diff --git a/merkle/iavl.go b/merkle/iavl.go index 02043462e..2d428c360 100644 --- a/merkle/iavl.go +++ b/merkle/iavl.go @@ -2,7 +2,8 @@ package merkle import ( //"fmt" - "hash" + "math" + //"hash" "crypto/sha256" ) @@ -24,20 +25,24 @@ func (self *IAVLTree) Size() int { return self.root.Size() } -func (self *IAVLTree) Has(key Sortable) bool { +func (self *IAVLTree) Has(key Key) bool { return self.root.Has(key) } -func (self *IAVLTree) Put(key Sortable, value interface{}) (err error) { +func (self *IAVLTree) Put(key Key, value Value) (err error) { self.root, _ = self.root.Put(key, value) return nil } -func (self *IAVLTree) Get(key Sortable) (value interface{}, err error) { +func (self *IAVLTree) Hash() []byte { + return self.root.Hash() +} + +func (self *IAVLTree) Get(key Key) (value Value, err error) { return self.root.Get(key) } -func (self *IAVLTree) Remove(key Sortable) (value interface{}, err error) { +func (self *IAVLTree) Remove(key Key) (value Value, err error) { new_root, value, err := self.root.Remove(key) if err != nil { return nil, err @@ -49,8 +54,8 @@ func (self *IAVLTree) Remove(key Sortable) (value interface{}, err error) { // Node type IAVLNode struct { - key Sortable - value interface{} + key Key + value Value height int hash []byte left *IAVLNode @@ -75,7 +80,32 @@ func (self *IAVLNode) Copy(copyHash bool) *IAVLNode { } } -func (self *IAVLNode) Has(key Sortable) (has bool) { +func (self *IAVLNode) Key() Key { + return self.key +} + +func (self *IAVLNode) Value() Value { + return self.value +} + +func (self *IAVLNode) Left() Node { + if self.left == nil { return nil } + return self.left +} + +func (self *IAVLNode) Right() Node { + if self.right == nil { return nil } + return self.right +} + +func (self *IAVLNode) Size() int { + if self == nil { + return 0 + } + return 1 + self.left.Size() + self.right.Size() +} + +func (self *IAVLNode) Has(key Key) (has bool) { if self == nil { return false } @@ -88,7 +118,7 @@ func (self *IAVLNode) Has(key Sortable) (has bool) { } } -func (self *IAVLNode) Get(key Sortable) (value interface{}, err error) { +func (self *IAVLNode) Get(key Key) (value Value, err error) { if self == nil { return nil, NotFound(key) } @@ -101,6 +131,47 @@ func (self *IAVLNode) Get(key Sortable) (value interface{}, err error) { } } +func (self *IAVLNode) Hash() []byte { + if self == nil { + return nil + } + if self.hash != nil { return self.hash } + hasher := sha256.New() + + // node descriptor + nodeDesc := byte(0) + if self.value != nil { nodeDesc |= 0x01 } + if self.left != nil { nodeDesc |= 0x02 } + if self.right != nil { nodeDesc |= 0x04 } + hasher.Write([]byte{nodeDesc}) + + // node key + keyBytes := self.key.Bytes() + if len(keyBytes) > 255 { panic("key is too long") } + hasher.Write([]byte{byte(len(keyBytes))}) + hasher.Write(keyBytes) + + // node value + if self.value != nil { + valueBytes := self.value.Bytes() + if len(valueBytes) > math.MaxUint32 { panic("value is too long") } + hasher.Write([]byte{byte(len(valueBytes))}) + hasher.Write(valueBytes) + } + + // left child + if self.left != nil { + hasher.Write(self.left.Hash()) + } + + // right child + if self.right != nil { + hasher.Write(self.right.Hash()) + } + + return hasher.Sum(nil) +} + // Returns a new tree (unless node is the root) & a copy of the popped node. // Can only pop nodes that have one or no children. func (self *IAVLNode) pop_node(node *IAVLNode) (new_self, new_node *IAVLNode) { @@ -231,7 +302,7 @@ func (self *IAVLNode) balance() (new_self *IAVLNode) { } // TODO: don't clear the hash if the value hasn't changed. -func (self *IAVLNode) Put(key Sortable, value interface{}) (_ *IAVLNode, updated bool) { +func (self *IAVLNode) Put(key Key, value Value) (_ *IAVLNode, updated bool) { if self == nil { return &IAVLNode{key: key, value: value, height: 1, hash: nil}, false } @@ -256,7 +327,7 @@ func (self *IAVLNode) Put(key Sortable, value interface{}) (_ *IAVLNode, updated } } -func (self *IAVLNode) Remove(key Sortable) (new_self *IAVLNode, value interface{}, err error) { +func (self *IAVLNode) Remove(key Key) (new_self *IAVLNode, value Value, err error) { if self == nil { return nil, nil, NotFound(key) } @@ -318,32 +389,6 @@ func (self *IAVLNode) Height() int { return self.height } -func (self *IAVLNode) Size() int { - if self == nil { - return 0 - } - return 1 + self.left.Size() + self.right.Size() -} - - -func (self *IAVLNode) Key() Sortable { - return self.key -} - -func (self *IAVLNode) Value() interface{} { - return self.value -} - -func (self *IAVLNode) Left() Node { - if self.left == nil { return nil } - return self.left -} - -func (self *IAVLNode) Right() Node { - if self.right == nil { return nil } - return self.right -} - // ... func (self *IAVLNode) _md(side func(*IAVLNode)*IAVLNode) (*IAVLNode) { @@ -377,14 +422,3 @@ func max(a, b int) int { } return b } - -// Calculate the hash of hasher over buf. -func CalcHash(buf []byte, hasher hash.Hash) []byte { - hasher.Write(buf) - return hasher.Sum(nil) -} - -// calculate hash256 which is sha256(sha256(data)) -func CalcSha256(buf []byte) []byte { - return CalcHash(buf, sha256.New()) -} diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index 94fce5434..ed87e2a60 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -50,7 +50,7 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) { records := make([]*record, 400) var tree *IAVLNode var err error - var val interface{} + var val Value var updated bool ranrec := func() *record { @@ -155,7 +155,7 @@ func TestTraversals(t *testing.T) { test := func(T Tree) { t.Logf("%T", T) for j := range order { - if err := T.Put(Int(data[order[j]]), order[j]); err != nil { + if err := T.Put(Int(data[order[j]]), Int(order[j])); err != nil { t.Error(err) } } @@ -174,28 +174,17 @@ func TestTraversals(t *testing.T) { test(NewIAVLTree()) } -func BenchmarkSha256(b *testing.B) { - b.StopTimer() - - str := []byte(randstr(32)) - - ransha256 := func() []byte { - return CalcSha256(str) - } - - b.StartTimer() - for i := 0; i < b.N; i++ { - ransha256() - } -} - // from http://stackoverflow.com/questions/3955680/how-to-check-if-my-avl-tree-implementation-is-correct func TestGriffin(t *testing.T) { + + // Convenience for a new node N := func(l *IAVLNode, i int, r *IAVLNode) *IAVLNode { n := &IAVLNode{Int32(i), nil, -1, nil, l, r} n.calc_height() return n } + + // Convenience for simple printing of keys & tree structure var P func(*IAVLNode) string P = func(n *IAVLNode) string { if n.left == nil && n.right == nil { @@ -273,3 +262,17 @@ func TestGriffin(t *testing.T) { expectRemove(n6, 1, "(((2 3 4) 5 (6 7 -)) 8 (9 10 (- 11 12)))") } + +func TestHash(t *testing.T) { + + // Maybe, construct some tree & determine the number of new hashes calculated. + // Make sure the number of new hashings is expected, as well as the hash value. + // Then, nuke the hash values and reconstruct, ensure that the resulting hash is the same. + + tree := NewIAVLTree() + tree.Put(String("foo"), String("bar")) + fmt.Println(tree.Hash()) + + tree.Put(String("foo2"), String("bar")) + fmt.Println(tree.Hash()) +} diff --git a/merkle/int.go b/merkle/int.go index 26115a96c..f16022e87 100644 --- a/merkle/int.go +++ b/merkle/int.go @@ -1,5 +1,8 @@ package merkle +import ( + "encoding/binary" +) type Int8 int8 type UInt8 uint8 @@ -13,7 +16,7 @@ type Int int type UInt uint -func (self Int8) Equals(other Sortable) bool { +func (self Int8) Equals(other Key) bool { if o, ok := other.(Int8); ok { return self == o } else { @@ -21,7 +24,7 @@ func (self Int8) Equals(other Sortable) bool { } } -func (self Int8) Less(other Sortable) bool { +func (self Int8) Less(other Key) bool { if o, ok := other.(Int8); ok { return self < o } else { @@ -29,12 +32,12 @@ func (self Int8) Less(other Sortable) bool { } } -func (self Int8) Hash() int { - return int(self) +func (self Int8) Bytes() []byte { + return []byte{byte(self)} } -func (self UInt8) Equals(other Sortable) bool { +func (self UInt8) Equals(other Key) bool { if o, ok := other.(UInt8); ok { return self == o } else { @@ -42,7 +45,7 @@ func (self UInt8) Equals(other Sortable) bool { } } -func (self UInt8) Less(other Sortable) bool { +func (self UInt8) Less(other Key) bool { if o, ok := other.(UInt8); ok { return self < o } else { @@ -50,12 +53,12 @@ func (self UInt8) Less(other Sortable) bool { } } -func (self UInt8) Hash() int { - return int(self) +func (self UInt8) Bytes() []byte { + return []byte{byte(self)} } -func (self Int16) Equals(other Sortable) bool { +func (self Int16) Equals(other Key) bool { if o, ok := other.(Int16); ok { return self == o } else { @@ -63,7 +66,7 @@ func (self Int16) Equals(other Sortable) bool { } } -func (self Int16) Less(other Sortable) bool { +func (self Int16) Less(other Key) bool { if o, ok := other.(Int16); ok { return self < o } else { @@ -71,12 +74,14 @@ func (self Int16) Less(other Sortable) bool { } } -func (self Int16) Hash() int { - return int(self) +func (self Int16) Bytes() []byte { + b := [2]byte{} + binary.LittleEndian.PutUint16(b[:], uint16(self)) + return b[:] } -func (self UInt16) Equals(other Sortable) bool { +func (self UInt16) Equals(other Key) bool { if o, ok := other.(UInt16); ok { return self == o } else { @@ -84,7 +89,7 @@ func (self UInt16) Equals(other Sortable) bool { } } -func (self UInt16) Less(other Sortable) bool { +func (self UInt16) Less(other Key) bool { if o, ok := other.(UInt16); ok { return self < o } else { @@ -92,12 +97,14 @@ func (self UInt16) Less(other Sortable) bool { } } -func (self UInt16) Hash() int { - return int(self) +func (self UInt16) Bytes() []byte { + b := [2]byte{} + binary.LittleEndian.PutUint16(b[:], uint16(self)) + return b[:] } -func (self Int32) Equals(other Sortable) bool { +func (self Int32) Equals(other Key) bool { if o, ok := other.(Int32); ok { return self == o } else { @@ -105,7 +112,7 @@ func (self Int32) Equals(other Sortable) bool { } } -func (self Int32) Less(other Sortable) bool { +func (self Int32) Less(other Key) bool { if o, ok := other.(Int32); ok { return self < o } else { @@ -113,12 +120,14 @@ func (self Int32) Less(other Sortable) bool { } } -func (self Int32) Hash() int { - return int(self) +func (self Int32) Bytes() []byte { + b := [4]byte{} + binary.LittleEndian.PutUint32(b[:], uint32(self)) + return b[:] } -func (self UInt32) Equals(other Sortable) bool { +func (self UInt32) Equals(other Key) bool { if o, ok := other.(UInt32); ok { return self == o } else { @@ -126,7 +135,7 @@ func (self UInt32) Equals(other Sortable) bool { } } -func (self UInt32) Less(other Sortable) bool { +func (self UInt32) Less(other Key) bool { if o, ok := other.(UInt32); ok { return self < o } else { @@ -134,12 +143,14 @@ func (self UInt32) Less(other Sortable) bool { } } -func (self UInt32) Hash() int { - return int(self) +func (self UInt32) Bytes() []byte { + b := [4]byte{} + binary.LittleEndian.PutUint32(b[:], uint32(self)) + return b[:] } -func (self Int64) Equals(other Sortable) bool { +func (self Int64) Equals(other Key) bool { if o, ok := other.(Int64); ok { return self == o } else { @@ -147,7 +158,7 @@ func (self Int64) Equals(other Sortable) bool { } } -func (self Int64) Less(other Sortable) bool { +func (self Int64) Less(other Key) bool { if o, ok := other.(Int64); ok { return self < o } else { @@ -155,12 +166,14 @@ func (self Int64) Less(other Sortable) bool { } } -func (self Int64) Hash() int { - return int(self>>32) ^ int(self) +func (self Int64) Bytes() []byte { + b := [8]byte{} + binary.LittleEndian.PutUint64(b[:], uint64(self)) + return b[:] } -func (self UInt64) Equals(other Sortable) bool { +func (self UInt64) Equals(other Key) bool { if o, ok := other.(UInt64); ok { return self == o } else { @@ -168,7 +181,7 @@ func (self UInt64) Equals(other Sortable) bool { } } -func (self UInt64) Less(other Sortable) bool { +func (self UInt64) Less(other Key) bool { if o, ok := other.(UInt64); ok { return self < o } else { @@ -176,12 +189,14 @@ func (self UInt64) Less(other Sortable) bool { } } -func (self UInt64) Hash() int { - return int(self>>32) ^ int(self) +func (self UInt64) Bytes() []byte { + b := [8]byte{} + binary.LittleEndian.PutUint64(b[:], uint64(self)) + return b[:] } -func (self Int) Equals(other Sortable) bool { +func (self Int) Equals(other Key) bool { if o, ok := other.(Int); ok { return self == o } else { @@ -189,7 +204,7 @@ func (self Int) Equals(other Sortable) bool { } } -func (self Int) Less(other Sortable) bool { +func (self Int) Less(other Key) bool { if o, ok := other.(Int); ok { return self < o } else { @@ -197,12 +212,14 @@ func (self Int) Less(other Sortable) bool { } } -func (self Int) Hash() int { - return int(self) +func (self Int) Bytes() []byte { + b := [8]byte{} + binary.LittleEndian.PutUint64(b[:], uint64(self)) + return b[:] } -func (self UInt) Equals(other Sortable) bool { +func (self UInt) Equals(other Key) bool { if o, ok := other.(UInt); ok { return self == o } else { @@ -210,7 +227,7 @@ func (self UInt) Equals(other Sortable) bool { } } -func (self UInt) Less(other Sortable) bool { +func (self UInt) Less(other Key) bool { if o, ok := other.(UInt); ok { return self < o } else { @@ -218,8 +235,8 @@ func (self UInt) Less(other Sortable) bool { } } -func (self UInt) Hash() int { - return int(self) +func (self UInt) Bytes() []byte { + b := [8]byte{} + binary.LittleEndian.PutUint64(b[:], uint64(self)) + return b[:] } - - diff --git a/merkle/string.go b/merkle/string.go index e497f0282..b09ab4a19 100644 --- a/merkle/string.go +++ b/merkle/string.go @@ -5,7 +5,7 @@ import "bytes" type String string type ByteSlice []byte -func (self String) Equals(other Sortable) bool { +func (self String) Equals(other Key) bool { if o, ok := other.(String); ok { return self == o } else { @@ -13,7 +13,7 @@ func (self String) Equals(other Sortable) bool { } } -func (self String) Less(other Sortable) bool { +func (self String) Less(other Key) bool { if o, ok := other.(String); ok { return self < o } else { @@ -21,16 +21,11 @@ func (self String) Less(other Sortable) bool { } } -func (self String) Hash() int { - bytes := []byte(self) - hash := 0 - for i, c := range bytes { - hash += (i+1)*int(c) - } - return hash +func (self String) Bytes() []byte { + return []byte(self) } -func (self ByteSlice) Equals(other Sortable) bool { +func (self ByteSlice) Equals(other Key) bool { if o, ok := other.(ByteSlice); ok { return bytes.Equal(self, o) } else { @@ -38,7 +33,7 @@ func (self ByteSlice) Equals(other Sortable) bool { } } -func (self ByteSlice) Less(other Sortable) bool { +func (self ByteSlice) Less(other Key) bool { if o, ok := other.(ByteSlice); ok { return bytes.Compare(self, o) < 0 // -1 if a < b } else { @@ -46,12 +41,6 @@ func (self ByteSlice) Less(other Sortable) bool { } } -func (self ByteSlice) Hash() int { - hash := 0 - for i, c := range self { - hash += (i+1)*int(c) - } - return hash +func (self ByteSlice) Bytes() []byte { + return []byte(self) } - - diff --git a/merkle/types.go b/merkle/types.go index a7bb33ee9..625af68ee 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -4,38 +4,45 @@ import ( "fmt" ) -type Sortable interface { - Equals(b Sortable) bool - Less(b Sortable) bool +type Value interface { + Bytes() []byte +} + +type Key interface { + Equals(b Key) bool + Less(b Key) bool + Bytes() []byte } type Tree interface { Root() Node Size() int - Has(key Sortable) bool - Get(key Sortable) (value interface{}, err error) + Has(key Key) bool + Get(key Key) (value Value, err error) + Hash() []byte - Put(key Sortable, value interface{}) (err error) - Remove(key Sortable) (value interface{}, err error) + Put(key Key, value Value) (err error) + Remove(key Key) (value Value, err error) } type Node interface { - Key() Sortable - Value() interface{} - Left() Node + Key() Key + Value() Value + Left() Node Right() Node Size() int - Has(key Sortable) bool - Get(key Sortable) (value interface{}, err error) + Has(key Key) bool + Get(key Key) (value Value, err error) + Hash() []byte - Put(key Sortable, value interface{}) (_ *IAVLNode, updated bool) - Remove(key Sortable) (_ *IAVLNode, value interface{}, err error) + Put(key Key, value Value) (_ *IAVLNode, updated bool) + Remove(key Key) (_ *IAVLNode, value Value, err error) } type NodeIterator func() (node Node, next NodeIterator) -func NotFound(key Sortable) error { +func NotFound(key Key) error { return fmt.Errorf("Key was not found.") }