From b164c05b2701863f0b97ab87f763b68a8633a70e Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Wed, 17 Jun 2015 19:04:31 -0700 Subject: [PATCH] Refactor IAVLProof and SimpleProof to use similar Verify() --- merkle/iavl_proof.go | 73 +++++++------ merkle/iavl_test.go | 32 +++--- merkle/simple_tree.go | 205 ++++++++++++++++++++++--------------- merkle/simple_tree_test.go | 61 ++++++----- 4 files changed, 222 insertions(+), 149 deletions(-) diff --git a/merkle/iavl_proof.go b/merkle/iavl_proof.go index 4d671cf66..b3e3b2a04 100644 --- a/merkle/iavl_proof.go +++ b/merkle/iavl_proof.go @@ -9,30 +9,39 @@ import ( ) type IAVLProof struct { - Root []byte - Branches []IAVLProofBranch - Leaf IAVLProofLeaf + LeafNode IAVLProofLeafNode + InnerNodes []IAVLProofInnerNode + RootHash []byte } -func (proof *IAVLProof) Verify() bool { - hash := proof.Leaf.Hash() +func (proof *IAVLProof) Verify(keyBytes, valueBytes, rootHash []byte) bool { + if !bytes.Equal(keyBytes, proof.LeafNode.KeyBytes) { + return false + } + if !bytes.Equal(valueBytes, proof.LeafNode.ValueBytes) { + return false + } + if !bytes.Equal(rootHash, proof.RootHash) { + return false + } + hash := proof.LeafNode.Hash() // fmt.Printf("leaf hash: %X\n", hash) - for i := len(proof.Branches) - 1; 0 <= i; i-- { - hash = proof.Branches[i].Hash(hash) + for _, branch := range proof.InnerNodes { + hash = branch.Hash(hash) // fmt.Printf("branch hash: %X\n", hash) } - // fmt.Printf("root: %X, computed: %X\n", proof.Root, hash) - return bytes.Equal(proof.Root, hash) + // fmt.Printf("root: %X, computed: %X\n", proof.RootHash, hash) + return bytes.Equal(proof.RootHash, hash) } -type IAVLProofBranch struct { +type IAVLProofInnerNode struct { Height uint8 Size uint Left []byte Right []byte } -func (branch IAVLProofBranch) Hash(childHash []byte) []byte { +func (branch IAVLProofInnerNode) Hash(childHash []byte) []byte { hasher := sha256.New() buf := new(bytes.Buffer) n, err := int64(0), error(nil) @@ -46,19 +55,19 @@ func (branch IAVLProofBranch) Hash(childHash []byte) []byte { binary.WriteByteSlice(childHash, buf, &n, &err) } if err != nil { - panic(Fmt("Failed to hash IAVLProofBranch: %v", err)) + panic(Fmt("Failed to hash IAVLProofInnerNode: %v", err)) } - // fmt.Printf("Branch hash bytes: %X\n", buf.Bytes()) + // fmt.Printf("InnerNode hash bytes: %X\n", buf.Bytes()) hasher.Write(buf.Bytes()) return hasher.Sum(nil) } -type IAVLProofLeaf struct { +type IAVLProofLeafNode struct { KeyBytes []byte ValueBytes []byte } -func (leaf IAVLProofLeaf) Hash() []byte { +func (leaf IAVLProofLeafNode) Hash() []byte { hasher := sha256.New() buf := new(bytes.Buffer) n, err := int64(0), error(nil) @@ -67,9 +76,9 @@ func (leaf IAVLProofLeaf) Hash() []byte { binary.WriteByteSlice(leaf.KeyBytes, buf, &n, &err) binary.WriteByteSlice(leaf.ValueBytes, buf, &n, &err) if err != nil { - panic(Fmt("Failed to hash IAVLProofLeaf: %v", err)) + panic(Fmt("Failed to hash IAVLProofLeafNode: %v", err)) } - // fmt.Printf("Leaf hash bytes: %X\n", buf.Bytes()) + // fmt.Printf("LeafNode hash bytes: %X\n", buf.Bytes()) hasher.Write(buf.Bytes()) return hasher.Sum(nil) } @@ -87,38 +96,42 @@ func (node *IAVLNode) constructProof(t *IAVLTree, key interface{}, proof *IAVLPr if err != nil { panic(Fmt("Failed to encode node.value: %v", err)) } - leaf := IAVLProofLeaf{ + leaf := IAVLProofLeafNode{ KeyBytes: keyBuf.Bytes(), ValueBytes: valueBuf.Bytes(), } - proof.Leaf = leaf + proof.LeafNode = leaf return true } else { return false } } else { if t.keyCodec.Compare(key, node.key) < 0 { - branch := IAVLProofBranch{ + exists := node.getLeftNode(t).constructProof(t, key, proof) + if !exists { + return false + } + branch := IAVLProofInnerNode{ Height: node.height, Size: node.size, Left: nil, Right: node.getRightNode(t).hash, } - if node.getRightNode(t).hash == nil { - // fmt.Println(node.getRightNode(t)) - panic("WTF") - } - proof.Branches = append(proof.Branches, branch) - return node.getLeftNode(t).constructProof(t, key, proof) + proof.InnerNodes = append(proof.InnerNodes, branch) + return true } else { - branch := IAVLProofBranch{ + exists := node.getRightNode(t).constructProof(t, key, proof) + if !exists { + return false + } + branch := IAVLProofInnerNode{ Height: node.height, Size: node.size, Left: node.getLeftNode(t).hash, Right: nil, } - proof.Branches = append(proof.Branches, branch) - return node.getRightNode(t).constructProof(t, key, proof) + proof.InnerNodes = append(proof.InnerNodes, branch) + return true } } } @@ -130,7 +143,7 @@ func (t *IAVLTree) ConstructProof(key interface{}) *IAVLProof { } t.root.hashWithCount(t) // Ensure that all hashes are calculated. proof := &IAVLProof{ - Root: t.root.hash, + RootHash: t.root.hash, } t.root.constructProof(t, key, proof) return proof diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index 4e87505a6..e92bf4dd9 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -238,9 +238,9 @@ func TestPersistence(t *testing.T) { } } -func testProof(t *testing.T, proof *IAVLProof) { +func testProof(t *testing.T, proof *IAVLProof, keyBytes, valueBytes, rootHash []byte) { // Proof must verify. - if !proof.Verify() { + if !proof.Verify(keyBytes, valueBytes, rootHash) { t.Errorf("Invalid proof. Verification failed.") return } @@ -252,25 +252,36 @@ func testProof(t *testing.T, proof *IAVLProof) { t.Errorf("Failed to read IAVLProof from bytes: %v", err) return } - if !proof2.Verify() { + if !proof2.Verify(keyBytes, valueBytes, rootHash) { t.Errorf("Invalid proof after write/read. Verification failed.") return } // Random mutations must not verify - for i := 0; i < 3; i++ { + for i := 0; i < 5; i++ { badProofBytes := MutateByteSlice(proofBytes) n, err := int64(0), error(nil) badProof := binary.ReadBinary(&IAVLProof{}, bytes.NewBuffer(badProofBytes), &n, &err).(*IAVLProof) if err != nil { continue // This is fine. } - if badProof.Verify() { + if badProof.Verify(keyBytes, valueBytes, rootHash) { t.Errorf("Proof was still valid after a random mutation:\n%X\n%X", proofBytes, badProofBytes) } } } -func TestConstructProof(t *testing.T) { +func TestIAVLProof(t *testing.T) { + + // Convenient wrapper around binary.BasicCodec. + toBytes := func(o interface{}) []byte { + buf, n, err := new(bytes.Buffer), int64(0), error(nil) + binary.BasicCodec.Encode(o, buf, &n, &err) + if err != nil { + panic(Fmt("Failed to encode thing: %v", err)) + } + return buf.Bytes() + } + // Construct some random tree db := db.NewMemDB() var tree *IAVLTree = NewIAVLTree(binary.BasicCodec, binary.BasicCodec, 100, db) @@ -291,13 +302,10 @@ func TestConstructProof(t *testing.T) { // Now for each item, construct a proof and verify tree.Iterate(func(key interface{}, value interface{}) bool { proof := tree.ConstructProof(key) - if !bytes.Equal(proof.Root, tree.Hash()) { - t.Errorf("Invalid proof. Expected root %X, got %X", tree.Hash(), proof.Root) - } - if !proof.Verify() { - t.Errorf("Invalid proof. Verification failed.") + if !bytes.Equal(proof.RootHash, tree.Hash()) { + t.Errorf("Invalid proof. Expected root %X, got %X", tree.Hash(), proof.RootHash) } - testProof(t, proof) + testProof(t, proof, toBytes(key), toBytes(value), tree.Hash()) return false }) diff --git a/merkle/simple_tree.go b/merkle/simple_tree.go index bdc9164d8..2f05f8bb5 100644 --- a/merkle/simple_tree.go +++ b/merkle/simple_tree.go @@ -7,18 +7,18 @@ the tree the same size, but the left may be one greater. Use this for short deterministic trees, such as the validator list. For larger datasets, use IAVLTree. - * - / \ - / \ - / \ - / \ - * * - / \ / \ - / \ / \ - / \ / \ - * * * h6 - / \ / \ / \ - h0 h1 h2 h3 h4 h5 + * + / \ + / \ + / \ + / \ + * * + / \ / \ + / \ / \ + / \ / \ + * * * h6 + / \ / \ / \ + h0 h1 h2 h3 h4 h5 */ @@ -31,7 +31,7 @@ import ( "github.com/tendermint/tendermint/binary" ) -func HashFromTwoHashes(left []byte, right []byte) []byte { +func SimpleHashFromTwoHashes(left []byte, right []byte) []byte { var n int64 var err error var hasher = sha256.New() @@ -43,7 +43,7 @@ func HashFromTwoHashes(left []byte, right []byte) []byte { return hasher.Sum(nil) } -func HashFromHashes(hashes [][]byte) []byte { +func SimpleHashFromHashes(hashes [][]byte) []byte { // Recursive impl. switch len(hashes) { case 0: @@ -51,23 +51,23 @@ func HashFromHashes(hashes [][]byte) []byte { case 1: return hashes[0] default: - left := HashFromHashes(hashes[:(len(hashes)+1)/2]) - right := HashFromHashes(hashes[(len(hashes)+1)/2:]) - return HashFromTwoHashes(left, right) + left := SimpleHashFromHashes(hashes[:(len(hashes)+1)/2]) + right := SimpleHashFromHashes(hashes[(len(hashes)+1)/2:]) + return SimpleHashFromTwoHashes(left, right) } } -// Convenience for HashFromHashes. -func HashFromBinaries(items []interface{}) []byte { +// Convenience for SimpleHashFromHashes. +func SimpleHashFromBinaries(items []interface{}) []byte { hashes := [][]byte{} for _, item := range items { - hashes = append(hashes, HashFromBinary(item)) + hashes = append(hashes, SimpleHashFromBinary(item)) } - return HashFromHashes(hashes) + return SimpleHashFromHashes(hashes) } // General Convenience -func HashFromBinary(item interface{}) []byte { +func SimpleHashFromBinary(item interface{}) []byte { hasher, n, err := sha256.New(), new(int64), new(error) binary.WriteBinary(item, hasher, n, err) if *err != nil { @@ -76,103 +76,146 @@ func HashFromBinary(item interface{}) []byte { return hasher.Sum(nil) } -// Convenience for HashFromHashes. -func HashFromHashables(items []Hashable) []byte { +// Convenience for SimpleHashFromHashes. +func SimpleHashFromHashables(items []Hashable) []byte { hashes := [][]byte{} for _, item := range items { hash := item.Hash() hashes = append(hashes, hash) } - return HashFromHashes(hashes) + return SimpleHashFromHashes(hashes) } -type HashTrail struct { - Hash []byte - Parent *HashTrail - Left *HashTrail - Right *HashTrail +//-------------------------------------------------------------------------------- + +type SimpleProof struct { + Index uint + Total uint + LeafHash []byte + InnerHashes [][]byte // Hashes from leaf's sibling to a root's child. + RootHash []byte } -func (ht *HashTrail) Flatten() [][]byte { - // Nonrecursive impl. - trail := [][]byte{} - for ht != nil { - if ht.Left != nil { - trail = append(trail, ht.Left.Hash) - } else if ht.Right != nil { - trail = append(trail, ht.Right.Hash) - } else { - break +// proofs[0] is the proof for items[0]. +func SimpleProofsFromHashables(items []Hashable) (proofs []*SimpleProof) { + trails, root := trailsFromHashables(items) + proofs = make([]*SimpleProof, len(items)) + for i, trail := range trails { + proofs[i] = &SimpleProof{ + Index: uint(i), + Total: uint(len(items)), + LeafHash: trail.Hash, + InnerHashes: trail.FlattenInnerHashes(), + RootHash: root.Hash, } - ht = ht.Parent } - return trail + return } -// returned trails[0].Hash is the leaf hash. -// trails[0].Parent.Hash is the hash above that, etc. -func HashTrailsFromHashables(items []Hashable) (trails []*HashTrail, root *HashTrail) { - // Recursive impl. - switch len(items) { - case 0: - return nil, nil - case 1: - trail := &HashTrail{items[0].Hash(), nil, nil, nil} - return []*HashTrail{trail}, trail - default: - lefts, leftRoot := HashTrailsFromHashables(items[:(len(items)+1)/2]) - rights, rightRoot := HashTrailsFromHashables(items[(len(items)+1)/2:]) - rootHash := HashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) - root := &HashTrail{rootHash, nil, nil, nil} - leftRoot.Parent = root - leftRoot.Right = rightRoot - rightRoot.Parent = root - rightRoot.Left = leftRoot - return append(lefts, rights...), root +// Verify that leafHash is a leaf hash of the simple-merkle-tree +// which hashes to rootHash. +func (sp *SimpleProof) Verify(leafHash []byte, rootHash []byte) bool { + if !bytes.Equal(leafHash, sp.LeafHash) { + return false } -} - -// Ensures that leafHash is part of rootHash. -func VerifyHashTrail(index uint, total uint, leafHash []byte, trail [][]byte, rootHash []byte) bool { - computedRoot := ComputeRootFromTrail(index, total, leafHash, trail) - if computedRoot == nil { + if !bytes.Equal(rootHash, sp.RootHash) { + return false + } + computedHash := computeHashFromInnerHashes(sp.Index, sp.Total, sp.LeafHash, sp.InnerHashes) + if computedHash == nil { return false } - return bytes.Equal(computedRoot, rootHash) + if !bytes.Equal(computedHash, rootHash) { + return false + } + return true } -// Use the leafHash and trail to get the root merkle hash. -// If the length of the trail slice isn't exactly correct, the result is nil. -func ComputeRootFromTrail(index uint, total uint, leafHash []byte, trail [][]byte) []byte { +// Use the leafHash and innerHashes to get the root merkle hash. +// If the length of the innerHashes slice isn't exactly correct, the result is nil. +func computeHashFromInnerHashes(index uint, total uint, leafHash []byte, innerHashes [][]byte) []byte { // Recursive impl. if index >= total { return nil } switch total { case 0: - panic("Cannot call ComputeRootFromTrail() with 0 total") + panic("Cannot call computeHashFromInnerHashes() with 0 total") case 1: - if len(trail) != 0 { + if len(innerHashes) != 0 { return nil } return leafHash default: - if len(trail) == 0 { + if len(innerHashes) == 0 { return nil } numLeft := (total + 1) / 2 if index < numLeft { - leftRoot := ComputeRootFromTrail(index, numLeft, leafHash, trail[:len(trail)-1]) - if leftRoot == nil { + leftHash := computeHashFromInnerHashes(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if leftHash == nil { return nil } - return HashFromTwoHashes(leftRoot, trail[len(trail)-1]) + return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) } else { - rightRoot := ComputeRootFromTrail(index-numLeft, total-numLeft, leafHash, trail[:len(trail)-1]) - if rightRoot == nil { + rightHash := computeHashFromInnerHashes(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if rightHash == nil { return nil } - return HashFromTwoHashes(trail[len(trail)-1], rightRoot) + return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) } } } + +// Helper structure to construct merkle proof. +// The node and the tree is thrown away afterwards. +// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. +// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or +// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. +type SimpleProofNode struct { + Hash []byte + Parent *SimpleProofNode + Left *SimpleProofNode // Left sibling (only one of Left,Right is set) + Right *SimpleProofNode // Right sibling (only one of Left,Right is set) +} + +// Starting from a leaf SimpleProofNode, FlattenInnerHashes() will return +// the inner hashes for the item corresponding to the leaf. +func (spn *SimpleProofNode) FlattenInnerHashes() [][]byte { + // Nonrecursive impl. + innerHashes := [][]byte{} + for spn != nil { + if spn.Left != nil { + innerHashes = append(innerHashes, spn.Left.Hash) + } else if spn.Right != nil { + innerHashes = append(innerHashes, spn.Right.Hash) + } else { + break + } + spn = spn.Parent + } + return innerHashes +} + +// trails[0].Hash is the leaf hash for items[0]. +// trails[i].Parent.Parent....Parent == root for all i. +func trailsFromHashables(items []Hashable) (trails []*SimpleProofNode, root *SimpleProofNode) { + // Recursive impl. + switch len(items) { + case 0: + return nil, nil + case 1: + trail := &SimpleProofNode{items[0].Hash(), nil, nil, nil} + return []*SimpleProofNode{trail}, trail + default: + lefts, leftRoot := trailsFromHashables(items[:(len(items)+1)/2]) + rights, rightRoot := trailsFromHashables(items[(len(items)+1)/2:]) + rootHash := SimpleHashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) + root := &SimpleProofNode{rootHash, nil, nil, nil} + leftRoot.Parent = root + leftRoot.Right = rightRoot + rightRoot.Parent = root + rightRoot.Left = leftRoot + return append(lefts, rights...), root + } +} diff --git a/merkle/simple_tree_test.go b/merkle/simple_tree_test.go index 893a30ce9..eabf8f320 100644 --- a/merkle/simple_tree_test.go +++ b/merkle/simple_tree_test.go @@ -2,8 +2,8 @@ package merkle import ( . "github.com/tendermint/tendermint/common" + . "github.com/tendermint/tendermint/common/test" - "bytes" "testing" ) @@ -13,7 +13,7 @@ func (tI testItem) Hash() []byte { return []byte(tI) } -func TestMerkleTrails(t *testing.T) { +func TestSimpleProof(t *testing.T) { numItems := uint(100) @@ -22,53 +22,62 @@ func TestMerkleTrails(t *testing.T) { items[i] = testItem(RandBytes(32)) } - root := HashFromHashables(items) + rootHash := SimpleHashFromHashables(items) - trails, rootTrail := HashTrailsFromHashables(items) - - // Assert that HashFromHashables and HashTrailsFromHashables are compatible. - if !bytes.Equal(root, rootTrail.Hash) { - t.Errorf("Root mismatch:\n%X vs\n%X", root, rootTrail.Hash) - } + proofs := SimpleProofsFromHashables(items) // For each item, check the trail. for i, item := range items { itemHash := item.Hash() - flatTrail := trails[i].Flatten() + proof := proofs[i] // Verify success - ok := VerifyHashTrail(uint(i), numItems, itemHash, flatTrail, root) + ok := proof.Verify(itemHash, rootHash) if !ok { t.Errorf("Verification failed for index %v.", i) } // Wrong item index should make it fail - ok = VerifyHashTrail(uint(i)+1, numItems, itemHash, flatTrail, root) - if ok { - t.Errorf("Expected verification to fail for wrong index %v.", i) + proof.Index += 1 + { + ok = proof.Verify(itemHash, rootHash) + if ok { + t.Errorf("Expected verification to fail for wrong index %v.", i) + } } + proof.Index -= 1 // Trail too long should make it fail - trail2 := append(flatTrail, RandBytes(32)) - ok = VerifyHashTrail(uint(i), numItems, itemHash, trail2, root) - if ok { - t.Errorf("Expected verification to fail for wrong trail length.") + origInnerHashes := proof.InnerHashes + proof.InnerHashes = append(proof.InnerHashes, RandBytes(32)) + { + ok = proof.Verify(itemHash, rootHash) + if ok { + t.Errorf("Expected verification to fail for wrong trail length.") + } } + proof.InnerHashes = origInnerHashes // Trail too short should make it fail - trail2 = flatTrail[:len(flatTrail)-1] - ok = VerifyHashTrail(uint(i), numItems, itemHash, trail2, root) - if ok { - t.Errorf("Expected verification to fail for wrong trail length.") + proof.InnerHashes = proof.InnerHashes[0 : len(proof.InnerHashes)-1] + { + ok = proof.Verify(itemHash, rootHash) + if ok { + t.Errorf("Expected verification to fail for wrong trail length.") + } } + proof.InnerHashes = origInnerHashes // Mutating the itemHash should make it fail. - itemHash2 := make([]byte, len(itemHash)) - copy(itemHash2, itemHash) - itemHash2[0] += byte(0x01) - ok = VerifyHashTrail(uint(i), numItems, itemHash2, flatTrail, root) + ok = proof.Verify(MutateByteSlice(itemHash), rootHash) if ok { t.Errorf("Expected verification to fail for mutated leaf hash") } + + // Mutating the rootHash should make it fail. + ok = proof.Verify(itemHash, MutateByteSlice(rootHash)) + if ok { + t.Errorf("Expected verification to fail for mutated root hash") + } } }