diff --git a/binary/codec.go b/binary/codec.go index cd18ec628..6dce295a4 100644 --- a/binary/codec.go +++ b/binary/codec.go @@ -2,16 +2,17 @@ package binary import ( "bytes" + "io" "time" ) type Codec interface { - Write(interface{}) ([]byte, error) - Read([]byte) (interface{}, error) + Encode(o interface{}, w io.Writer, n *int64, err *error) + Decode(r io.Reader, n *int64, err *error) interface{} + Compare(o1 interface{}, o2 interface{}) int } const ( - typeNil = byte(0x00) typeByte = byte(0x01) typeInt8 = byte(0x02) // typeUInt8 = byte(0x03) @@ -21,6 +22,8 @@ const ( typeUInt32 = byte(0x07) typeInt64 = byte(0x08) typeUInt64 = byte(0x09) + typeVarInt = byte(0x0A) + typeUVarInt = byte(0x0B) typeString = byte(0x10) typeByteSlice = byte(0x11) typeTime = byte(0x20) @@ -30,11 +33,10 @@ var BasicCodec = basicCodec{} type basicCodec struct{} -func (bc basicCodec) Write(o interface{}) ([]byte, error) { - n, err, w := new(int64), new(error), new(bytes.Buffer) +func (bc basicCodec) Encode(o interface{}, w io.Writer, n *int64, err *error) { switch o.(type) { case nil: - WriteByte(w, typeNil, n, err) + panic("nil type unsupported") case byte: WriteByte(w, typeByte, n, err) WriteByte(w, o.(byte), n, err) @@ -62,6 +64,12 @@ func (bc basicCodec) Write(o interface{}) ([]byte, error) { case uint64: WriteByte(w, typeUInt64, n, err) WriteUInt64(w, o.(uint64), n, err) + case int: + WriteByte(w, typeVarInt, n, err) + WriteVarInt(w, o.(int), n, err) + case uint: + WriteByte(w, typeUVarInt, n, err) + WriteUVarInt(w, o.(uint), n, err) case string: WriteByte(w, typeString, n, err) WriteString(w, o.(string), n, err) @@ -74,15 +82,11 @@ func (bc basicCodec) Write(o interface{}) ([]byte, error) { default: panic("Unsupported type") } - return w.Bytes(), *err } -func (bc basicCodec) Read(bz []byte) (interface{}, error) { - n, err, r, o := new(int64), new(error), bytes.NewBuffer(bz), interface{}(nil) +func (bc basicCodec) Decode(r io.Reader, n *int64, err *error) (o interface{}) { type_ := ReadByte(r, n, err) switch type_ { - case typeNil: - o = nil case typeByte: o = ReadByte(r, n, err) case typeInt8: @@ -101,6 +105,10 @@ func (bc basicCodec) Read(bz []byte) (interface{}, error) { o = ReadInt64(r, n, err) case typeUInt64: o = ReadUInt64(r, n, err) + case typeVarInt: + o = ReadVarInt(r, n, err) + case typeUVarInt: + o = ReadUVarInt(r, n, err) case typeString: o = ReadString(r, n, err) case typeByteSlice: @@ -110,5 +118,39 @@ func (bc basicCodec) Read(bz []byte) (interface{}, error) { default: panic("Unsupported type") } - return o, *err + return o +} + +func (bc basicCodec) Compare(o1 interface{}, o2 interface{}) int { + switch o1.(type) { + case byte: + return int(o1.(byte) - o2.(byte)) + case int8: + return int(o1.(int8) - o2.(int8)) + //case uint8: + case int16: + return int(o1.(int16) - o2.(int16)) + case uint16: + return int(o1.(uint16) - o2.(uint16)) + case int32: + return int(o1.(int32) - o2.(int32)) + case uint32: + return int(o1.(uint32) - o2.(uint32)) + case int64: + return int(o1.(int64) - o2.(int64)) + case uint64: + return int(o1.(uint64) - o2.(uint64)) + case int: + return o1.(int) - o2.(int) + case uint: + return int(o1.(uint)) - int(o2.(uint)) + case string: + return bytes.Compare([]byte(o1.(string)), []byte(o2.(string))) + case []byte: + return bytes.Compare(o1.([]byte), o2.([]byte)) + case time.Time: + return int(o1.(time.Time).UnixNano() - o2.(time.Time).UnixNano()) + default: + panic("Unsupported type") + } } diff --git a/binary/int.go b/binary/int.go index ae03adc7c..42445e185 100644 --- a/binary/int.go +++ b/binary/int.go @@ -158,34 +158,34 @@ func ReadUInt64(r io.Reader, n *int64, err *error) uint64 { // VarInt -func WriteVarInt(w io.Writer, i int64, n *int64, err *error) { +func WriteVarInt(w io.Writer, i int, n *int64, err *error) { buf := make([]byte, 9) n_ := int64(binary.PutVarint(buf, int64(i))) *n += n_ WriteTo(w, buf[:n_], n, err) } -func ReadVarInt(r io.Reader, n *int64, err *error) int64 { +func ReadVarInt(r io.Reader, n *int64, err *error) int { res, n_, err_ := readVarint(r) *n += n_ *err = err_ - return res + return int(res) } // UVarInt -func WriteUVarInt(w io.Writer, i uint64, n *int64, err *error) { +func WriteUVarInt(w io.Writer, i uint, n *int64, err *error) { buf := make([]byte, 9) n_ := int64(binary.PutUvarint(buf, uint64(i))) *n += n_ WriteTo(w, buf[:n_], n, err) } -func ReadUVarInt(r io.Reader, n *int64, err *error) uint64 { +func ReadUVarInt(r io.Reader, n *int64, err *error) uint { res, n_, err_ := readUvarint(r) *n += n_ *err = err_ - return res + return uint(res) } //----------------------------------------------------------------------------- diff --git a/merkle/iavl_node.go b/merkle/iavl_node.go index bb746e9b9..c4de4a1f1 100644 --- a/merkle/iavl_node.go +++ b/merkle/iavl_node.go @@ -1,7 +1,6 @@ package merkle import ( - "bytes" "crypto/sha256" . "github.com/tendermint/tendermint/binary" "io" @@ -10,8 +9,8 @@ import ( // Node type IAVLNode struct { - key []byte - value []byte + key interface{} + value interface{} size uint64 height uint8 hash []byte @@ -22,7 +21,7 @@ type IAVLNode struct { persisted bool } -func NewIAVLNode(key []byte, value []byte) *IAVLNode { +func NewIAVLNode(key interface{}, value interface{}) *IAVLNode { return &IAVLNode{ key: key, value: value, @@ -30,20 +29,20 @@ func NewIAVLNode(key []byte, value []byte) *IAVLNode { } } -func ReadIAVLNode(r io.Reader, n *int64, err *error) *IAVLNode { +func ReadIAVLNode(t *IAVLTree, r io.Reader, n *int64, err *error) *IAVLNode { node := &IAVLNode{} // node header & key node.height = ReadUInt8(r, n, err) node.size = ReadUInt64(r, n, err) - node.key = ReadByteSlice(r, n, err) + node.key = t.keyCodec.Decode(r, n, err) if *err != nil { panic(*err) } // node value or children. if node.height == 0 { - node.value = ReadByteSlice(r, n, err) + node.value = t.valueCodec.Decode(r, n, err) } else { node.leftHash = ReadByteSlice(r, n, err) node.rightHash = ReadByteSlice(r, n, err) @@ -54,319 +53,331 @@ func ReadIAVLNode(r io.Reader, n *int64, err *error) *IAVLNode { return node } -func (self *IAVLNode) Copy() *IAVLNode { - if self.height == 0 { +func (node *IAVLNode) _copy() *IAVLNode { + if node.height == 0 { panic("Why are you copying a value node?") } return &IAVLNode{ - key: self.key, - size: self.size, - height: self.height, + key: node.key, + size: node.size, + height: node.height, hash: nil, // Going to be mutated anyways. - leftHash: self.leftHash, - leftNode: self.leftNode, - rightHash: self.rightHash, - rightNode: self.rightNode, - persisted: self.persisted, + leftHash: node.leftHash, + leftNode: node.leftNode, + rightHash: node.rightHash, + rightNode: node.rightNode, + persisted: node.persisted, } } -func (self *IAVLNode) Size() uint64 { - return self.size -} - -func (self *IAVLNode) Height() uint8 { - return self.height -} - -func (self *IAVLNode) has(ndb *IAVLNodeDB, key []byte) (has bool) { - if bytes.Equal(self.key, key) { +func (node *IAVLNode) has(t *IAVLTree, key interface{}) (has bool) { + if t.keyCodec.Compare(node.key, key) == 0 { return true } - if self.height == 0 { + if node.height == 0 { return false } else { - if bytes.Compare(key, self.key) == -1 { - return self.getLeftNode(ndb).has(ndb, key) + if t.keyCodec.Compare(key, node.key) < 0 { + return node.getLeftNode(t).has(t, key) } else { - return self.getRightNode(ndb).has(ndb, key) + return node.getRightNode(t).has(t, key) } } } -func (self *IAVLNode) get(ndb *IAVLNodeDB, key []byte) (value []byte) { - if self.height == 0 { - if bytes.Equal(self.key, key) { - return self.value +func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint64, value interface{}) { + if node.height == 0 { + if t.keyCodec.Compare(node.key, key) == 0 { + return 0, node.value } else { - return nil + return 0, nil } } else { - if bytes.Compare(key, self.key) == -1 { - return self.getLeftNode(ndb).get(ndb, key) + if t.keyCodec.Compare(key, node.key) < 0 { + return node.getLeftNode(t).get(t, key) } else { - return self.getRightNode(ndb).get(ndb, key) + rightNode := node.getRightNode(t) + index, value = rightNode.get(t, key) + index += node.size - rightNode.size + return index, value } } } -func (self *IAVLNode) HashWithCount() ([]byte, uint64) { - if self.hash != nil { - return self.hash, 0 +func (node *IAVLNode) getByIndex(t *IAVLTree, index uint64) (key interface{}, value interface{}) { + if node.height == 0 { + if index == 0 { + return node.key, node.value + } else { + panic("getByIndex asked for invalid index") + } + } else { + // TODO: could improve this by storing the + // sizes as well as left/right hash. + leftNode := node.getLeftNode(t) + if index < leftNode.size { + return leftNode.getByIndex(t, index) + } else { + return node.getRightNode(t).getByIndex(t, index-leftNode.size) + } + } +} + +func (node *IAVLNode) hashWithCount(t *IAVLTree) ([]byte, uint64) { + if node.hash != nil { + return node.hash, 0 } hasher := sha256.New() - _, hashCount, err := self.writeToCountHashes(hasher) + _, hashCount, err := node.writeToCountHashes(t, hasher) if err != nil { panic(err) } - self.hash = hasher.Sum(nil) + node.hash = hasher.Sum(nil) - return self.hash, hashCount + 1 + return node.hash, hashCount + 1 } -func (self *IAVLNode) Save(ndb *IAVLNodeDB) []byte { - if self.hash == nil { - self.hash, _ = self.HashWithCount() +func (node *IAVLNode) save(t *IAVLTree) []byte { + if node.hash == nil { + node.hash, _ = node.hashWithCount(t) } - if self.persisted { - return self.hash + if node.persisted { + return node.hash } // save children - if self.leftNode != nil { - self.leftHash = self.leftNode.Save(ndb) - self.leftNode = nil + if node.leftNode != nil { + node.leftHash = node.leftNode.save(t) + node.leftNode = nil } - if self.rightNode != nil { - self.rightHash = self.rightNode.Save(ndb) - self.rightNode = nil + if node.rightNode != nil { + node.rightHash = node.rightNode.save(t) + node.rightNode = nil } - // save self - ndb.Save(self) - return self.hash + // save node + t.saveNode(node) + return node.hash } -func (self *IAVLNode) set(ndb *IAVLNodeDB, key []byte, value []byte) (_ *IAVLNode, updated bool) { - if self.height == 0 { - if bytes.Compare(key, self.key) == -1 { +func (node *IAVLNode) set(t *IAVLTree, key interface{}, value interface{}) (newSelf *IAVLNode, updated bool) { + if node.height == 0 { + cmp := t.keyCodec.Compare(key, node.key) + if cmp < 0 { return &IAVLNode{ - key: self.key, + key: node.key, height: 1, size: 2, leftNode: NewIAVLNode(key, value), - rightNode: self, + rightNode: node, }, false - } else if bytes.Equal(self.key, key) { + } else if cmp == 0 { return NewIAVLNode(key, value), true } else { return &IAVLNode{ key: key, height: 1, size: 2, - leftNode: self, + leftNode: node, rightNode: NewIAVLNode(key, value), }, false } } else { - self = self.Copy() - if bytes.Compare(key, self.key) == -1 { - self.leftNode, updated = self.getLeftNode(ndb).set(ndb, key, value) - self.leftHash = nil + node = node._copy() + if t.keyCodec.Compare(key, node.key) < 0 { + node.leftNode, updated = node.getLeftNode(t).set(t, key, value) + node.leftHash = nil } else { - self.rightNode, updated = self.getRightNode(ndb).set(ndb, key, value) - self.rightHash = nil + node.rightNode, updated = node.getRightNode(t).set(t, key, value) + node.rightHash = nil } if updated { - return self, updated + return node, updated } else { - self.calcHeightAndSize(ndb) - return self.balance(ndb), updated + node.calcHeightAndSize(t) + return node.balance(t), updated } } } -// newHash/newNode: The new hash or node to replace self after remove. +// newHash/newNode: The new hash or node to replace node after remove. // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. // value: removed value. -func (self *IAVLNode) remove(ndb *IAVLNodeDB, key []byte) ( - newHash []byte, newNode *IAVLNode, newKey []byte, value []byte, err error) { - if self.height == 0 { - if bytes.Equal(self.key, key) { - return nil, nil, nil, self.value, nil +func (node *IAVLNode) remove(t *IAVLTree, key interface{}) ( + newHash []byte, newNode *IAVLNode, newKey interface{}, value interface{}, removed bool) { + if node.height == 0 { + if t.keyCodec.Compare(key, node.key) == 0 { + return nil, nil, nil, node.value, true } else { - return nil, self, nil, nil, NotFound(key) + return nil, node, nil, nil, false } } else { - if bytes.Compare(key, self.key) == -1 { + if t.keyCodec.Compare(key, node.key) < 0 { var newLeftHash []byte var newLeftNode *IAVLNode - newLeftHash, newLeftNode, newKey, value, err = self.getLeftNode(ndb).remove(ndb, key) - if err != nil { - return nil, self, nil, value, err + newLeftHash, newLeftNode, newKey, value, removed = node.getLeftNode(t).remove(t, key) + if !removed { + return nil, node, nil, value, false } else if newLeftHash == nil && newLeftNode == nil { // left node held value, was removed - return self.rightHash, self.rightNode, self.key, value, nil + return node.rightHash, node.rightNode, node.key, value, true } - self = self.Copy() - self.leftHash, self.leftNode = newLeftHash, newLeftNode + node = node._copy() + node.leftHash, node.leftNode = newLeftHash, newLeftNode + node.calcHeightAndSize(t) + return nil, node.balance(t), newKey, value, true } else { var newRightHash []byte var newRightNode *IAVLNode - newRightHash, newRightNode, newKey, value, err = self.getRightNode(ndb).remove(ndb, key) - if err != nil { - return nil, self, nil, value, err + newRightHash, newRightNode, newKey, value, removed = node.getRightNode(t).remove(t, key) + if !removed { + return nil, node, nil, value, false } else if newRightHash == nil && newRightNode == nil { // right node held value, was removed - return self.leftHash, self.leftNode, nil, value, nil + return node.leftHash, node.leftNode, nil, value, true } - self = self.Copy() - self.rightHash, self.rightNode = newRightHash, newRightNode + node = node._copy() + node.rightHash, node.rightNode = newRightHash, newRightNode if newKey != nil { - self.key = newKey + node.key = newKey newKey = nil } + node.calcHeightAndSize(t) + return nil, node.balance(t), newKey, value, true } - self.calcHeightAndSize(ndb) - return nil, self.balance(ndb), newKey, value, err } } -func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) { - n, _, err = self.writeToCountHashes(w) - return -} - -func (self *IAVLNode) writeToCountHashes(w io.Writer) (n int64, hashCount uint64, err error) { +func (node *IAVLNode) writeToCountHashes(t *IAVLTree, w io.Writer) (n int64, hashCount uint64, err error) { // height & size & key - WriteUInt8(w, self.height, &n, &err) - WriteUInt64(w, self.size, &n, &err) - WriteByteSlice(w, self.key, &n, &err) + WriteUInt8(w, node.height, &n, &err) + WriteUInt64(w, node.size, &n, &err) + t.keyCodec.Encode(node.key, w, &n, &err) if err != nil { return } - if self.height == 0 { + if node.height == 0 { // value - WriteByteSlice(w, self.value, &n, &err) + t.valueCodec.Encode(node.value, w, &n, &err) } else { // left - if self.leftNode != nil { - leftHash, leftCount := self.leftNode.HashWithCount() - self.leftHash = leftHash + if node.leftNode != nil { + leftHash, leftCount := node.leftNode.hashWithCount(t) + node.leftHash = leftHash hashCount += leftCount } - if self.leftHash == nil { - panic("self.leftHash was nil in save") + if node.leftHash == nil { + panic("node.leftHash was nil in save") } - WriteByteSlice(w, self.leftHash, &n, &err) + WriteByteSlice(w, node.leftHash, &n, &err) // right - if self.rightNode != nil { - rightHash, rightCount := self.rightNode.HashWithCount() - self.rightHash = rightHash + if node.rightNode != nil { + rightHash, rightCount := node.rightNode.hashWithCount(t) + node.rightHash = rightHash hashCount += rightCount } - if self.rightHash == nil { - panic("self.rightHash was nil in save") + if node.rightHash == nil { + panic("node.rightHash was nil in save") } - WriteByteSlice(w, self.rightHash, &n, &err) + WriteByteSlice(w, node.rightHash, &n, &err) } return } -func (self *IAVLNode) getLeftNode(ndb *IAVLNodeDB) *IAVLNode { - if self.leftNode != nil { - return self.leftNode +func (node *IAVLNode) getLeftNode(t *IAVLTree) *IAVLNode { + if node.leftNode != nil { + return node.leftNode } else { - return ndb.Get(self.leftHash) + return t.getNode(node.leftHash) } } -func (self *IAVLNode) getRightNode(ndb *IAVLNodeDB) *IAVLNode { - if self.rightNode != nil { - return self.rightNode +func (node *IAVLNode) getRightNode(t *IAVLTree) *IAVLNode { + if node.rightNode != nil { + return node.rightNode } else { - return ndb.Get(self.rightHash) + return t.getNode(node.rightHash) } } -func (self *IAVLNode) rotateRight(ndb *IAVLNodeDB) *IAVLNode { - self = self.Copy() - sl := self.getLeftNode(ndb).Copy() +func (node *IAVLNode) rotateRight(t *IAVLTree) *IAVLNode { + node = node._copy() + sl := node.getLeftNode(t)._copy() slrHash, slrCached := sl.rightHash, sl.rightNode - sl.rightHash, sl.rightNode = nil, self - self.leftHash, self.leftNode = slrHash, slrCached + sl.rightHash, sl.rightNode = nil, node + node.leftHash, node.leftNode = slrHash, slrCached - self.calcHeightAndSize(ndb) - sl.calcHeightAndSize(ndb) + node.calcHeightAndSize(t) + sl.calcHeightAndSize(t) return sl } -func (self *IAVLNode) rotateLeft(ndb *IAVLNodeDB) *IAVLNode { - self = self.Copy() - sr := self.getRightNode(ndb).Copy() +func (node *IAVLNode) rotateLeft(t *IAVLTree) *IAVLNode { + node = node._copy() + sr := node.getRightNode(t)._copy() srlHash, srlCached := sr.leftHash, sr.leftNode - sr.leftHash, sr.leftNode = nil, self - self.rightHash, self.rightNode = srlHash, srlCached + sr.leftHash, sr.leftNode = nil, node + node.rightHash, node.rightNode = srlHash, srlCached - self.calcHeightAndSize(ndb) - sr.calcHeightAndSize(ndb) + node.calcHeightAndSize(t) + sr.calcHeightAndSize(t) return sr } -func (self *IAVLNode) calcHeightAndSize(ndb *IAVLNodeDB) { - self.height = maxUint8(self.getLeftNode(ndb).Height(), self.getRightNode(ndb).Height()) + 1 - self.size = self.getLeftNode(ndb).Size() + self.getRightNode(ndb).Size() +func (node *IAVLNode) calcHeightAndSize(t *IAVLTree) { + node.height = maxUint8(node.getLeftNode(t).height, node.getRightNode(t).height) + 1 + node.size = node.getLeftNode(t).size + node.getRightNode(t).size } -func (self *IAVLNode) calcBalance(ndb *IAVLNodeDB) int { - return int(self.getLeftNode(ndb).Height()) - int(self.getRightNode(ndb).Height()) +func (node *IAVLNode) calcBalance(t *IAVLTree) int { + return int(node.getLeftNode(t).height) - int(node.getRightNode(t).height) } -func (self *IAVLNode) balance(ndb *IAVLNodeDB) (newSelf *IAVLNode) { - balance := self.calcBalance(ndb) +func (node *IAVLNode) balance(t *IAVLTree) (newSelf *IAVLNode) { + balance := node.calcBalance(t) if balance > 1 { - if self.getLeftNode(ndb).calcBalance(ndb) >= 0 { + if node.getLeftNode(t).calcBalance(t) >= 0 { // Left Left Case - return self.rotateRight(ndb) + return node.rotateRight(t) } else { // Left Right Case - self = self.Copy() - self.leftHash, self.leftNode = nil, self.getLeftNode(ndb).rotateLeft(ndb) - //self.calcHeightAndSize() - return self.rotateRight(ndb) + node = node._copy() + node.leftHash, node.leftNode = nil, node.getLeftNode(t).rotateLeft(t) + //node.calcHeightAndSize() + return node.rotateRight(t) } } if balance < -1 { - if self.getRightNode(ndb).calcBalance(ndb) <= 0 { + if node.getRightNode(t).calcBalance(t) <= 0 { // Right Right Case - return self.rotateLeft(ndb) + return node.rotateLeft(t) } else { // Right Left Case - self = self.Copy() - self.rightHash, self.rightNode = nil, self.getRightNode(ndb).rotateRight(ndb) - //self.calcHeightAndSize() - return self.rotateLeft(ndb) + node = node._copy() + node.rightHash, node.rightNode = nil, node.getRightNode(t).rotateRight(t) + //node.calcHeightAndSize() + return node.rotateLeft(t) } } // Nothing changed - return self + return node } -func (self *IAVLNode) traverse(ndb *IAVLNodeDB, cb func(*IAVLNode) bool) bool { - stop := cb(self) +func (node *IAVLNode) traverse(t *IAVLTree, cb func(*IAVLNode) bool) bool { + stop := cb(node) if stop { return stop } - if self.height > 0 { - stop = self.getLeftNode(ndb).traverse(ndb, cb) + if node.height > 0 { + stop = node.getLeftNode(t).traverse(t, cb) if stop { return stop } - stop = self.getRightNode(ndb).traverse(ndb, cb) + stop = node.getRightNode(t).traverse(t, cb) if stop { return stop } @@ -375,17 +386,17 @@ func (self *IAVLNode) traverse(ndb *IAVLNodeDB, cb func(*IAVLNode) bool) bool { } // Only used in testing... -func (self *IAVLNode) lmd(ndb *IAVLNodeDB) *IAVLNode { - if self.height == 0 { - return self +func (node *IAVLNode) lmd(t *IAVLTree) *IAVLNode { + if node.height == 0 { + return node } - return self.getLeftNode(ndb).lmd(ndb) + return node.getLeftNode(t).lmd(t) } // Only used in testing... -func (self *IAVLNode) rmd(ndb *IAVLNodeDB) *IAVLNode { - if self.height == 0 { - return self +func (node *IAVLNode) rmd(t *IAVLTree) *IAVLNode { + if node.height == 0 { + return node } - return self.getRightNode(ndb).rmd(ndb) + return node.getRightNode(t).rmd(t) } diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index f759771c5..ba1519444 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -2,9 +2,7 @@ package merkle import ( "bytes" - "crypto/sha256" "fmt" - "time" . "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/common" @@ -22,116 +20,127 @@ func randstr(length int) string { return RandStr(length) } -func TestUnit(t *testing.T) { - - // Convenience for a new node - N := func(l, r interface{}) *IAVLNode { - var left, right *IAVLNode - if _, ok := l.(*IAVLNode); ok { - left = l.(*IAVLNode) - } else { - left = NewIAVLNode([]byte{byte(l.(int))}, nil) - } - if _, ok := r.(*IAVLNode); ok { - right = r.(*IAVLNode) - } else { - right = NewIAVLNode([]byte{byte(r.(int))}, nil) - } +// Convenience for a new node +func N(l, r interface{}) *IAVLNode { + var left, right *IAVLNode + if _, ok := l.(*IAVLNode); ok { + left = l.(*IAVLNode) + } else { + left = NewIAVLNode(l, "") + } + if _, ok := r.(*IAVLNode); ok { + right = r.(*IAVLNode) + } else { + right = NewIAVLNode(r, "") + } - n := &IAVLNode{ - key: right.lmd(nil).key, - leftNode: left, - rightNode: right, - } - n.calcHeightAndSize(nil) - n.HashWithCount() - return n + n := &IAVLNode{ + key: right.lmd(nil).key, + value: "", + leftNode: left, + rightNode: right, } + n.calcHeightAndSize(nil) + return n +} - // Convenience for simple printing of keys & tree structure - var P func(*IAVLNode) string - P = func(n *IAVLNode) string { - if n.height == 0 { - return fmt.Sprintf("%v", n.key[0]) - } else { - return fmt.Sprintf("(%v %v)", P(n.leftNode), P(n.rightNode)) - } +// Setup a deep node +func T(n *IAVLNode) *IAVLTree { + t := NewIAVLTree(BasicCodec, BasicCodec, 0, nil) + n.hashWithCount(t) + t.root = n + return t +} + +// Convenience for simple printing of keys & tree structure +func P(n *IAVLNode) string { + if n.height == 0 { + return fmt.Sprintf("%v", n.key) + } else { + return fmt.Sprintf("(%v %v)", P(n.leftNode), P(n.rightNode)) } +} - expectHash := func(n2 *IAVLNode, hashCount uint64) { +func TestUnit(t *testing.T) { + + expectHash := func(tree *IAVLTree, hashCount uint64) { // ensure number of new hash calculations is as expected. - hash, count := n2.HashWithCount() + hash, count := tree.HashWithCount() if count != hashCount { t.Fatalf("Expected %v new hashes, got %v", hashCount, count) } // nuke hashes and reconstruct hash, ensure it's the same. - n2.traverse(nil, func(node *IAVLNode) bool { + tree.root.traverse(tree, func(node *IAVLNode) bool { node.hash = nil return false }) // ensure that the new hash after nuking is the same as the old. - newHash, _ := n2.HashWithCount() + newHash, _ := tree.HashWithCount() if bytes.Compare(hash, newHash) != 0 { t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash) } } - expectSet := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, updated := n.set(nil, []byte{byte(i)}, nil) + expectSet := func(tree *IAVLTree, i int, repr string, hashCount uint64) { + origNode := tree.root + updated := tree.Set(i, "") // ensure node was added & structure is as expected. - if updated == true || P(n2) != repr { + if updated == true || P(tree.root) != repr { t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", - i, P(n), repr, P(n2), updated) + i, P(origNode), repr, P(tree.root), updated) } // ensure hash calculation requirements - expectHash(n2, hashCount) + expectHash(tree, hashCount) + tree.root = origNode } - expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { - _, n2, _, value, err := n.remove(nil, []byte{byte(i)}) + expectRemove := func(tree *IAVLTree, i int, repr string, hashCount uint64) { + origNode := tree.root + value, removed := tree.Remove(i) // ensure node was added & structure is as expected. - if value != nil || err != nil || P(n2) != repr { - t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", - i, P(n), repr, P(n2), value, err) + if value != "" || !removed || P(tree.root) != repr { + t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v removed:%v", + i, P(origNode), repr, P(tree.root), value, removed) } // ensure hash calculation requirements - expectHash(n2, hashCount) + expectHash(tree, hashCount) + tree.root = origNode } //////// Test Set cases: // Case 1: - n1 := N(4, 20) + t1 := T(N(4, 20)) - expectSet(n1, 8, "((4 8) 20)", 3) - expectSet(n1, 25, "(4 (20 25))", 3) + expectSet(t1, 8, "((4 8) 20)", 3) + expectSet(t1, 25, "(4 (20 25))", 3) - n2 := N(4, N(20, 25)) + t2 := T(N(4, N(20, 25))) - expectSet(n2, 8, "((4 8) (20 25))", 3) - expectSet(n2, 30, "((4 20) (25 30))", 4) + expectSet(t2, 8, "((4 8) (20 25))", 3) + expectSet(t2, 30, "((4 20) (25 30))", 4) - n3 := N(N(1, 2), 6) + t3 := T(N(N(1, 2), 6)) - expectSet(n3, 4, "((1 2) (4 6))", 4) - expectSet(n3, 8, "((1 2) (6 8))", 3) + expectSet(t3, 4, "((1 2) (4 6))", 4) + expectSet(t3, 8, "((1 2) (6 8))", 3) - n4 := N(N(1, 2), N(N(5, 6), N(7, 9))) + t4 := T(N(N(1, 2), N(N(5, 6), N(7, 9)))) - expectSet(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5) - expectSet(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5) + expectSet(t4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5) + expectSet(t4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5) //////// Test Remove cases: - n10 := N(N(1, 2), 3) + t10 := T(N(N(1, 2), 3)) - expectRemove(n10, 2, "(1 3)", 1) - expectRemove(n10, 3, "(1 2)", 0) + expectRemove(t10, 2, "(1 3)", 1) + expectRemove(t10, 3, "(1 2)", 0) - n11 := N(N(N(1, 2), 3), N(4, 5)) + t11 := T(N(N(N(1, 2), 3), N(4, 5))) - expectRemove(n11, 4, "((1 2) (3 5))", 2) - expectRemove(n11, 3, "((1 2) (4 5))", 1) + expectRemove(t11, 4, "((1 2) (3 5))", 2) + expectRemove(t11, 3, "((1 2) (4 5))", 1) } @@ -143,10 +152,7 @@ func TestIntegration(t *testing.T) { } records := make([]*record, 400) - var tree *IAVLTree = NewIAVLTree(nil) - var err error - var val []byte - var updated bool + var tree *IAVLTree = NewIAVLTree(BasicCodec, BasicCodec, 0, nil) randomRecord := func() *record { return &record{randstr(20), randstr(20)} @@ -157,11 +163,11 @@ func TestIntegration(t *testing.T) { records[i] = r //t.Log("New record", r) //PrintIAVLNode(tree.root) - updated = tree.Set([]byte(r.key), []byte("")) + updated := tree.Set(r.key, "") if updated { t.Error("should have not been updated") } - updated = tree.Set([]byte(r.key), []byte(r.value)) + updated = tree.Set(r.key, r.value) if !updated { t.Error("should have been updated") } @@ -171,32 +177,32 @@ func TestIntegration(t *testing.T) { } for _, r := range records { - if has := tree.Has([]byte(r.key)); !has { + if has := tree.Has(r.key); !has { t.Error("Missing key", r.key) } - if has := tree.Has([]byte(randstr(12))); has { + if has := tree.Has(randstr(12)); has { t.Error("Table has extra key") } - if val := tree.Get([]byte(r.key)); string(val) != r.value { + if _, val := tree.Get(r.key); val.(string) != r.value { t.Error("wrong value") } } for i, x := range records { - if val, err = tree.Remove([]byte(x.key)); err != nil { - t.Error(err) - } else if string(val) != x.value { - t.Error("wrong value") + if val, removed := tree.Remove(x.key); !removed { + t.Error("Wasn't removed") + } else if val != x.value { + t.Error("Wrong value") } for _, r := range records[i+1:] { - if has := tree.Has([]byte(r.key)); !has { + if has := tree.Has(r.key); !has { t.Error("Missing key", r.key) } - if has := tree.Has([]byte(randstr(12))); has { + if has := tree.Has(randstr(12)); has { t.Error("Table has extra key") } - val := tree.Get([]byte(r.key)) - if string(val) != r.value { + _, val := tree.Get(r.key) + if val != r.value { t.Error("wrong value") } } @@ -216,96 +222,30 @@ func TestPersistence(t *testing.T) { } // Construct some tree and save it - t1 := NewIAVLTree(db) + t1 := NewIAVLTree(BasicCodec, BasicCodec, 0, db) for key, value := range records { - t1.Set([]byte(key), []byte(value)) + t1.Set(key, value) } t1.Save() hash, _ := t1.HashWithCount() // Load a tree - t2 := LoadIAVLTreeFromHash(db, hash) + t2 := LoadIAVLTreeFromHash(BasicCodec, BasicCodec, 0, db, hash) for key, value := range records { - t2value := t2.Get([]byte(key)) - if string(t2value) != value { + _, t2value := t2.Get(key) + if t2value != value { t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) } } } -func TestTypedTree(t *testing.T) { - db := db.NewMemDB() - - // Construct some tree and save it - t1 := NewTypedTree(NewIAVLTree(db), BasicCodec, BasicCodec) - t1.Set(uint8(1), "uint8(1)") - t1.Set(uint16(1), "uint16(1)") - t1.Set(uint32(1), "uint32(1)") - t1.Set(uint64(1), "uint64(1)") - t1.Set("byteslice01", []byte{byte(0x00), byte(0x01)}) - t1.Set("byteslice23", []byte{byte(0x02), byte(0x03)}) - t1.Set("time", time.Unix(123, 0)) - t1.Set("nil", nil) - t1Hash := t1.Tree.Save() - - // Reconstruct tree - t2 := NewTypedTree(LoadIAVLTreeFromHash(db, t1Hash), BasicCodec, BasicCodec) - if t2.Get(uint8(1)).(string) != "uint8(1)" { - t.Errorf("Expected string uint8(1)") - } - if t2.Get(uint16(1)).(string) != "uint16(1)" { - t.Errorf("Expected string uint16(1)") - } - if t2.Get(uint32(1)).(string) != "uint32(1)" { - t.Errorf("Expected string uint32(1)") - } - if t2.Get(uint64(1)).(string) != "uint64(1)" { - t.Errorf("Expected string uint64(1)") - } - if !bytes.Equal(t2.Get("byteslice01").([]byte), []byte{byte(0x00), byte(0x01)}) { - t.Errorf("Expected byteslice 0x00 0x01") - } - if !bytes.Equal(t2.Get("byteslice23").([]byte), []byte{byte(0x02), byte(0x03)}) { - t.Errorf("Expected byteslice 0x02 0x03") - } - if t2.Get("time").(time.Time).Unix() != 123 { - t.Errorf("Expected time 123") - } - if t2.Get("nil") != nil { - t.Errorf("Expected nil") - } -} - -func BenchmarkHash(b *testing.B) { - b.StopTimer() - - s := randstr(128) - - b.StartTimer() - for i := 0; i < b.N; i++ { - hasher := sha256.New() - hasher.Write([]byte(s)) - hasher.Sum(nil) - } -} - func BenchmarkImmutableAvlTree(b *testing.B) { b.StopTimer() - type record struct { - key string - value string - } - - randomRecord := func() *record { - return &record{randstr(32), randstr(32)} - } - - t := NewIAVLTree(nil) + t := NewIAVLTree(BasicCodec, BasicCodec, 0, nil) for i := 0; i < 1000000; i++ { - r := randomRecord() - t.Set([]byte(r.key), []byte(r.value)) + t.Set(RandUInt64(), "") } fmt.Println("ok, starting") @@ -314,8 +254,8 @@ func BenchmarkImmutableAvlTree(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - r := randomRecord() - t.Set([]byte(r.key), []byte(r.value)) - t.Remove([]byte(r.key)) + ri := RandUInt64() + t.Set(ri, "") + t.Remove(ri) } } diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index 5fed69677..84424e79a 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -3,6 +3,9 @@ package merkle import ( "bytes" "container/list" + + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/db" ) const defaultCacheCapacity = 1000 // TODO make configurable. @@ -14,53 +17,64 @@ This tree is not concurrency safe. You must wrap your calls with your own mutex. */ type IAVLTree struct { - ndb *IAVLNodeDB - root *IAVLNode + keyCodec Codec + valueCodec Codec + root *IAVLNode + + // Cache + cache map[string]nodeElement + cacheSize int + queue *list.List + + // Persistence + db DB } -func NewIAVLTree(db DB) *IAVLTree { +func NewIAVLTree(keyCodec, valueCodec Codec, cacheSize int, db DB) *IAVLTree { return &IAVLTree{ - ndb: NewIAVLNodeDB(defaultCacheCapacity, db), - root: nil, + keyCodec: keyCodec, + valueCodec: valueCodec, + root: nil, + cache: make(map[string]nodeElement), + cacheSize: cacheSize, + queue: list.New(), + db: db, } } -func LoadIAVLTreeFromHash(db DB, hash []byte) *IAVLTree { - ndb := NewIAVLNodeDB(defaultCacheCapacity, db) - root := ndb.Get(hash) - if root == nil { - return nil - } - return &IAVLTree{ndb: ndb, root: root} +func LoadIAVLTreeFromHash(keyCodec, valueCodec Codec, cacheSize int, db DB, hash []byte) *IAVLTree { + t := NewIAVLTree(keyCodec, valueCodec, cacheSize, db) + t.root = t.getNode(hash) + return t } func (t *IAVLTree) Size() uint64 { if t.root == nil { return 0 } - return t.root.Size() + return t.root.size } func (t *IAVLTree) Height() uint8 { if t.root == nil { return 0 } - return t.root.Height() + return t.root.height } -func (t *IAVLTree) Has(key []byte) bool { +func (t *IAVLTree) Has(key interface{}) bool { if t.root == nil { return false } - return t.root.has(t.ndb, key) + return t.root.has(t, key) } -func (t *IAVLTree) Set(key []byte, value []byte) (updated bool) { +func (t *IAVLTree) Set(key interface{}, value interface{}) (updated bool) { if t.root == nil { t.root = NewIAVLNode(key, value) return false } - t.root, updated = t.root.set(t.ndb, key, value) + t.root, updated = t.root.set(t, key, value) return updated } @@ -68,7 +82,7 @@ func (t *IAVLTree) Hash() []byte { if t.root == nil { return nil } - hash, _ := t.root.HashWithCount() + hash, _ := t.root.hashWithCount(t) return hash } @@ -76,117 +90,110 @@ func (t *IAVLTree) HashWithCount() ([]byte, uint64) { if t.root == nil { return nil, 0 } - return t.root.HashWithCount() + return t.root.hashWithCount(t) } func (t *IAVLTree) Save() []byte { if t.root == nil { return nil } - return t.root.Save(t.ndb) + return t.root.save(t) } -func (t *IAVLTree) Get(key []byte) (value []byte) { +func (t *IAVLTree) Get(key interface{}) (index uint64, value interface{}) { if t.root == nil { - return nil + return 0, nil } - return t.root.get(t.ndb, key) + return t.root.get(t, key) } -func (t *IAVLTree) Remove(key []byte) (value []byte, err error) { +func (t *IAVLTree) GetByIndex(index uint64) (key interface{}, value interface{}) { if t.root == nil { - return nil, NotFound(key) + return nil, nil } - newRootHash, newRoot, _, value, err := t.root.remove(t.ndb, key) - if err != nil { - return nil, err + return t.root.getByIndex(t, index) +} + +func (t *IAVLTree) Remove(key interface{}) (value interface{}, removed bool) { + if t.root == nil { + return nil, false + } + newRootHash, newRoot, _, value, removed := t.root.remove(t, key) + if !removed { + return nil, false } if newRoot == nil && newRootHash != nil { - t.root = t.ndb.Get(newRootHash) + t.root = t.getNode(newRootHash) } else { t.root = newRoot } - return value, nil + return value, true } -func (t *IAVLTree) Copy() Tree { - return &IAVLTree{ndb: t.ndb, root: t.root} +func (t *IAVLTree) Checkpoint() interface{} { + return t.root } -//----------------------------------------------------------------------------- +func (t *IAVLTree) Restore(checkpoint interface{}) { + t.root = checkpoint.(*IAVLNode) +} type nodeElement struct { node *IAVLNode elem *list.Element } -type IAVLNodeDB struct { - capacity int - db DB - cache map[string]nodeElement - queue *list.List -} - -func NewIAVLNodeDB(capacity int, db DB) *IAVLNodeDB { - return &IAVLNodeDB{ - capacity: capacity, - db: db, - cache: make(map[string]nodeElement), - queue: list.New(), - } -} - -func (ndb *IAVLNodeDB) Get(hash []byte) *IAVLNode { +func (t *IAVLTree) getNode(hash []byte) *IAVLNode { // Check the cache. - nodeElem, ok := ndb.cache[string(hash)] + nodeElem, ok := t.cache[string(hash)] if ok { // Already exists. Move to back of queue. - ndb.queue.MoveToBack(nodeElem.elem) + t.queue.MoveToBack(nodeElem.elem) return nodeElem.node } else { // Doesn't exist, load. - buf := ndb.db.Get(hash) + buf := t.db.Get(hash) r := bytes.NewReader(buf) var n int64 var err error - node := ReadIAVLNode(r, &n, &err) + node := ReadIAVLNode(t, r, &n, &err) if err != nil { panic(err) } node.persisted = true - ndb.cacheNode(node) + t.cacheNode(node) return node } } -func (ndb *IAVLNodeDB) Save(node *IAVLNode) { +func (t *IAVLTree) cacheNode(node *IAVLNode) { + // Create entry in cache and append to queue. + elem := t.queue.PushBack(node.hash) + t.cache[string(node.hash)] = nodeElement{node, elem} + // Maybe expire an item. + if t.queue.Len() > t.cacheSize { + hash := t.queue.Remove(t.queue.Front()).([]byte) + delete(t.cache, string(hash)) + } +} + +func (t *IAVLTree) saveNode(node *IAVLNode) { if node.hash == nil { panic("Expected to find node.hash, but none found.") } if node.persisted { panic("Shouldn't be calling save on an already persisted node.") } - if _, ok := ndb.cache[string(node.hash)]; ok { + if _, ok := t.cache[string(node.hash)]; ok { panic("Shouldn't be calling save on an already cached node.") } // Save node bytes to db buf := bytes.NewBuffer(nil) - _, err := node.WriteTo(buf) + _, _, err := node.writeToCountHashes(t, buf) if err != nil { panic(err) } - ndb.db.Set(node.hash, buf.Bytes()) + t.db.Set(node.hash, buf.Bytes()) node.persisted = true - ndb.cacheNode(node) -} - -func (ndb *IAVLNodeDB) cacheNode(node *IAVLNode) { - // Create entry in cache and append to queue. - elem := ndb.queue.PushBack(node.hash) - ndb.cache[string(node.hash)] = nodeElement{node, elem} - // Maybe expire an item. - if ndb.queue.Len() > ndb.capacity { - hash := ndb.queue.Remove(ndb.queue.Front()).([]byte) - delete(ndb.cache, string(hash)) - } + t.cacheNode(node) } diff --git a/merkle/typed_tree.go b/merkle/typed_tree.go deleted file mode 100644 index efafc6b88..000000000 --- a/merkle/typed_tree.go +++ /dev/null @@ -1,81 +0,0 @@ -package merkle - -import ( - . "github.com/tendermint/tendermint/binary" - . "github.com/tendermint/tendermint/common" -) - -// TODO: make TypedTree work with the underlying tree to cache the decoded value. -type TypedTree struct { - Tree Tree - keyCodec Codec - valueCodec Codec -} - -func NewTypedTree(tree Tree, keyCodec, valueCodec Codec) *TypedTree { - return &TypedTree{ - Tree: tree, - keyCodec: keyCodec, - valueCodec: valueCodec, - } -} - -func (t *TypedTree) Has(key interface{}) bool { - bytes, err := t.keyCodec.Write(key) - if err != nil { - Panicf("Error from keyCodec: %v", err) - } - return t.Tree.Has(bytes) -} - -func (t *TypedTree) Get(key interface{}) interface{} { - keyBytes, err := t.keyCodec.Write(key) - if err != nil { - Panicf("Error from keyCodec: %v", err) - } - valueBytes := t.Tree.Get(keyBytes) - if valueBytes == nil { - return nil - } - value, err := t.valueCodec.Read(valueBytes) - if err != nil { - Panicf("Error from valueCodec: %v", err) - } - return value -} - -func (t *TypedTree) Set(key interface{}, value interface{}) bool { - keyBytes, err := t.keyCodec.Write(key) - if err != nil { - Panicf("Error from keyCodec: %v", err) - } - valueBytes, err := t.valueCodec.Write(value) - if err != nil { - Panicf("Error from valueCodec: %v", err) - } - return t.Tree.Set(keyBytes, valueBytes) -} - -func (t *TypedTree) Remove(key interface{}) (interface{}, error) { - keyBytes, err := t.keyCodec.Write(key) - if err != nil { - Panicf("Error from keyCodec: %v", err) - } - valueBytes, err := t.Tree.Remove(keyBytes) - if valueBytes == nil { - return nil, err - } - value, err_ := t.valueCodec.Read(valueBytes) - if err_ != nil { - Panicf("Error from valueCodec: %v", err) - } - return value, err -} - -func (t *TypedTree) Copy() *TypedTree { - return &TypedTree{ - Tree: t.Tree.Copy(), - keyCodec: t.keyCodec, - valueCodec: t.valueCodec, - } -} diff --git a/merkle/types.go b/merkle/types.go index 5538197a6..3631725d3 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -1,29 +1,18 @@ package merkle -import ( - "fmt" -) - -type DB interface { - Get([]byte) []byte - Set([]byte, []byte) -} - type Tree interface { - Size() uint64 - Height() uint8 - Has(key []byte) bool - Get(key []byte) []byte - Set(key []byte, value []byte) bool - Remove(key []byte) ([]byte, error) - HashWithCount() ([]byte, uint64) - Hash() []byte - Save() []byte - Copy() Tree -} - -func NotFound(key []byte) error { - return fmt.Errorf("Key was not found.") + Size() (size uint64) + Height() (height uint8) + Has(key interface{}) (has bool) + Get(key interface{}) (index uint64, value interface{}) + GetByIndex(index uint64) (key interface{}, value interface{}) + Set(key interface{}, value interface{}) (updated bool) + Remove(key interface{}) (value interface{}, removed bool) + HashWithCount() (hash []byte, count uint64) + Hash() (hash []byte) + Save() (hash []byte) + Checkpoint() (checkpoint interface{}) + Restore(checkpoint interface{}) } type Hashable interface {