From fdf0e720bcfac92920a8cd228fa7eb655e3249d5 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Sun, 24 May 2015 14:19:46 -0700 Subject: [PATCH] Merkle proofs! --- binary/int.go | 21 +++--- binary/int_test.go | 2 +- common/random.go | 4 ++ common/test/mutate.go | 28 ++++++++ merkle/iavl_node.go | 162 ++++++++++++++++++++++++++++-------------- merkle/iavl_proof.go | 137 +++++++++++++++++++++++++++++++++++ merkle/iavl_test.go | 76 ++++++++++++++++++-- merkle/iavl_tree.go | 11 +-- merkle/types.go | 8 +-- 9 files changed, 371 insertions(+), 78 deletions(-) create mode 100644 common/test/mutate.go create mode 100644 merkle/iavl_proof.go diff --git a/binary/int.go b/binary/int.go index 8b4a36d24..90f3f4323 100644 --- a/binary/int.go +++ b/binary/int.go @@ -160,6 +160,9 @@ func ReadUint64(r io.Reader, n *int64, err *error) uint64 { func uvarintSize(i_ uint) int { i := uint64(i_) + if i == 0 { + return 0 + } if i < 1<<8 { return 1 } @@ -197,9 +200,11 @@ func WriteVarint(i int, w io.Writer, n *int64, err *error) { } else { WriteUint8(uint8(size), w, n, err) } - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, uint64(i)) - WriteTo(buf[(8-size):], w, n, err) + if size > 0 { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + WriteTo(buf[(8-size):], w, n, err) + } *n += int64(1 + size) } @@ -215,7 +220,6 @@ func ReadVarint(r io.Reader, n *int64, err *error) int { return 0 } if size == 0 { - setFirstErr(err, errors.New("Varint underflow")) return 0 } buf := make([]byte, 8) @@ -234,9 +238,11 @@ func ReadVarint(r io.Reader, n *int64, err *error) int { func WriteUvarint(i uint, w io.Writer, n *int64, err *error) { var size = uvarintSize(i) WriteUint8(uint8(size), w, n, err) - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, uint64(i)) - WriteTo(buf[(8-size):], w, n, err) + if size > 0 { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + WriteTo(buf[(8-size):], w, n, err) + } *n += int64(1 + size) } @@ -247,7 +253,6 @@ func ReadUvarint(r io.Reader, n *int64, err *error) uint { return 0 } if size == 0 { - setFirstErr(err, errors.New("Uvarint underflow")) return 0 } buf := make([]byte, 8) diff --git a/binary/int_test.go b/binary/int_test.go index 9d472b3f9..49348fae7 100644 --- a/binary/int_test.go +++ b/binary/int_test.go @@ -32,7 +32,7 @@ func TestVarint(t *testing.T) { // Near zero check(-1, "F101") - check(0, "0100") + check(0, "00") check(1, "0101") // Positives check(1<<32-1, "04FFFFFFFF") diff --git a/common/random.go b/common/random.go index 12ce8a24d..62a509220 100644 --- a/common/random.go +++ b/common/random.go @@ -62,6 +62,10 @@ func RandUint() uint { return uint(rand.Int()) } +func RandInt() int { + return rand.Int() +} + // Distributed pseudo-exponentially to test for various cases func RandUint16Exp() uint16 { bits := rand.Uint32() % 16 diff --git a/common/test/mutate.go b/common/test/mutate.go new file mode 100644 index 000000000..39bf90557 --- /dev/null +++ b/common/test/mutate.go @@ -0,0 +1,28 @@ +package test + +import ( + . "github.com/tendermint/tendermint/common" +) + +// Contract: !bytes.Equal(input, output) && len(input) >= len(output) +func MutateByteSlice(bytez []byte) []byte { + // If bytez is empty, panic + if len(bytez) == 0 { + panic("Cannot mutate an empty bytez") + } + + // Copy bytez + mBytez := make([]byte, len(bytez)) + copy(mBytez, bytez) + bytez = mBytez + + // Try a random mutation + switch RandInt() % 2 { + case 0: // Mutate a single byte + bytez[RandInt()%len(bytez)] += byte(RandInt()%255 + 1) + case 1: // Remove an arbitrary byte + pos := RandInt() % len(bytez) + bytez = append(bytez[:pos], bytez[pos+1:]...) + } + return bytez +} diff --git a/merkle/iavl_node.go b/merkle/iavl_node.go index 6d7d5d634..086283a53 100644 --- a/merkle/iavl_node.go +++ b/merkle/iavl_node.go @@ -1,6 +1,7 @@ package merkle import ( + "bytes" "crypto/sha256" "io" @@ -12,8 +13,8 @@ import ( type IAVLNode struct { key interface{} value interface{} - size uint64 height uint8 + size uint hash []byte leftHash []byte leftNode *IAVLNode @@ -24,27 +25,28 @@ type IAVLNode struct { func NewIAVLNode(key interface{}, value interface{}) *IAVLNode { return &IAVLNode{ - key: key, - value: value, - size: 1, + key: key, + value: value, + height: 0, + size: 1, } } +// NOTE: The hash is not saved or set. The caller should set the hash afterwards. +// (Presumably the caller already has the hash) func ReadIAVLNode(t *IAVLTree, r io.Reader, n *int64, err *error) *IAVLNode { node := &IAVLNode{} - // node header & key + // node header node.height = binary.ReadUint8(r, n, err) - node.size = binary.ReadUint64(r, n, err) - node.key = t.keyCodec.Decode(r, n, err) - if *err != nil { - panic(*err) - } + node.size = binary.ReadUvarint(r, n, err) + node.key = decodeByteSlice(t.keyCodec, r, n, err) - // node value or children. if node.height == 0 { - node.value = t.valueCodec.Decode(r, n, err) + // value + node.value = decodeByteSlice(t.valueCodec, r, n, err) } else { + // children node.leftHash = binary.ReadByteSlice(r, n, err) node.rightHash = binary.ReadByteSlice(r, n, err) } @@ -60,8 +62,8 @@ func (node *IAVLNode) _copy() *IAVLNode { } return &IAVLNode{ key: node.key, - size: node.size, height: node.height, + size: node.size, hash: nil, // Going to be mutated anyways. leftHash: node.leftHash, leftNode: node.leftNode, @@ -86,7 +88,7 @@ func (node *IAVLNode) has(t *IAVLTree, key interface{}) (has bool) { } } -func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint64, value interface{}) { +func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint, value interface{}) { if node.height == 0 { if t.keyCodec.Compare(node.key, key) == 0 { return 0, node.value @@ -105,7 +107,7 @@ func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint64, value int } } -func (node *IAVLNode) getByIndex(t *IAVLTree, index uint64) (key interface{}, value interface{}) { +func (node *IAVLNode) getByIndex(t *IAVLTree, index uint) (key interface{}, value interface{}) { if node.height == 0 { if index == 0 { return node.key, node.value @@ -125,21 +127,61 @@ func (node *IAVLNode) getByIndex(t *IAVLTree, index uint64) (key interface{}, va } // NOTE: sets hashes recursively -func (node *IAVLNode) hashWithCount(t *IAVLTree) ([]byte, uint64) { +func (node *IAVLNode) hashWithCount(t *IAVLTree) ([]byte, uint) { if node.hash != nil { return node.hash, 0 } hasher := sha256.New() - _, hashCount, err := node.writeToCountHashes(t, hasher) + buf := new(bytes.Buffer) + _, hashCount, err := node.writeHashBytes(t, buf) if err != nil { panic(err) } + // fmt.Printf("Wrote IAVL hash bytes: %X\n", buf.Bytes()) + hasher.Write(buf.Bytes()) node.hash = hasher.Sum(nil) + // fmt.Printf("Write IAVL hash: %X\n", node.hash) return node.hash, hashCount + 1 } +// NOTE: sets hashes recursively +func (node *IAVLNode) writeHashBytes(t *IAVLTree, w io.Writer) (n int64, hashCount uint, err error) { + // height & size + binary.WriteUint8(node.height, w, &n, &err) + binary.WriteUvarint(node.size, w, &n, &err) + // key is not written for inner nodes, unlike writePersistBytes + + if node.height == 0 { + // key & value + encodeByteSlice(node.key, t.keyCodec, w, &n, &err) + encodeByteSlice(node.value, t.valueCodec, w, &n, &err) + } else { + // left + if node.leftNode != nil { + leftHash, leftCount := node.leftNode.hashWithCount(t) + node.leftHash = leftHash + hashCount += leftCount + } + if node.leftHash == nil { + panic("node.leftHash was nil in writeHashBytes") + } + binary.WriteByteSlice(node.leftHash, w, &n, &err) + // right + if node.rightNode != nil { + rightHash, rightCount := node.rightNode.hashWithCount(t) + node.rightHash = rightHash + hashCount += rightCount + } + if node.rightHash == nil { + panic("node.rightHash was nil in writeHashBytes") + } + binary.WriteByteSlice(node.rightHash, w, &n, &err) + } + return +} + // NOTE: sets hashes recursively // NOTE: clears leftNode/rightNode recursively func (node *IAVLNode) save(t *IAVLTree) []byte { @@ -165,6 +207,32 @@ func (node *IAVLNode) save(t *IAVLTree) []byte { return node.hash } +// NOTE: sets hashes recursively +func (node *IAVLNode) writePersistBytes(t *IAVLTree, w io.Writer) (n int64, err error) { + // node header + binary.WriteUint8(node.height, w, &n, &err) + binary.WriteUvarint(node.size, w, &n, &err) + // key (unlike writeHashBytes, key is written for inner nodes) + encodeByteSlice(node.key, t.keyCodec, w, &n, &err) + + if node.height == 0 { + // value + encodeByteSlice(node.value, t.valueCodec, w, &n, &err) + } else { + // left + if node.leftHash == nil { + panic("node.leftHash was nil in writePersistBytes") + } + binary.WriteByteSlice(node.leftHash, w, &n, &err) + // right + if node.rightHash == nil { + panic("node.rightHash was nil in writePersistBytes") + } + binary.WriteByteSlice(node.rightHash, w, &n, &err) + } + return +} + func (node *IAVLNode) set(t *IAVLTree, key interface{}, value interface{}) (newSelf *IAVLNode, updated bool) { if node.height == 0 { cmp := t.keyCodec.Compare(key, node.key) @@ -251,44 +319,6 @@ func (node *IAVLNode) remove(t *IAVLTree, key interface{}) ( } } -// NOTE: sets hashes recursively -func (node *IAVLNode) writeToCountHashes(t *IAVLTree, w io.Writer) (n int64, hashCount uint64, err error) { - // height & size & key - binary.WriteUint8(node.height, w, &n, &err) - binary.WriteUint64(node.size, w, &n, &err) - t.keyCodec.Encode(node.key, w, &n, &err) - if err != nil { - return - } - - if node.height == 0 { - // value - t.valueCodec.Encode(node.value, w, &n, &err) - } else { - // left - if node.leftNode != nil { - leftHash, leftCount := node.leftNode.hashWithCount(t) - node.leftHash = leftHash - hashCount += leftCount - } - if node.leftHash == nil { - panic("node.leftHash was nil in save") - } - binary.WriteByteSlice(node.leftHash, w, &n, &err) - // right - if node.rightNode != nil { - rightHash, rightCount := node.rightNode.hashWithCount(t) - node.rightHash = rightHash - hashCount += rightCount - } - if node.rightHash == nil { - panic("node.rightHash was nil in save") - } - binary.WriteByteSlice(node.rightHash, w, &n, &err) - } - return -} - func (node *IAVLNode) getLeftNode(t *IAVLTree) *IAVLNode { if node.leftNode != nil { return node.leftNode @@ -406,3 +436,25 @@ func (node *IAVLNode) rmd(t *IAVLTree) *IAVLNode { } return node.getRightNode(t).rmd(t) } + +//-------------------------------------------------------------------------------- + +// Read a (length prefixed) byteslice then decode the object using the codec +func decodeByteSlice(codec binary.Codec, r io.Reader, n *int64, err *error) interface{} { + bytez := binary.ReadByteSlice(r, n, err) + if *err != nil { + return nil + } + n_ := new(int64) + return codec.Decode(bytes.NewBuffer(bytez), n_, err) +} + +// Encode object using codec, then write a (length prefixed) byteslice. +func encodeByteSlice(o interface{}, codec binary.Codec, w io.Writer, n *int64, err *error) { + buf, n_ := new(bytes.Buffer), new(int64) + codec.Encode(o, buf, n_, err) + if *err != nil { + return + } + binary.WriteByteSlice(buf.Bytes(), w, n, err) +} diff --git a/merkle/iavl_proof.go b/merkle/iavl_proof.go new file mode 100644 index 000000000..4d671cf66 --- /dev/null +++ b/merkle/iavl_proof.go @@ -0,0 +1,137 @@ +package merkle + +import ( + "bytes" + "crypto/sha256" + + "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" +) + +type IAVLProof struct { + Root []byte + Branches []IAVLProofBranch + Leaf IAVLProofLeaf +} + +func (proof *IAVLProof) Verify() bool { + hash := proof.Leaf.Hash() + // fmt.Printf("leaf hash: %X\n", hash) + for i := len(proof.Branches) - 1; 0 <= i; i-- { + hash = proof.Branches[i].Hash(hash) + // fmt.Printf("branch hash: %X\n", hash) + } + // fmt.Printf("root: %X, computed: %X\n", proof.Root, hash) + return bytes.Equal(proof.Root, hash) +} + +type IAVLProofBranch struct { + Height uint8 + Size uint + Left []byte + Right []byte +} + +func (branch IAVLProofBranch) Hash(childHash []byte) []byte { + hasher := sha256.New() + buf := new(bytes.Buffer) + n, err := int64(0), error(nil) + binary.WriteUint8(branch.Height, buf, &n, &err) + binary.WriteUvarint(branch.Size, buf, &n, &err) + if branch.Left == nil { + binary.WriteByteSlice(childHash, buf, &n, &err) + binary.WriteByteSlice(branch.Right, buf, &n, &err) + } else { + binary.WriteByteSlice(branch.Left, buf, &n, &err) + binary.WriteByteSlice(childHash, buf, &n, &err) + } + if err != nil { + panic(Fmt("Failed to hash IAVLProofBranch: %v", err)) + } + // fmt.Printf("Branch hash bytes: %X\n", buf.Bytes()) + hasher.Write(buf.Bytes()) + return hasher.Sum(nil) +} + +type IAVLProofLeaf struct { + KeyBytes []byte + ValueBytes []byte +} + +func (leaf IAVLProofLeaf) Hash() []byte { + hasher := sha256.New() + buf := new(bytes.Buffer) + n, err := int64(0), error(nil) + binary.WriteUint8(0, buf, &n, &err) + binary.WriteUvarint(1, buf, &n, &err) + binary.WriteByteSlice(leaf.KeyBytes, buf, &n, &err) + binary.WriteByteSlice(leaf.ValueBytes, buf, &n, &err) + if err != nil { + panic(Fmt("Failed to hash IAVLProofLeaf: %v", err)) + } + // fmt.Printf("Leaf hash bytes: %X\n", buf.Bytes()) + hasher.Write(buf.Bytes()) + return hasher.Sum(nil) +} + +func (node *IAVLNode) constructProof(t *IAVLTree, key interface{}, proof *IAVLProof) (exists bool) { + if node.height == 0 { + if t.keyCodec.Compare(node.key, key) == 0 { + keyBuf, valueBuf := new(bytes.Buffer), new(bytes.Buffer) + n, err := int64(0), error(nil) + t.keyCodec.Encode(node.key, keyBuf, &n, &err) + if err != nil { + panic(Fmt("Failed to encode node.key: %v", err)) + } + t.valueCodec.Encode(node.value, valueBuf, &n, &err) + if err != nil { + panic(Fmt("Failed to encode node.value: %v", err)) + } + leaf := IAVLProofLeaf{ + KeyBytes: keyBuf.Bytes(), + ValueBytes: valueBuf.Bytes(), + } + proof.Leaf = leaf + return true + } else { + return false + } + } else { + if t.keyCodec.Compare(key, node.key) < 0 { + branch := IAVLProofBranch{ + Height: node.height, + Size: node.size, + Left: nil, + Right: node.getRightNode(t).hash, + } + if node.getRightNode(t).hash == nil { + // fmt.Println(node.getRightNode(t)) + panic("WTF") + } + proof.Branches = append(proof.Branches, branch) + return node.getLeftNode(t).constructProof(t, key, proof) + } else { + branch := IAVLProofBranch{ + Height: node.height, + Size: node.size, + Left: node.getLeftNode(t).hash, + Right: nil, + } + proof.Branches = append(proof.Branches, branch) + return node.getRightNode(t).constructProof(t, key, proof) + } + } +} + +// Returns nil if key is not in tree. +func (t *IAVLTree) ConstructProof(key interface{}) *IAVLProof { + if t.root == nil { + return nil + } + t.root.hashWithCount(t) // Ensure that all hashes are calculated. + proof := &IAVLProof{ + Root: t.root.hash, + } + t.root.constructProof(t, key, proof) + return proof +} diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index e6926c9a3..4e87505a6 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -6,6 +6,7 @@ import ( "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/common" + . "github.com/tendermint/tendermint/common/test" "github.com/tendermint/tendermint/db" "runtime" @@ -59,7 +60,7 @@ func P(n *IAVLNode) string { func TestUnit(t *testing.T) { - expectHash := func(tree *IAVLTree, hashCount uint64) { + expectHash := func(tree *IAVLTree, hashCount uint) { // ensure number of new hash calculations is as expected. hash, count := tree.HashWithCount() if count != hashCount { @@ -77,7 +78,7 @@ func TestUnit(t *testing.T) { } } - expectSet := func(tree *IAVLTree, i int, repr string, hashCount uint64) { + expectSet := func(tree *IAVLTree, i int, repr string, hashCount uint) { origNode := tree.root updated := tree.Set(i, "") // ensure node was added & structure is as expected. @@ -90,7 +91,7 @@ func TestUnit(t *testing.T) { tree.root = origNode } - expectRemove := func(tree *IAVLTree, i int, repr string, hashCount uint64) { + expectRemove := func(tree *IAVLTree, i int, repr string, hashCount uint) { origNode := tree.root value, removed := tree.Remove(i) // ensure node was added & structure is as expected. @@ -167,7 +168,7 @@ func TestIntegration(t *testing.T) { if !updated { t.Error("should have been updated") } - if tree.Size() != uint64(i+1) { + if tree.Size() != uint(i+1) { t.Error("size was wrong", tree.Size(), i+1) } } @@ -202,7 +203,7 @@ func TestIntegration(t *testing.T) { t.Error("wrong value") } } - if tree.Size() != uint64(len(records)-(i+1)) { + if tree.Size() != uint(len(records)-(i+1)) { t.Error("size was wrong", tree.Size(), (len(records) - (i + 1))) } } @@ -237,6 +238,71 @@ func TestPersistence(t *testing.T) { } } +func testProof(t *testing.T, proof *IAVLProof) { + // Proof must verify. + if !proof.Verify() { + t.Errorf("Invalid proof. Verification failed.") + return + } + // Write/Read then verify. + proofBytes := binary.BinaryBytes(proof) + n, err := int64(0), error(nil) + proof2 := binary.ReadBinary(&IAVLProof{}, bytes.NewBuffer(proofBytes), &n, &err).(*IAVLProof) + if err != nil { + t.Errorf("Failed to read IAVLProof from bytes: %v", err) + return + } + if !proof2.Verify() { + t.Errorf("Invalid proof after write/read. Verification failed.") + return + } + // Random mutations must not verify + for i := 0; i < 3; i++ { + badProofBytes := MutateByteSlice(proofBytes) + n, err := int64(0), error(nil) + badProof := binary.ReadBinary(&IAVLProof{}, bytes.NewBuffer(badProofBytes), &n, &err).(*IAVLProof) + if err != nil { + continue // This is fine. + } + if badProof.Verify() { + t.Errorf("Proof was still valid after a random mutation:\n%X\n%X", proofBytes, badProofBytes) + } + } +} + +func TestConstructProof(t *testing.T) { + // Construct some random tree + db := db.NewMemDB() + var tree *IAVLTree = NewIAVLTree(binary.BasicCodec, binary.BasicCodec, 100, db) + for i := 0; i < 1000; i++ { + key, value := randstr(20), randstr(20) + tree.Set(key, value) + } + + // Persist the items so far + tree.Save() + + // Add more items so it's not all persisted + for i := 0; i < 100; i++ { + key, value := randstr(20), randstr(20) + tree.Set(key, value) + } + + // Now for each item, construct a proof and verify + tree.Iterate(func(key interface{}, value interface{}) bool { + proof := tree.ConstructProof(key) + if !bytes.Equal(proof.Root, tree.Hash()) { + t.Errorf("Invalid proof. Expected root %X, got %X", tree.Hash(), proof.Root) + } + if !proof.Verify() { + t.Errorf("Invalid proof. Verification failed.") + } + testProof(t, proof) + return false + }) + +} + func BenchmarkImmutableAvlTree(b *testing.B) { b.StopTimer() diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index 725d9a6f6..66c2ea919 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -68,7 +68,7 @@ func (t *IAVLTree) Copy() Tree { } } -func (t *IAVLTree) Size() uint64 { +func (t *IAVLTree) Size() uint { if t.root == nil { return 0 } @@ -106,7 +106,7 @@ func (t *IAVLTree) Hash() []byte { return hash } -func (t *IAVLTree) HashWithCount() ([]byte, uint64) { +func (t *IAVLTree) HashWithCount() ([]byte, uint) { if t.root == nil { return nil, 0 } @@ -130,14 +130,14 @@ func (t *IAVLTree) Load(hash []byte) { } } -func (t *IAVLTree) Get(key interface{}) (index uint64, value interface{}) { +func (t *IAVLTree) Get(key interface{}) (index uint, value interface{}) { if t.root == nil { return 0, nil } return t.root.get(t, key) } -func (t *IAVLTree) GetByIndex(index uint64) (key interface{}, value interface{}) { +func (t *IAVLTree) GetByIndex(index uint) (key interface{}, value interface{}) { if t.root == nil { return nil, nil } @@ -220,6 +220,7 @@ func (ndb *nodeDB) GetNode(t *IAVLTree, hash []byte) *IAVLNode { if err != nil { panic(Fmt("Error reading IAVLNode. bytes: %X error: %v", buf, err)) } + node.hash = hash node.persisted = true ndb.cacheNode(node) return node @@ -240,7 +241,7 @@ func (ndb *nodeDB) SaveNode(t *IAVLTree, node *IAVLNode) { }*/ // Save node bytes to db buf := bytes.NewBuffer(nil) - _, _, err := node.writeToCountHashes(t, buf) + _, err := node.writePersistBytes(t, buf) if err != nil { panic(err) } diff --git a/merkle/types.go b/merkle/types.go index bba0eee11..68a461310 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -1,14 +1,14 @@ package merkle type Tree interface { - Size() (size uint64) + Size() (size uint) Height() (height uint8) Has(key interface{}) (has bool) - Get(key interface{}) (index uint64, value interface{}) - GetByIndex(index uint64) (key interface{}, value interface{}) + Get(key interface{}) (index uint, value interface{}) + GetByIndex(index uint) (key interface{}, value interface{}) Set(key interface{}, value interface{}) (updated bool) Remove(key interface{}) (value interface{}, removed bool) - HashWithCount() (hash []byte, count uint64) + HashWithCount() (hash []byte, count uint) Hash() (hash []byte) Save() (hash []byte) Load(hash []byte)