From 0124593a61c12e8614e04d7eaaf724c51ed724b1 Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Fri, 26 Mar 2021 15:15:45 -0400 Subject: [PATCH] fix: avoid race with a deeper copy (#6285) --- consensus/peer_state.go | 2 +- consensus/types/peer_round_state.go | 25 ++++++++++++++++++++ consensus/types/peer_round_state_test.go | 29 ++++++++++++++++++++++++ libs/bits/bit_array.go | 18 ++++++++++----- libs/bits/bit_array_test.go | 2 +- 5 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 consensus/types/peer_round_state_test.go diff --git a/consensus/peer_state.go b/consensus/peer_state.go index 31406a025..1fb16d1be 100644 --- a/consensus/peer_state.go +++ b/consensus/peer_state.go @@ -89,7 +89,7 @@ func (ps *PeerState) GetRoundState() *cstypes.PeerRoundState { ps.mtx.Lock() defer ps.mtx.Unlock() - prs := ps.PRS // copy + prs := ps.PRS.Copy() return &prs } diff --git a/consensus/types/peer_round_state.go b/consensus/types/peer_round_state.go index 07283c5b4..9d294d9af 100644 --- a/consensus/types/peer_round_state.go +++ b/consensus/types/peer_round_state.go @@ -46,6 +46,31 @@ func (prs PeerRoundState) String() string { return prs.StringIndented("") } +// Copy provides a deep copy operation. Because many of the fields in +// the PeerRound struct are pointers, we need an explicit deep copy +// operation to avoid a non-obvious shared data situation. +func (prs PeerRoundState) Copy() PeerRoundState { + // this works because it's not a pointer receiver so it's + // already, effectively a copy. + + headerHash := prs.ProposalBlockPartSetHeader.Hash.Bytes() + + hashCopy := make([]byte, len(headerHash)) + copy(hashCopy, headerHash) + prs.ProposalBlockPartSetHeader = types.PartSetHeader{ + Total: prs.ProposalBlockPartSetHeader.Total, + Hash: hashCopy, + } + prs.ProposalBlockParts = prs.ProposalBlockParts.Copy() + prs.ProposalPOL = prs.ProposalPOL.Copy() + prs.Prevotes = prs.Prevotes.Copy() + prs.Precommits = prs.Precommits.Copy() + prs.LastCommit = prs.LastCommit.Copy() + prs.CatchupCommit = prs.CatchupCommit.Copy() + + return prs +} + // StringIndented returns a string representation of the PeerRoundState func (prs PeerRoundState) StringIndented(indent string) string { return fmt.Sprintf(`PeerRoundState{ diff --git a/consensus/types/peer_round_state_test.go b/consensus/types/peer_round_state_test.go new file mode 100644 index 000000000..393fd2056 --- /dev/null +++ b/consensus/types/peer_round_state_test.go @@ -0,0 +1,29 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/bits" +) + +func TestCopy(t *testing.T) { + t.Run("VerifyShallowCopy", func(t *testing.T) { + prsOne := PeerRoundState{} + prsOne.Prevotes = bits.NewBitArray(12) + prsTwo := prsOne + + prsOne.Prevotes.SetIndex(1, true) + + require.Equal(t, prsOne.Prevotes, prsTwo.Prevotes) + }) + t.Run("DeepCopy", func(t *testing.T) { + prsOne := PeerRoundState{} + prsOne.Prevotes = bits.NewBitArray(12) + prsTwo := prsOne.Copy() + + prsOne.Prevotes.SetIndex(1, true) + + require.NotEqual(t, prsOne.Prevotes, prsTwo.Prevotes) + }) +} diff --git a/libs/bits/bit_array.go b/libs/bits/bit_array.go index 1a41d87f9..3ebad38ce 100644 --- a/libs/bits/bit_array.go +++ b/libs/bits/bit_array.go @@ -422,21 +422,21 @@ func (bA *BitArray) UnmarshalJSON(bz []byte) error { // ToProto converts BitArray to protobuf. It returns nil if BitArray is // nil/empty. -// -// XXX: It does not copy the array. func (bA *BitArray) ToProto() *tmprotobits.BitArray { if bA == nil || (len(bA.Elems) == 0 && bA.Bits == 0) { // empty return nil } - return &tmprotobits.BitArray{Bits: int64(bA.Bits), Elems: bA.Elems} + bA.mtx.Lock() + defer bA.mtx.Unlock() + + bc := bA.copy() + return &tmprotobits.BitArray{Bits: int64(bc.Bits), Elems: bc.Elems} } // FromProto sets BitArray to the given protoBitArray. It returns an error if // protoBitArray is invalid. -// -// XXX: It does not copy the array. func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) error { if protoBitArray == nil { return nil @@ -454,8 +454,14 @@ func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) error { return fmt.Errorf("invalid number of Elems: got %d, but exp %d", got, exp) } + bA.mtx.Lock() + defer bA.mtx.Unlock() + + ec := make([]uint64, len(protoBitArray.Elems)) + copy(ec, protoBitArray.Elems) + bA.Bits = int(protoBitArray.Bits) - bA.Elems = protoBitArray.Elems + bA.Elems = ec return nil } diff --git a/libs/bits/bit_array_test.go b/libs/bits/bit_array_test.go index 10d607ef2..96f2e2257 100644 --- a/libs/bits/bit_array_test.go +++ b/libs/bits/bit_array_test.go @@ -299,7 +299,7 @@ func TestBitArrayFromProto(t *testing.T) { expErr bool }{ 0: {nil, &BitArray{}, false}, - 1: {&tmprotobits.BitArray{}, &BitArray{}, false}, + 1: {&tmprotobits.BitArray{}, &BitArray{Elems: []uint64{}}, false}, 2: {&tmprotobits.BitArray{Bits: 1, Elems: make([]uint64, 1)}, &BitArray{Bits: 1, Elems: make([]uint64, 1)}, false},