commit 35b352afb0d8b9e0e3ad89afdacc231e1fbfe3b5 Author: Jae Kwon Date: Mon May 19 20:46:41 2014 -0700 First iteration of the immutable AVL tree diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..1377554eb --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.swp diff --git a/README.md b/README.md new file mode 100644 index 000000000..398ccfb63 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +TenderMint - proof of concept diff --git a/merkle/iavl.go b/merkle/iavl.go new file mode 100644 index 000000000..d360339b7 --- /dev/null +++ b/merkle/iavl.go @@ -0,0 +1,355 @@ +package merkle + +import ( + "hash" + "crypto/sha256" +) + +// Immutable AVL Tree (wraps the Node root) + +type IAVLTree struct { + root *IAVLNode +} + +func NewIAVLTree() *IAVLTree { + return &IAVLTree{} +} + +func (self *IAVLTree) Root() Node { + return self.root.Copy(true) +} + +func (self *IAVLTree) Size() int { + return self.root.Size() +} + +func (self *IAVLTree) Has(key Sortable) bool { + return self.root.Has(key) +} + +func (self *IAVLTree) Put(key Sortable, value interface{}) (err error) { + self.root, _ = self.root.Put(key, value) + return nil +} + +func (self *IAVLTree) Get(key Sortable) (value interface{}, err error) { + return self.root.Get(key) +} + +func (self *IAVLTree) Remove(key Sortable) (value interface{}, err error) { + new_root, value, err := self.root.Remove(key) + if err != nil { + return nil, err + } + self.root = new_root + return value, nil +} + +// Node + +type IAVLNode struct { + key Sortable + value interface{} + height int + hash []byte + left *IAVLNode + right *IAVLNode +} + +func (self *IAVLNode) Copy(copyHash bool) *IAVLNode { + if self == nil { + return nil + } + var hash []byte + if copyHash { + hash = self.hash + } + return &IAVLNode{ + key: self.key, + value: self.value, + height: self.height, + hash: hash, + left: self.left, + right: self.right, + } +} + +func (self *IAVLNode) Has(key Sortable) (has bool) { + if self == nil { + return false + } + if self.key.Equals(key) { + return true + } else if key.Less(self.key) { + return self.left.Has(key) + } else { + return self.right.Has(key) + } +} + +func (self *IAVLNode) Get(key Sortable) (value interface{}, err error) { + if self == nil { + return nil, NotFound(key) + } + if self.key.Equals(key) { + return self.value, nil + } else if key.Less(self.key) { + return self.left.Get(key) + } else { + return self.right.Get(key) + } +} + +// Copies and pops node from the tree. +// Returns a new tree (unless node is the root) & new (popped) node. +func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) { + if node == nil { + panic("node can't be nil") + } else if node.left != nil && node.right != nil { + panic("node must not have both left and right") + } + + if self == nil { + return nil, node.Copy(true) + } else if self == node { + var n *IAVLNode + if node.left != nil { + n = node.left + } else if node.right != nil { + n = node.right + } else { + n = nil + } + node = node.Copy(false) + node.left = nil + node.right = nil + return n, node + } + + self = self.Copy(false) + + if node.key.Less(self.key) { + self.left, node = self.left.pop_node(node) + } else { + self.right, node = self.right.pop_node(node) + } + + self.height = max(self.left.Height(), self.right.Height()) + 1 + return self, node +} + +// Pushes the node to the tree, returns a new tree +func (self *IAVLNode) push_node(node *IAVLNode) *IAVLNode { + if node == nil { + panic("node can't be nil") + } else if node.left != nil || node.right != nil { + panic("node must now be a leaf") + } + + self = self.Copy(false) + + if self == nil { + node.height = 1 + return node + } else if node.key.Less(self.key) { + self.left = self.left.push_node(node) + } else { + self.right = self.right.push_node(node) + } + self.height = max(self.left.Height(), self.right.Height()) + 1 + return self +} + +func (self *IAVLNode) rotate_right() *IAVLNode { + if self == nil { + return self + } + if self.left == nil { + return self + } + return self.rotate(self.left.rmd) +} + +func (self *IAVLNode) rotate_left() *IAVLNode { + if self == nil { + return self + } + if self.right == nil { + return self + } + return self.rotate(self.right.lmd) +} + +func (self *IAVLNode) rotate(get_new_root func() *IAVLNode) *IAVLNode { + self, new_root := self.pop_node(get_new_root()) + new_root.left = self.left + new_root.right = self.right + self.hash = nil + self.left = nil + self.right = nil + return new_root.push_node(self) +} + +func (self *IAVLNode) balance() *IAVLNode { + if self == nil { + return self + } + for abs(self.left.Height() - self.right.Height()) > 2 { + if self.left.Height() > self.right.Height() { + self = self.rotate_right() + } else { + self = self.rotate_left() + } + } + return self +} + +// TODO: don't clear the hash if the value hasn't changed. +func (self *IAVLNode) Put(key Sortable, value interface{}) (_ *IAVLNode, updated bool) { + if self == nil { + return &IAVLNode{key: key, value: value, height: 1, hash: nil}, false + } + + self = self.Copy(false) + + if self.key.Equals(key) { + self.value = value + return self, true + } + + if key.Less(self.key) { + self.left, updated = self.left.Put(key, value) + } else { + self.right, updated = self.right.Put(key, value) + } + self.height = max(self.left.Height(), self.right.Height()) + 1 + + if !updated { + self.height += 1 + return self.balance(), updated + } + return self, updated +} + +func (self *IAVLNode) Remove(key Sortable) (_ *IAVLNode, value interface{}, err error) { + if self == nil { + return nil, nil, NotFound(key) + } + + if self.key.Equals(key) { + if self.left != nil && self.right != nil { + var new_root *IAVLNode + if self.left.Size() < self.right.Size() { + self, new_root = self.pop_node(self.right.lmd()) + } else { + self, new_root = self.pop_node(self.left.rmd()) + } + new_root.left = self.left + new_root.right = self.right + return new_root, self.value, nil + } else if self.left == nil { + return self.right, self.value, nil + } else if self.right == nil { + return self.left, self.value, nil + } else { + return nil, self.value, nil + } + } + + self = self.Copy(true) + + if key.Less(self.key) { + self.left, value, err = self.left.Remove(key) + } else { + self.right, value, err = self.right.Remove(key) + } + if err == nil { + self.hash = nil + self.height = max(self.left.Height(), self.right.Height()) + 1 + return self.balance(), value, err + } else { + return self, value, err + } +} + +func (self *IAVLNode) Height() int { + if self == nil { + return 0 + } + return self.height +} + +func (self *IAVLNode) Size() int { + if self == nil { + return 0 + } + return 1 + self.left.Size() + self.right.Size() +} + + +func (self *IAVLNode) Key() Sortable { + return self.key +} + +func (self *IAVLNode) Value() interface{} { + return self.value +} + +func (self *IAVLNode) Left() Node { + if self.left == nil { + return nil + } + return self.left +} + +func (self *IAVLNode) Right() Node { + if self.right == nil { + return nil + } + return self.right +} + +// ... + +func (self *IAVLNode) _md(side func(*IAVLNode)*IAVLNode) (*IAVLNode) { + if self == nil { + return nil + } else if side(self) != nil { + return side(self)._md(side) + } else { + return self + } +} + +func (self *IAVLNode) lmd() (*IAVLNode) { + return self._md(func(node *IAVLNode)*IAVLNode { return node.left }) +} + +func (self *IAVLNode) rmd() (*IAVLNode) { + return self._md(func(node *IAVLNode)*IAVLNode { return node.right }) +} + +func abs(i int) int { + if i < 0 { + return -i + } + return i +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +// Calculate the hash of hasher over buf. +func CalcHash(buf []byte, hasher hash.Hash) []byte { + hasher.Write(buf) + return hasher.Sum(nil) +} + +// calculate hash256 which is sha256(sha256(data)) +func CalcSha256(buf []byte) []byte { + return CalcHash(buf, sha256.New()) +} diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go new file mode 100644 index 000000000..47f149c06 --- /dev/null +++ b/merkle/iavl_test.go @@ -0,0 +1,200 @@ +package merkle + +import "testing" + +import ( + "os" + "bytes" + "math/rand" + "encoding/binary" +) + + +func init() { + if urandom, err := os.Open("/dev/urandom"); err != nil { + return + } else { + buf := make([]byte, 8) + if _, err := urandom.Read(buf); err == nil { + buf_reader := bytes.NewReader(buf) + if seed, err := binary.ReadVarint(buf_reader); err == nil { + rand.Seed(seed) + } + } + urandom.Close() + } +} + +func randstr(length int) String { + if urandom, err := os.Open("/dev/urandom"); err != nil { + panic(err) + } else { + slice := make([]byte, length) + if _, err := urandom.Read(slice); err != nil { + panic(err) + } + urandom.Close() + return String(slice) + } + panic("unreachable") +} + +func TestImmutableAvlPutHasGetRemove(t *testing.T) { + + type record struct { + key String + value String + } + + records := make([]*record, 400) + var tree *IAVLNode + var err error + var val interface{} + var updated bool + + ranrec := func() *record { + return &record{ randstr(20), randstr(20) } + } + + for i := range records { + r := ranrec() + records[i] = r + tree, updated = tree.Put(r.key, String("")) + if updated { + t.Error("should have not been updated") + } + tree, updated = tree.Put(r.key, r.value) + if !updated { + t.Error("should have been updated") + } + if tree.Size() != (i+1) { + t.Error("size was wrong", tree.Size(), i+1) + } + } + + for _, r := range records { + if has := tree.Has(r.key); !has { + t.Error("Missing key") + } + if has := tree.Has(randstr(12)); has { + t.Error("Table has extra key") + } + if val, err := tree.Get(r.key); err != nil { + t.Error(err, val.(String), r.value) + } else if !(val.(String)).Equals(r.value) { + t.Error("wrong value") + } + } + + for i, x := range records { + if tree, val, err = tree.Remove(x.key); err != nil { + t.Error(err) + } else if !(val.(String)).Equals(x.value) { + t.Error("wrong value") + } + for _, r := range records[i+1:] { + if has := tree.Has(r.key); !has { + t.Error("Missing key") + } + if has := tree.Has(randstr(12)); has { + t.Error("Table has extra key") + } + if val, err := tree.Get(r.key); err != nil { + t.Error(err) + } else if !(val.(String)).Equals(r.value) { + t.Error("wrong value") + } + } + if tree.Size() != (len(records) - (i+1)) { + t.Error("size was wrong", tree.Size(), (len(records) - (i+1))) + } + } +} + + +func BenchmarkImmutableAvlTree(b *testing.B) { + b.StopTimer() + + type record struct { + key String + value String + } + + records := make([]*record, 100) + + ranrec := func() *record { + return &record{ randstr(20), randstr(20) } + } + + for i := range records { + records[i] = ranrec() + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + t := NewIAVLTree() + for _, r := range records { + t.Put(r.key, r.value) + } + for _, r := range records { + t.Remove(r.key) + } + } +} + + +func TestTraversals(t *testing.T) { + var data []int = []int{ + 1, 5, 7, 9, 12, 13, 17, 18, 19, 20, + } + var order []int = []int{ + 6, 1, 8, 2, 4 , 9 , 5 , 7 , 0 , 3 , + } + /* + var preorder []int = []int { + 17, 7, 5, 1, 12, 9, 13, 19, 18, 20, + } + var postorder []int = []int { + 1, 5, 9, 13, 12, 7, 18, 20, 19, 17, + } + */ + + test := func(T Tree) { + t.Logf("%T", T) + for j := range order { + if err := T.Put(Int(data[order[j]]), order[j]); err != nil { + t.Error(err) + } + } + + j := 0 + for + tn, next := Iterator(T.Root())(); + next != nil; + tn, next = next () { + if int(tn.Key().(Int)) != data[j] { + t.Error("key in wrong spot in-order") + } + j += 1 + } + + /* + j = 0 + for tn, next := tree.TraverseTreePreOrder(T.Root())(); next != nil; tn, next = next () { + if int(tn.Key().(Int)) != preorder[j] { + t.Error("key in wrong spot pre-order") + } + j += 1 + } + + j = 0 + for tn, next := tree.TraverseTreePostOrder(T.Root())(); next != nil; tn, next = next () { + if int(tn.Key().(Int)) != postorder[j] { + t.Error("key in wrong spot post-order") + } + j += 1 + } + */ + } + test(NewIAVLTree()) +} diff --git a/merkle/int.go b/merkle/int.go new file mode 100644 index 000000000..26115a96c --- /dev/null +++ b/merkle/int.go @@ -0,0 +1,225 @@ +package merkle + + +type Int8 int8 +type UInt8 uint8 +type Int16 int16 +type UInt16 uint16 +type Int32 int32 +type UInt32 uint32 +type Int64 int64 +type UInt64 uint64 +type Int int +type UInt uint + + +func (self Int8) Equals(other Sortable) bool { + if o, ok := other.(Int8); ok { + return self == o + } else { + return false + } +} + +func (self Int8) Less(other Sortable) bool { + if o, ok := other.(Int8); ok { + return self < o + } else { + return false + } +} + +func (self Int8) Hash() int { + return int(self) +} + + +func (self UInt8) Equals(other Sortable) bool { + if o, ok := other.(UInt8); ok { + return self == o + } else { + return false + } +} + +func (self UInt8) Less(other Sortable) bool { + if o, ok := other.(UInt8); ok { + return self < o + } else { + return false + } +} + +func (self UInt8) Hash() int { + return int(self) +} + + +func (self Int16) Equals(other Sortable) bool { + if o, ok := other.(Int16); ok { + return self == o + } else { + return false + } +} + +func (self Int16) Less(other Sortable) bool { + if o, ok := other.(Int16); ok { + return self < o + } else { + return false + } +} + +func (self Int16) Hash() int { + return int(self) +} + + +func (self UInt16) Equals(other Sortable) bool { + if o, ok := other.(UInt16); ok { + return self == o + } else { + return false + } +} + +func (self UInt16) Less(other Sortable) bool { + if o, ok := other.(UInt16); ok { + return self < o + } else { + return false + } +} + +func (self UInt16) Hash() int { + return int(self) +} + + +func (self Int32) Equals(other Sortable) bool { + if o, ok := other.(Int32); ok { + return self == o + } else { + return false + } +} + +func (self Int32) Less(other Sortable) bool { + if o, ok := other.(Int32); ok { + return self < o + } else { + return false + } +} + +func (self Int32) Hash() int { + return int(self) +} + + +func (self UInt32) Equals(other Sortable) bool { + if o, ok := other.(UInt32); ok { + return self == o + } else { + return false + } +} + +func (self UInt32) Less(other Sortable) bool { + if o, ok := other.(UInt32); ok { + return self < o + } else { + return false + } +} + +func (self UInt32) Hash() int { + return int(self) +} + + +func (self Int64) Equals(other Sortable) bool { + if o, ok := other.(Int64); ok { + return self == o + } else { + return false + } +} + +func (self Int64) Less(other Sortable) bool { + if o, ok := other.(Int64); ok { + return self < o + } else { + return false + } +} + +func (self Int64) Hash() int { + return int(self>>32) ^ int(self) +} + + +func (self UInt64) Equals(other Sortable) bool { + if o, ok := other.(UInt64); ok { + return self == o + } else { + return false + } +} + +func (self UInt64) Less(other Sortable) bool { + if o, ok := other.(UInt64); ok { + return self < o + } else { + return false + } +} + +func (self UInt64) Hash() int { + return int(self>>32) ^ int(self) +} + + +func (self Int) Equals(other Sortable) bool { + if o, ok := other.(Int); ok { + return self == o + } else { + return false + } +} + +func (self Int) Less(other Sortable) bool { + if o, ok := other.(Int); ok { + return self < o + } else { + return false + } +} + +func (self Int) Hash() int { + return int(self) +} + + +func (self UInt) Equals(other Sortable) bool { + if o, ok := other.(UInt); ok { + return self == o + } else { + return false + } +} + +func (self UInt) Less(other Sortable) bool { + if o, ok := other.(UInt); ok { + return self < o + } else { + return false + } +} + +func (self UInt) Hash() int { + return int(self) +} + + diff --git a/merkle/string.go b/merkle/string.go new file mode 100644 index 000000000..e497f0282 --- /dev/null +++ b/merkle/string.go @@ -0,0 +1,57 @@ +package merkle + +import "bytes" + +type String string +type ByteSlice []byte + +func (self String) Equals(other Sortable) bool { + if o, ok := other.(String); ok { + return self == o + } else { + return false + } +} + +func (self String) Less(other Sortable) bool { + if o, ok := other.(String); ok { + return self < o + } else { + return false + } +} + +func (self String) Hash() int { + bytes := []byte(self) + hash := 0 + for i, c := range bytes { + hash += (i+1)*int(c) + } + return hash +} + +func (self ByteSlice) Equals(other Sortable) bool { + if o, ok := other.(ByteSlice); ok { + return bytes.Equal(self, o) + } else { + return false + } +} + +func (self ByteSlice) Less(other Sortable) bool { + if o, ok := other.(ByteSlice); ok { + return bytes.Compare(self, o) < 0 // -1 if a < b + } else { + return false + } +} + +func (self ByteSlice) Hash() int { + hash := 0 + for i, c := range self { + hash += (i+1)*int(c) + } + return hash +} + + diff --git a/merkle/tree.go b/merkle/tree.go new file mode 100644 index 000000000..e694e6f02 --- /dev/null +++ b/merkle/tree.go @@ -0,0 +1,30 @@ +package merkle + +func Iterator(node Node) NodeIterator { + stack := make([]Node, 0, 10) + var cur Node = node + var tn_iterator NodeIterator + tn_iterator = func()(tn Node, next NodeIterator) { + if len(stack) > 0 || cur != nil { + for cur != nil { + stack = append(stack, cur) + cur = cur.Left() + } + stack, cur = pop(stack) + tn = cur + cur = cur.Right() + return tn, tn_iterator + } else { + return nil, nil + } + } + return tn_iterator +} + +func pop(stack []Node) ([]Node, Node) { + if len(stack) <= 0 { + return stack, nil + } else { + return stack[0:len(stack)-1], stack[len(stack)-1] + } +} diff --git a/merkle/types.go b/merkle/types.go new file mode 100644 index 000000000..606b0f037 --- /dev/null +++ b/merkle/types.go @@ -0,0 +1,38 @@ +package merkle + +import ( + "fmt" +) + +type Sortable interface { + Equals(b Sortable) bool + Less(b Sortable) bool +} + +type Tree interface { + Root() Node + + Size() int + Has(key Sortable) bool + Get(key Sortable) (value interface{}, err error) + + Put(key Sortable, value interface{}) (err error) + Remove(key Sortable) (value interface{}, err error) +} + +type Node interface { + Key() Sortable + Value() interface{} + Left() Node + Right() Node + + Size() int + Has(key Sortable) bool + Get(key Sortable) (value interface{}, err error) +} + +type NodeIterator func() (node Node, next NodeIterator) + +func NotFound(key Sortable) error { + return fmt.Errorf("Key was not found.") +}