Browse Source

Merkle proofs!

pull/79/head
Jae Kwon 9 years ago
parent
commit
fdf0e720bc
9 changed files with 371 additions and 78 deletions
  1. +13
    -8
      binary/int.go
  2. +1
    -1
      binary/int_test.go
  3. +4
    -0
      common/random.go
  4. +28
    -0
      common/test/mutate.go
  5. +107
    -55
      merkle/iavl_node.go
  6. +137
    -0
      merkle/iavl_proof.go
  7. +71
    -5
      merkle/iavl_test.go
  8. +6
    -5
      merkle/iavl_tree.go
  9. +4
    -4
      merkle/types.go

+ 13
- 8
binary/int.go View File

@ -160,6 +160,9 @@ func ReadUint64(r io.Reader, n *int64, err *error) uint64 {
func uvarintSize(i_ uint) int {
i := uint64(i_)
if i == 0 {
return 0
}
if i < 1<<8 {
return 1
}
@ -197,9 +200,11 @@ func WriteVarint(i int, w io.Writer, n *int64, err *error) {
} else {
WriteUint8(uint8(size), w, n, err)
}
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(i))
WriteTo(buf[(8-size):], w, n, err)
if size > 0 {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(i))
WriteTo(buf[(8-size):], w, n, err)
}
*n += int64(1 + size)
}
@ -215,7 +220,6 @@ func ReadVarint(r io.Reader, n *int64, err *error) int {
return 0
}
if size == 0 {
setFirstErr(err, errors.New("Varint underflow"))
return 0
}
buf := make([]byte, 8)
@ -234,9 +238,11 @@ func ReadVarint(r io.Reader, n *int64, err *error) int {
func WriteUvarint(i uint, w io.Writer, n *int64, err *error) {
var size = uvarintSize(i)
WriteUint8(uint8(size), w, n, err)
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(i))
WriteTo(buf[(8-size):], w, n, err)
if size > 0 {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(i))
WriteTo(buf[(8-size):], w, n, err)
}
*n += int64(1 + size)
}
@ -247,7 +253,6 @@ func ReadUvarint(r io.Reader, n *int64, err *error) uint {
return 0
}
if size == 0 {
setFirstErr(err, errors.New("Uvarint underflow"))
return 0
}
buf := make([]byte, 8)


+ 1
- 1
binary/int_test.go View File

