From 19971bd181a81f1985712f5d142e26e300346454 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Wed, 16 Oct 2019 18:00:39 -0500 Subject: [PATCH] types: validate Part#Proof add ValidateBasic to crypto/merkle/SimpleProof --- consensus/reactor_test.go | 261 +++++++++++++++++++++++++++++ crypto/merkle/simple_proof.go | 30 ++++ crypto/merkle/simple_proof_test.go | 38 +++++ types/part_set.go | 7 +- types/part_set_test.go | 8 + 5 files changed, 342 insertions(+), 2 deletions(-) create mode 100644 crypto/merkle/simple_proof_test.go diff --git a/consensus/reactor_test.go b/consensus/reactor_test.go index b237da6b5..bae6b507d 100644 --- a/consensus/reactor_test.go +++ b/consensus/reactor_test.go @@ -19,6 +19,9 @@ import ( abci "github.com/tendermint/tendermint/abci/types" bc "github.com/tendermint/tendermint/blockchain" cfg "github.com/tendermint/tendermint/config" + cstypes "github.com/tendermint/tendermint/consensus/types" + "github.com/tendermint/tendermint/crypto/tmhash" + cmn "github.com/tendermint/tendermint/libs/common" dbm "github.com/tendermint/tendermint/libs/db" "github.com/tendermint/tendermint/libs/log" mempl "github.com/tendermint/tendermint/mempool" @@ -632,3 +635,261 @@ func capture() { count := runtime.Stack(trace, true) fmt.Printf("Stack of %d bytes: %s\n", count, trace) } + +//------------------------------------------------------------- +// Ensure basic validation of structs is functioning + +func TestNewRoundStepMessageValidateBasic(t *testing.T) { + testCases := []struct { // nolint: maligned + expectErr bool + messageRound int + messageLastCommitRound int + messageHeight int64 + testName string + messageStep cstypes.RoundStepType + }{ + {false, 0, 0, 0, "Valid Message", 0x01}, + {true, -1, 0, 0, "Invalid Message", 0x01}, + {true, 0, 0, -1, "Invalid Message", 0x01}, + {true, 0, 0, 1, "Invalid Message", 0x00}, + {true, 0, 0, 1, "Invalid Message", 0x00}, + {true, 0, -2, 2, "Invalid Message", 0x01}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := NewRoundStepMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + Step: tc.messageStep, + LastCommitRound: tc.messageLastCommitRound, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestNewValidBlockMessageValidateBasic(t *testing.T) { + testBitArray := cmn.NewBitArray(1) + testCases := []struct { + testName string + messageHeight int64 + messageRound int + messageBlockParts *cmn.BitArray + expectErr bool + }{ + {"Valid Message", 0, 0, testBitArray, false}, + {"Invalid Message", -1, 0, testBitArray, true}, + {"Invalid Message", 0, -1, testBitArray, true}, + {"Invalid Message", 0, 0, cmn.NewBitArray(0), true}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := NewValidBlockMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + BlockParts: tc.messageBlockParts, + } + + message.BlockPartsHeader.Total = 1 + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestProposalPOLMessageValidateBasic(t *testing.T) { + testBitArray := cmn.NewBitArray(1) + testCases := []struct { + testName string + messageHeight int64 + messageProposalPOLRound int + messageProposalPOL *cmn.BitArray + expectErr bool + }{ + {"Valid Message", 0, 0, testBitArray, false}, + {"Invalid Message", -1, 0, testBitArray, true}, + {"Invalid Message", 0, -1, testBitArray, true}, + {"Invalid Message", 0, 0, cmn.NewBitArray(0), true}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := ProposalPOLMessage{ + Height: tc.messageHeight, + ProposalPOLRound: tc.messageProposalPOLRound, + ProposalPOL: tc.messageProposalPOL, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestBlockPartMessageValidateBasic(t *testing.T) { + testPart := new(types.Part) + testPart.Proof.LeafHash = tmhash.Sum([]byte("leaf")) + testCases := []struct { + testName string + messageHeight int64 + messageRound int + messagePart *types.Part + expectErr bool + }{ + {"Valid Message", 0, 0, testPart, false}, + {"Invalid Message", -1, 0, testPart, true}, + {"Invalid Message", 0, -1, testPart, true}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := BlockPartMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + Part: tc.messagePart, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } + + message := BlockPartMessage{Height: 0, Round: 0, Part: new(types.Part)} + message.Part.Index = -1 + + assert.Equal(t, true, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") +} + +func TestHasVoteMessageValidateBasic(t *testing.T) { + const ( + validSignedMsgType types.SignedMsgType = 0x01 + invalidSignedMsgType types.SignedMsgType = 0x03 + ) + + testCases := []struct { // nolint: maligned + expectErr bool + messageRound int + messageIndex int + messageHeight int64 + testName string + messageType types.SignedMsgType + }{ + {false, 0, 0, 0, "Valid Message", validSignedMsgType}, + {true, -1, 0, 0, "Invalid Message", validSignedMsgType}, + {true, 0, -1, 0, "Invalid Message", validSignedMsgType}, + {true, 0, 0, 0, "Invalid Message", invalidSignedMsgType}, + {true, 0, 0, -1, "Invalid Message", validSignedMsgType}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := HasVoteMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + Type: tc.messageType, + Index: tc.messageIndex, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestVoteSetMaj23MessageValidateBasic(t *testing.T) { + const ( + validSignedMsgType types.SignedMsgType = 0x01 + invalidSignedMsgType types.SignedMsgType = 0x03 + ) + + validBlockID := types.BlockID{} + invalidBlockID := types.BlockID{ + Hash: cmn.HexBytes{}, + PartsHeader: types.PartSetHeader{ + Total: -1, + Hash: cmn.HexBytes{}, + }, + } + + testCases := []struct { // nolint: maligned + expectErr bool + messageRound int + messageHeight int64 + testName string + messageType types.SignedMsgType + messageBlockID types.BlockID + }{ + {false, 0, 0, "Valid Message", validSignedMsgType, validBlockID}, + {true, -1, 0, "Invalid Message", validSignedMsgType, validBlockID}, + {true, 0, -1, "Invalid Message", validSignedMsgType, validBlockID}, + {true, 0, 0, "Invalid Message", invalidSignedMsgType, validBlockID}, + {true, 0, 0, "Invalid Message", validSignedMsgType, invalidBlockID}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := VoteSetMaj23Message{ + Height: tc.messageHeight, + Round: tc.messageRound, + Type: tc.messageType, + BlockID: tc.messageBlockID, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestVoteSetBitsMessageValidateBasic(t *testing.T) { + const ( + validSignedMsgType types.SignedMsgType = 0x01 + invalidSignedMsgType types.SignedMsgType = 0x03 + ) + + validBlockID := types.BlockID{} + invalidBlockID := types.BlockID{ + Hash: cmn.HexBytes{}, + PartsHeader: types.PartSetHeader{ + Total: -1, + Hash: cmn.HexBytes{}, + }, + } + testBitArray := cmn.NewBitArray(1) + + testCases := []struct { // nolint: maligned + expectErr bool + messageRound int + messageHeight int64 + testName string + messageType types.SignedMsgType + messageBlockID types.BlockID + messageVotes *cmn.BitArray + }{ + {false, 0, 0, "Valid Message", validSignedMsgType, validBlockID, testBitArray}, + {true, -1, 0, "Invalid Message", validSignedMsgType, validBlockID, testBitArray}, + {true, 0, -1, "Invalid Message", validSignedMsgType, validBlockID, testBitArray}, + {true, 0, 0, "Invalid Message", invalidSignedMsgType, validBlockID, testBitArray}, + {true, 0, 0, "Invalid Message", validSignedMsgType, invalidBlockID, testBitArray}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := VoteSetBitsMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + Type: tc.messageType, + // Votes: tc.messageVotes, + BlockID: tc.messageBlockID, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} diff --git a/crypto/merkle/simple_proof.go b/crypto/merkle/simple_proof.go index f01dcdca1..fa4fefb63 100644 --- a/crypto/merkle/simple_proof.go +++ b/crypto/merkle/simple_proof.go @@ -5,9 +5,16 @@ import ( "errors" "fmt" + "github.com/pkg/errors" + "github.com/tendermint/tendermint/crypto/tmhash" cmn "github.com/tendermint/tendermint/libs/common" ) +const ( + // given maxMsgSizeBytes in consensus wal is 1MB + maxAunts = 30000 +) + // SimpleProof represents a simple Merkle proof. // NOTE: The convention for proofs is to include leaf hashes but to // exclude the root hash. @@ -109,6 +116,29 @@ func (sp *SimpleProof) StringIndented(indent string) string { indent) } +// ValidateBasic performs basic validation. +// NOTE: it expects LeafHash and Aunts of tmhash.Size size. +func (sp *SimpleProof) ValidateBasic() error { + if sp.Total < 0 { + return errors.New("negative Total") + } + if sp.Index < 0 { + return errors.New("negative Index") + } + if len(sp.LeafHash) != tmhash.Size { + return errors.Errorf("expected LeafHash size to be %d, got %d", tmhash.Size, len(sp.LeafHash)) + } + if len(sp.Aunts) > maxAunts { + return errors.Errorf("expected no more than %d aunts, got %d", maxAunts, len(sp.Aunts)) + } + for i, auntHash := range sp.Aunts { + if len(auntHash) != tmhash.Size { + return errors.Errorf("expected Aunts#%d size to be %d, got %d", i, tmhash.Size, len(auntHash)) + } + } + return nil +} + // Use the leafHash and innerHashes to get the root merkle hash. // If the length of the innerHashes slice isn't exactly correct, the result is nil. // Recursive impl. diff --git a/crypto/merkle/simple_proof_test.go b/crypto/merkle/simple_proof_test.go new file mode 100644 index 000000000..521bf4a35 --- /dev/null +++ b/crypto/merkle/simple_proof_test.go @@ -0,0 +1,38 @@ +package merkle + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleProofValidateBasic(t *testing.T) { + testCases := []struct { + testName string + malleateProof func(*SimpleProof) + errStr string + }{ + {"Good", func(sp *SimpleProof) {}, ""}, + {"Negative Total", func(sp *SimpleProof) { sp.Total = -1 }, "negative Total"}, + {"Negative Index", func(sp *SimpleProof) { sp.Index = -1 }, "negative Index"}, + {"Invalid LeafHash", func(sp *SimpleProof) { sp.LeafHash = make([]byte, 10) }, "expected LeafHash size to be 32, got 10"}, + {"Too many Aunts", func(sp *SimpleProof) { sp.Aunts = make([][]byte, maxAunts+1) }, "expected no more than 30000 aunts, got 30001"}, + {"Invalid Aunt", func(sp *SimpleProof) { sp.Aunts[0] = make([]byte, 10) }, "expected Aunts#0 size to be 32, got 10"}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + _, proofs := SimpleProofsFromByteSlices([][]byte{ + []byte("apple"), + []byte("watermelon"), + []byte("kiwi"), + }) + tc.malleateProof(proofs[0]) + err := proofs[0].ValidateBasic() + if tc.errStr != "" { + assert.Contains(t, err.Error(), tc.errStr) + } + }) + } +} diff --git a/types/part_set.go b/types/part_set.go index 389db7a0b..ecac027f9 100644 --- a/types/part_set.go +++ b/types/part_set.go @@ -26,10 +26,13 @@ type Part struct { // ValidateBasic performs basic validation. func (part *Part) ValidateBasic() error { if part.Index < 0 { - return errors.New("Negative Index") + return errors.New("negative Index") } if len(part.Bytes) > BlockPartSizeBytes { - return fmt.Errorf("Too big (max: %d)", BlockPartSizeBytes) + return errors.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes) + } + if err := part.Proof.ValidateBasic(); err != nil { + return errors.Wrap(err, "wrong Proof") } return nil } diff --git a/types/part_set_test.go b/types/part_set_test.go index daa2fa5c5..5c0edaffd 100644 --- a/types/part_set_test.go +++ b/types/part_set_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/crypto/merkle" cmn "github.com/tendermint/tendermint/libs/common" ) @@ -114,6 +115,13 @@ func TestPartValidateBasic(t *testing.T) { {"Good Part", func(pt *Part) {}, false}, {"Negative index", func(pt *Part) { pt.Index = -1 }, true}, {"Too big part", func(pt *Part) { pt.Bytes = make([]byte, BlockPartSizeBytes+1) }, true}, + {"Too big proof", func(pt *Part) { + pt.Proof = merkle.SimpleProof{ + Total: 1, + Index: 1, + LeafHash: make([]byte, 1024*1024), + } + }, true}, } for _, tc := range testCases {