diff --git a/internal/consensus/peer_state.go b/internal/consensus/peer_state.go index 8bf1280f4..f081eada1 100644 --- a/internal/consensus/peer_state.go +++ b/internal/consensus/peer_state.go @@ -19,6 +19,8 @@ import ( var ( ErrPeerStateHeightRegression = errors.New("peer state height regression") ErrPeerStateInvalidStartTime = errors.New("peer state invalid startTime") + ErrPeerStateSetNilVote = errors.New("peer state set a nil vote") + ErrPeerStateInvalidVoteIndex = errors.New("peer sent a vote with an invalid vote index") ) // peerStateStats holds internal statistics for a peer. @@ -356,17 +358,19 @@ func (ps *PeerState) BlockPartsSent() int { } // SetHasVote sets the given vote as known by the peer -func (ps *PeerState) SetHasVote(vote *types.Vote) { +func (ps *PeerState) SetHasVote(vote *types.Vote) error { + // sanity check if vote == nil { - return + return ErrPeerStateSetNilVote } ps.mtx.Lock() defer ps.mtx.Unlock() - ps.setHasVote(vote.Height, vote.Round, vote.Type, vote.ValidatorIndex) + return ps.setHasVote(vote.Height, vote.Round, vote.Type, vote.ValidatorIndex) } -func (ps *PeerState) setHasVote(height int64, round int32, voteType tmproto.SignedMsgType, index int32) { +// setHasVote will return an error when the index exceeds the bitArray length +func (ps *PeerState) setHasVote(height int64, round int32, voteType tmproto.SignedMsgType, index int32) error { logger := ps.logger.With( "peerH/R", fmt.Sprintf("%d/%d", ps.PRS.Height, ps.PRS.Round), "H/R", fmt.Sprintf("%d/%d", height, round), @@ -377,8 +381,12 @@ func (ps *PeerState) setHasVote(height int64, round int32, voteType tmproto.Sign // NOTE: some may be nil BitArrays -> no side effects psVotes := ps.getVoteBitArray(height, round, voteType) if psVotes != nil { - psVotes.SetIndex(int(index), true) + if ok := psVotes.SetIndex(int(index), true); !ok { + // https://github.com/tendermint/tendermint/issues/2871 + return ErrPeerStateInvalidVoteIndex + } } + return nil } // ApplyNewRoundStepMessage updates the peer state for the new round. @@ -475,15 +483,15 @@ func (ps *PeerState) ApplyProposalPOLMessage(msg *ProposalPOLMessage) { } // ApplyHasVoteMessage updates the peer state for the new vote. -func (ps *PeerState) ApplyHasVoteMessage(msg *HasVoteMessage) { +func (ps *PeerState) ApplyHasVoteMessage(msg *HasVoteMessage) error { ps.mtx.Lock() defer ps.mtx.Unlock() if ps.PRS.Height != msg.Height { - return + return nil } - ps.setHasVote(msg.Height, msg.Round, msg.Type, msg.Index) + return ps.setHasVote(msg.Height, msg.Round, msg.Type, msg.Index) } // ApplyVoteSetBitsMessage updates the peer state for the bit-array of votes diff --git a/internal/consensus/peer_state_test.go b/internal/consensus/peer_state_test.go new file mode 100644 index 000000000..06f49508a --- /dev/null +++ b/internal/consensus/peer_state_test.go @@ -0,0 +1,100 @@ +package consensus + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/log" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + "github.com/tendermint/tendermint/types" +) + +func peerStateSetup(h, r, v int) *PeerState { + ps := NewPeerState(log.TestingLogger(), "testPeerState") + ps.PRS.Height = int64(h) + ps.PRS.Round = int32(r) + ps.ensureVoteBitArrays(int64(h), v) + return ps +} + +func TestSetHasVote(t *testing.T) { + ps := peerStateSetup(1, 1, 1) + pva := ps.PRS.Prevotes.Copy() + + // nil vote should return ErrPeerStateNilVote + err := ps.SetHasVote(nil) + require.Equal(t, ErrPeerStateSetNilVote, err) + + // the peer giving an invalid index should returns ErrPeerStateInvalidVoteIndex + v0 := &types.Vote{ + Height: 1, + ValidatorIndex: -1, + Round: 1, + Type: tmproto.PrevoteType, + } + + err = ps.SetHasVote(v0) + require.Equal(t, ErrPeerStateInvalidVoteIndex, err) + + // the peer giving an invalid index should returns ErrPeerStateInvalidVoteIndex + v1 := &types.Vote{ + Height: 1, + ValidatorIndex: 1, + Round: 1, + Type: tmproto.PrevoteType, + } + + err = ps.SetHasVote(v1) + require.Equal(t, ErrPeerStateInvalidVoteIndex, err) + + // the peer giving a correct index should return nil (vote has been set) + v2 := &types.Vote{ + Height: 1, + ValidatorIndex: 0, + Round: 1, + Type: tmproto.PrevoteType, + } + require.Nil(t, ps.SetHasVote(v2)) + + // verify vote + pva.SetIndex(0, true) + require.Equal(t, pva, ps.getVoteBitArray(1, 1, tmproto.PrevoteType)) + + // the vote is not in the correct height/round/voteType should return nil (ignore the vote) + v3 := &types.Vote{ + Height: 2, + ValidatorIndex: 0, + Round: 1, + Type: tmproto.PrevoteType, + } + require.Nil(t, ps.SetHasVote(v3)) + // prevote bitarray has no update + require.Equal(t, pva, ps.getVoteBitArray(1, 1, tmproto.PrevoteType)) +} + +func TestApplyHasVoteMessage(t *testing.T) { + ps := peerStateSetup(1, 1, 1) + pva := ps.PRS.Prevotes.Copy() + + // ignore the message with an invalid height + msg := &HasVoteMessage{ + Height: 2, + } + require.Nil(t, ps.ApplyHasVoteMessage(msg)) + + // apply a message like v2 in TestSetHasVote + msg2 := &HasVoteMessage{ + Height: 1, + Index: 0, + Round: 1, + Type: tmproto.PrevoteType, + } + + require.Nil(t, ps.ApplyHasVoteMessage(msg2)) + + // verify vote + pva.SetIndex(0, true) + require.Equal(t, pva, ps.getVoteBitArray(1, 1, tmproto.PrevoteType)) + + // skip test cases like v & v3 in TestSetHasVote due to the same path +} diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index 1673c0e21..eb038d9f5 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -656,7 +656,10 @@ func (r *Reactor) pickSendVote(ctx context.Context, ps *PeerState, votes types.V return false, err } - ps.SetHasVote(vote) + if err := ps.SetHasVote(vote); err != nil { + return false, err + } + return true, nil } @@ -1060,8 +1063,10 @@ func (r *Reactor) handleStateMessage(ctx context.Context, envelope *p2p.Envelope ps.ApplyNewValidBlockMessage(msgI.(*NewValidBlockMessage)) case *tmcons.HasVote: - ps.ApplyHasVoteMessage(msgI.(*HasVoteMessage)) - + if err := ps.ApplyHasVoteMessage(msgI.(*HasVoteMessage)); err != nil { + r.logger.Error("applying HasVote message", "msg", msg, "err", err) + return err + } case *tmcons.VoteSetMaj23: r.state.mtx.RLock() height, votes := r.state.Height, r.state.Votes @@ -1195,7 +1200,9 @@ func (r *Reactor) handleVoteMessage(ctx context.Context, envelope *p2p.Envelope, ps.EnsureVoteBitArrays(height, valSize) ps.EnsureVoteBitArrays(height-1, lastCommitSize) - ps.SetHasVote(vMsg.Vote) + if err := ps.SetHasVote(vMsg.Vote); err != nil { + return err + } select { case r.state.peerMsgQueue <- msgInfo{vMsg, envelope.From, tmtime.Now()}: diff --git a/libs/bits/bit_array.go b/libs/bits/bit_array.go index d0f79ce14..ff824af21 100644 --- a/libs/bits/bit_array.go +++ b/libs/bits/bit_array.go @@ -72,7 +72,7 @@ func (bA *BitArray) getIndex(i int) bool { } // SetIndex sets the bit at index i within the bit array. -// The behavior is undefined if i >= bA.Bits +// This method returns false if i is out of range of the BitArray. func (bA *BitArray) SetIndex(i int, v bool) bool { if bA == nil { return false @@ -83,7 +83,7 @@ func (bA *BitArray) SetIndex(i int, v bool) bool { } func (bA *BitArray) setIndex(i int, v bool) bool { - if i >= bA.Bits { + if i < 0 || i >= bA.Bits { return false } if v { diff --git a/libs/bits/bit_array_test.go b/libs/bits/bit_array_test.go index a12cc80a2..b76085bee 100644 --- a/libs/bits/bit_array_test.go +++ b/libs/bits/bit_array_test.go @@ -170,6 +170,8 @@ func TestBytes(t *testing.T) { check(bA, []byte{0x80, 0x01}) bA.SetIndex(9, true) check(bA, []byte{0x80, 0x03}) + + require.False(t, bA.SetIndex(-1, true)) } func TestEmptyFull(t *testing.T) {