Browse Source

add db to api for lazy loading

pull/9/head
Jae Kwon 11 years ago
parent
commit
268eaa79f0
4 changed files with 109 additions and 90 deletions
  1. +79
    -65
      merkle/iavl.go
  2. +15
    -15
      merkle/iavl_test.go
  3. +11
    -6
      merkle/types.go
  4. +4
    -4
      merkle/util.go

+ 79
- 65
merkle/iavl.go View File

@ -30,11 +30,11 @@ func (self *IAVLTree) Height() uint8 {
} }
func (self *IAVLTree) Has(key Key) bool { func (self *IAVLTree) Has(key Key) bool {
return self.root.Has(key)
return self.root.Has(nil, key)
} }
func (self *IAVLTree) Put(key Key, value Value) (err error) { func (self *IAVLTree) Put(key Key, value Value) (err error) {
self.root, _ = self.root.Put(key, value)
self.root, _ = self.root.Put(nil, key, value)
return nil return nil
} }
@ -43,11 +43,11 @@ func (self *IAVLTree) Hash() ([]byte, uint64) {
} }
func (self *IAVLTree) Get(key Key) (value Value, err error) { func (self *IAVLTree) Get(key Key) (value Value, err error) {
return self.root.Get(key)
return self.root.Get(nil, key)
} }
func (self *IAVLTree) Remove(key Key) (value Value, err error) { func (self *IAVLTree) Remove(key Key) (value Value, err error) {
new_root, value, err := self.root.Remove(key)
new_root, value, err := self.root.Remove(nil, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -94,13 +94,23 @@ func (self *IAVLNode) Value() Value {
return self.value return self.value
} }
func (self *IAVLNode) Left() Node {
func (self *IAVLNode) Left(db Db) Node {
if self.left == nil { return nil } if self.left == nil { return nil }
return self.left
return self.left_filled(db)
} }
func (self *IAVLNode) Right() Node {
func (self *IAVLNode) Right(db Db) Node {
if self.right == nil { return nil } if self.right == nil { return nil }
return self.right_filled(db)
}
func (self *IAVLNode) left_filled(db Db) *IAVLNode {
// XXX
return self.left
}
func (self *IAVLNode) right_filled(db Db) *IAVLNode {
// XXX
return self.right return self.right
} }
@ -111,29 +121,29 @@ func (self *IAVLNode) Size() uint64 {
return self.size return self.size
} }
func (self *IAVLNode) Has(key Key) (has bool) {
func (self *IAVLNode) Has(db Db, key Key) (has bool) {
if self == nil { if self == nil {
return false return false
} }
if self.key.Equals(key) { if self.key.Equals(key) {
return true return true
} else if key.Less(self.key) { } else if key.Less(self.key) {
return self.left.Has(key)
return self.left_filled(db).Has(db, key)
} else { } else {
return self.right.Has(key)
return self.right_filled(db).Has(db, key)
} }
} }
func (self *IAVLNode) Get(key Key) (value Value, err error) {
func (self *IAVLNode) Get(db Db, key Key) (value Value, err error) {
if self == nil { if self == nil {
return nil, NotFound(key) return nil, NotFound(key)
} }
if self.key.Equals(key) { if self.key.Equals(key) {
return self.value, nil return self.value, nil
} else if key.Less(self.key) { } else if key.Less(self.key) {
return self.left.Get(key)
return self.left_filled(db).Get(db, key)
} else { } else {
return self.right.Get(key)
return self.right_filled(db).Get(db, key)
} }
} }
@ -176,6 +186,10 @@ func (self *IAVLNode) WriteTo(writer io.Writer) (written int64, hashCount uint64
if self.right != nil { nodeDesc |= 0x04 } if self.right != nil { nodeDesc |= 0x04 }
write([]byte{nodeDesc}) write([]byte{nodeDesc})
// node height & size
write(UInt8(self.height).Bytes())
write(UInt64(self.size).Bytes())
// node key // node key
keyBytes := self.key.Bytes() keyBytes := self.key.Bytes()
if len(keyBytes) > 255 { panic("key is too long") } if len(keyBytes) > 255 { panic("key is too long") }
@ -209,7 +223,7 @@ func (self *IAVLNode) WriteTo(writer io.Writer) (written int64, hashCount uint64
// Returns a new tree (unless node is the root) & a copy of the popped node. // Returns a new tree (unless node is the root) & a copy of the popped node.
// Can only pop nodes that have one or no children. // Can only pop nodes that have one or no children.
func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) {
func (self *IAVLNode) pop_node(db Db, node *IAVLNode) (new_self, new_node *IAVLNode) {
if self == nil { if self == nil {
panic("self can't be nil") panic("self can't be nil")
} else if node == nil { } else if node == nil {
@ -222,96 +236,96 @@ func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) {
var n *IAVLNode var n *IAVLNode
if node.left != nil { if node.left != nil {
n = node.left
n = node.left_filled(db)
} else if node.right != nil { } else if node.right != nil {
n = node.right
n = node.right_filled(db)
} else { } else {
n = nil n = nil
} }
node = node.Copy() node = node.Copy()
node.left = nil node.left = nil
node.right = nil node.right = nil
node.calc_height_and_size()
node.calc_height_and_size(db)
return n, node return n, node
} else { } else {
self = self.Copy() self = self.Copy()
if node.key.Less(self.key) { if node.key.Less(self.key) {
self.left, node = self.left.pop_node(node)
self.left, node = self.left_filled(db).pop_node(db, node)
} else { } else {
self.right, node = self.right.pop_node(node)
self.right, node = self.right_filled(db).pop_node(db, node)
} }
self.calc_height_and_size()
self.calc_height_and_size(db)
return self, node return self, node
} }
} }
func (self *IAVLNode) rotate_right() *IAVLNode {
func (self *IAVLNode) rotate_right(db Db) *IAVLNode {
self = self.Copy() self = self.Copy()
sl := self.left.Copy()
sl := self.left_filled(db).Copy()
slr := sl.right slr := sl.right
sl.right = self sl.right = self
self.left = slr self.left = slr
self.calc_height_and_size()
sl.calc_height_and_size()
self.calc_height_and_size(db)
sl.calc_height_and_size(db)
return sl return sl
} }
func (self *IAVLNode) rotate_left() *IAVLNode {
func (self *IAVLNode) rotate_left(db Db) *IAVLNode {
self = self.Copy() self = self.Copy()
sr := self.right.Copy()
sr := self.right_filled(db).Copy()
srl := sr.left srl := sr.left
sr.left = self sr.left = self
self.right = srl self.right = srl
self.calc_height_and_size()
sr.calc_height_and_size()
self.calc_height_and_size(db)
sr.calc_height_and_size(db)
return sr return sr
} }
func (self *IAVLNode) calc_height_and_size() {
self.height = maxUint8(self.left.Height(), self.right.Height()) + 1
self.size = self.left.Size() + self.right.Size() + 1
func (self *IAVLNode) calc_height_and_size(db Db) {
self.height = maxUint8(self.left_filled(db).Height(), self.right_filled(db).Height()) + 1
self.size = self.left_filled(db).Size() + self.right_filled(db).Size() + 1
} }
func (self *IAVLNode) calc_balance() int {
func (self *IAVLNode) calc_balance(db Db) int {
if self == nil { if self == nil {
return 0 return 0
} }
return int(self.left.Height()) - int(self.right.Height())
return int(self.left_filled(db).Height()) - int(self.right_filled(db).Height())
} }
func (self *IAVLNode) balance() (new_self *IAVLNode) {
balance := self.calc_balance()
func (self *IAVLNode) balance(db Db) (new_self *IAVLNode) {
balance := self.calc_balance(db)
if (balance > 1) { if (balance > 1) {
if (self.left.calc_balance() >= 0) {
if (self.left_filled(db).calc_balance(db) >= 0) {
// Left Left Case // Left Left Case
return self.rotate_right()
return self.rotate_right(db)
} else { } else {
// Left Right Case // Left Right Case
self = self.Copy() self = self.Copy()
self.left = self.left.rotate_left()
self.left = self.left_filled(db).rotate_left(db)
//self.calc_height_and_size() //self.calc_height_and_size()
return self.rotate_right()
return self.rotate_right(db)
} }
} }
if (balance < -1) { if (balance < -1) {
if (self.right.calc_balance() <= 0) {
if (self.right_filled(db).calc_balance(db) <= 0) {
// Right Right Case // Right Right Case
return self.rotate_left()
return self.rotate_left(db)
} else { } else {
// Right Left Case // Right Left Case
self = self.Copy() self = self.Copy()
self.right = self.right.rotate_right()
self.right = self.right_filled(db).rotate_right(db)
//self.calc_height_and_size() //self.calc_height_and_size()
return self.rotate_left()
return self.rotate_left(db)
} }
} }
// Nothing changed // Nothing changed
@ -319,7 +333,7 @@ func (self *IAVLNode) balance() (new_self *IAVLNode) {
} }
// TODO: don't clear the hash if the value hasn't changed. // TODO: don't clear the hash if the value hasn't changed.
func (self *IAVLNode) Put(key Key, value Value) (_ *IAVLNode, updated bool) {
func (self *IAVLNode) Put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) {
if self == nil { if self == nil {
return &IAVLNode{key: key, value: value, height: 1, size: 1, hash: nil}, false return &IAVLNode{key: key, value: value, height: 1, size: 1, hash: nil}, false
} }
@ -332,38 +346,38 @@ func (self *IAVLNode) Put(key Key, value Value) (_ *IAVLNode, updated bool) {
} }
if key.Less(self.key) { if key.Less(self.key) {
self.left, updated = self.left.Put(key, value)
self.left, updated = self.left_filled(db).Put(db, key, value)
} else { } else {
self.right, updated = self.right.Put(key, value)
self.right, updated = self.right_filled(db).Put(db, key, value)
} }
if updated { if updated {
return self, updated return self, updated
} else { } else {
self.calc_height_and_size()
return self.balance(), updated
self.calc_height_and_size(db)
return self.balance(db), updated
} }
} }
func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err error) {
func (self *IAVLNode) Remove(db Db, key Key) (new_self *IAVLNode, value Value, err error) {
if self == nil { if self == nil {
return nil, nil, NotFound(key) return nil, nil, NotFound(key)
} }
if self.key.Equals(key) { if self.key.Equals(key) {
if self.left != nil && self.right != nil { if self.left != nil && self.right != nil {
if self.left.Size() < self.right.Size() {
self, new_self = self.pop_node(self.right.lmd())
if self.left_filled(db).Size() < self.right_filled(db).Size() {
self, new_self = self.pop_node(db, self.right_filled(db).lmd(db))
} else { } else {
self, new_self = self.pop_node(self.left.rmd())
self, new_self = self.pop_node(db, self.left_filled(db).rmd(db))
} }
new_self.left = self.left new_self.left = self.left
new_self.right = self.right new_self.right = self.right
new_self.calc_height_and_size()
new_self.calc_height_and_size(db)
return new_self, self.value, nil return new_self, self.value, nil
} else if self.left == nil { } else if self.left == nil {
return self.right, self.value, nil
return self.right_filled(db), self.value, nil
} else if self.right == nil { } else if self.right == nil {
return self.left, self.value, nil
return self.left_filled(db), self.value, nil
} else { } else {
return nil, self.value, nil return nil, self.value, nil
} }
@ -374,8 +388,8 @@ func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err erro
return self, nil, NotFound(key) return self, nil, NotFound(key)
} }
var new_left *IAVLNode var new_left *IAVLNode
new_left, value, err = self.left.Remove(key)
if new_left == self.left { // not found
new_left, value, err = self.left_filled(db).Remove(db, key)
if new_left == self.left_filled(db) { // not found
return self, nil, err return self, nil, err
} else if err != nil { // some other error } else if err != nil { // some other error
return self, value, err return self, value, err
@ -387,8 +401,8 @@ func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err erro
return self, nil, NotFound(key) return self, nil, NotFound(key)
} }
var new_right *IAVLNode var new_right *IAVLNode
new_right, value, err = self.right.Remove(key)
if new_right == self.right { // not found
new_right, value, err = self.right_filled(db).Remove(db, key)
if new_right == self.right_filled(db) { // not found
return self, nil, err return self, nil, err
} else if err != nil { // some other error } else if err != nil { // some other error
return self, value, err return self, value, err
@ -396,8 +410,8 @@ func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err erro
self = self.Copy() self = self.Copy()
self.right = new_right self.right = new_right
} }
self.calc_height_and_size()
return self.balance(), value, err
self.calc_height_and_size(db)
return self.balance(db), value, err
} }
func (self *IAVLNode) Height() uint8 { func (self *IAVLNode) Height() uint8 {
@ -419,12 +433,12 @@ func (self *IAVLNode) _md(side func(*IAVLNode)*IAVLNode) (*IAVLNode) {
} }
} }
func (self *IAVLNode) lmd() (*IAVLNode) {
return self._md(func(node *IAVLNode)*IAVLNode { return node.left })
func (self *IAVLNode) lmd(db Db) (*IAVLNode) {
return self._md(func(node *IAVLNode)*IAVLNode { return node.left_filled(db) })
} }
func (self *IAVLNode) rmd() (*IAVLNode) {
return self._md(func(node *IAVLNode)*IAVLNode { return node.right })
func (self *IAVLNode) rmd(db Db) (*IAVLNode) {
return self._md(func(node *IAVLNode)*IAVLNode { return node.right_filled(db) })
} }
func maxUint8(a, b uint8) uint8 { func maxUint8(a, b uint8) uint8 {


+ 15
- 15
merkle/iavl_test.go View File

@ -60,11 +60,11 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) {
for i := range records { for i := range records {
r := randomRecord() r := randomRecord()
records[i] = r records[i] = r
tree, updated = tree.Put(r.key, String(""))
tree, updated = tree.Put(nil, r.key, String(""))
if updated { if updated {
t.Error("should have not been updated") t.Error("should have not been updated")
} }
tree, updated = tree.Put(r.key, r.value)
tree, updated = tree.Put(nil, r.key, r.value)
if !updated { if !updated {
t.Error("should have been updated") t.Error("should have been updated")
} }
@ -74,13 +74,13 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) {
} }
for _, r := range records { for _, r := range records {
if has := tree.Has(r.key); !has {
if has := tree.Has(nil, r.key); !has {
t.Error("Missing key") t.Error("Missing key")
} }
if has := tree.Has(randstr(12)); has {
if has := tree.Has(nil, randstr(12)); has {
t.Error("Table has extra key") t.Error("Table has extra key")
} }
if val, err := tree.Get(r.key); err != nil {
if val, err := tree.Get(nil, r.key); err != nil {
t.Error(err, val.(String), r.value) t.Error(err, val.(String), r.value)
} else if !(val.(String)).Equals(r.value) { } else if !(val.(String)).Equals(r.value) {
t.Error("wrong value") t.Error("wrong value")
@ -88,19 +88,19 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) {
} }
for i, x := range records { for i, x := range records {
if tree, val, err = tree.Remove(x.key); err != nil {
if tree, val, err = tree.Remove(nil, x.key); err != nil {
t.Error(err) t.Error(err)
} else if !(val.(String)).Equals(x.value) { } else if !(val.(String)).Equals(x.value) {
t.Error("wrong value") t.Error("wrong value")
} }
for _, r := range records[i+1:] { for _, r := range records[i+1:] {
if has := tree.Has(r.key); !has {
if has := tree.Has(nil, r.key); !has {
t.Error("Missing key") t.Error("Missing key")
} }
if has := tree.Has(randstr(12)); has {
if has := tree.Has(nil, randstr(12)); has {
t.Error("Table has extra key") t.Error("Table has extra key")
} }
if val, err := tree.Get(r.key); err != nil {
if val, err := tree.Get(nil, r.key); err != nil {
t.Error(err) t.Error(err)
} else if !(val.(String)).Equals(r.value) { } else if !(val.(String)).Equals(r.value) {
t.Error("wrong value") t.Error("wrong value")
@ -178,7 +178,7 @@ func TestGriffin(t *testing.T) {
left: l, left: l,
right: r, right: r,
} }
n.calc_height_and_size()
n.calc_height_and_size(nil)
n.Hash() n.Hash()
return n return n
} }
@ -189,11 +189,11 @@ func TestGriffin(t *testing.T) {
if n.left == nil && n.right == nil { if n.left == nil && n.right == nil {
return fmt.Sprintf("%v", n.key) return fmt.Sprintf("%v", n.key)
} else if n.left == nil { } else if n.left == nil {
return fmt.Sprintf("(- %v %v)", n.key, P(n.right))
return fmt.Sprintf("(- %v %v)", n.key, P(n.right_filled(nil)))
} else if n.right == nil { } else if n.right == nil {
return fmt.Sprintf("(%v %v -)", P(n.left), n.key)
return fmt.Sprintf("(%v %v -)", P(n.left_filled(nil)), n.key)
} else { } else {
return fmt.Sprintf("(%v %v %v)", P(n.left), n.key, P(n.right))
return fmt.Sprintf("(%v %v %v)", P(n.left_filled(nil)), n.key, P(n.right_filled(nil)))
} }
} }
@ -218,7 +218,7 @@ func TestGriffin(t *testing.T) {
} }
expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) { expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, updated := n.Put(Int32(i), nil)
n2, updated := n.Put(nil, Int32(i), nil)
// ensure node was added & structure is as expected. // ensure node was added & structure is as expected.
if updated == true || P(n2) != repr { if updated == true || P(n2) != repr {
t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v",
@ -229,7 +229,7 @@ func TestGriffin(t *testing.T) {
} }
expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, value, err := n.Remove(Int32(i))
n2, value, err := n.Remove(nil, Int32(i))
// ensure node was added & structure is as expected. // ensure node was added & structure is as expected.
if value != nil || err != nil || P(n2) != repr { if value != nil || err != nil || P(n2) != repr {
t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v",


+ 11
- 6
merkle/types.go View File

@ -27,21 +27,26 @@ type Tree interface {
Remove(Key) (Value, error) Remove(Key) (Value, error)
} }
type Db interface {
Get([]byte) ([]byte, error)
Put([]byte, []byte) error
}
type Node interface { type Node interface {
Key() Key Key() Key
Value() Value Value() Value
Left() Node
Right() Node
Left(Db) Node
Right(Db) Node
Size() uint64 Size() uint64
Height() uint8 Height() uint8
Has(Key) bool
Get(Key) (Value, error)
Has(Db, Key) bool
Get(Db, Key) (Value, error)
Hash() ([]byte, uint64) Hash() ([]byte, uint64)
Bytes() []byte Bytes() []byte
Put(Key, Value) (*IAVLNode, bool)
Remove(Key) (*IAVLNode, Value, error)
Put(Db, Key, Value) (*IAVLNode, bool)
Remove(Db, Key) (*IAVLNode, Value, error)
} }
type NodeIterator func() Node type NodeIterator func() Node


+ 4
- 4
merkle/util.go View File

@ -12,11 +12,11 @@ func Iterator(node Node) NodeIterator {
if len(stack) > 0 || cur != nil { if len(stack) > 0 || cur != nil {
for cur != nil { for cur != nil {
stack = append(stack, cur) stack = append(stack, cur)
cur = cur.Left()
cur = cur.Left(nil)
} }
stack, cur = pop(stack) stack, cur = pop(stack)
tn = cur tn = cur
cur = cur.Right()
cur = cur.Right(nil)
return tn return tn
} else { } else {
return nil return nil
@ -47,8 +47,8 @@ func printIAVLNode(node *IAVLNode, indent int) {
if node == nil { if node == nil {
fmt.Printf("%s--\n", indentPrefix) fmt.Printf("%s--\n", indentPrefix)
} else { } else {
printIAVLNode(node.left, indent+1)
printIAVLNode(node.left_filled(nil), indent+1)
fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height) fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height)
printIAVLNode(node.right, indent+1)
printIAVLNode(node.right_filled(nil), indent+1)
} }
} }

Loading…
Cancel
Save