Browse Source

Fixed merkle implementation to fit official algorithm.

pull/9/head
Jae Kwon 10 years ago
parent
commit
795e183273
4 changed files with 233 additions and 96 deletions
  1. +108
    -73
      merkle/iavl.go
  2. +98
    -23
      merkle/iavl_test.go
  3. +3
    -0
      merkle/types.go
  4. +24
    -0
      merkle/util.go

+ 108
- 73
merkle/iavl.go View File

@ -1,6 +1,7 @@
package merkle
import (
//"fmt"
"hash"
"crypto/sha256"
)
@ -100,18 +101,19 @@ func (self *IAVLNode) Get(key Sortable) (value interface{}, err error) {
}
}
// Copies and pops node from the tree.
// Returns a new tree (unless node is the root) & new (popped) node.
// Returns a new tree (unless node is the root) & a copy of the popped node.
// Can only pop nodes that have one or no children.
func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) {
if node == nil {
if self == nil {
panic("self can't be nil")
} else if node == nil {
panic("node can't be nil")
} else if node.left != nil && node.right != nil {
panic("node must not have both left and right")
}
if self == nil {
return nil, node.Copy(true)
} else if self == node {
if self == node {
var n *IAVLNode
if node.left != nil {
n = node.left
@ -124,18 +126,19 @@ func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) {
node.left = nil
node.right = nil
return n, node
}
self = self.Copy(false)
if node.key.Less(self.key) {
self.left, node = self.left.pop_node(node)
} else {
self.right, node = self.right.pop_node(node)
}
self.height = max(self.left.Height(), self.right.Height()) + 1
return self, node
self = self.Copy(false)
if node.key.Less(self.key) {
self.left, node = self.left.pop_node(node)
} else {
self.right, node = self.right.pop_node(node)
}
self.calc_height()
return self, node
}
}
// Pushes the node to the tree, returns a new tree
@ -156,51 +159,74 @@ func (self *IAVLNode) push_node(node *IAVLNode) *IAVLNode {
} else {
self.right = self.right.push_node(node)
}
self.height = max(self.left.Height(), self.right.Height()) + 1
self.calc_height()
return self
}
func (self *IAVLNode) rotate_right() *IAVLNode {
if self == nil {
return self
}
if self.left == nil {
return self
}
return self.rotate(self.left.rmd)
self = self.Copy(false)
sl := self.left.Copy(false)
slr := sl.right
sl.right = self
self.left = slr
self.calc_height()
sl.calc_height()
return sl
}
func (self *IAVLNode) rotate_left() *IAVLNode {
if self == nil {
return self
}
if self.right == nil {
return self
}
return self.rotate(self.right.lmd)
self = self.Copy(false)
sr := self.right.Copy(false)
srl := sr.left
sr.left = self
self.right = srl
self.calc_height()
sr.calc_height()
return sr
}
func (self *IAVLNode) rotate(get_new_root func() *IAVLNode) *IAVLNode {
self, new_root := self.pop_node(get_new_root())
new_root.left = self.left
new_root.right = self.right
self.hash = nil
self.left = nil
self.right = nil
return new_root.push_node(self)
func (self *IAVLNode) calc_height() {
self.height = max(self.left.Height(), self.right.Height()) + 1
}
func (self *IAVLNode) balance() *IAVLNode {
func (self *IAVLNode) calc_balance() int {
if self == nil {
return self
return 0
}
for abs(self.left.Height() - self.right.Height()) > 2 {
if self.left.Height() > self.right.Height() {
self = self.rotate_right()
return self.left.Height() - self.right.Height()
}
func (self *IAVLNode) balance() (new_self *IAVLNode) {
balance := self.calc_balance()
if (balance > 1) {
if (self.left.calc_balance() >= 0) {
// Left Left Case
return self.rotate_right()
} else {
self = self.rotate_left()
// Left Right Case
self = self.Copy(false)
self.left = self.left.rotate_left()
return self.rotate_right()
}
}
if (balance < -1) {
if (self.right.calc_balance() <= 0) {
// Right Right Case
return self.rotate_left()
} else {
// Right Left Case
self = self.Copy(false)
self.right = self.right.rotate_right()
return self.rotate_left()
}
}
// Nothing changed
return self
}
@ -222,31 +248,29 @@ func (self *IAVLNode) Put(key Sortable, value interface{}) (_ *IAVLNode, updated
} else {
self.right, updated = self.right.Put(key, value)
}
self.height = max(self.left.Height(), self.right.Height()) + 1
if !updated {
self.height += 1
if updated {
return self, updated
} else {
self.calc_height()
return self.balance(), updated
}
return self, updated
}
func (self *IAVLNode) Remove(key Sortable) (_ *IAVLNode, value interface{}, err error) {
func (self *IAVLNode) Remove(key Sortable) (new_self *IAVLNode, value interface{}, err error) {
if self == nil {
return nil, nil, NotFound(key)
}
if self.key.Equals(key) {
if self.left != nil && self.right != nil {
var new_root *IAVLNode
if self.left.Size() < self.right.Size() {
self, new_root = self.pop_node(self.right.lmd())
self, new_self = self.pop_node(self.right.lmd())
} else {
self, new_root = self.pop_node(self.left.rmd())
self, new_self = self.pop_node(self.left.rmd())
}
new_root.left = self.left
new_root.right = self.right
return new_root, self.value, nil
new_self.left = self.left
new_self.right = self.right
return new_self, self.value, nil
} else if self.left == nil {
return self.right, self.value, nil
} else if self.right == nil {
@ -256,20 +280,35 @@ func (self *IAVLNode) Remove(key Sortable) (_ *IAVLNode, value interface{}, err
}
}
self = self.Copy(true)
if key.Less(self.key) {
self.left, value, err = self.left.Remove(key)
} else {
self.right, value, err = self.right.Remove(key)
}
if err == nil {
self.hash = nil
self.height = max(self.left.Height(), self.right.Height()) + 1
return self.balance(), value, err
if self.left == nil {
return self, nil, NotFound(key)
}
var new_left *IAVLNode
new_left, value, err = self.left.Remove(key)
if new_left == self.left { // not found
return self, nil, err
} else if err != nil { // some other error
return self, value, err
}
self = self.Copy(false)
self.left = new_left
} else {
return self, value, err
if self.right == nil {
return self, nil, NotFound(key)
}
var new_right *IAVLNode
new_right, value, err = self.right.Remove(key)
if new_right == self.right { // not found
return self, nil, err
} else if err != nil { // some other error
return self, value, err
}
self = self.Copy(false)
self.right = new_right
}
self.calc_height()
return self.balance(), value, err
}
func (self *IAVLNode) Height() int {
@ -296,16 +335,12 @@ func (self *IAVLNode) Value() interface{} {
}
func (self *IAVLNode) Left() Node {
if self.left == nil {
return nil
}
if self.left == nil { return nil }
return self.left
}
func (self *IAVLNode) Right() Node {
if self.right == nil {
return nil
}
if self.right == nil { return nil }
return self.right
}


+ 98
- 23
merkle/iavl_test.go View File

@ -3,6 +3,7 @@ package merkle
import "testing"
import (
"fmt"
"os"
"bytes"
"math/rand"
@ -150,14 +151,6 @@ func TestTraversals(t *testing.T) {
var order []int = []int{
6, 1, 8, 2, 4 , 9 , 5 , 7 , 0 , 3 ,
}
/*
var preorder []int = []int {
17, 7, 5, 1, 12, 9, 13, 19, 18, 20,
}
var postorder []int = []int {
1, 5, 9, 13, 12, 7, 18, 20, 19, 17,
}
*/
test := func(T Tree) {
t.Logf("%T", T)
@ -177,24 +170,106 @@ func TestTraversals(t *testing.T) {
}
j += 1
}
}
test(NewIAVLTree())
}
/*
j = 0
for tn, next := tree.TraverseTreePreOrder(T.Root())(); next != nil; tn, next = next () {
if int(tn.Key().(Int)) != preorder[j] {
t.Error("key in wrong spot pre-order")
}
j += 1
func BenchmarkSha256(b *testing.B) {
b.StopTimer()
str := []byte(randstr(32))
ransha256 := func() []byte {
return CalcSha256(str)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
ransha256()
}
}
// from http://stackoverflow.com/questions/3955680/how-to-check-if-my-avl-tree-implementation-is-correct
func TestGriffin(t *testing.T) {
N := func(l *IAVLNode, i int, r *IAVLNode) *IAVLNode {
n := &IAVLNode{Int32(i), nil, -1, nil, l, r}
n.calc_height()
return n
}
var P func(*IAVLNode) string
P = func(n *IAVLNode) string {
if n.left == nil && n.right == nil {
return fmt.Sprintf("%v", n.key)
} else if n.left == nil {
return fmt.Sprintf("(- %v %v)", n.key, P(n.right))
} else if n.right == nil {
return fmt.Sprintf("(%v %v -)", P(n.left), n.key)
} else {
return fmt.Sprintf("(%v %v %v)", P(n.left), n.key, P(n.right))
}
}
j = 0
for tn, next := tree.TraverseTreePostOrder(T.Root())(); next != nil; tn, next = next () {
if int(tn.Key().(Int)) != postorder[j] {
t.Error("key in wrong spot post-order")
}
j += 1
expectPut := func(n *IAVLNode, i int, repr string) {
n2, updated := n.Put(Int32(i), nil)
if updated == true || P(n2) != repr {
t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v",
i, P(n), repr, P(n2), updated)
}
*/
}
test(NewIAVLTree())
expectRemove := func(n *IAVLNode, i int, repr string) {
n2, value, err := n.Remove(Int32(i))
if value != nil || err != nil || P(n2) != repr {
t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v",
i, P(n), repr, P(n2), value, err)
}
}
//////// Test Put cases:
// Case 1:
n1 := N(N(nil, 4, nil), 20, nil)
if P(n1) != "(4 20 -)" { t.Fatalf("Got %v", P(n1)) }
expectPut(n1, 15, "(4 15 20)")
expectPut(n1, 8, "(4 8 20)")
// Case 2:
n2 := N(N(N(nil, 3, nil), 4, N(nil, 9, nil)), 20, N(nil, 26, nil))
if P(n2) != "((3 4 9) 20 26)" { t.Fatalf("Got %v", P(n2)) }
expectPut(n2, 15, "((3 4 -) 9 (15 20 26))")
expectPut(n2, 8, "((3 4 8) 9 (- 20 26))")
// Case 2:
n3 := N(N(N(N(nil, 2, nil), 3, nil), 4, N(N(nil, 7, nil), 9, N(nil, 11, nil))),
20, N(N(nil, 21, nil), 26, N(nil, 30, nil)))
if P(n3) != "(((2 3 -) 4 (7 9 11)) 20 (21 26 30))" { t.Fatalf("Got %v", P(n3)) }
expectPut(n3, 15, "(((2 3 -) 4 7) 9 ((- 11 15) 20 (21 26 30)))")
expectPut(n3, 8, "(((2 3 -) 4 (- 7 8)) 9 (11 20 (21 26 30)))")
//////// Test Remove cases:
// Case 4:
n4 := N(N(nil, 1, nil), 2, N(N(nil, 3, nil), 4, N(nil, 5, nil)))
if P(n4) != "(1 2 (3 4 5))" { t.Fatalf("Got %v", P(n4)) }
expectRemove(n4, 1, "((- 2 3) 4 5)")
// Case 5:
n5 := N(N(N(nil, 1, nil), 2, N(N(nil, 3, nil), 4, N(nil, 5, nil))), 6,
N(N(N(nil, 7, nil), 8, nil), 9, N(N(nil, 10, nil), 11, N(nil, 12, N(nil, 13, nil)))))
if P(n5) != "((1 2 (3 4 5)) 6 ((7 8 -) 9 (10 11 (- 12 13))))" { t.Fatalf("Got %v", P(n5)) }
expectRemove(n5, 1, "(((- 2 3) 4 5) 6 ((7 8 -) 9 (10 11 (- 12 13))))")
// Case 6:
n6 := N(N(N(nil, 1, nil), 2, N(nil, 3, N(nil, 4, nil))), 5,
N(N(N(nil, 6, nil), 7, nil), 8, N(N(nil, 9, nil), 10, N(nil, 11, N(nil, 12, nil)))))
if P(n6) != "((1 2 (- 3 4)) 5 ((6 7 -) 8 (9 10 (- 11 12))))" { t.Fatalf("Got %v", P(n6)) }
expectRemove(n6, 1, "(((2 3 4) 5 (6 7 -)) 8 (9 10 (- 11 12)))")
}

+ 3
- 0
merkle/types.go View File

@ -29,6 +29,9 @@ type Node interface {
Size() int
Has(key Sortable) bool
Get(key Sortable) (value interface{}, err error)
Put(key Sortable, value interface{}) (_ *IAVLNode, updated bool)
Remove(key Sortable) (_ *IAVLNode, value interface{}, err error)
}
type NodeIterator func() (node Node, next NodeIterator)


merkle/tree.go → merkle/util.go View File


Loading…
Cancel
Save