From 7fb735fe8721e5e41413f3b249edb6046de5071c Mon Sep 17 00:00:00 2001 From: tycho garen Date: Wed, 9 Mar 2022 16:22:12 -0500 Subject: [PATCH] retry --- internal/consensus/byzantine_test.go | 6 +- internal/consensus/common_test.go | 4 +- internal/consensus/invalid_test.go | 19 +----- internal/consensus/pbts_test.go | 5 +- internal/consensus/reactor_test.go | 10 ++- internal/consensus/state.go | 97 ++++++++++++++++------------ internal/mempool/mempool.go | 3 + 7 files changed, 79 insertions(+), 65 deletions(-) diff --git a/internal/consensus/byzantine_test.go b/internal/consensus/byzantine_test.go index dfeb556fe..56e5d0e9b 100644 --- a/internal/consensus/byzantine_test.go +++ b/internal/consensus/byzantine_test.go @@ -3,7 +3,6 @@ package consensus import ( "context" "fmt" - "os" "path" "sync" "testing" @@ -63,8 +62,6 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { thisConfig, err := ResetConfig(t.TempDir(), fmt.Sprintf("%s_%d", testName, i)) require.NoError(t, err) - defer os.RemoveAll(thisConfig.RootDir) - ensureDir(t, path.Dir(thisConfig.Consensus.WalFile()), 0700) // dir for wal app := kvstore.NewApplication() vals := types.TM2PB.ValidatorUpdates(state.Validators) @@ -103,6 +100,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { cs.SetPrivValidator(ctx, pv) cs.SetTimeoutTicker(tickerFunc()) + cs.Start(ctx) states[i] = cs }() @@ -233,7 +231,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { } for _, reactor := range rts.reactors { - reactor.SwitchToConsensus(ctx, reactor.state.GetState(), false) + reactor.SwitchToConsensus(ctx, reactor.state.state, false) } // Evidence should be submitted and committed at the third height but diff --git a/internal/consensus/common_test.go b/internal/consensus/common_test.go index 161594021..565063a6b 100644 --- a/internal/consensus/common_test.go +++ b/internal/consensus/common_test.go @@ -504,8 +504,8 @@ func newStateWithConfigAndBlockStore( if err != nil { t.Fatal(err) } - cs.SetPrivValidator(ctx, pv) + cs.Start(ctx) return cs } @@ -826,6 +826,7 @@ func makeConsensusState( l := logger.With("validator", i, "module", "consensus") css[i] = newStateWithConfigAndBlockStore(ctx, t, l, thisConfig, state, privVals[i], app, blockStore) css[i].SetTimeoutTicker(tickerFunc()) + css[i].Start(ctx) } return css, func() { @@ -897,6 +898,7 @@ func randConsensusNetWithPeers( css[i] = newStateWithConfig(ctx, t, logger.With("validator", i, "module", "consensus"), thisConfig, state, privVal, app) css[i].SetTimeoutTicker(tickerFunc()) + css[i].Start(ctx) } return css, genDoc, peer0Config, func() { for _, dir := range configRootDirs { diff --git a/internal/consensus/invalid_test.go b/internal/consensus/invalid_test.go index 033b096ba..bb89c3492 100644 --- a/internal/consensus/invalid_test.go +++ b/internal/consensus/invalid_test.go @@ -32,16 +32,10 @@ func TestReactorInvalidPrecommit(t *testing.T) { newMockTickerFunc(true)) t.Cleanup(cleanup) - for i := 0; i < 4; i++ { - ticker := NewTimeoutTicker(states[i].logger) - states[i].SetTimeoutTicker(ticker) - } - rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { - state := reactor.state.GetState() - reactor.SwitchToConsensus(ctx, state, false) + reactor.SwitchToConsensus(ctx, reactor.state.state, false) } // this val sends a random precommit at each height @@ -53,13 +47,11 @@ func TestReactorInvalidPrecommit(t *testing.T) { signal := make(chan struct{}) // Update the doPrevote function to just send a valid precommit for a random // block and otherwise disable the priv validator. - byzState.mtx.Lock() privVal := byzState.privValidator byzState.doPrevote = func(ctx context.Context, height int64, round int32) { defer close(signal) invalidDoPrevoteFunc(ctx, t, height, round, byzState, byzReactor, privVal) } - byzState.mtx.Unlock() // wait for a bunch of blocks // @@ -87,14 +79,9 @@ func TestReactorInvalidPrecommit(t *testing.T) { select { case <-wait: - if _, ok := <-signal; !ok { - t.Fatal("test condition did not fire") - } + t.Fatal("test condition did not fire") case <-ctx.Done(): - if _, ok := <-signal; !ok { - t.Fatal("test condition did not fire after timeout") - return - } + t.Fatal("test condition did not fire after timeout") case <-signal: // test passed } diff --git a/internal/consensus/pbts_test.go b/internal/consensus/pbts_test.go index c5a5ac535..ba55f7fd9 100644 --- a/internal/consensus/pbts_test.go +++ b/internal/consensus/pbts_test.go @@ -115,6 +115,9 @@ func newPBTSTestHarness(ctx context.Context, t *testing.T, tc pbtsTestConfigurat Validators: validators, }) cs := newState(ctx, t, log.NewNopLogger(), state, privVals[0], kvstore.NewApplication()) + if err := cs.Start(ctx); err != nil { + t.Fatal(err) + } vss := make([]*validatorStub, validators) for i := 0; i < validators; i++ { vss[i] = newValidatorStub(privVals[i], int32(i)) @@ -153,7 +156,7 @@ func (p *pbtsTestHarness) observedValidatorProposerHeight(ctx context.Context, t timeout := time.Until(previousBlockTime.Add(ensureTimeout)) ensureProposalWithTimeout(t, p.ensureProposalCh, p.currentHeight, p.currentRound, nil, timeout) - rs := p.observedState.GetRoundState() + rs := p.observedState.RoundState bid := types.BlockID{Hash: rs.ProposalBlock.Hash(), PartSetHeader: rs.ProposalBlockParts.Header()} ensurePrevote(t, p.ensureVoteCh, p.currentHeight, p.currentRound) signAddVotes(ctx, t, p.observedState, tmproto.PrevoteType, p.chainID, bid, p.otherValidators...) diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index b8d638976..2e74342d8 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -364,8 +364,7 @@ func TestReactorBasic(t *testing.T) { rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { - state := reactor.state.GetState() - reactor.SwitchToConsensus(ctx, state, false) + reactor.SwitchToConsensus(ctx, reactor.state.state, false) } var wg sync.WaitGroup @@ -384,7 +383,11 @@ func TestReactorBasic(t *testing.T) { case errors.Is(err, context.Canceled): return case err != nil: - errCh <- err + select { + case errCh <- err: + case <-ctx.Done(): + return + } cancel() // terminate other workers return } @@ -512,6 +515,7 @@ func TestReactorWithEvidence(t *testing.T) { cs.SetPrivValidator(ctx, pv) cs.SetTimeoutTicker(tickerFunc()) + cs.Start(ctx) states[i] = cs } diff --git a/internal/consensus/state.go b/internal/consensus/state.go index bd79f4f83..8a8da2972 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -223,10 +223,6 @@ func NewState( cs.doPrevote = cs.defaultDoPrevote cs.setProposal = cs.defaultSetProposal - if err := cs.updateStateFromStore(ctx); err != nil { - return nil, err - } - // NOTE: we do not call scheduleRound0 yet, we do that upon Start() cs.BaseService = *service.NewBaseService(logger, "State", cs) for _, option := range options { @@ -351,12 +347,19 @@ func (cs *State) SetPrivValidator(ctx context.Context, priv types.PrivValidator) } } +func (cs *State) getTimeoutTicker() <-chan timeoutInfo { + cs.mtx.Lock() + defer cs.mtx.Unlock() + + return cs.timeoutTicker.Chan() +} + // SetTimeoutTicker sets the local timer. It may be useful to overwrite for // testing. func (cs *State) SetTimeoutTicker(timeoutTicker TimeoutTicker) { cs.mtx.Lock() + defer cs.mtx.Unlock() cs.timeoutTicker = timeoutTicker - cs.mtx.Unlock() } // LoadCommit loads the commit for a given height. @@ -472,6 +475,7 @@ func (cs *State) OnStart(ctx context.Context) error { // // this is only used in tests. func (cs *State) startRoutines(ctx context.Context, maxSteps int) { + return err := cs.timeoutTicker.Start(ctx) if err != nil { cs.logger.Error("failed to start timeout ticker", "err", err) @@ -645,6 +649,9 @@ func (cs *State) updateRoundStep(round int32, step cstypes.RoundStepType) { cs.metrics.MarkStep(cs.Step) } } + cs.mtx.Lock() + defer cs.mtx.Unlock() + cs.Round = round cs.Step = step } @@ -912,7 +919,7 @@ func (cs *State) receiveRoutine(ctx context.Context, maxSteps int) { // handles proposals, block parts, votes cs.handleMsg(ctx, mi) - case ti := <-cs.timeoutTicker.Chan(): // tockChan: + case ti := <-cs.getTimeoutTicker(): // tockChan: if err := cs.wal.Write(ti); err != nil { cs.logger.Error("failed writing to WAL", "err", err) } @@ -977,12 +984,6 @@ func (cs *State) handleMsg(ctx context.Context, mi msgInfo) { } if err != nil && msg.Round != cs.Round { - cs.logger.Debug( - "received block part from wrong round", - "height", cs.Height, - "cs_round", cs.Round, - "block_round", msg.Round, - ) err = nil } @@ -1349,6 +1350,9 @@ func (cs *State) defaultDecideProposal(ctx context.Context, height int64, round // Returns true if the proposal block is complete && // (if POLRound was proposed, we have +2/3 prevotes from there). func (cs *State) isProposalComplete() bool { + cs.mtx.RLock() + defer cs.mtx.RUnlock() + if cs.Proposal == nil || cs.ProposalBlock == nil { return false } @@ -2058,40 +2062,51 @@ func (cs *State) RecordMetrics(height int64, block *types.Block) { func (cs *State) defaultSetProposal(proposal *types.Proposal, recvTime time.Time) error { // Already have one // TODO: possibly catch double proposals - if cs.Proposal != nil || proposal == nil { - return nil - } - // Does not apply - if proposal.Height != cs.Height || proposal.Round != cs.Round { - return nil - } + var p *tmproto.Proposal + if err := func() error { + cs.mtx.RLock() + defer cs.mtx.RUnlock() + if cs.Proposal != nil || proposal == nil { + return nil + } - // Verify POLRound, which must be -1 or in range [0, proposal.Round). - if proposal.POLRound < -1 || - (proposal.POLRound >= 0 && proposal.POLRound >= proposal.Round) { - return ErrInvalidProposalPOLRound - } + // Does not apply + if proposal.Height != cs.Height || proposal.Round != cs.Round { + return nil + } - p := proposal.ToProto() - // Verify signature - if !cs.Validators.GetProposer().PubKey.VerifySignature( - types.ProposalSignBytes(cs.state.ChainID, p), proposal.Signature, - ) { - return ErrInvalidProposalSignature + // Verify POLRound, which must be -1 or in range [0, proposal.Round). + if proposal.POLRound < -1 || (proposal.POLRound >= 0 && proposal.POLRound >= proposal.Round) { + return ErrInvalidProposalPOLRound + } + + p = proposal.ToProto() + // Verify signature + if !cs.Validators.GetProposer().PubKey.VerifySignature(types.ProposalSignBytes(cs.state.ChainID, p), proposal.Signature) { + return ErrInvalidProposalSignature + } + return nil + }(); err != nil { + return err } proposal.Signature = p.Signature - cs.Proposal = proposal - cs.ProposalReceiveTime = recvTime - cs.calculateProposalTimestampDifferenceMetric() - // We don't update cs.ProposalBlockParts if it is already set. - // This happens if we're already in cstypes.RoundStepCommit or if there is a valid block in the current round. - // TODO: We can check if Proposal is for a different block as this is a sign of misbehavior! - if cs.ProposalBlockParts == nil { - cs.metrics.MarkBlockGossipStarted() - cs.ProposalBlockParts = types.NewPartSetFromHeader(proposal.BlockID.PartSetHeader) - } + func() { + cs.mtx.Lock() + defer cs.mtx.Unlock() + + cs.Proposal = proposal + cs.ProposalReceiveTime = recvTime + cs.calculateProposalTimestampDifferenceMetric() + // We don't update cs.ProposalBlockParts if it is already set. + // This happens if we're already in cstypes.RoundStepCommit or if there is a valid block in the current round. + // TODO: We can check if Proposal is for a different block as this is a sign of misbehavior! + if cs.ProposalBlockParts == nil { + cs.metrics.MarkBlockGossipStarted() + cs.ProposalBlockParts = types.NewPartSetFromHeader(proposal.BlockID.PartSetHeader) + } + }() cs.logger.Info("received proposal", "proposal", proposal) return nil @@ -2106,6 +2121,8 @@ func (cs *State) addProposalBlockPart( peerID types.NodeID, ) (added bool, err error) { height, round, part := msg.Height, msg.Round, msg.Part + cs.mtx.Lock() + defer cs.mtx.Unlock() // Blocks might be reused, so round mismatch is OK if cs.Height != height { diff --git a/internal/mempool/mempool.go b/internal/mempool/mempool.go index 6fcfe86c1..0278b191d 100644 --- a/internal/mempool/mempool.go +++ b/internal/mempool/mempool.go @@ -200,6 +200,9 @@ func (txmp *TxMempool) EnableTxsAvailable() { // TxsAvailable returns a channel which fires once for every height, and only // when transactions are available in the mempool. It is thread-safe. func (txmp *TxMempool) TxsAvailable() <-chan struct{} { + txmp.mtx.Lock() + defer txmp.mtx.Unlock() + return txmp.txsAvailable }