@ -32,7 +32,7 @@ func TestVarint(t *testing.T) {
// Near zero
check(-1, "F101")
check(0, "0100")
check(0, "00")
check(1, "0101")
// Positives
check(1<<32-1, "04FFFFFFFF")


+ 4
- 0
common/random.go View File

@ -62,6 +62,10 @@ func RandUint() uint {
return uint(rand.Int())
}
func RandInt() int {
return rand.Int()
}
// Distributed pseudo-exponentially to test for various cases
func RandUint16Exp() uint16 {
bits := rand.Uint32() % 16


+ 28
- 0
common/test/mutate.go View File

@ -0,0 +1,28 @@
package test
import (
. "github.com/tendermint/tendermint/common"
)
// Contract: !bytes.Equal(input, output) && len(input) >= len(output)
func MutateByteSlice(bytez []byte) []byte {
// If bytez is empty, panic
if len(bytez) == 0 {
panic("Cannot mutate an empty bytez")
}
// Copy bytez
mBytez := make([]byte, len(bytez))
copy(mBytez, bytez)
bytez = mBytez
// Try a random mutation
switch RandInt() % 2 {
case 0: // Mutate a single byte
bytez[RandInt()%len(bytez)] += byte(RandInt()%255 + 1)
case 1: // Remove an arbitrary byte
pos := RandInt() % len(bytez)
bytez = append(bytez[:pos], bytez[pos+1:]...)
}
return bytez
}

+ 107
- 55
merkle/iavl_node.go View File

@ -1,6 +1,7 @@
package merkle
import (
"bytes"
"crypto/sha256"
"io"
@ -12,8 +13,8 @@ import (
type IAVLNode struct {
key interface{}
value interface{}
size uint64
height uint8
size uint
hash []byte
leftHash []byte
leftNode *IAVLNode
@ -24,27 +25,28 @@ type IAVLNode struct {
func NewIAVLNode(key interface{}, value interface{}) *IAVLNode {
return &IAVLNode{
key: key,
value: value,
size: 1,
key: key,
value: value,
height: 0,
size: 1,
}
}
// NOTE: The hash is not saved or set. The caller should set the hash afterwards.
// (Presumably the caller already has the hash)
func ReadIAVLNode(t *IAVLTree, r io.Reader, n *int64, err *error) *IAVLNode {
node := &IAVLNode{}
// node header & key
// node header
node.height = binary.ReadUint8(r, n, err)
node.size = binary.ReadUint64(r, n, err)
node.key = t.keyCodec.Decode(r, n, err)
if *err != nil {
panic(*err)
}
node.size = binary.ReadUvarint(r, n, err)
node.key = decodeByteSlice(t.keyCodec, r, n, err)
// node value or children.
if node.height == 0 {
node.value = t.valueCodec.Decode(r, n, err)
// value
node.value = decodeByteSlice(t.valueCodec, r, n, err)
} else {
// children
node.leftHash = binary.ReadByteSlice(r, n, err)
node.rightHash = binary.ReadByteSlice(r, n, err)
}
@ -60,8 +62,8 @@ func (node *IAVLNode) _copy() *IAVLNode {
}
return &IAVLNode{
key: node.key,
size: node.size,
height: node.height,
size: node.size,
hash: nil, // Going to be mutated anyways.
leftHash: node.leftHash,
leftNode: node.leftNode,
@ -86,7 +88,7 @@ func (node *IAVLNode) has(t *IAVLTree, key interface{}) (has bool) {
}
}
func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint64, value interface{}) {
func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint, value interface{}) {
if node.height == 0 {
if t.keyCodec.Compare(node.key, key) == 0 {
return 0, node.value
@ -105,7 +107,7 @@ func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint64, value int
}
}
func (node *IAVLNode) getByIndex(t *IAVLTree, index uint64) (key interface{}, value interface{}) {
func (node *IAVLNode) getByIndex(t *IAVLTree, index uint) (key interface{}, value interface{}) {
if node.height == 0 {
if index == 0 {
return node.key, node.value
@ -125,21 +127,61 @@ func (node *IAVLNode) getByIndex(t *IAVLTree, index uint64) (key interface{}, va
}
// NOTE: sets hashes recursively
func (node *IAVLNode) hashWithCount(t *IAVLTree) ([]byte, uint64) {
func (node *IAVLNode) hashWithCount(t *IAVLTree) ([]byte, uint) {
if node.hash != nil {
return node.hash, 0
}
hasher := sha256.New()
_, hashCount, err := node.writeToCountHashes(t, hasher)
buf := new(bytes.Buffer)
_, hashCount, err := node.writeHashBytes(t, buf)
if err != nil {
panic(err)
}
// fmt.Printf("Wrote IAVL hash bytes: %X\n", buf.Bytes())
hasher.Write(buf.Bytes())
node.hash = hasher.Sum(nil)
// fmt.Printf("Write IAVL hash: %X\n", node.hash)
return node.hash, hashCount + 1
}
// NOTE: sets hashes recursively
func (node *IAVLNode) writeHashBytes(t *IAVLTree, w io.Writer) (n int64, hashCount uint, err error) {
// height & size
binary.WriteUint8(node.height, w, &n, &err)
binary.WriteUvarint(node.size, w, &n, &err)
// key is not written for inner nodes, unlike writePersistBytes
if node.height == 0 {
// key & value
encodeByteSlice(node.key, t.keyCodec, w, &n, &err)
encodeByteSlice(node.value, t.valueCodec, w, &n, &err)
} else {
// left
if node.leftNode != nil {
leftHash, leftCount := node.leftNode.hashWithCount(t)
node.leftHash = leftHash
hashCount += leftCount
}
if node.leftHash == nil {
panic("node.leftHash was nil in writeHashBytes")
}
binary.WriteByteSlice(node.leftHash, w, &n, &err)
// right
if node.rightNode != nil {
rightHash, rightCount := node.rightNode.hashWithCount(t)
node.rightHash = rightHash
hashCount += rightCount
}
if node.rightHash == nil {
panic("node.rightHash was nil in writeHashBytes")
}
binary.WriteByteSlice(node.rightHash, w, &n, &err)
}
return
}
// NOTE: sets hashes recursively
// NOTE: clears leftNode/rightNode recursively
func (node *IAVLNode) save(t *IAVLTree) []byte {
@ -165,6 +207,32 @@ func (node *IAVLNode) save(t *IAVLTree) []byte {
return node.hash
}
// NOTE: sets hashes recursively
func (node *IAVLNode) writePersistBytes(t *IAVLTree, w io.Writer) (n int64, err error) {
// node header
binary.WriteUint8(node.height, w, &n, &err)
binary.WriteUvarint(node.size, w, &n, &err)
// key (unlike writeHashBytes, key is written for inner nodes)
encodeByteSlice(node.key, t.keyCodec, w, &n, &err)
if node.height == 0 {
// value
encodeByteSlice(node.value, t.valueCodec, w, &n, &err)
} else {
// left
if node.leftHash == nil {
panic("node.leftHash was nil in writePersistBytes")
}
binary.WriteByteSlice(node.leftHash, w, &n, &err)
// right
if node.rightHash == nil {
panic("node.rightHash was nil in writePersistBytes")
}
binary.WriteByteSlice(node.rightHash, w, &n, &err)
}
return
}
func (node *IAVLNode) set(t *IAVLTree, key interface{}, value interface{}) (newSelf *IAVLNode, updated bool) {
if node.height == 0 {
cmp := t.keyCodec.Compare(key, node.key)
@ -251,44 +319,6 @@ func (node *IAVLNode) remove(t *IAVLTree, key interface{}) (
}
}
// NOTE: sets hashes recursively
func (node *IAVLNode) writeToCountHashes(t *IAVLTree, w io.Writer) (n int64, hashCount uint64, err error) {
// height & size & key
binary.WriteUint8(node.height, w, &n, &err)
binary.WriteUint64(node.size, w, &n, &err)
t.keyCodec.Encode(node.key, w, &n, &err)
if err != nil {
return
}
if node.height == 0 {
// value
t.valueCodec.Encode(node.value, w, &n, &err)
} else {
// left
if node.leftNode != nil {
leftHash, leftCount := node.leftNode.hashWithCount(t)
node.leftHash = leftHash
hashCount += leftCount
}
if node.leftHash == nil {
panic("node.leftHash was nil in save")
}
binary.WriteByteSlice(node.leftHash, w, &n, &err)
// right
if node.rightNode != nil {
rightHash, rightCount := node.rightNode.hashWithCount(t)
node.rightHash = rightHash
hashCount += rightCount
}
if node.rightHash == nil {
panic("node.rightHash was nil in save")
}
binary.WriteByteSlice(node.rightHash, w, &n, &err)
}
return
}
func (node *IAVLNode) getLeftNode(t *IAVLTree) *IAVLNode {
if node.leftNode != nil {
return node.leftNode
@ -406,3 +436,25 @@ func (node *IAVLNode) rmd(t *IAVLTree) *IAVLNode {
}
return node.getRightNode(t).rmd(t)
}
//--------------------------------------------------------------------------------
// Read a (length prefixed) byteslice then decode the object using the codec
func decodeByteSlice(codec binary.Codec, r io.Reader, n *int64, err *error) interface{} {
bytez := binary.ReadByteSlice(r, n, err)
if *err != nil {
return nil
}
n_ := new(int64)
return codec.Decode(bytes.NewBuffer(bytez), n_, err)
}
// Encode object using codec, then write a (length prefixed) byteslice.
func encodeByteSlice(o interface{}, codec binary.Codec, w io.Writer, n *int64, err *error) {
buf, n_ := new(bytes.Buffer), new(int64)
codec.Encode(o, buf, n_, err)
if *err != nil {
return
}
binary.WriteByteSlice(buf.Bytes(), w, n, err)
}

+ 137
- 0
merkle/iavl_proof.go View File

@ -0,0 +1,137 @@
package merkle
import (
"bytes"
"crypto/sha256"
"github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
)
type IAVLProof struct {
Root []byte
Branches []IAVLProofBranch
Leaf IAVLProofLeaf
}
func (proof *IAVLProof) Verify() bool {
hash := proof.Leaf.Hash()
// fmt.Printf("leaf hash: %X\n", hash)
for i := len(proof.Branches) - 1; 0 <= i; i-- {
hash = proof.Branches[i].Hash(hash)
// fmt.Printf("branch hash: %X\n", hash)
}
// fmt.Printf("root: %X, computed: %X\n", proof.Root, hash)
return bytes.Equal(proof.Root, hash)
}
type IAVLProofBranch struct {
Height uint8
Size uint
Left []byte
Right []byte
}
func (branch IAVLProofBranch) Hash(childHash []byte) []byte {
hasher := sha256.New()
buf := new(bytes.Buffer)
n, err := int64(0), error(nil)
binary.WriteUint8(branch.Height, buf, &n, &err)
binary.WriteUvarint(branch.Size, buf, &n, &err)
if branch.Left == nil {
binary.WriteByteSlice(childHash, buf, &n, &err)
binary.WriteByteSlice(branch.Right, buf, &n, &err)
} else {
binary.WriteByteSlice(branch.Left, buf, &n, &err)
binary.WriteByteSlice(childHash, buf, &n, &err)
}
if err != nil {
panic(Fmt("Failed to hash IAVLProofBranch: %v", err))
}
// fmt.Printf("Branch hash bytes: %X\n", buf.Bytes())
hasher.Write(buf.Bytes())
return hasher.Sum(nil)
}
type IAVLProofLeaf struct {
KeyBytes []byte
ValueBytes []byte
}
func (leaf IAVLProofLeaf) Hash() []byte {
hasher := sha256.New()
buf := new(bytes.Buffer)
n, err := int64(0), error(nil)
binary.WriteUint8(0, buf, &n, &err)
binary.WriteUvarint(1, buf, &n, &err)
binary.WriteByteSlice(leaf.KeyBytes, buf, &n, &err)
binary.WriteByteSlice(leaf.ValueBytes, buf, &n, &err)
if err != nil {
panic(Fmt("Failed to hash IAVLProofLeaf: %v", err))
}
// fmt.Printf("Leaf hash bytes: %X\n", buf.Bytes())
hasher.Write(buf.Bytes())
return hasher.Sum(nil)
}
func (node *IAVLNode) constructProof(t *IAVLTree, key interface{}, proof *IAVLProof) (exists bool) {
if node.height == 0 {
if t.keyCodec.Compare(node.key, key) == 0 {
keyBuf, valueBuf := new(bytes.Buffer), new(bytes.Buffer)
n, err := int64(0), error(nil)
t.keyCodec.Encode(node.key, keyBuf, &n, &err)
if err != nil {
panic(Fmt("Failed to encode node.key: %v", err))
}
t.valueCodec.Encode(node.value, valueBuf, &n, &err)
if err != nil {
panic(Fmt("Failed to encode node.value: %v", err))
}
leaf := IAVLProofLeaf{
KeyBytes: keyBuf.Bytes(),
ValueBytes: valueBuf.Bytes(),
}
proof.Leaf = leaf
return true
} else {
return false
}
} else {
if t.keyCodec.Compare(key, node.key) < 0 {
branch := IAVLProofBranch{
Height: node.height,
Size: node.size,
Left: nil,
Right: node.getRightNode(t).hash,
}
if node.getRightNode(t).hash == nil {
// fmt.Println(node.getRightNode(t))
panic("WTF")
}
proof.Branches = append(proof.Branches, branch)
return node.getLeftNode(t).constructProof(t, key, proof)
} else {
branch := IAVLProofBranch{
Height: node.height,
Size: node.size,
Left: node.getLeftNode(t).hash,
Right: nil,
}
proof.Branches = append(proof.Branches, branch)
return node.getRightNode(t).constructProof(t, key, proof)
}
}
}
// Returns nil if key is not in tree.
func (t *IAVLTree) ConstructProof(key interface{}) *IAVLProof {
if t.root == nil {
return nil
}
t.root.hashWithCount(t) // Ensure that all hashes are calculated.
proof := &IAVLProof{
Root: t.root.hash,
}
t.root.constructProof(t, key, proof)
return proof
}

+ 71
- 5
merkle/iavl_test.go View File

@ -6,6 +6,7 @@ import (
"github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/common/test"
"github.com/tendermint/tendermint/db"
"runtime"
@ -59,7 +60,7 @@ func P(n *IAVLNode) string {
func TestUnit(t *testing.T) {
expectHash := func(tree *IAVLTree, hashCount uint64) {
expectHash := func(tree *IAVLTree, hashCount uint) {
// ensure number of new hash calculations is as expected.
hash, count := tree.HashWithCount()
if count != hashCount {
@ -77,7 +78,7 @@ func TestUnit(t *testing.T) {
}
}
expectSet := func(tree *IAVLTree, i int, repr string, hashCount uint64) {
expectSet := func(tree *IAVLTree, i int, repr string, hashCount uint) {
origNode := tree.root
updated := tree.Set(i, "")
// ensure node was added & structure is as expected.
@ -90,7 +91,7 @@ func TestUnit(t *testing.T) {
tree.root = origNode
}
expectRemove := func(tree *IAVLTree, i int, repr string, hashCount uint64) {
expectRemove := func(tree *IAVLTree, i int, repr string, hashCount uint) {
origNode := tree.root
value, removed := tree.Remove(i)
// ensure node was added & structure is as expected.
@ -167,7 +168,7 @@ func TestIntegration(t *testing.T) {
if !updated {
t.Error("should have been updated")
}
if tree.Size() != uint64(i+1) {
if tree.Size() != uint(i+1) {
t.Error("size was wrong", tree.Size(), i+1)
}
}
@ -202,7 +203,7 @@ func TestIntegration(t *testing.T) {
t.Error("wrong value")
}
}
if tree.Size() != uint64(len(records)-(i+1)) {
if tree.Size() != uint(len(records)-(i+1)) {
t.Error("size was wrong", tree.Size(), (len(records) - (i + 1)))
}
}
@ -237,6 +238,71 @@ func TestPersistence(t *testing.T) {
}
}
func testProof(t *testing.T, proof *IAVLProof) {
// Proof must verify.
if !proof.Verify() {
t.Errorf("Invalid proof. Verification failed.")
return
}
// Write/Read then verify.
proofBytes := binary.BinaryBytes(proof)
n, err := int64(0), error(nil)
proof2 := binary.ReadBinary(&IAVLProof{}, bytes.NewBuffer(proofBytes), &n, &err).(*IAVLProof)
if err != nil {
t.Errorf("Failed to read IAVLProof from bytes: %v", err)
return
}
if !proof2.Verify() {
t.Errorf("Invalid proof after write/read. Verification failed.")
return
}
// Random mutations must not verify
for i := 0; i < 3; i++ {
badProofBytes := MutateByteSlice(proofBytes)
n, err := int64(0), error(nil)
badProof := binary.ReadBinary(&IAVLProof{}, bytes.NewBuffer(badProofBytes), &n, &err).(*IAVLProof)
if err != nil {
continue // This is fine.
}
if badProof.Verify() {
t.Errorf("Proof was still valid after a random mutation:\n%X\n%X", proofBytes, badProofBytes)
}
}
}
func TestConstructProof(t *testing.T) {
// Construct some random tree
db := db.NewMemDB()
var tree *IAVLTree = NewIAVLTree(binary.BasicCodec, binary.BasicCodec, 100, db)
for i := 0; i < 1000; i++ {
key, value := randstr(20), randstr(20)
tree.Set(key, value)
}
// Persist the items so far
tree.Save()
// Add more items so it's not all persisted
for i := 0; i < 100; i++ {
key, value := randstr(20), randstr(20)
tree.Set(key, value)
}
// Now for each item, construct a proof and verify
tree.Iterate(func(key interface{}, value interface{}) bool {
proof := tree.ConstructProof(key)
if !bytes.Equal(proof.Root, tree.Hash()) {
t.Errorf("Invalid proof. Expected root %X, got %X", tree.Hash(), proof.Root)
}
if !proof.Verify() {
t.Errorf("Invalid proof. Verification failed.")
}
testProof(t, proof)
return false
})
}
func BenchmarkImmutableAvlTree(b *testing.B) {
b.StopTimer()


+ 6
- 5
merkle/iavl_tree.go View File

@ -68,7 +68,7 @@ func (t *IAVLTree) Copy() Tree {
}
}
func (t *IAVLTree) Size() uint64 {
func (t *IAVLTree) Size() uint {
if t.root == nil {
return 0
}
@ -106,7 +106,7 @@ func (t *IAVLTree) Hash() []byte {
return hash
}
func (t *IAVLTree) HashWithCount() ([]byte, uint64) {
func (t *IAVLTree) HashWithCount() ([]byte, uint) {
if t.root == nil {
return nil, 0
}
@ -130,14 +130,14 @@ func (t *IAVLTree) Load(hash []byte) {
}
}
func (t *IAVLTree) Get(key interface{}) (index uint64, value interface{}) {
func (t *IAVLTree) Get(key interface{}) (index uint, value interface{}) {
if t.root == nil {
return 0, nil
}
return t.root.get(t, key)
}
func (t *IAVLTree) GetByIndex(index uint64) (key interface{}, value interface{}) {
func (t *IAVLTree) GetByIndex(index uint) (key interface{}, value interface{}) {
if t.root == nil {
return nil, nil
}
@ -220,6 +220,7 @@ func (ndb *nodeDB) GetNode(t *IAVLTree, hash []byte) *IAVLNode {
if err != nil {
panic(Fmt("Error reading IAVLNode. bytes: %X error: %v", buf, err))
}
node.hash = hash
node.persisted = true
ndb.cacheNode(node)
return node
@ -240,7 +241,7 @@ func (ndb *nodeDB) SaveNode(t *IAVLTree, node *IAVLNode) {
}*/
// Save node bytes to db
buf := bytes.NewBuffer(nil)
_, _, err := node.writeToCountHashes(t, buf)
_, err := node.writePersistBytes(t, buf)
if err != nil {
panic(err)
}


+ 4
- 4
merkle/types.go View File

@ -1,14 +1,14 @@
package merkle
type Tree interface {
Size() (size uint64)
Size() (size uint)
Height() (height uint8)
Has(key interface{}) (has bool)
Get(key interface{}) (index uint64, value interface{})
GetByIndex(index uint64) (key interface{}, value interface{})
Get(key interface{}) (index uint, value interface{})
GetByIndex(index uint) (key interface{}, value interface{})
Set(key interface{}, value interface{}) (updated bool)
Remove(key interface{}) (value interface{}, removed bool)
HashWithCount() (hash []byte, count uint64)
HashWithCount() (hash []byte, count uint)
Hash() (hash []byte)
Save() (hash []byte)
Load(hash []byte)


Loading…
Cancel
Save