diff --git a/merkle/db.go b/merkle/db.go index 89f3c05cb..74c180a3a 100644 --- a/merkle/db.go +++ b/merkle/db.go @@ -22,13 +22,13 @@ func NewLDBDatabase(name string) (*LDBDatabase, error) { func (db *LDBDatabase) Put(key []byte, value []byte) { err := db.db.Put(key, value, nil) - if err != nil { - fmt.Println("Error put", err) - } + if err != nil { panic(err) } } -func (db *LDBDatabase) Get(key []byte) ([]byte, error) { - return db.db.Get(key, nil) +func (db *LDBDatabase) Get(key []byte) ([]byte) { + res, err := db.db.Get(key, nil) + if err != nil { panic(err) } + return res } func (db *LDBDatabase) Delete(key []byte) error { diff --git a/merkle/iavl.go b/merkle/iavl.go index 6b6ebde46..46ab71b63 100644 --- a/merkle/iavl.go +++ b/merkle/iavl.go @@ -1,12 +1,11 @@ package merkle import ( - "bytes" - "math" - "io" "crypto/sha256" ) +const HASH_BYTE_SIZE int = 4+32 + // Immutable AVL Tree (wraps the Node root) type IAVLTree struct { @@ -33,16 +32,15 @@ func (self *IAVLTree) Has(key Key) bool { return self.root.Has(nil, key) } -func (self *IAVLTree) Put(key Key, value Value) (err error) { +func (self *IAVLTree) Put(key Key, value Value) { self.root, _ = self.root.Put(nil, key, value) - return nil } -func (self *IAVLTree) Hash() ([]byte, uint64) { +func (self *IAVLTree) Hash() (ByteSlice, uint64) { return self.root.Hash() } -func (self *IAVLTree) Get(key Key) (value Value, err error) { +func (self *IAVLTree) Get(key Key) (value Value) { return self.root.Get(nil, key) } @@ -104,27 +102,18 @@ func (self *IAVLNode) Right(db Db) Node { return self.right_filled(db) } -func (self *IAVLNode) left_filled(db Db) *IAVLNode { - // XXX - return self.left -} - -func (self *IAVLNode) right_filled(db Db) *IAVLNode { - // XXX - return self.right -} - func (self *IAVLNode) Size() uint64 { - if self == nil { - return 0 - } + if self == nil { return 0 } return self.size } +func (self *IAVLNode) Height() uint8 { + if self == nil { return 0 } + return self.height +} + func (self *IAVLNode) Has(db Db, key Key) (has bool) { - if self == nil { - return false - } + if self == nil { return false } if self.key.Equals(key) { return true } else if key.Less(self.key) { @@ -134,12 +123,10 @@ func (self *IAVLNode) Has(db Db, key Key) (has bool) { } } -func (self *IAVLNode) Get(db Db, key Key) (value Value, err error) { - if self == nil { - return nil, NotFound(key) - } +func (self *IAVLNode) Get(db Db, key Key) (value Value) { + if self == nil { return nil } if self.key.Equals(key) { - return self.value, nil + return self.value } else if key.Less(self.key) { return self.left_filled(db).Get(db, key) } else { @@ -147,78 +134,177 @@ func (self *IAVLNode) Get(db Db, key Key) (value Value, err error) { } } -func (self *IAVLNode) Bytes() []byte { - b := new(bytes.Buffer) - self.WriteTo(b) - return b.Bytes() -} - -func (self *IAVLNode) Hash() ([]byte, uint64) { - if self == nil { - return nil, 0 - } +func (self *IAVLNode) Hash() (ByteSlice, uint64) { + if self == nil { return nil, 0 } if self.hash != nil { return self.hash, 0 } + size := self.ByteSize() + buf := make([]byte, size, size) hasher := sha256.New() - _, hashCount, err := self.WriteTo(hasher) - if err != nil { panic(err) } + _, hashCount := self.saveToCountHashes(buf) + hasher.Write(buf) self.hash = hasher.Sum(nil) - return self.hash, hashCount + return self.hash, hashCount+1 +} + +// TODO: don't clear the hash if the value hasn't changed. +func (self *IAVLNode) Put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { + if self == nil { + return &IAVLNode{key: key, value: value, height: 1, size: 1, hash: nil}, false + } + + self = self.Copy() + + if self.key.Equals(key) { + self.value = value + return self, true + } + + if key.Less(self.key) { + self.left, updated = self.left_filled(db).Put(db, key, value) + } else { + self.right, updated = self.right_filled(db).Put(db, key, value) + } + if updated { + return self, updated + } else { + self.calc_height_and_size(db) + return self.balance(db), updated + } } -func (self *IAVLNode) WriteTo(writer io.Writer) (written int64, hashCount uint64, err error) { +func (self *IAVLNode) Remove(db Db, key Key) (new_self *IAVLNode, value Value, err error) { + if self == nil { return nil, nil, NotFound(key) } - write := func(bytes []byte) { - if err == nil { - var n int - n, err = writer.Write(bytes) - written += int64(n) + if self.key.Equals(key) { + if self.left != nil && self.right != nil { + if self.left_filled(db).Size() < self.right_filled(db).Size() { + self, new_self = self.pop_node(db, self.right_filled(db).lmd(db)) + } else { + self, new_self = self.pop_node(db, self.left_filled(db).rmd(db)) + } + new_self.left = self.left + new_self.right = self.right + new_self.calc_height_and_size(db) + return new_self, self.value, nil + } else if self.left == nil { + return self.right_filled(db), self.value, nil + } else if self.right == nil { + return self.left_filled(db), self.value, nil + } else { + return nil, self.value, nil } } + if key.Less(self.key) { + if self.left == nil { + return self, nil, NotFound(key) + } + var new_left *IAVLNode + new_left, value, err = self.left_filled(db).Remove(db, key) + if new_left == self.left_filled(db) { // not found + return self, nil, err + } else if err != nil { // some other error + return self, value, err + } + self = self.Copy() + self.left = new_left + } else { + if self.right == nil { + return self, nil, NotFound(key) + } + var new_right *IAVLNode + new_right, value, err = self.right_filled(db).Remove(db, key) + if new_right == self.right_filled(db) { // not found + return self, nil, err + } else if err != nil { // some other error + return self, value, err + } + self = self.Copy() + self.right = new_right + } + self.calc_height_and_size(db) + return self.balance(db), value, err +} + +func (self *IAVLNode) ByteSize() int { + // 1 byte node descriptor + // 1 byte node neight + // 8 bytes node size + size := 10 + size += self.key.ByteSize() + if self.value != nil { + size += self.value.ByteSize() + } else { + size += 1 + } + if self.left != nil { + size += HASH_BYTE_SIZE + } + if self.right != nil { + size += HASH_BYTE_SIZE + } + return size +} + +func (self *IAVLNode) SaveTo(buf []byte) int { + written, _ := self.saveToCountHashes(buf) + return written +} + +func (self *IAVLNode) saveToCountHashes(buf []byte) (int, uint64) { + cur := 0 + hashCount := uint64(0) + // node descriptor nodeDesc := byte(0) if self.value != nil { nodeDesc |= 0x01 } if self.left != nil { nodeDesc |= 0x02 } if self.right != nil { nodeDesc |= 0x04 } - write([]byte{nodeDesc}) + cur += UInt8(nodeDesc).SaveTo(buf[cur:]) // node height & size - write(UInt8(self.height).Bytes()) - write(UInt64(self.size).Bytes()) + cur += UInt8(self.height).SaveTo(buf[cur:]) + cur += UInt64(self.size).SaveTo(buf[cur:]) // node key - keyBytes := self.key.Bytes() - if len(keyBytes) > 255 { panic("key is too long") } - write([]byte{byte(len(keyBytes))}) - write(keyBytes) + cur += self.key.SaveTo(buf[cur:]) // node value if self.value != nil { - valueBytes := self.value.Bytes() - if len(valueBytes) > math.MaxUint32 { panic("value is too long") } - write([]byte{byte(len(valueBytes))}) - write(valueBytes) + cur += self.value.SaveTo(buf[cur:]) + } else { + cur += UInt8(0).SaveTo(buf[cur:]) } // left child if self.left != nil { leftHash, leftCount := self.left.Hash() hashCount += leftCount - write(leftHash) + cur += leftHash.SaveTo(buf[cur:]) } // right child if self.right != nil { rightHash, rightCount := self.right.Hash() hashCount += rightCount - write(rightHash) + cur += rightHash.SaveTo(buf[cur:]) } - return written, hashCount+1, err + return cur, hashCount +} + +func (self *IAVLNode) left_filled(db Db) *IAVLNode { + // XXX + return self.left +} + +func (self *IAVLNode) right_filled(db Db) *IAVLNode { + // XXX + return self.right } // Returns a new tree (unless node is the root) & a copy of the popped node. @@ -332,97 +418,6 @@ func (self *IAVLNode) balance(db Db) (new_self *IAVLNode) { return self } -// TODO: don't clear the hash if the value hasn't changed. -func (self *IAVLNode) Put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { - if self == nil { - return &IAVLNode{key: key, value: value, height: 1, size: 1, hash: nil}, false - } - - self = self.Copy() - - if self.key.Equals(key) { - self.value = value - return self, true - } - - if key.Less(self.key) { - self.left, updated = self.left_filled(db).Put(db, key, value) - } else { - self.right, updated = self.right_filled(db).Put(db, key, value) - } - if updated { - return self, updated - } else { - self.calc_height_and_size(db) - return self.balance(db), updated - } -} - -func (self *IAVLNode) Remove(db Db, key Key) (new_self *IAVLNode, value Value, err error) { - if self == nil { - return nil, nil, NotFound(key) - } - - if self.key.Equals(key) { - if self.left != nil && self.right != nil { - if self.left_filled(db).Size() < self.right_filled(db).Size() { - self, new_self = self.pop_node(db, self.right_filled(db).lmd(db)) - } else { - self, new_self = self.pop_node(db, self.left_filled(db).rmd(db)) - } - new_self.left = self.left - new_self.right = self.right - new_self.calc_height_and_size(db) - return new_self, self.value, nil - } else if self.left == nil { - return self.right_filled(db), self.value, nil - } else if self.right == nil { - return self.left_filled(db), self.value, nil - } else { - return nil, self.value, nil - } - } - - if key.Less(self.key) { - if self.left == nil { - return self, nil, NotFound(key) - } - var new_left *IAVLNode - new_left, value, err = self.left_filled(db).Remove(db, key) - if new_left == self.left_filled(db) { // not found - return self, nil, err - } else if err != nil { // some other error - return self, value, err - } - self = self.Copy() - self.left = new_left - } else { - if self.right == nil { - return self, nil, NotFound(key) - } - var new_right *IAVLNode - new_right, value, err = self.right_filled(db).Remove(db, key) - if new_right == self.right_filled(db) { // not found - return self, nil, err - } else if err != nil { // some other error - return self, value, err - } - self = self.Copy() - self.right = new_right - } - self.calc_height_and_size(db) - return self.balance(db), value, err -} - -func (self *IAVLNode) Height() uint8 { - if self == nil { - return 0 - } - return self.height -} - -// ... - func (self *IAVLNode) _md(side func(*IAVLNode)*IAVLNode) (*IAVLNode) { if self == nil { return nil diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index 295c1f17c..ec0d2c368 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -80,9 +80,7 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) { if has := tree.Has(nil, randstr(12)); has { t.Error("Table has extra key") } - if val, err := tree.Get(nil, r.key); err != nil { - t.Error(err, val.(String), r.value) - } else if !(val.(String)).Equals(r.value) { + if val := tree.Get(nil, r.key); !(val.(String)).Equals(r.value) { t.Error("wrong value") } } @@ -100,9 +98,7 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) { if has := tree.Has(nil, randstr(12)); has { t.Error("Table has extra key") } - if val, err := tree.Get(nil, r.key); err != nil { - t.Error(err) - } else if !(val.(String)).Equals(r.value) { + if val := tree.Get(nil, r.key); !(val.(String)).Equals(r.value) { t.Error("wrong value") } } @@ -151,9 +147,7 @@ func TestTraversals(t *testing.T) { test := func(T Tree) { t.Logf("%T", T) for j := range order { - if err := T.Put(Int(data[order[j]]), Int(order[j])); err != nil { - t.Error(err) - } + T.Put(Int(data[order[j]]), Int(order[j])) } j := 0 diff --git a/merkle/int.go b/merkle/int.go index f16022e87..d6d6a09a7 100644 --- a/merkle/int.go +++ b/merkle/int.go @@ -16,6 +16,8 @@ type Int int type UInt uint +// Int8 + func (self Int8) Equals(other Key) bool { if o, ok := other.(Int8); ok { return self == o @@ -32,10 +34,22 @@ func (self Int8) Less(other Key) bool { } } -func (self Int8) Bytes() []byte { - return []byte{byte(self)} +func (self Int8) ByteSize() int { + return 1 +} + +func (self Int8) SaveTo(b []byte) int { + if cap(b) < 1 { panic("buf too small") } + b[0] = byte(self) + return 1 } +func LoadInt8(bytes []byte) Int8 { + return Int8(bytes[0]) +} + + +// UInt8 func (self UInt8) Equals(other Key) bool { if o, ok := other.(UInt8); ok { @@ -53,10 +67,22 @@ func (self UInt8) Less(other Key) bool { } } -func (self UInt8) Bytes() []byte { - return []byte{byte(self)} +func (self UInt8) ByteSize() int { + return 1 } +func (self UInt8) SaveTo(b []byte) int { + if cap(b) < 1 { panic("buf too small") } + b[0] = byte(self) + return 1 +} + +func LoadUInt8(bytes []byte) UInt8 { + return UInt8(bytes[0]) +} + + +// Int16 func (self Int16) Equals(other Key) bool { if o, ok := other.(Int16); ok { @@ -74,12 +100,22 @@ func (self Int16) Less(other Key) bool { } } -func (self Int16) Bytes() []byte { - b := [2]byte{} - binary.LittleEndian.PutUint16(b[:], uint16(self)) - return b[:] +func (self Int16) ByteSize() int { + return 2 } +func (self Int16) SaveTo(b []byte) int { + if cap(b) < 2 { panic("buf too small") } + binary.LittleEndian.PutUint16(b, uint16(self)) + return 2 +} + +func LoadInt16(bytes []byte) Int16 { + return Int16(binary.LittleEndian.Uint16(bytes)) +} + + +// UInt16 func (self UInt16) Equals(other Key) bool { if o, ok := other.(UInt16); ok { @@ -97,13 +133,23 @@ func (self UInt16) Less(other Key) bool { } } -func (self UInt16) Bytes() []byte { - b := [2]byte{} - binary.LittleEndian.PutUint16(b[:], uint16(self)) - return b[:] +func (self UInt16) ByteSize() int { + return 2 +} + +func (self UInt16) SaveTo(b []byte) int { + if cap(b) < 2 { panic("buf too small") } + binary.LittleEndian.PutUint16(b, uint16(self)) + return 2 +} + +func LoadUInt16(bytes []byte) UInt16 { + return UInt16(binary.LittleEndian.Uint16(bytes)) } +// Int32 + func (self Int32) Equals(other Key) bool { if o, ok := other.(Int32); ok { return self == o @@ -120,13 +166,23 @@ func (self Int32) Less(other Key) bool { } } -func (self Int32) Bytes() []byte { - b := [4]byte{} - binary.LittleEndian.PutUint32(b[:], uint32(self)) - return b[:] +func (self Int32) ByteSize() int { + return 4 +} + +func (self Int32) SaveTo(b []byte) int { + if cap(b) < 4 { panic("buf too small") } + binary.LittleEndian.PutUint32(b, uint32(self)) + return 4 +} + +func LoadInt32(bytes []byte) Int32 { + return Int32(binary.LittleEndian.Uint32(bytes)) } +// UInt32 + func (self UInt32) Equals(other Key) bool { if o, ok := other.(UInt32); ok { return self == o @@ -143,12 +199,22 @@ func (self UInt32) Less(other Key) bool { } } -func (self UInt32) Bytes() []byte { - b := [4]byte{} - binary.LittleEndian.PutUint32(b[:], uint32(self)) - return b[:] +func (self UInt32) ByteSize() int { + return 4 +} + +func (self UInt32) SaveTo(b []byte) int { + if cap(b) < 4 { panic("buf too small") } + binary.LittleEndian.PutUint32(b, uint32(self)) + return 4 } +func LoadUInt32(bytes []byte) UInt32 { + return UInt32(binary.LittleEndian.Uint32(bytes)) +} + + +// Int64 func (self Int64) Equals(other Key) bool { if o, ok := other.(Int64); ok { @@ -166,12 +232,22 @@ func (self Int64) Less(other Key) bool { } } -func (self Int64) Bytes() []byte { - b := [8]byte{} - binary.LittleEndian.PutUint64(b[:], uint64(self)) - return b[:] +func (self Int64) ByteSize() int { + return 8 } +func (self Int64) SaveTo(b []byte) int { + if cap(b) < 8 { panic("buf too small") } + binary.LittleEndian.PutUint64(b, uint64(self)) + return 8 +} + +func LoadInt64(bytes []byte) Int64 { + return Int64(binary.LittleEndian.Uint64(bytes)) +} + + +// UInt64 func (self UInt64) Equals(other Key) bool { if o, ok := other.(UInt64); ok { @@ -189,13 +265,23 @@ func (self UInt64) Less(other Key) bool { } } -func (self UInt64) Bytes() []byte { - b := [8]byte{} - binary.LittleEndian.PutUint64(b[:], uint64(self)) - return b[:] +func (self UInt64) ByteSize() int { + return 8 +} + +func (self UInt64) SaveTo(b []byte) int { + if cap(b) < 8 { panic("buf too small") } + binary.LittleEndian.PutUint64(b, uint64(self)) + return 8 +} + +func LoadUInt64(bytes []byte) UInt64 { + return UInt64(binary.LittleEndian.Uint64(bytes)) } +// Int + func (self Int) Equals(other Key) bool { if o, ok := other.(Int); ok { return self == o @@ -212,12 +298,21 @@ func (self Int) Less(other Key) bool { } } -func (self Int) Bytes() []byte { - b := [8]byte{} - binary.LittleEndian.PutUint64(b[:], uint64(self)) - return b[:] +func (self Int) ByteSize() int { + return 8 +} + +func (self Int) SaveTo(b []byte) int { + if cap(b) < 8 { panic("buf too small") } + binary.LittleEndian.PutUint64(b, uint64(self)) + return 8 } +func LoadInt(bytes []byte) Int { + return Int(binary.LittleEndian.Uint64(bytes)) +} + +// UInt func (self UInt) Equals(other Key) bool { if o, ok := other.(UInt); ok { @@ -235,8 +330,16 @@ func (self UInt) Less(other Key) bool { } } -func (self UInt) Bytes() []byte { - b := [8]byte{} - binary.LittleEndian.PutUint64(b[:], uint64(self)) - return b[:] +func (self UInt) ByteSize() int { + return 8 +} + +func (self UInt) SaveTo(b []byte) int { + if cap(b) < 8 { panic("buf too small") } + binary.LittleEndian.PutUint64(b, uint64(self)) + return 8 +} + +func LoadUInt(bytes []byte) UInt { + return UInt(binary.LittleEndian.Uint64(bytes)) } diff --git a/merkle/string.go b/merkle/string.go index b09ab4a19..e52cc59f9 100644 --- a/merkle/string.go +++ b/merkle/string.go @@ -5,6 +5,8 @@ import "bytes" type String string type ByteSlice []byte +// String + func (self String) Equals(other Key) bool { if o, ok := other.(String); ok { return self == o @@ -21,10 +23,25 @@ func (self String) Less(other Key) bool { } } -func (self String) Bytes() []byte { - return []byte(self) +func (self String) ByteSize() int { + return len(self)+4 +} + +func (self String) SaveTo(buf []byte) int { + if len(buf) < self.ByteSize() { panic("buf too small") } + UInt32(len(self)).SaveTo(buf) + copy(buf[4:], []byte(self)) + return len(self)+4 } +func LoadString(bytes []byte) String { + length := LoadUInt32(bytes) + return String(bytes[4:4+length]) +} + + +// ByteSlice + func (self ByteSlice) Equals(other Key) bool { if o, ok := other.(ByteSlice); ok { return bytes.Equal(self, o) @@ -41,6 +58,18 @@ func (self ByteSlice) Less(other Key) bool { } } -func (self ByteSlice) Bytes() []byte { - return []byte(self) +func (self ByteSlice) ByteSize() int { + return len(self)+4 +} + +func (self ByteSlice) SaveTo(buf []byte) int { + if len(buf) < self.ByteSize() { panic("buf too small") } + UInt32(len(self)).SaveTo(buf) + copy(buf[4:], self) + return len(self)+4 +} + +func LoadByteSlice(bytes []byte) ByteSlice { + length := LoadUInt32(bytes) + return ByteSlice(bytes[4:4+length]) } diff --git a/merkle/types.go b/merkle/types.go index ecd1b0ca0..00ffdff18 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -4,14 +4,20 @@ import ( "fmt" ) +type Binary interface { + ByteSize() int + SaveTo([]byte) int +} + type Value interface { - Bytes() []byte + Binary } type Key interface { + Binary + Equals(b Key) bool Less(b Key) bool - Bytes() []byte } type Tree interface { @@ -20,19 +26,21 @@ type Tree interface { Size() uint64 Height() uint8 Has(key Key) bool - Get(key Key) (Value, error) - Hash() ([]byte, uint64) + Get(key Key) Value + Hash() (ByteSlice, uint64) - Put(Key, Value) (err error) + Put(Key, Value) Remove(Key) (Value, error) } type Db interface { - Get([]byte) ([]byte, error) - Put([]byte, []byte) error + Get([]byte) []byte + Put([]byte, []byte) } type Node interface { + Binary + Key() Key Value() Value Left(Db) Node @@ -41,9 +49,8 @@ type Node interface { Size() uint64 Height() uint8 Has(Db, Key) bool - Get(Db, Key) (Value, error) - Hash() ([]byte, uint64) - Bytes() []byte + Get(Db, Key) Value + Hash() (ByteSlice, uint64) Put(Db, Key, Value) (*IAVLNode, bool) Remove(Db, Key) (*IAVLNode, Value, error)