Browse Source

First iteration of the immutable AVL tree

pull/9/head
Jae Kwon 10 years ago
commit
35b352afb0
8 changed files with 907 additions and 0 deletions
  1. +1
    -0
      .gitignore
  2. +1
    -0
      README.md
  3. +355
    -0
      merkle/iavl.go
  4. +200
    -0
      merkle/iavl_test.go
  5. +225
    -0
      merkle/int.go
  6. +57
    -0
      merkle/string.go
  7. +30
    -0
      merkle/tree.go
  8. +38
    -0
      merkle/types.go

+ 1
- 0
.gitignore View File

@ -0,0 +1 @@
*.swp

+ 1
- 0
README.md View File

@ -0,0 +1 @@
TenderMint - proof of concept

+ 355
- 0
merkle/iavl.go View File

@ -0,0 +1,355 @@
package merkle
import (
"hash"
"crypto/sha256"
)
// Immutable AVL Tree (wraps the Node root)
type IAVLTree struct {
root *IAVLNode
}
func NewIAVLTree() *IAVLTree {
return &IAVLTree{}
}
func (self *IAVLTree) Root() Node {
return self.root.Copy(true)
}
func (self *IAVLTree) Size() int {
return self.root.Size()
}
func (self *IAVLTree) Has(key Sortable) bool {
return self.root.Has(key)
}
func (self *IAVLTree) Put(key Sortable, value interface{}) (err error) {
self.root, _ = self.root.Put(key, value)
return nil
}
func (self *IAVLTree) Get(key Sortable) (value interface{}, err error) {
return self.root.Get(key)
}
func (self *IAVLTree) Remove(key Sortable) (value interface{}, err error) {
new_root, value, err := self.root.Remove(key)
if err != nil {
return nil, err
}
self.root = new_root
return value, nil
}
// Node
type IAVLNode struct {
key Sortable
value interface{}
height int
hash []byte
left *IAVLNode
right *IAVLNode
}
func (self *IAVLNode) Copy(copyHash bool) *IAVLNode {
if self == nil {
return nil
}
var hash []byte
if copyHash {
hash = self.hash
}
return &IAVLNode{
key: self.key,
value: self.value,
height: self.height,
hash: hash,
left: self.left,
right: self.right,
}
}
func (self *IAVLNode) Has(key Sortable) (has bool) {
if self == nil {
return false
}
if self.key.Equals(key) {
return true
} else if key.Less(self.key) {
return self.left.Has(key)
} else {
return self.right.Has(key)
}
}
func (self *IAVLNode) Get(key Sortable) (value interface{}, err error) {
if self == nil {
return nil, NotFound(key)
}
if self.key.Equals(key) {
return self.value, nil
} else if key.Less(self.key) {
return self.left.Get(key)
} else {
return self.right.Get(key)
}
}
// Copies and pops node from the tree.
// Returns a new tree (unless node is the root) & new (popped) node.
func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) {
if node == nil {
panic("node can't be nil")
} else if node.left != nil && node.right != nil {
panic("node must not have both left and right")
}
if self == nil {
return nil, node.Copy(true)
} else if self == node {
var n *IAVLNode
if node.left != nil {
n = node.left
} else if node.right != nil {
n = node.right
} else {
n = nil
}
node = node.Copy(false)
node.left = nil
node.right = nil
return n, node
}
self = self.Copy(false)
if node.key.Less(self.key) {
self.left, node = self.left.pop_node(node)
} else {
self.right, node = self.right.pop_node(node)
}
self.height = max(self.left.Height(), self.right.Height()) + 1
return self, node
}
// Pushes the node to the tree, returns a new tree
func (self *IAVLNode) push_node(node *IAVLNode) *IAVLNode {
if node == nil {
panic("node can't be nil")
} else if node.left != nil || node.right != nil {
panic("node must now be a leaf")
}
self = self.Copy(false)
if self == nil {
node.height = 1
return node
} else if node.key.Less(self.key) {
self.left = self.left.push_node(node)
} else {
self.right = self.right.push_node(node)
}
self.height = max(self.left.Height(), self.right.Height()) + 1
return self
}
func (self *IAVLNode) rotate_right() *IAVLNode {
if self == nil {
return self
}
if self.left == nil {
return self
}
return self.rotate(self.left.rmd)
}
func (self *IAVLNode) rotate_left() *IAVLNode {
if self == nil {
return self
}
if self.right == nil {
return self
}
return self.rotate(self.right.lmd)
}
func (self *IAVLNode) rotate(get_new_root func() *IAVLNode) *IAVLNode {
self, new_root := self.pop_node(get_new_root())
new_root.left = self.left
new_root.right = self.right
self.hash = nil
self.left = nil
self.right = nil
return new_root.push_node(self)
}
func (self *IAVLNode) balance() *IAVLNode {
if self == nil {
return self
}
for abs(self.left.Height() - self.right.Height()) > 2 {
if self.left.Height() > self.right.Height() {
self = self.rotate_right()
} else {
self = self.rotate_left()
}
}
return self
}
// TODO: don't clear the hash if the value hasn't changed.
func (self *IAVLNode) Put(key Sortable, value interface{}) (_ *IAVLNode, updated bool) {
if self == nil {
return &IAVLNode{key: key, value: value, height: 1, hash: nil}, false
}
self = self.Copy(false)
if self.key.Equals(key) {
self.value = value
return self, true
}
if key.Less(self.key) {
self.left, updated = self.left.Put(key, value)
} else {
self.right, updated = self.right.Put(key, value)
}
self.height = max(self.left.Height(), self.right.Height()) + 1
if !updated {
self.height += 1
return self.balance(), updated
}
return self, updated
}
func (self *IAVLNode) Remove(key Sortable) (_ *IAVLNode, value interface{}, err error) {
if self == nil {
return nil, nil, NotFound(key)
}
if self.key.Equals(key) {
if self.left != nil && self.right != nil {
var new_root *IAVLNode
if self.left.Size() < self.right.Size() {
self, new_root = self.pop_node(self.right.lmd())
} else {
self, new_root = self.pop_node(self.left.rmd())
}
new_root.left = self.left
new_root.right = self.right
return new_root, self.value, nil
} else if self.left == nil {
return self.right, self.value, nil
} else if self.right == nil {
return self.left, self.value, nil
} else {
return nil, self.value, nil
}
}
self = self.Copy(true)
if key.Less(self.key) {
self.left, value, err = self.left.Remove(key)
} else {
self.right, value, err = self.right.Remove(key)
}
if err == nil {
self.hash = nil
self.height = max(self.left.Height(), self.right.Height()) + 1
return self.balance(), value, err
} else {
return self, value, err
}
}
func (self *IAVLNode) Height() int {
if self == nil {
return 0
}
return self.height
}
func (self *IAVLNode) Size() int {
if self == nil {
return 0
}
return 1 + self.left.Size() + self.right.Size()
}
func (self *IAVLNode) Key() Sortable {
return self.key
}
func (self *IAVLNode) Value() interface{} {
return self.value
}
func (self *IAVLNode) Left() Node {
if self.left == nil {
return nil
}
return self.left
}
func (self *IAVLNode) Right() Node {
if self.right == nil {
return nil
}
return self.right
}
// ...
func (self *IAVLNode) _md(side func(*IAVLNode)*IAVLNode) (*IAVLNode) {
if self == nil {
return nil
} else if side(self) != nil {
return side(self)._md(side)
} else {
return self
}
}
func (self *IAVLNode) lmd() (*IAVLNode) {
return self._md(func(node *IAVLNode)*IAVLNode { return node.left })
}
func (self *IAVLNode) rmd() (*IAVLNode) {
return self._md(func(node *IAVLNode)*IAVLNode { return node.right })
}
func abs(i int) int {
if i < 0 {
return -i
}
return i
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
// Calculate the hash of hasher over buf.
func CalcHash(buf []byte, hasher hash.Hash) []byte {
hasher.Write(buf)
return hasher.Sum(nil)
}
// calculate hash256 which is sha256(sha256(data))
func CalcSha256(buf []byte) []byte {
return CalcHash(buf, sha256.New())
}

+ 200
- 0
merkle/iavl_test.go View File

@ -0,0 +1,200 @@
package merkle
import "testing"
import (
"os"
"bytes"
"math/rand"
"encoding/binary"
)
func init() {
if urandom, err := os.Open("/dev/urandom"); err != nil {
return
} else {
buf := make([]byte, 8)
if _, err := urandom.Read(buf); err == nil {
buf_reader := bytes.NewReader(buf)
if seed, err := binary.ReadVarint(buf_reader); err == nil {
rand.Seed(seed)
}
}
urandom.Close()
}
}
func randstr(length int) String {
if urandom, err := os.Open("/dev/urandom"); err != nil {
panic(err)
} else {
slice := make([]byte, length)
if _, err := urandom.Read(slice); err != nil {
panic(err)
}
urandom.Close()
return String(slice)
}
panic("unreachable")
}
func TestImmutableAvlPutHasGetRemove(t *testing.T) {
type record struct {
key String
value String
}
records := make([]*record, 400)
var tree *IAVLNode
var err error
var val interface{}
var updated bool
ranrec := func() *record {
return &record{ randstr(20), randstr(20) }
}
for i := range records {
r := ranrec()
records[i] = r
tree, updated = tree.Put(r.key, String(""))
if updated {
t.Error("should have not been updated")
}
tree, updated = tree.Put(r.key, r.value)
if !updated {
t.Error("should have been updated")
}
if tree.Size() != (i+1) {
t.Error("size was wrong", tree.Size(), i+1)
}
}
for _, r := range records {
if has := tree.Has(r.key); !has {
t.Error("Missing key")
}
if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key")
}
if val, err := tree.Get(r.key); err != nil {
t.Error(err, val.(String), r.value)
} else if !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
for i, x := range records {
if tree, val, err = tree.Remove(x.key); err != nil {
t.Error(err)
} else if !(val.(String)).Equals(x.value) {
t.Error("wrong value")
}
for _, r := range records[i+1:] {
if has := tree.Has(r.key); !has {
t.Error("Missing key")
}
if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key")
}
if val, err := tree.Get(r.key); err != nil {
t.Error(err)
} else if !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
if tree.Size() != (len(records) - (i+1)) {
t.Error("size was wrong", tree.Size(), (len(records) - (i+1)))
}
}
}
func BenchmarkImmutableAvlTree(b *testing.B) {
b.StopTimer()
type record struct {
key String
value String
}
records := make([]*record, 100)
ranrec := func() *record {
return &record{ randstr(20), randstr(20) }
}
for i := range records {
records[i] = ranrec()
}
b.StartTimer()
for i := 0; i < b.N; i++ {
t := NewIAVLTree()
for _, r := range records {
t.Put(r.key, r.value)
}
for _, r := range records {
t.Remove(r.key)
}
}
}
func TestTraversals(t *testing.T) {
var data []int = []int{
1, 5, 7, 9, 12, 13, 17, 18, 19, 20,
}
var order []int = []int{
6, 1, 8, 2, 4 , 9 , 5 , 7 , 0 , 3 ,
}
/*
var preorder []int = []int {
17, 7, 5, 1, 12, 9, 13, 19, 18, 20,
}
var postorder []int = []int {
1, 5, 9, 13, 12, 7, 18, 20, 19, 17,
}
*/
test := func(T Tree) {
t.Logf("%T", T)
for j := range order {
if err := T.Put(Int(data[order[j]]), order[j]); err != nil {
t.Error(err)
}
}
j := 0
for
tn, next := Iterator(T.Root())();
next != nil;
tn, next = next () {
if int(tn.Key().(Int)) != data[j] {
t.Error("key in wrong spot in-order")
}
j += 1
}
/*
j = 0
for tn, next := tree.TraverseTreePreOrder(T.Root())(); next != nil; tn, next = next () {
if int(tn.Key().(Int)) != preorder[j] {
t.Error("key in wrong spot pre-order")
}
j += 1
}
j = 0
for tn, next := tree.TraverseTreePostOrder(T.Root())(); next != nil; tn, next = next () {
if int(tn.Key().(Int)) != postorder[j] {
t.Error("key in wrong spot post-order")
}
j += 1
}
*/
}
test(NewIAVLTree())
}

+ 225
- 0
merkle/int.go View File

@ -0,0 +1,225 @@
package merkle
type Int8 int8
type UInt8 uint8
type Int16 int16
type UInt16 uint16
type Int32 int32
type UInt32 uint32
type Int64 int64
type UInt64 uint64
type Int int
type UInt uint
func (self Int8) Equals(other Sortable) bool {
if o, ok := other.(Int8); ok {
return self == o
} else {
return false
}
}
func (self Int8) Less(other Sortable) bool {
if o, ok := other.(Int8); ok {
return self < o
} else {
return false
}
}
func (self Int8) Hash() int {
return int(self)
}
func (self UInt8) Equals(other Sortable) bool {
if o, ok := other.(UInt8); ok {
return self == o
} else {
return false
}
}
func (self UInt8) Less(other Sortable) bool {
if o, ok := other.(UInt8); ok {
return self < o
} else {
return false
}
}
func (self UInt8) Hash() int {
return int(self)
}
func (self Int16) Equals(other Sortable) bool {
if o, ok := other.(Int16); ok {
return self == o
} else {
return false
}
}
func (self Int16) Less(other Sortable) bool {
if o, ok := other.(Int16); ok {
return self < o
} else {
return false
}
}
func (self Int16) Hash() int {
return int(self)
}
func (self UInt16) Equals(other Sortable) bool {
if o, ok := other.(UInt16); ok {
return self == o
} else {
return false
}
}
func (self UInt16) Less(other Sortable) bool {
if o, ok := other.(UInt16); ok {
return self < o
} else {
return false
}
}
func (self UInt16) Hash() int {
return int(self)
}
func (self Int32) Equals(other Sortable) bool {
if o, ok := other.(Int32); ok {
return self == o
} else {
return false
}
}
func (self Int32) Less(other Sortable) bool {
if o, ok := other.(Int32); ok {
return self < o
} else {
return false
}
}
func (self Int32) Hash() int {
return int(self)
}
func (self UInt32) Equals(other Sortable) bool {
if o, ok := other.(UInt32); ok {
return self == o
} else {
return false
}
}
func (self UInt32) Less(other Sortable) bool {
if o, ok := other.(UInt32); ok {
return self < o
} else {
return false
}
}
func (self UInt32) Hash() int {
return int(self)
}
func (self Int64) Equals(other Sortable) bool {
if o, ok := other.(Int64); ok {
return self == o
} else {
return false
}
}
func (self Int64) Less(other Sortable) bool {
if o, ok := other.(Int64); ok {
return self < o
} else {
return false
}
}
func (self Int64) Hash() int {
return int(self>>32) ^ int(self)
}
func (self UInt64) Equals(other Sortable) bool {
if o, ok := other.(UInt64); ok {
return self == o
} else {
return false
}
}
func (self UInt64) Less(other Sortable) bool {
if o, ok := other.(UInt64); ok {
return self < o
} else {
return false
}
}
func (self UInt64) Hash() int {
return int(self>>32) ^ int(self)
}
func (self Int) Equals(other Sortable) bool {
if o, ok := other.(Int); ok {
return self == o
} else {
return false
}
}
func (self Int) Less(other Sortable) bool {
if o, ok := other.(Int); ok {
return self < o
} else {
return false
}
}
func (self Int) Hash() int {
return int(self)
}
func (self UInt) Equals(other Sortable) bool {
if o, ok := other.(UInt); ok {
return self == o
} else {
return false
}
}
func (self UInt) Less(other Sortable) bool {
if o, ok := other.(UInt); ok {
return self < o
} else {
return false
}
}
func (self UInt) Hash() int {
return int(self)
}

+ 57
- 0
merkle/string.go View File

@ -0,0 +1,57 @@
package merkle
import "bytes"
type String string
type ByteSlice []byte
func (self String) Equals(other Sortable) bool {
if o, ok := other.(String); ok {
return self == o
} else {
return false
}
}
func (self String) Less(other Sortable) bool {
if o, ok := other.(String); ok {
return self < o
} else {
return false
}
}
func (self String) Hash() int {
bytes := []byte(self)
hash := 0
for i, c := range bytes {
hash += (i+1)*int(c)
}
return hash
}
func (self ByteSlice) Equals(other Sortable) bool {
if o, ok := other.(ByteSlice); ok {
return bytes.Equal(self, o)
} else {
return false
}
}
func (self ByteSlice) Less(other Sortable) bool {
if o, ok := other.(ByteSlice); ok {
return bytes.Compare(self, o) < 0 // -1 if a < b
} else {
return false
}
}
func (self ByteSlice) Hash() int {
hash := 0
for i, c := range self {
hash += (i+1)*int(c)
}
return hash
}

+ 30
- 0
merkle/tree.go View File

@ -0,0 +1,30 @@
package merkle
func Iterator(node Node) NodeIterator {
stack := make([]Node, 0, 10)
var cur Node = node
var tn_iterator NodeIterator
tn_iterator = func()(tn Node, next NodeIterator) {
if len(stack) > 0 || cur != nil {
for cur != nil {
stack = append(stack, cur)
cur = cur.Left()
}
stack, cur = pop(stack)
tn = cur
cur = cur.Right()
return tn, tn_iterator
} else {
return nil, nil
}
}
return tn_iterator
}
func pop(stack []Node) ([]Node, Node) {
if len(stack) <= 0 {
return stack, nil
} else {
return stack[0:len(stack)-1], stack[len(stack)-1]
}
}

+ 38
- 0
merkle/types.go View File

@ -0,0 +1,38 @@
package merkle
import (
"fmt"
)
type Sortable interface {
Equals(b Sortable) bool
Less(b Sortable) bool
}
type Tree interface {
Root() Node
Size() int
Has(key Sortable) bool
Get(key Sortable) (value interface{}, err error)
Put(key Sortable, value interface{}) (err error)
Remove(key Sortable) (value interface{}, err error)
}
type Node interface {
Key() Sortable
Value() interface{}
Left() Node
Right() Node
Size() int
Has(key Sortable) bool
Get(key Sortable) (value interface{}, err error)
}
type NodeIterator func() (node Node, next NodeIterator)
func NotFound(key Sortable) error {
return fmt.Errorf("Key was not found.")
}

Loading…
Cancel
Save