From 795e183273a9810f95f93b731e8bf7dc85cea742 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Tue, 20 May 2014 16:33:10 -0700 Subject: [PATCH] Fixed merkle implementation to fit official algorithm. --- merkle/iavl.go | 181 +++++++++++++++++++++--------------- merkle/iavl_test.go | 121 +++++++++++++++++++----- merkle/types.go | 3 + merkle/{tree.go => util.go} | 24 +++++ 4 files changed, 233 insertions(+), 96 deletions(-) rename merkle/{tree.go => util.go} (58%) diff --git a/merkle/iavl.go b/merkle/iavl.go index d360339b7..02043462e 100644 --- a/merkle/iavl.go +++ b/merkle/iavl.go @@ -1,6 +1,7 @@ package merkle import ( + //"fmt" "hash" "crypto/sha256" ) @@ -100,18 +101,19 @@ func (self *IAVLNode) Get(key Sortable) (value interface{}, err error) { } } -// Copies and pops node from the tree. -// Returns a new tree (unless node is the root) & new (popped) node. +// Returns a new tree (unless node is the root) & a copy of the popped node. +// Can only pop nodes that have one or no children. func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) { - if node == nil { + if self == nil { + panic("self can't be nil") + } else if node == nil { panic("node can't be nil") } else if node.left != nil && node.right != nil { panic("node must not have both left and right") } - if self == nil { - return nil, node.Copy(true) - } else if self == node { + if self == node { + var n *IAVLNode if node.left != nil { n = node.left @@ -124,18 +126,19 @@ func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) { node.left = nil node.right = nil return n, node - } - - self = self.Copy(false) - if node.key.Less(self.key) { - self.left, node = self.left.pop_node(node) } else { - self.right, node = self.right.pop_node(node) - } - self.height = max(self.left.Height(), self.right.Height()) + 1 - return self, node + self = self.Copy(false) + if node.key.Less(self.key) { + self.left, node = self.left.pop_node(node) + } else { + self.right, node = self.right.pop_node(node) + } + self.calc_height() + return self, node + + } } // Pushes the node to the tree, returns a new tree @@ -156,51 +159,74 @@ func (self *IAVLNode) push_node(node *IAVLNode) *IAVLNode { } else { self.right = self.right.push_node(node) } - self.height = max(self.left.Height(), self.right.Height()) + 1 + self.calc_height() return self } func (self *IAVLNode) rotate_right() *IAVLNode { - if self == nil { - return self - } - if self.left == nil { - return self - } - return self.rotate(self.left.rmd) + self = self.Copy(false) + sl := self.left.Copy(false) + slr := sl.right + + sl.right = self + self.left = slr + + self.calc_height() + sl.calc_height() + + return sl } func (self *IAVLNode) rotate_left() *IAVLNode { - if self == nil { - return self - } - if self.right == nil { - return self - } - return self.rotate(self.right.lmd) + self = self.Copy(false) + sr := self.right.Copy(false) + srl := sr.left + + sr.left = self + self.right = srl + + self.calc_height() + sr.calc_height() + + return sr } -func (self *IAVLNode) rotate(get_new_root func() *IAVLNode) *IAVLNode { - self, new_root := self.pop_node(get_new_root()) - new_root.left = self.left - new_root.right = self.right - self.hash = nil - self.left = nil - self.right = nil - return new_root.push_node(self) +func (self *IAVLNode) calc_height() { + self.height = max(self.left.Height(), self.right.Height()) + 1 } -func (self *IAVLNode) balance() *IAVLNode { +func (self *IAVLNode) calc_balance() int { if self == nil { - return self + return 0 } - for abs(self.left.Height() - self.right.Height()) > 2 { - if self.left.Height() > self.right.Height() { - self = self.rotate_right() + return self.left.Height() - self.right.Height() +} + +func (self *IAVLNode) balance() (new_self *IAVLNode) { + balance := self.calc_balance() + if (balance > 1) { + if (self.left.calc_balance() >= 0) { + // Left Left Case + return self.rotate_right() } else { - self = self.rotate_left() + // Left Right Case + self = self.Copy(false) + self.left = self.left.rotate_left() + return self.rotate_right() } } + if (balance < -1) { + if (self.right.calc_balance() <= 0) { + // Right Right Case + return self.rotate_left() + } else { + // Right Left Case + self = self.Copy(false) + self.right = self.right.rotate_right() + return self.rotate_left() + } + } + // Nothing changed return self } @@ -222,31 +248,29 @@ func (self *IAVLNode) Put(key Sortable, value interface{}) (_ *IAVLNode, updated } else { self.right, updated = self.right.Put(key, value) } - self.height = max(self.left.Height(), self.right.Height()) + 1 - - if !updated { - self.height += 1 + if updated { + return self, updated + } else { + self.calc_height() return self.balance(), updated } - return self, updated } -func (self *IAVLNode) Remove(key Sortable) (_ *IAVLNode, value interface{}, err error) { +func (self *IAVLNode) Remove(key Sortable) (new_self *IAVLNode, value interface{}, err error) { if self == nil { return nil, nil, NotFound(key) } if self.key.Equals(key) { if self.left != nil && self.right != nil { - var new_root *IAVLNode if self.left.Size() < self.right.Size() { - self, new_root = self.pop_node(self.right.lmd()) + self, new_self = self.pop_node(self.right.lmd()) } else { - self, new_root = self.pop_node(self.left.rmd()) + self, new_self = self.pop_node(self.left.rmd()) } - new_root.left = self.left - new_root.right = self.right - return new_root, self.value, nil + new_self.left = self.left + new_self.right = self.right + return new_self, self.value, nil } else if self.left == nil { return self.right, self.value, nil } else if self.right == nil { @@ -256,20 +280,35 @@ func (self *IAVLNode) Remove(key Sortable) (_ *IAVLNode, value interface{}, err } } - self = self.Copy(true) - if key.Less(self.key) { - self.left, value, err = self.left.Remove(key) - } else { - self.right, value, err = self.right.Remove(key) - } - if err == nil { - self.hash = nil - self.height = max(self.left.Height(), self.right.Height()) + 1 - return self.balance(), value, err + if self.left == nil { + return self, nil, NotFound(key) + } + var new_left *IAVLNode + new_left, value, err = self.left.Remove(key) + if new_left == self.left { // not found + return self, nil, err + } else if err != nil { // some other error + return self, value, err + } + self = self.Copy(false) + self.left = new_left } else { - return self, value, err + if self.right == nil { + return self, nil, NotFound(key) + } + var new_right *IAVLNode + new_right, value, err = self.right.Remove(key) + if new_right == self.right { // not found + return self, nil, err + } else if err != nil { // some other error + return self, value, err + } + self = self.Copy(false) + self.right = new_right } + self.calc_height() + return self.balance(), value, err } func (self *IAVLNode) Height() int { @@ -296,16 +335,12 @@ func (self *IAVLNode) Value() interface{} { } func (self *IAVLNode) Left() Node { - if self.left == nil { - return nil - } + if self.left == nil { return nil } return self.left } func (self *IAVLNode) Right() Node { - if self.right == nil { - return nil - } + if self.right == nil { return nil } return self.right } diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index 47f149c06..94fce5434 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -3,6 +3,7 @@ package merkle import "testing" import ( + "fmt" "os" "bytes" "math/rand" @@ -150,14 +151,6 @@ func TestTraversals(t *testing.T) { var order []int = []int{ 6, 1, 8, 2, 4 , 9 , 5 , 7 , 0 , 3 , } - /* - var preorder []int = []int { - 17, 7, 5, 1, 12, 9, 13, 19, 18, 20, - } - var postorder []int = []int { - 1, 5, 9, 13, 12, 7, 18, 20, 19, 17, - } - */ test := func(T Tree) { t.Logf("%T", T) @@ -177,24 +170,106 @@ func TestTraversals(t *testing.T) { } j += 1 } + } + test(NewIAVLTree()) +} - /* - j = 0 - for tn, next := tree.TraverseTreePreOrder(T.Root())(); next != nil; tn, next = next () { - if int(tn.Key().(Int)) != preorder[j] { - t.Error("key in wrong spot pre-order") - } - j += 1 +func BenchmarkSha256(b *testing.B) { + b.StopTimer() + + str := []byte(randstr(32)) + + ransha256 := func() []byte { + return CalcSha256(str) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + ransha256() + } +} + +// from http://stackoverflow.com/questions/3955680/how-to-check-if-my-avl-tree-implementation-is-correct +func TestGriffin(t *testing.T) { + N := func(l *IAVLNode, i int, r *IAVLNode) *IAVLNode { + n := &IAVLNode{Int32(i), nil, -1, nil, l, r} + n.calc_height() + return n + } + var P func(*IAVLNode) string + P = func(n *IAVLNode) string { + if n.left == nil && n.right == nil { + return fmt.Sprintf("%v", n.key) + } else if n.left == nil { + return fmt.Sprintf("(- %v %v)", n.key, P(n.right)) + } else if n.right == nil { + return fmt.Sprintf("(%v %v -)", P(n.left), n.key) + } else { + return fmt.Sprintf("(%v %v %v)", P(n.left), n.key, P(n.right)) } + } - j = 0 - for tn, next := tree.TraverseTreePostOrder(T.Root())(); next != nil; tn, next = next () { - if int(tn.Key().(Int)) != postorder[j] { - t.Error("key in wrong spot post-order") - } - j += 1 + expectPut := func(n *IAVLNode, i int, repr string) { + n2, updated := n.Put(Int32(i), nil) + if updated == true || P(n2) != repr { + t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", + i, P(n), repr, P(n2), updated) } - */ } - test(NewIAVLTree()) + + expectRemove := func(n *IAVLNode, i int, repr string) { + n2, value, err := n.Remove(Int32(i)) + if value != nil || err != nil || P(n2) != repr { + t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", + i, P(n), repr, P(n2), value, err) + } + } + + //////// Test Put cases: + + // Case 1: + n1 := N(N(nil, 4, nil), 20, nil) + if P(n1) != "(4 20 -)" { t.Fatalf("Got %v", P(n1)) } + + expectPut(n1, 15, "(4 15 20)") + expectPut(n1, 8, "(4 8 20)") + + // Case 2: + n2 := N(N(N(nil, 3, nil), 4, N(nil, 9, nil)), 20, N(nil, 26, nil)) + if P(n2) != "((3 4 9) 20 26)" { t.Fatalf("Got %v", P(n2)) } + + expectPut(n2, 15, "((3 4 -) 9 (15 20 26))") + expectPut(n2, 8, "((3 4 8) 9 (- 20 26))") + + // Case 2: + n3 := N(N(N(N(nil, 2, nil), 3, nil), 4, N(N(nil, 7, nil), 9, N(nil, 11, nil))), + 20, N(N(nil, 21, nil), 26, N(nil, 30, nil))) + if P(n3) != "(((2 3 -) 4 (7 9 11)) 20 (21 26 30))" { t.Fatalf("Got %v", P(n3)) } + + expectPut(n3, 15, "(((2 3 -) 4 7) 9 ((- 11 15) 20 (21 26 30)))") + expectPut(n3, 8, "(((2 3 -) 4 (- 7 8)) 9 (11 20 (21 26 30)))") + + + //////// Test Remove cases: + + // Case 4: + n4 := N(N(nil, 1, nil), 2, N(N(nil, 3, nil), 4, N(nil, 5, nil))) + if P(n4) != "(1 2 (3 4 5))" { t.Fatalf("Got %v", P(n4)) } + + expectRemove(n4, 1, "((- 2 3) 4 5)") + + // Case 5: + n5 := N(N(N(nil, 1, nil), 2, N(N(nil, 3, nil), 4, N(nil, 5, nil))), 6, + N(N(N(nil, 7, nil), 8, nil), 9, N(N(nil, 10, nil), 11, N(nil, 12, N(nil, 13, nil))))) + if P(n5) != "((1 2 (3 4 5)) 6 ((7 8 -) 9 (10 11 (- 12 13))))" { t.Fatalf("Got %v", P(n5)) } + + expectRemove(n5, 1, "(((- 2 3) 4 5) 6 ((7 8 -) 9 (10 11 (- 12 13))))") + + // Case 6: + n6 := N(N(N(nil, 1, nil), 2, N(nil, 3, N(nil, 4, nil))), 5, + N(N(N(nil, 6, nil), 7, nil), 8, N(N(nil, 9, nil), 10, N(nil, 11, N(nil, 12, nil))))) + if P(n6) != "((1 2 (- 3 4)) 5 ((6 7 -) 8 (9 10 (- 11 12))))" { t.Fatalf("Got %v", P(n6)) } + + expectRemove(n6, 1, "(((2 3 4) 5 (6 7 -)) 8 (9 10 (- 11 12)))") + } diff --git a/merkle/types.go b/merkle/types.go index 606b0f037..a7bb33ee9 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -29,6 +29,9 @@ type Node interface { Size() int Has(key Sortable) bool Get(key Sortable) (value interface{}, err error) + + Put(key Sortable, value interface{}) (_ *IAVLNode, updated bool) + Remove(key Sortable) (_ *IAVLNode, value interface{}, err error) } type NodeIterator func() (node Node, next NodeIterator) diff --git a/merkle/tree.go b/merkle/util.go similarity index 58% rename from merkle/tree.go rename to merkle/util.go index e694e6f02..b4871645d 100644 --- a/merkle/tree.go +++ b/merkle/util.go @@ -1,5 +1,9 @@ package merkle +import ( + "fmt" +) + func Iterator(node Node) NodeIterator { stack := make([]Node, 0, 10) var cur Node = node @@ -28,3 +32,23 @@ func pop(stack []Node) ([]Node, Node) { return stack[0:len(stack)-1], stack[len(stack)-1] } } + +func PrintIAVLNode(node *IAVLNode) { + fmt.Println("==== NODE") + printIAVLNode(node, 0) + fmt.Println("==== END") +} + +func printIAVLNode(node *IAVLNode, indent int) { + indentPrefix := "" + for i:=0; i