diff --git a/merkle/iavl.go b/merkle/iavl.go index 000326b8d..6b6ebde46 100644 --- a/merkle/iavl.go +++ b/merkle/iavl.go @@ -30,11 +30,11 @@ func (self *IAVLTree) Height() uint8 { } func (self *IAVLTree) Has(key Key) bool { - return self.root.Has(key) + return self.root.Has(nil, key) } func (self *IAVLTree) Put(key Key, value Value) (err error) { - self.root, _ = self.root.Put(key, value) + self.root, _ = self.root.Put(nil, key, value) return nil } @@ -43,11 +43,11 @@ func (self *IAVLTree) Hash() ([]byte, uint64) { } func (self *IAVLTree) Get(key Key) (value Value, err error) { - return self.root.Get(key) + return self.root.Get(nil, key) } func (self *IAVLTree) Remove(key Key) (value Value, err error) { - new_root, value, err := self.root.Remove(key) + new_root, value, err := self.root.Remove(nil, key) if err != nil { return nil, err } @@ -94,13 +94,23 @@ func (self *IAVLNode) Value() Value { return self.value } -func (self *IAVLNode) Left() Node { +func (self *IAVLNode) Left(db Db) Node { if self.left == nil { return nil } - return self.left + return self.left_filled(db) } -func (self *IAVLNode) Right() Node { +func (self *IAVLNode) Right(db Db) Node { if self.right == nil { return nil } + return self.right_filled(db) +} + +func (self *IAVLNode) left_filled(db Db) *IAVLNode { + // XXX + return self.left +} + +func (self *IAVLNode) right_filled(db Db) *IAVLNode { + // XXX return self.right } @@ -111,29 +121,29 @@ func (self *IAVLNode) Size() uint64 { return self.size } -func (self *IAVLNode) Has(key Key) (has bool) { +func (self *IAVLNode) Has(db Db, key Key) (has bool) { if self == nil { return false } if self.key.Equals(key) { return true } else if key.Less(self.key) { - return self.left.Has(key) + return self.left_filled(db).Has(db, key) } else { - return self.right.Has(key) + return self.right_filled(db).Has(db, key) } } -func (self *IAVLNode) Get(key Key) (value Value, err error) { +func (self *IAVLNode) Get(db Db, key Key) (value Value, err error) { if self == nil { return nil, NotFound(key) } if self.key.Equals(key) { return self.value, nil } else if key.Less(self.key) { - return self.left.Get(key) + return self.left_filled(db).Get(db, key) } else { - return self.right.Get(key) + return self.right_filled(db).Get(db, key) } } @@ -176,6 +186,10 @@ func (self *IAVLNode) WriteTo(writer io.Writer) (written int64, hashCount uint64 if self.right != nil { nodeDesc |= 0x04 } write([]byte{nodeDesc}) + // node height & size + write(UInt8(self.height).Bytes()) + write(UInt64(self.size).Bytes()) + // node key keyBytes := self.key.Bytes() if len(keyBytes) > 255 { panic("key is too long") } @@ -209,7 +223,7 @@ func (self *IAVLNode) WriteTo(writer io.Writer) (written int64, hashCount uint64 // Returns a new tree (unless node is the root) & a copy of the popped node. // Can only pop nodes that have one or no children. -func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) { +func (self *IAVLNode) pop_node(db Db, node *IAVLNode) (new_self, new_node *IAVLNode) { if self == nil { panic("self can't be nil") } else if node == nil { @@ -222,96 +236,96 @@ func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) { var n *IAVLNode if node.left != nil { - n = node.left + n = node.left_filled(db) } else if node.right != nil { - n = node.right + n = node.right_filled(db) } else { n = nil } node = node.Copy() node.left = nil node.right = nil - node.calc_height_and_size() + node.calc_height_and_size(db) return n, node } else { self = self.Copy() if node.key.Less(self.key) { - self.left, node = self.left.pop_node(node) + self.left, node = self.left_filled(db).pop_node(db, node) } else { - self.right, node = self.right.pop_node(node) + self.right, node = self.right_filled(db).pop_node(db, node) } - self.calc_height_and_size() + self.calc_height_and_size(db) return self, node } } -func (self *IAVLNode) rotate_right() *IAVLNode { +func (self *IAVLNode) rotate_right(db Db) *IAVLNode { self = self.Copy() - sl := self.left.Copy() + sl := self.left_filled(db).Copy() slr := sl.right sl.right = self self.left = slr - self.calc_height_and_size() - sl.calc_height_and_size() + self.calc_height_and_size(db) + sl.calc_height_and_size(db) return sl } -func (self *IAVLNode) rotate_left() *IAVLNode { +func (self *IAVLNode) rotate_left(db Db) *IAVLNode { self = self.Copy() - sr := self.right.Copy() + sr := self.right_filled(db).Copy() srl := sr.left sr.left = self self.right = srl - self.calc_height_and_size() - sr.calc_height_and_size() + self.calc_height_and_size(db) + sr.calc_height_and_size(db) return sr } -func (self *IAVLNode) calc_height_and_size() { - self.height = maxUint8(self.left.Height(), self.right.Height()) + 1 - self.size = self.left.Size() + self.right.Size() + 1 +func (self *IAVLNode) calc_height_and_size(db Db) { + self.height = maxUint8(self.left_filled(db).Height(), self.right_filled(db).Height()) + 1 + self.size = self.left_filled(db).Size() + self.right_filled(db).Size() + 1 } -func (self *IAVLNode) calc_balance() int { +func (self *IAVLNode) calc_balance(db Db) int { if self == nil { return 0 } - return int(self.left.Height()) - int(self.right.Height()) + return int(self.left_filled(db).Height()) - int(self.right_filled(db).Height()) } -func (self *IAVLNode) balance() (new_self *IAVLNode) { - balance := self.calc_balance() +func (self *IAVLNode) balance(db Db) (new_self *IAVLNode) { + balance := self.calc_balance(db) if (balance > 1) { - if (self.left.calc_balance() >= 0) { + if (self.left_filled(db).calc_balance(db) >= 0) { // Left Left Case - return self.rotate_right() + return self.rotate_right(db) } else { // Left Right Case self = self.Copy() - self.left = self.left.rotate_left() + self.left = self.left_filled(db).rotate_left(db) //self.calc_height_and_size() - return self.rotate_right() + return self.rotate_right(db) } } if (balance < -1) { - if (self.right.calc_balance() <= 0) { + if (self.right_filled(db).calc_balance(db) <= 0) { // Right Right Case - return self.rotate_left() + return self.rotate_left(db) } else { // Right Left Case self = self.Copy() - self.right = self.right.rotate_right() + self.right = self.right_filled(db).rotate_right(db) //self.calc_height_and_size() - return self.rotate_left() + return self.rotate_left(db) } } // Nothing changed @@ -319,7 +333,7 @@ func (self *IAVLNode) balance() (new_self *IAVLNode) { } // TODO: don't clear the hash if the value hasn't changed. -func (self *IAVLNode) Put(key Key, value Value) (_ *IAVLNode, updated bool) { +func (self *IAVLNode) Put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { if self == nil { return &IAVLNode{key: key, value: value, height: 1, size: 1, hash: nil}, false } @@ -332,38 +346,38 @@ func (self *IAVLNode) Put(key Key, value Value) (_ *IAVLNode, updated bool) { } if key.Less(self.key) { - self.left, updated = self.left.Put(key, value) + self.left, updated = self.left_filled(db).Put(db, key, value) } else { - self.right, updated = self.right.Put(key, value) + self.right, updated = self.right_filled(db).Put(db, key, value) } if updated { return self, updated } else { - self.calc_height_and_size() - return self.balance(), updated + self.calc_height_and_size(db) + return self.balance(db), updated } } -func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err error) { +func (self *IAVLNode) Remove(db Db, key Key) (new_self *IAVLNode, value Value, err error) { if self == nil { return nil, nil, NotFound(key) } if self.key.Equals(key) { if self.left != nil && self.right != nil { - if self.left.Size() < self.right.Size() { - self, new_self = self.pop_node(self.right.lmd()) + if self.left_filled(db).Size() < self.right_filled(db).Size() { + self, new_self = self.pop_node(db, self.right_filled(db).lmd(db)) } else { - self, new_self = self.pop_node(self.left.rmd()) + self, new_self = self.pop_node(db, self.left_filled(db).rmd(db)) } new_self.left = self.left new_self.right = self.right - new_self.calc_height_and_size() + new_self.calc_height_and_size(db) return new_self, self.value, nil } else if self.left == nil { - return self.right, self.value, nil + return self.right_filled(db), self.value, nil } else if self.right == nil { - return self.left, self.value, nil + return self.left_filled(db), self.value, nil } else { return nil, self.value, nil } @@ -374,8 +388,8 @@ func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err erro return self, nil, NotFound(key) } var new_left *IAVLNode - new_left, value, err = self.left.Remove(key) - if new_left == self.left { // not found + new_left, value, err = self.left_filled(db).Remove(db, key) + if new_left == self.left_filled(db) { // not found return self, nil, err } else if err != nil { // some other error return self, value, err @@ -387,8 +401,8 @@ func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err erro return self, nil, NotFound(key) } var new_right *IAVLNode - new_right, value, err = self.right.Remove(key) - if new_right == self.right { // not found + new_right, value, err = self.right_filled(db).Remove(db, key) + if new_right == self.right_filled(db) { // not found return self, nil, err } else if err != nil { // some other error return self, value, err @@ -396,8 +410,8 @@ func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err erro self = self.Copy() self.right = new_right } - self.calc_height_and_size() - return self.balance(), value, err + self.calc_height_and_size(db) + return self.balance(db), value, err } func (self *IAVLNode) Height() uint8 { @@ -419,12 +433,12 @@ func (self *IAVLNode) _md(side func(*IAVLNode)*IAVLNode) (*IAVLNode) { } } -func (self *IAVLNode) lmd() (*IAVLNode) { - return self._md(func(node *IAVLNode)*IAVLNode { return node.left }) +func (self *IAVLNode) lmd(db Db) (*IAVLNode) { + return self._md(func(node *IAVLNode)*IAVLNode { return node.left_filled(db) }) } -func (self *IAVLNode) rmd() (*IAVLNode) { - return self._md(func(node *IAVLNode)*IAVLNode { return node.right }) +func (self *IAVLNode) rmd(db Db) (*IAVLNode) { + return self._md(func(node *IAVLNode)*IAVLNode { return node.right_filled(db) }) } func maxUint8(a, b uint8) uint8 { diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index d57d5c139..295c1f17c 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -60,11 +60,11 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) { for i := range records { r := randomRecord() records[i] = r - tree, updated = tree.Put(r.key, String("")) + tree, updated = tree.Put(nil, r.key, String("")) if updated { t.Error("should have not been updated") } - tree, updated = tree.Put(r.key, r.value) + tree, updated = tree.Put(nil, r.key, r.value) if !updated { t.Error("should have been updated") } @@ -74,13 +74,13 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) { } for _, r := range records { - if has := tree.Has(r.key); !has { + if has := tree.Has(nil, r.key); !has { t.Error("Missing key") } - if has := tree.Has(randstr(12)); has { + if has := tree.Has(nil, randstr(12)); has { t.Error("Table has extra key") } - if val, err := tree.Get(r.key); err != nil { + if val, err := tree.Get(nil, r.key); err != nil { t.Error(err, val.(String), r.value) } else if !(val.(String)).Equals(r.value) { t.Error("wrong value") @@ -88,19 +88,19 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) { } for i, x := range records { - if tree, val, err = tree.Remove(x.key); err != nil { + if tree, val, err = tree.Remove(nil, x.key); err != nil { t.Error(err) } else if !(val.(String)).Equals(x.value) { t.Error("wrong value") } for _, r := range records[i+1:] { - if has := tree.Has(r.key); !has { + if has := tree.Has(nil, r.key); !has { t.Error("Missing key") } - if has := tree.Has(randstr(12)); has { + if has := tree.Has(nil, randstr(12)); has { t.Error("Table has extra key") } - if val, err := tree.Get(r.key); err != nil { + if val, err := tree.Get(nil, r.key); err != nil { t.Error(err) } else if !(val.(String)).Equals(r.value) { t.Error("wrong value") @@ -178,7 +178,7 @@ func TestGriffin(t *testing.T) { left: l, right: r, } - n.calc_height_and_size() + n.calc_height_and_size(nil) n.Hash() return n } @@ -189,11 +189,11 @@ func TestGriffin(t *testing.T) { if n.left == nil && n.right == nil { return fmt.Sprintf("%v", n.key) } else if n.left == nil { - return fmt.Sprintf("(- %v %v)", n.key, P(n.right)) + return fmt.Sprintf("(- %v %v)", n.key, P(n.right_filled(nil))) } else if n.right == nil { - return fmt.Sprintf("(%v %v -)", P(n.left), n.key) + return fmt.Sprintf("(%v %v -)", P(n.left_filled(nil)), n.key) } else { - return fmt.Sprintf("(%v %v %v)", P(n.left), n.key, P(n.right)) + return fmt.Sprintf("(%v %v %v)", P(n.left_filled(nil)), n.key, P(n.right_filled(nil))) } } @@ -218,7 +218,7 @@ func TestGriffin(t *testing.T) { } expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, updated := n.Put(Int32(i), nil) + n2, updated := n.Put(nil, Int32(i), nil) // ensure node was added & structure is as expected. if updated == true || P(n2) != repr { t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", @@ -229,7 +229,7 @@ func TestGriffin(t *testing.T) { } expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, value, err := n.Remove(Int32(i)) + n2, value, err := n.Remove(nil, Int32(i)) // ensure node was added & structure is as expected. if value != nil || err != nil || P(n2) != repr { t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", diff --git a/merkle/types.go b/merkle/types.go index 4f99383ac..ecd1b0ca0 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -27,21 +27,26 @@ type Tree interface { Remove(Key) (Value, error) } +type Db interface { + Get([]byte) ([]byte, error) + Put([]byte, []byte) error +} + type Node interface { Key() Key Value() Value - Left() Node - Right() Node + Left(Db) Node + Right(Db) Node Size() uint64 Height() uint8 - Has(Key) bool - Get(Key) (Value, error) + Has(Db, Key) bool + Get(Db, Key) (Value, error) Hash() ([]byte, uint64) Bytes() []byte - Put(Key, Value) (*IAVLNode, bool) - Remove(Key) (*IAVLNode, Value, error) + Put(Db, Key, Value) (*IAVLNode, bool) + Remove(Db, Key) (*IAVLNode, Value, error) } type NodeIterator func() Node diff --git a/merkle/util.go b/merkle/util.go index c2076bbfa..41bbfa7c8 100644 --- a/merkle/util.go +++ b/merkle/util.go @@ -12,11 +12,11 @@ func Iterator(node Node) NodeIterator { if len(stack) > 0 || cur != nil { for cur != nil { stack = append(stack, cur) - cur = cur.Left() + cur = cur.Left(nil) } stack, cur = pop(stack) tn = cur - cur = cur.Right() + cur = cur.Right(nil) return tn } else { return nil @@ -47,8 +47,8 @@ func printIAVLNode(node *IAVLNode, indent int) { if node == nil { fmt.Printf("%s--\n", indentPrefix) } else { - printIAVLNode(node.left, indent+1) + printIAVLNode(node.left_filled(nil), indent+1) fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height) - printIAVLNode(node.right, indent+1) + printIAVLNode(node.right_filled(nil), indent+1) } }