From a823d167bc615b3a482420efde98c239d14985d1 Mon Sep 17 00:00:00 2001
From: Sam Kleinman
Date: Wed, 1 Dec 2021 09:28:06 -0500
Subject: [PATCH] service: cleanup base implementation and some caller implementations (#7301)

---
 abci/client/local_client.go              |  5 +-
 abci/client/socket_client.go             |  2 -
 internal/blocksync/pool.go               | 37 +++++++++++----
 internal/blocksync/reactor.go            | 10 ++--
 internal/consensus/common_test.go        |  2 +
 internal/consensus/reactor.go            |  6 ++-
 internal/consensus/replay.go             |  2 +-
 internal/consensus/state.go              | 41 +++++++++++------
 internal/consensus/ticker.go             | 14 +++---
 internal/consensus/wal.go                | 13 ++++--
 internal/eventbus/event_bus.go           |  9 +---
 internal/eventbus/event_bus_test.go      |  6 +--
 internal/libs/autofile/cmd/logjack.go    |  6 +--
 internal/libs/autofile/group.go          |  9 ++--
 internal/mempool/mempool_test.go         | 24 +++++-----
 internal/mempool/reactor_test.go         | 10 ++--
 internal/p2p/conn/connection.go          |  8 ++--
 internal/p2p/conn/connection_test.go     | 58 ++++++++++++++----------
 libs/service/service.go                  | 43 +++++++++---------
 libs/service/service_test.go             |  5 +-
 light/rpc/client.go                      | 17 +++++--
 node/node_test.go                        | 41 ++++++-----------
 privval/signer_client_test.go            |  6 +--
 privval/signer_dialer_endpoint.go        |  4 ++
 privval/signer_listener_endpoint.go      | 14 +++---
 privval/signer_listener_endpoint_test.go |  6 +--
 privval/signer_server.go                 |  2 -
 rpc/client/local/local.go                |  2 +-
 28 files changed, 220 insertions(+), 182 deletions(-)

diff --git a/abci/client/local_client.go b/abci/client/local_client.go
index 701108a3c..f534a1716 100644
--- a/abci/client/local_client.go
+++ b/abci/client/local_client.go
@@ -38,10 +38,13 @@ func NewLocalClient(mtx *tmsync.Mutex, app types.Application) Client {
 	return cli
 }
 
+func (*localClient) OnStart(context.Context) error { return nil }
+func (*localClient) OnStop()                        {}
+
 func (app *localClient) SetResponseCallback(cb Callback) {
 	app.mtx.Lock()
+	defer app.mtx.Unlock()
 	app.Callback = cb
-	app.mtx.Unlock()
 }
 
 // TODO: change types.Application to include Error()?
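Note: taken together, the changes in this patch move service implementations from the old Quit()-channel lifecycle to a context-driven one. BaseService no longer supplies default OnStart/OnStop methods or a Quit() channel, so each implementation declares its own hooks, runs its goroutines off the context passed to Start, and lets callers block on Wait() for shutdown. The sketch below is illustrative only and is not part of the patch; the ticker type and newTicker constructor are hypothetical, written against the service API as it stands after this change.

package example

import (
	"context"
	"time"

	"github.com/tendermint/tendermint/libs/log"
	"github.com/tendermint/tendermint/libs/service"
)

// ticker is a hypothetical service used only to illustrate the contract.
type ticker struct {
	service.BaseService
	interval time.Duration
}

func newTicker(logger log.Logger, interval time.Duration) *ticker {
	t := &ticker{interval: interval}
	t.BaseService = *service.NewBaseService(logger, "ticker", t)
	return t
}

// OnStart must now be provided by the implementation; it launches the work
// loop and hands it the start context, which controls the loop's lifetime.
func (t *ticker) OnStart(ctx context.Context) error {
	go t.loop(ctx)
	return nil
}

// OnStop must also be provided explicitly; this service has nothing to clean up.
func (t *ticker) OnStop() {}

func (t *ticker) loop(ctx context.Context) {
	tick := time.NewTicker(t.interval)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done(): // replaces selecting on the removed Quit() channel
			return
		case <-tick.C:
			t.Logger.Debug("tick")
		}
	}
}

A caller starts such a service with Start(ctx), triggers shutdown by canceling ctx (or calling Stop), and blocks on Wait() until it has exited, which is the same pattern the reactor, privval, and node changes below adopt.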
diff --git a/abci/client/socket_client.go b/abci/client/socket_client.go index 8dfee0c8d..9173475c6 100644 --- a/abci/client/socket_client.go +++ b/abci/client/socket_client.go @@ -126,8 +126,6 @@ func (cli *socketClient) sendRequestsRoutine(ctx context.Context, conn io.Writer select { case <-ctx.Done(): return - case <-cli.Quit(): - return case reqres := <-cli.reqQueue: if ctx.Err() != nil { return diff --git a/internal/blocksync/pool.go b/internal/blocksync/pool.go index 6f06c9883..a83a119a6 100644 --- a/internal/blocksync/pool.go +++ b/internal/blocksync/pool.go @@ -84,6 +84,7 @@ type BlockPool struct { requestsCh chan<- BlockRequest errorsCh chan<- peerError + exitedCh chan struct{} startHeight int64 lastHundredBlockTimeStamp time.Time @@ -102,11 +103,11 @@ func NewBlockPool( bp := &BlockPool{ peers: make(map[types.NodeID]*bpPeer), - requesters: make(map[int64]*bpRequester), - height: start, - startHeight: start, - numPending: 0, - + requesters: make(map[int64]*bpRequester), + height: start, + startHeight: start, + numPending: 0, + exitedCh: make(chan struct{}), requestsCh: requestsCh, errorsCh: errorsCh, lastSyncRate: 0, @@ -121,9 +122,17 @@ func (pool *BlockPool) OnStart(ctx context.Context) error { pool.lastAdvance = time.Now() pool.lastHundredBlockTimeStamp = pool.lastAdvance go pool.makeRequestersRoutine(ctx) + + go func() { + defer close(pool.exitedCh) + pool.Wait() + }() + return nil } +func (*BlockPool) OnStop() {} + // spawns requesters as needed func (pool *BlockPool) makeRequestersRoutine(ctx context.Context) { for { @@ -572,10 +581,12 @@ func newBPRequester(pool *BlockPool, height int64) *bpRequester { } func (bpr *bpRequester) OnStart(ctx context.Context) error { - go bpr.requestRoutine() + go bpr.requestRoutine(ctx) return nil } +func (*bpRequester) OnStop() {} + // Returns true if the peer matches and block doesn't already exist. func (bpr *bpRequester) setBlock(block *types.Block, peerID types.NodeID) bool { bpr.mtx.Lock() @@ -630,7 +641,13 @@ func (bpr *bpRequester) redo(peerID types.NodeID) { // Responsible for making more requests as necessary // Returns only when a block is found (e.g. AddBlock() is called) -func (bpr *bpRequester) requestRoutine() { +func (bpr *bpRequester) requestRoutine(ctx context.Context) { + bprPoolDone := make(chan struct{}) + go func() { + defer close(bprPoolDone) + bpr.pool.Wait() + }() + OUTER_LOOP: for { // Pick a peer to send request to. 
@@ -656,13 +673,13 @@ OUTER_LOOP: WAIT_LOOP: for { select { - case <-bpr.pool.Quit(): + case <-ctx.Done(): + return + case <-bpr.pool.exitedCh: if err := bpr.Stop(); err != nil { bpr.Logger.Error("Error stopped requester", "err", err) } return - case <-bpr.Quit(): - return case peerID := <-bpr.redoCh: if peerID == bpr.peerID { bpr.reset() diff --git a/internal/blocksync/reactor.go b/internal/blocksync/reactor.go index ac5d45fb7..5fe8b2123 100644 --- a/internal/blocksync/reactor.go +++ b/internal/blocksync/reactor.go @@ -158,7 +158,7 @@ func (r *Reactor) OnStart(ctx context.Context) error { return err } r.poolWG.Add(1) - go r.requestRoutine() + go r.requestRoutine(ctx) r.poolWG.Add(1) go r.poolRoutine(false) @@ -375,7 +375,7 @@ func (r *Reactor) SwitchToBlockSync(ctx context.Context, state sm.State) error { r.syncStartTime = time.Now() r.poolWG.Add(1) - go r.requestRoutine() + go r.requestRoutine(ctx) r.poolWG.Add(1) go r.poolRoutine(true) @@ -383,7 +383,7 @@ func (r *Reactor) SwitchToBlockSync(ctx context.Context, state sm.State) error { return nil } -func (r *Reactor) requestRoutine() { +func (r *Reactor) requestRoutine(ctx context.Context) { statusUpdateTicker := time.NewTicker(statusUpdateIntervalSeconds * time.Second) defer statusUpdateTicker.Stop() @@ -394,7 +394,7 @@ func (r *Reactor) requestRoutine() { case <-r.closeCh: return - case <-r.pool.Quit(): + case <-ctx.Done(): return case request := <-r.requestsCh: @@ -607,7 +607,7 @@ FOR_LOOP: case <-r.closeCh: break FOR_LOOP - case <-r.pool.Quit(): + case <-r.pool.exitedCh: break FOR_LOOP } } diff --git a/internal/consensus/common_test.go b/internal/consensus/common_test.go index 30729c038..27f9628d1 100644 --- a/internal/consensus/common_test.go +++ b/internal/consensus/common_test.go @@ -915,6 +915,8 @@ func (m *mockTicker) Stop() error { return nil } +func (m *mockTicker) IsRunning() bool { return false } + func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) { m.mtx.Lock() defer m.mtx.Unlock() diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index 7e46444b9..4e2b1ad3a 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -2,6 +2,7 @@ package consensus import ( "context" + "errors" "fmt" "runtime/debug" "time" @@ -205,11 +206,12 @@ func (r *Reactor) OnStart(ctx context.Context) error { // blocking until they all exit, as well as unsubscribing from events and stopping // state. 
func (r *Reactor) OnStop() { - r.unsubscribeFromBroadcastEvents() if err := r.state.Stop(); err != nil { - r.Logger.Error("failed to stop consensus state", "err", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + r.Logger.Error("failed to stop consensus state", "err", err) + } } if !r.WaitSync() { diff --git a/internal/consensus/replay.go b/internal/consensus/replay.go index 2408b03f1..b797eabae 100644 --- a/internal/consensus/replay.go +++ b/internal/consensus/replay.go @@ -50,7 +50,7 @@ func (cs *State) readReplayMessage(ctx context.Context, msg *TimedWALMessage, ne cs.Logger.Info("Replay: New Step", "height", m.Height, "round", m.Round, "step", m.Step) // these are playback checks if newStepSub != nil { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() stepMsg, err := newStepSub.Next(ctx) if errors.Is(err, context.DeadlineExceeded) { diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 2e9f97503..f71d08649 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -363,6 +363,7 @@ func (cs *State) OnStart(ctx context.Context) error { // 1) prep work if err := cs.wal.Stop(); err != nil { + return err } @@ -421,6 +422,8 @@ func (cs *State) OnStart(ctx context.Context) error { // timeoutRoutine: receive requests for timeouts on tickChan and fire timeouts on tockChan // receiveRoutine: serializes processing of proposoals, block parts, votes; coordinates state transitions +// +// this is only used in tests. func (cs *State) startRoutines(ctx context.Context, maxSteps int) { err := cs.timeoutTicker.Start(ctx) if err != nil { @@ -445,7 +448,6 @@ func (cs *State) loadWalFile(ctx context.Context) error { // OnStop implements service.Service. func (cs *State) OnStop() { - // If the node is committing a new block, wait until it is finished! if cs.GetRoundState().Step == cstypes.RoundStepCommit { select { @@ -457,15 +459,19 @@ func (cs *State) OnStop() { close(cs.onStopCh) - if err := cs.evsw.Stop(); err != nil { - if !errors.Is(err, service.ErrAlreadyStopped) { - cs.Logger.Error("failed trying to stop eventSwitch", "error", err) + if cs.evsw.IsRunning() { + if err := cs.evsw.Stop(); err != nil { + if !errors.Is(err, service.ErrAlreadyStopped) { + cs.Logger.Error("failed trying to stop eventSwitch", "error", err) + } } } - if err := cs.timeoutTicker.Stop(); err != nil { - if !errors.Is(err, service.ErrAlreadyStopped) { - cs.Logger.Error("failed trying to stop timeoutTicket", "error", err) + if cs.timeoutTicker.IsRunning() { + if err := cs.timeoutTicker.Stop(); err != nil { + if !errors.Is(err, service.ErrAlreadyStopped) { + cs.Logger.Error("failed trying to stop timeoutTicket", "error", err) + } } } // WAL is stopped in receiveRoutine. @@ -845,9 +851,10 @@ func (cs *State) receiveRoutine(ctx context.Context, maxSteps int) { // go to the next step cs.handleTimeout(ctx, ti, rs) - case <-cs.Quit(): + case <-ctx.Done(): onExit(cs) return + } // TODO should we handle context cancels here? 
} @@ -875,7 +882,11 @@ func (cs *State) handleMsg(ctx context.Context, mi msgInfo) { // if the proposal is complete, we'll enterPrevote or tryFinalizeCommit added, err = cs.addProposalBlockPart(ctx, msg, peerID) if added { - cs.statsMsgQueue <- mi + select { + case cs.statsMsgQueue <- mi: + case <-ctx.Done(): + return + } } if err != nil && msg.Round != cs.Round { @@ -893,7 +904,11 @@ func (cs *State) handleMsg(ctx context.Context, mi msgInfo) { // if the vote gives us a 2/3-any or 2/3-one, we transition added, err = cs.tryAddVote(ctx, msg.Vote, peerID) if added { - cs.statsMsgQueue <- mi + select { + case cs.statsMsgQueue <- mi: + case <-ctx.Done(): + return + } } // if err == ErrAddingVote { @@ -1012,7 +1027,7 @@ func (cs *State) handleTxsAvailable(ctx context.Context) { // Used internally by handleTimeout and handleMsg to make state transitions // Enter: `timeoutNewHeight` by startTime (commitTime+timeoutCommit), -// or, if SkipTimeoutCommit==true, after receiving all precommits from (height,round-1) +// or, if SkipTimeoutCommit==true, after receiving all precommits from (height,round-1) // Enter: `timeoutPrecommits` after any +2/3 precommits from (height,round-1) // Enter: +2/3 precommits for nil at (height,round-1) // Enter: +2/3 prevotes any or +2/3 precommits for block or any from (height, round) @@ -1097,7 +1112,7 @@ func (cs *State) needProofBlock(height int64) bool { // Enter (CreateEmptyBlocks): from enterNewRound(height,round) // Enter (CreateEmptyBlocks, CreateEmptyBlocksInterval > 0 ): -// after enterNewRound(height,round), after timeout of CreateEmptyBlocksInterval +// after enterNewRound(height,round), after timeout of CreateEmptyBlocksInterval // Enter (!CreateEmptyBlocks) : after enterNewRound(height,round), once txs are in the mempool func (cs *State) enterPropose(ctx context.Context, height int64, round int32) { logger := cs.Logger.With("height", height, "round", round) @@ -2011,7 +2026,7 @@ func (cs *State) tryAddVote(ctx context.Context, vote *types.Vote, peerID types. // 1) bad peer OR // 2) not a bad peer? this can also err sometimes with "Unexpected step" OR // 3) tmkms use with multiple validators connecting to a single tmkms instance - // (https://github.com/tendermint/tendermint/issues/3839). + // (https://github.com/tendermint/tendermint/issues/3839). cs.Logger.Info("failed attempting to add vote", "err", err) return added, ErrAddingVote } diff --git a/internal/consensus/ticker.go b/internal/consensus/ticker.go index e8583932d..6e323b2d0 100644 --- a/internal/consensus/ticker.go +++ b/internal/consensus/ticker.go @@ -18,6 +18,7 @@ var ( type TimeoutTicker interface { Start(context.Context) error Stop() error + IsRunning() bool Chan() <-chan timeoutInfo // on which to receive a timeout ScheduleTimeout(ti timeoutInfo) // reset the timer } @@ -48,17 +49,14 @@ func NewTimeoutTicker(logger log.Logger) TimeoutTicker { } // OnStart implements service.Service. It starts the timeout routine. -func (t *timeoutTicker) OnStart(gctx context.Context) error { - go t.timeoutRoutine() +func (t *timeoutTicker) OnStart(ctx context.Context) error { + go t.timeoutRoutine(ctx) return nil } // OnStop implements service.Service. It stops the timeout routine. -func (t *timeoutTicker) OnStop() { - t.BaseService.OnStop() - t.stopTimer() -} +func (t *timeoutTicker) OnStop() { t.stopTimer() } // Chan returns a channel on which timeouts are sent. 
func (t *timeoutTicker) Chan() <-chan timeoutInfo { @@ -89,7 +87,7 @@ func (t *timeoutTicker) stopTimer() { // send on tickChan to start a new timer. // timers are interupted and replaced by new ticks from later steps // timeouts of 0 on the tickChan will be immediately relayed to the tockChan -func (t *timeoutTicker) timeoutRoutine() { +func (t *timeoutTicker) timeoutRoutine(ctx context.Context) { t.Logger.Debug("Starting timeout routine") var ti timeoutInfo for { @@ -125,7 +123,7 @@ func (t *timeoutTicker) timeoutRoutine() { // We can eliminate it by merging the timeoutRoutine into receiveRoutine // and managing the timeouts ourselves with a millisecond ticker go func(toi timeoutInfo) { t.tockChan <- toi }(ti) - case <-t.Quit(): + case <-ctx.Done(): return } } diff --git a/internal/consensus/wal.go b/internal/consensus/wal.go index 13f29a202..e89cf0992 100644 --- a/internal/consensus/wal.go +++ b/internal/consensus/wal.go @@ -131,18 +131,18 @@ func (wal *BaseWAL) OnStart(ctx context.Context) error { return err } wal.flushTicker = time.NewTicker(wal.flushInterval) - go wal.processFlushTicks() + go wal.processFlushTicks(ctx) return nil } -func (wal *BaseWAL) processFlushTicks() { +func (wal *BaseWAL) processFlushTicks(ctx context.Context) { for { select { case <-wal.flushTicker.C: if err := wal.FlushAndSync(); err != nil { wal.Logger.Error("Periodic WAL flush failed", "err", err) } - case <-wal.Quit(): + case <-ctx.Done(): return } } @@ -175,7 +175,12 @@ func (wal *BaseWAL) OnStop() { // Wait for the underlying autofile group to finish shutting down // so it's safe to cleanup files. func (wal *BaseWAL) Wait() { - wal.group.Wait() + if wal.IsRunning() { + wal.BaseService.Wait() + } + if wal.group.IsRunning() { + wal.group.Wait() + } } // Write is called in newStep and for each receive on the diff --git a/internal/eventbus/event_bus.go b/internal/eventbus/event_bus.go index 61473c713..0eb44c37b 100644 --- a/internal/eventbus/event_bus.go +++ b/internal/eventbus/event_bus.go @@ -2,7 +2,6 @@ package eventbus import ( "context" - "errors" "fmt" "strings" @@ -43,13 +42,7 @@ func (b *EventBus) OnStart(ctx context.Context) error { return b.pubsub.Start(ctx) } -func (b *EventBus) OnStop() { - if err := b.pubsub.Stop(); err != nil { - if !errors.Is(err, service.ErrAlreadyStopped) { - b.pubsub.Logger.Error("error trying to stop eventBus", "error", err) - } - } -} +func (b *EventBus) OnStop() {} func (b *EventBus) NumClients() int { return b.pubsub.NumClients() diff --git a/internal/eventbus/event_bus_test.go b/internal/eventbus/event_bus_test.go index 06f2bfa64..72b1094fb 100644 --- a/internal/eventbus/event_bus_test.go +++ b/internal/eventbus/event_bus_test.go @@ -436,11 +436,7 @@ func benchmarkEventBus(numClients int, randQueries bool, randEvents bool, b *tes if err != nil { b.Error(err) } - b.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - b.Error(err) - } - }) + b.Cleanup(eventBus.Wait) q := types.EventQueryNewBlock diff --git a/internal/libs/autofile/cmd/logjack.go b/internal/libs/autofile/cmd/logjack.go index 0f412a366..72d816fcf 100644 --- a/internal/libs/autofile/cmd/logjack.go +++ b/internal/libs/autofile/cmd/logjack.go @@ -63,14 +63,10 @@ func main() { for { n, err := os.Stdin.Read(buf) if err != nil { - if err := group.Stop(); err != nil { - fmt.Fprintf(os.Stderr, "logjack stopped with error %v\n", headPath) - os.Exit(1) - } if err == io.EOF { os.Exit(0) } else { - fmt.Println("logjack errored") + fmt.Println("logjack errored:", err.Error()) os.Exit(1) } } diff --git 
a/internal/libs/autofile/group.go b/internal/libs/autofile/group.go index 0e208d8e9..969d101c3 100644 --- a/internal/libs/autofile/group.go +++ b/internal/libs/autofile/group.go @@ -138,7 +138,7 @@ func GroupTotalSizeLimit(limit int64) func(*Group) { // and group limits. func (g *Group) OnStart(ctx context.Context) error { g.ticker = time.NewTicker(g.groupCheckDuration) - go g.processTicks() + go g.processTicks(ctx) return nil } @@ -237,15 +237,16 @@ func (g *Group) FlushAndSync() error { return err } -func (g *Group) processTicks() { +func (g *Group) processTicks(ctx context.Context) { defer close(g.doneProcessTicks) + for { select { + case <-ctx.Done(): + return case <-g.ticker.C: g.checkHeadSizeLimit() g.checkTotalSizeLimit() - case <-g.Quit(): - return } } } diff --git a/internal/mempool/mempool_test.go b/internal/mempool/mempool_test.go index f06ee18d9..1613dce98 100644 --- a/internal/mempool/mempool_test.go +++ b/internal/mempool/mempool_test.go @@ -98,7 +98,9 @@ func setup(ctx context.Context, t testing.TB, cacheSize int, options ...TxMempoo return NewTxMempool(logger.With("test", t.Name()), cfg.Mempool, appConnMem, 0, options...) } -func checkTxs(t *testing.T, txmp *TxMempool, numTxs int, peerID uint16) []testTx { +func checkTxs(ctx context.Context, t *testing.T, txmp *TxMempool, numTxs int, peerID uint16) []testTx { + t.Helper() + txs := make([]testTx, numTxs) txInfo := TxInfo{SenderID: peerID} @@ -115,7 +117,7 @@ func checkTxs(t *testing.T, txmp *TxMempool, numTxs int, peerID uint16) []testTx tx: []byte(fmt.Sprintf("sender-%d-%d=%X=%d", i, peerID, prefix, priority)), priority: priority, } - require.NoError(t, txmp.CheckTx(context.Background(), txs[i].tx, nil, txInfo)) + require.NoError(t, txmp.CheckTx(ctx, txs[i].tx, nil, txInfo)) } return txs @@ -161,7 +163,7 @@ func TestTxMempool_TxsAvailable(t *testing.T) { // Execute CheckTx for some transactions and ensure TxsAvailable only fires // once. - txs := checkTxs(t, txmp, 100, 0) + txs := checkTxs(ctx, t, txmp, 100, 0) ensureTxFire() ensureNoTxFire() @@ -184,7 +186,7 @@ func TestTxMempool_TxsAvailable(t *testing.T) { // Execute CheckTx for more transactions and ensure we do not fire another // event as we're still on the same height (1). 
- _ = checkTxs(t, txmp, 100, 0) + _ = checkTxs(ctx, t, txmp, 100, 0) ensureNoTxFire() } @@ -193,7 +195,7 @@ func TestTxMempool_Size(t *testing.T) { defer cancel() txmp := setup(ctx, t, 0) - txs := checkTxs(t, txmp, 100, 0) + txs := checkTxs(ctx, t, txmp, 100, 0) require.Equal(t, len(txs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -220,7 +222,7 @@ func TestTxMempool_Flush(t *testing.T) { defer cancel() txmp := setup(ctx, t, 0) - txs := checkTxs(t, txmp, 100, 0) + txs := checkTxs(ctx, t, txmp, 100, 0) require.Equal(t, len(txs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -248,7 +250,7 @@ func TestTxMempool_ReapMaxBytesMaxGas(t *testing.T) { defer cancel() txmp := setup(ctx, t, 0) - tTxs := checkTxs(t, txmp, 100, 0) // all txs request 1 gas unit + tTxs := checkTxs(ctx, t, txmp, 100, 0) // all txs request 1 gas unit require.Equal(t, len(tTxs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -301,7 +303,7 @@ func TestTxMempool_ReapMaxTxs(t *testing.T) { defer cancel() txmp := setup(ctx, t, 0) - tTxs := checkTxs(t, txmp, 100, 0) + tTxs := checkTxs(ctx, t, txmp, 100, 0) require.Equal(t, len(tTxs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -424,7 +426,7 @@ func TestTxMempool_ConcurrentTxs(t *testing.T) { wg.Add(1) go func() { for i := 0; i < 20; i++ { - _ = checkTxs(t, txmp, 100, 0) + _ = checkTxs(ctx, t, txmp, 100, 0) dur := rng.Intn(1000-500) + 500 time.Sleep(time.Duration(dur) * time.Millisecond) } @@ -486,7 +488,7 @@ func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { txmp.height = 100 txmp.config.TTLNumBlocks = 10 - tTxs := checkTxs(t, txmp, 100, 0) + tTxs := checkTxs(ctx, t, txmp, 100, 0) require.Equal(t, len(tTxs), txmp.Size()) require.Equal(t, 100, txmp.heightIndex.Size()) @@ -505,7 +507,7 @@ func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { require.Equal(t, 95, txmp.heightIndex.Size()) // check more txs at height 101 - _ = checkTxs(t, txmp, 50, 1) + _ = checkTxs(ctx, t, txmp, 50, 1) require.Equal(t, 145, txmp.Size()) require.Equal(t, 145, txmp.heightIndex.Size()) diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index 4456424b5..62cacfd10 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -203,7 +203,7 @@ func TestReactorBroadcastTxs(t *testing.T) { primary := rts.nodes[0] secondaries := rts.nodes[1:] - txs := checkTxs(t, rts.reactors[primary].mempool, numTxs, UnknownPeerID) + txs := checkTxs(ctx, t, rts.reactors[primary].mempool, numTxs, UnknownPeerID) // run the router rts.start(t) @@ -238,7 +238,7 @@ func TestReactorConcurrency(t *testing.T) { // 1. submit a bunch of txs // 2. update the whole mempool - txs := checkTxs(t, rts.reactors[primary].mempool, numTxs, UnknownPeerID) + txs := checkTxs(ctx, t, rts.reactors[primary].mempool, numTxs, UnknownPeerID) go func() { defer wg.Done() @@ -257,7 +257,7 @@ func TestReactorConcurrency(t *testing.T) { // 1. submit a bunch of txs // 2. 
update none - _ = checkTxs(t, rts.reactors[secondary].mempool, numTxs, UnknownPeerID) + _ = checkTxs(ctx, t, rts.reactors[secondary].mempool, numTxs, UnknownPeerID) go func() { defer wg.Done() @@ -290,7 +290,7 @@ func TestReactorNoBroadcastToSender(t *testing.T) { secondary := rts.nodes[1] peerID := uint16(1) - _ = checkTxs(t, rts.mempools[primary], numTxs, peerID) + _ = checkTxs(ctx, t, rts.mempools[primary], numTxs, peerID) rts.start(t) @@ -430,7 +430,7 @@ func TestBroadcastTxForPeerStopsWhenPeerStops(t *testing.T) { } time.Sleep(500 * time.Millisecond) - txs := checkTxs(t, rts.reactors[primary].mempool, 4, UnknownPeerID) + txs := checkTxs(ctx, t, rts.reactors[primary].mempool, 4, UnknownPeerID) require.Equal(t, 4, len(txs)) require.Equal(t, 4, rts.mempools[primary].Size()) require.Equal(t, 0, rts.mempools[secondary].Size()) diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 9fb330286..94f248a8c 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -101,6 +101,8 @@ type MConnection struct { // are safe to call concurrently. stopMtx tmsync.Mutex + cancel context.CancelFunc + flushTimer *timer.ThrottleTimer // flush writes as necessary but throttled. pingTimer *time.Ticker // send pings periodically @@ -187,6 +189,7 @@ func NewMConnectionWithConfig( onError: onError, config: config, created: time.Now(), + cancel: func() {}, } mconn.BaseService = *service.NewBaseService(logger, "MConnection", mconn) @@ -211,9 +214,6 @@ func NewMConnectionWithConfig( // OnStart implements BaseService func (c *MConnection) OnStart(ctx context.Context) error { - if err := c.BaseService.OnStart(ctx); err != nil { - return err - } c.flushTimer = timer.NewThrottleTimer("flush", c.config.FlushThrottle) c.pingTimer = time.NewTicker(c.config.PingInterval) c.pongTimeoutCh = make(chan bool, 1) @@ -247,7 +247,6 @@ func (c *MConnection) stopServices() (alreadyStopped bool) { default: } - c.BaseService.OnStop() c.flushTimer.Stop() c.pingTimer.Stop() c.chStatsTimer.Stop() @@ -296,6 +295,7 @@ func (c *MConnection) stopForError(r interface{}) { if err := c.Stop(); err != nil { c.Logger.Error("Error stopping connection", "err", err) } + if atomic.CompareAndSwapUint32(&c.errored, 0, 1) { if c.onError != nil { c.onError(r) diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index f8b34bad6..f1b2ae24c 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "net" + "sync" "testing" "time" @@ -14,6 +15,7 @@ import ( "github.com/tendermint/tendermint/internal/libs/protoio" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/service" tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/proto/tendermint/types" ) @@ -54,7 +56,7 @@ func TestMConnectionSendFlushStop(t *testing.T) { clientConn := createTestMConnection(log.TestingLogger(), client) err := clientConn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, clientConn)) + t.Cleanup(waitAll(clientConn)) msg := []byte("abc") assert.True(t, clientConn.Send(0x01, msg)) @@ -91,7 +93,7 @@ func TestMConnectionSend(t *testing.T) { mconn := createTestMConnection(log.TestingLogger(), client) err := mconn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + t.Cleanup(waitAll(mconn)) msg := []byte("Ant-Man") assert.True(t, mconn.Send(0x01, msg)) @@ -132,12 +134,12 @@ func TestMConnectionReceive(t *testing.T) { mconn1 
:= createMConnectionWithCallbacks(logger, client, onReceive, onError) err := mconn1.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn1)) + t.Cleanup(waitAll(mconn1)) mconn2 := createTestMConnection(logger, server) err = mconn2.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn2)) + t.Cleanup(waitAll(mconn2)) msg := []byte("Cyclops") assert.True(t, mconn2.Send(0x01, msg)) @@ -171,7 +173,7 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) { mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) err := mconn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + t.Cleanup(waitAll(mconn)) serverGotPing := make(chan struct{}) go func() { @@ -212,7 +214,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) err := mconn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + t.Cleanup(waitAll(mconn)) // sending 3 pongs in a row (abuse) protoWriter := protoio.NewDelimitedWriter(server) @@ -269,7 +271,7 @@ func TestMConnectionMultiplePings(t *testing.T) { mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) err := mconn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + t.Cleanup(waitAll(mconn)) // sending 3 pings in a row (abuse) // see https://github.com/tendermint/tendermint/issues/1190 @@ -320,7 +322,7 @@ func TestMConnectionPingPongs(t *testing.T) { mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) err := mconn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + t.Cleanup(waitAll(mconn)) serverGotPing := make(chan struct{}) go func() { @@ -380,7 +382,7 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) err := mconn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + t.Cleanup(waitAll(mconn)) if err := client.Close(); err != nil { t.Error(err) @@ -454,7 +456,7 @@ func TestMConnectionReadErrorBadEncoding(t *testing.T) { _, err := client.Write([]byte{1, 2, 3, 4, 5}) require.NoError(t, err) assert.True(t, expectSend(chOnErr), "badly encoded msgPacket") - t.Cleanup(stopAll(t, mconnClient, mconnServer)) + t.Cleanup(waitAll(mconnClient, mconnServer)) } func TestMConnectionReadErrorUnknownChannel(t *testing.T) { @@ -473,7 +475,7 @@ func TestMConnectionReadErrorUnknownChannel(t *testing.T) { // should cause an error assert.True(t, mconnClient.Send(0x02, msg)) assert.True(t, expectSend(chOnErr), "unknown channel") - t.Cleanup(stopAll(t, mconnClient, mconnServer)) + t.Cleanup(waitAll(mconnClient, mconnServer)) } func TestMConnectionReadErrorLongMessage(t *testing.T) { @@ -484,7 +486,7 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { defer cancel() mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) - t.Cleanup(stopAll(t, mconnClient, mconnServer)) + t.Cleanup(waitAll(mconnClient, mconnServer)) mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) { chOnRcv <- struct{}{} @@ -522,7 +524,7 @@ func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { chOnErr := make(chan struct{}) mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) - t.Cleanup(stopAll(t, mconnClient, mconnServer)) + t.Cleanup(waitAll(mconnClient, mconnServer)) // send msg with unknown msg type _, err := 
protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"}) @@ -539,7 +541,7 @@ func TestMConnectionTrySend(t *testing.T) { mconn := createTestMConnection(log.TestingLogger(), client) err := mconn.Start(ctx) require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + t.Cleanup(waitAll(mconn)) msg := []byte("Semicolon-Woman") resultCh := make(chan string, 2) @@ -586,7 +588,7 @@ func TestMConnectionChannelOverflow(t *testing.T) { defer cancel() mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) - t.Cleanup(stopAll(t, mconnClient, mconnServer)) + t.Cleanup(waitAll(mconnClient, mconnServer)) mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) { chOnRcv <- struct{}{} @@ -611,16 +613,26 @@ func TestMConnectionChannelOverflow(t *testing.T) { } -type stopper interface { - Stop() error -} - -func stopAll(t *testing.T, stoppers ...stopper) func() { +func waitAll(waiters ...service.Service) func() { return func() { - for _, s := range stoppers { - if err := s.Stop(); err != nil { - t.Log(err) + switch len(waiters) { + case 0: + return + case 1: + waiters[0].Wait() + return + default: + wg := &sync.WaitGroup{} + + for _, w := range waiters { + wg.Add(1) + go func(s service.Service) { + defer wg.Done() + s.Wait() + }(w) } + + wg.Wait() } } } diff --git a/libs/service/service.go b/libs/service/service.go index f2b440e94..ad88ccf16 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -136,17 +136,31 @@ func (bs *BaseService) Start(ctx context.Context) error { } go func(ctx context.Context) { - <-ctx.Done() - if err := bs.Stop(); err != nil { - bs.Logger.Error("stopped service", - "err", err.Error(), + select { + case <-bs.quit: + // someone else explicitly called stop + // and then we shouldn't. + return + case <-ctx.Done(): + // if nothing is running, no need to + // shut down again. + if !bs.impl.IsRunning() { + return + } + + // the context was cancel and we + // should stop. + if err := bs.Stop(); err != nil { + bs.Logger.Error("stopped service", + "err", err.Error(), + "service", bs.name, + "impl", bs.impl.String()) + } + + bs.Logger.Info("stopped service", "service", bs.name, "impl", bs.impl.String()) } - - bs.Logger.Info("stopped service", - "service", bs.name, - "impl", bs.impl.String()) }(ctx) return nil @@ -156,11 +170,6 @@ func (bs *BaseService) Start(ctx context.Context) error { return ErrAlreadyStarted } -// OnStart implements Service by doing nothing. -// NOTE: Do not put anything in here, -// that way users don't need to call BaseService.OnStart() -func (bs *BaseService) OnStart(ctx context.Context) error { return nil } - // Stop implements Service by calling OnStop (if defined) and closing quit // channel. An error will be returned if the service is already stopped. func (bs *BaseService) Stop() error { @@ -182,11 +191,6 @@ func (bs *BaseService) Stop() error { return ErrAlreadyStopped } -// OnStop implements Service by doing nothing. -// NOTE: Do not put anything in here, -// that way users don't need to call BaseService.OnStop() -func (bs *BaseService) OnStop() {} - // IsRunning implements Service by returning true or false depending on the // service's state. func (bs *BaseService) IsRunning() bool { @@ -198,6 +202,3 @@ func (bs *BaseService) Wait() { <-bs.quit } // String implements Service by returning a string representation of the service. func (bs *BaseService) String() string { return bs.name } - -// Quit Implements Service by returning a quit channel. 
-func (bs *BaseService) Quit() <-chan struct{} { return bs.quit } diff --git a/libs/service/service_test.go b/libs/service/service_test.go index dc5d0ccb1..9630d358b 100644 --- a/libs/service/service_test.go +++ b/libs/service/service_test.go @@ -12,7 +12,8 @@ type testService struct { BaseService } -func (testService) OnReset() error { +func (testService) OnStop() {} +func (testService) OnStart(context.Context) error { return nil } @@ -31,7 +32,7 @@ func TestBaseServiceWait(t *testing.T) { waitFinished <- struct{}{} }() - go ts.Stop() //nolint:errcheck // ignore for tests + go cancel() select { case <-waitFinished: diff --git a/light/rpc/client.go b/light/rpc/client.go index 51676dd3e..a5317ca0b 100644 --- a/light/rpc/client.go +++ b/light/rpc/client.go @@ -47,6 +47,8 @@ type Client struct { // proof runtime used to verify values returned by ABCIQuery prt *merkle.ProofRuntime keyPathFn KeyPathFunc + + quitCh chan struct{} } var _ rpcclient.Client = (*Client)(nil) @@ -87,9 +89,10 @@ func DefaultMerkleKeyPathFn() KeyPathFunc { // NewClient returns a new client. func NewClient(next rpcclient.Client, lc LightClient, opts ...Option) *Client { c := &Client{ - next: next, - lc: lc, - prt: merkle.DefaultProofRuntime(), + next: next, + lc: lc, + prt: merkle.DefaultProofRuntime(), + quitCh: make(chan struct{}), } c.BaseService = *service.NewBaseService(nil, "Client", c) for _, o := range opts { @@ -102,6 +105,12 @@ func (c *Client) OnStart(ctx context.Context) error { if !c.next.IsRunning() { return c.next.Start(ctx) } + + go func() { + defer close(c.quitCh) + c.Wait() + }() + return nil } @@ -586,7 +595,7 @@ func (c *Client) SubscribeWS(ctx *rpctypes.Context, query string) (*coretypes.Re rpctypes.JSONRPCStringID(fmt.Sprintf("%v#event", ctx.JSONReq.ID)), resultEvent, )) - case <-c.Quit(): + case <-c.quitCh: return } } diff --git a/node/node_test.go b/node/node_test.go index 3e4e83492..16f8c44aa 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -48,45 +48,34 @@ func TestNodeStartStop(t *testing.T) { // create & start node ns, err := newDefaultNode(ctx, cfg, log.TestingLogger()) require.NoError(t, err) - require.NoError(t, ns.Start(ctx)) + n, ok := ns.(*nodeImpl) + require.True(t, ok) t.Cleanup(func() { - if ns.IsRunning() { + if n.IsRunning() { bcancel() - ns.Wait() + n.Wait() } }) - n, ok := ns.(*nodeImpl) - require.True(t, ok) - + require.NoError(t, n.Start(ctx)) // wait for the node to produce a block - blocksSub, err := n.EventBus().SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ + tctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + blocksSub, err := n.EventBus().SubscribeWithArgs(tctx, pubsub.SubscribeArgs{ ClientID: "node_test", Query: types.EventQueryNewBlock, }) require.NoError(t, err) - tctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - if _, err := blocksSub.Next(tctx); err != nil { - t.Fatalf("Waiting for event: %v", err) - } - - // stop the node - go func() { - bcancel() - n.Wait() - }() + _, err = blocksSub.Next(tctx) + require.NoError(t, err, "waiting for event") - select { - case <-n.Quit(): - return - case <-time.After(10 * time.Second): - if n.IsRunning() { - t.Fatal("timed out waiting for shutdown") - } + cancel() // stop the subscription context + bcancel() // stop the base context + n.Wait() - } + require.False(t, n.IsRunning(), "node must shut down") } func getTestNode(ctx context.Context, t *testing.T, conf *config.Config, logger log.Logger) *nodeImpl { diff --git a/privval/signer_client_test.go 
b/privval/signer_client_test.go index f9272b004..7ff353dcd 100644 --- a/privval/signer_client_test.go +++ b/privval/signer_client_test.go @@ -60,10 +60,10 @@ func getSignerTestCases(ctx context.Context, t *testing.T) []signerTestCase { } func TestSignerClose(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() - for _, tc := range getSignerTestCases(ctx, t) { + for _, tc := range getSignerTestCases(bctx, t) { t.Run(tc.name, func(t *testing.T) { defer tc.closer() diff --git a/privval/signer_dialer_endpoint.go b/privval/signer_dialer_endpoint.go index 93d26b043..cc605617f 100644 --- a/privval/signer_dialer_endpoint.go +++ b/privval/signer_dialer_endpoint.go @@ -1,6 +1,7 @@ package privval import ( + "context" "time" "github.com/tendermint/tendermint/libs/log" @@ -69,6 +70,9 @@ func NewSignerDialerEndpoint( return sd } +func (sd *SignerDialerEndpoint) OnStart(context.Context) error { return nil } +func (sd *SignerDialerEndpoint) OnStop() {} + func (sd *SignerDialerEndpoint) ensureConnection() error { if sd.IsConnected() { return nil diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index e2287c630..a825f635d 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -72,8 +72,8 @@ func (sl *SignerListenerEndpoint) OnStart(ctx context.Context) error { sl.pingInterval = time.Duration(sl.signerEndpoint.timeoutReadWrite.Milliseconds()*2/3) * time.Millisecond sl.pingTimer = time.NewTicker(sl.pingInterval) - go sl.serviceLoop() - go sl.pingLoop() + go sl.serviceLoop(ctx) + go sl.pingLoop(ctx) sl.connectRequestCh <- struct{}{} @@ -173,7 +173,7 @@ func (sl *SignerListenerEndpoint) triggerReconnect() { sl.triggerConnect() } -func (sl *SignerListenerEndpoint) serviceLoop() { +func (sl *SignerListenerEndpoint) serviceLoop(ctx context.Context) { for { select { case <-sl.connectRequestCh: @@ -185,7 +185,7 @@ func (sl *SignerListenerEndpoint) serviceLoop() { // We have a good connection, wait for someone that needs one otherwise cancellation select { case sl.connectionAvailableCh <- conn: - case <-sl.Quit(): + case <-ctx.Done(): return } } @@ -195,13 +195,13 @@ func (sl *SignerListenerEndpoint) serviceLoop() { default: } } - case <-sl.Quit(): + case <-ctx.Done(): return } } } -func (sl *SignerListenerEndpoint) pingLoop() { +func (sl *SignerListenerEndpoint) pingLoop(ctx context.Context) { for { select { case <-sl.pingTimer.C: @@ -212,7 +212,7 @@ func (sl *SignerListenerEndpoint) pingLoop() { sl.triggerReconnect() } } - case <-sl.Quit(): + case <-ctx.Done(): return } } diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index b92e0abe5..969590bc8 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -77,11 +77,7 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { err = signerServer.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := signerServer.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(signerServer.Wait) select { case attempts := <-attemptCh: diff --git a/privval/signer_server.go b/privval/signer_server.go index e31d3bdb4..dd924f0eb 100644 --- a/privval/signer_server.go +++ b/privval/signer_server.go @@ -94,8 +94,6 @@ func (ss *SignerServer) servicePendingRequest() { func (ss *SignerServer) serviceLoop(ctx context.Context) { for { select { - case <-ss.Quit(): - return case <-ctx.Done(): return 
 		default:
diff --git a/rpc/client/local/local.go b/rpc/client/local/local.go
index 10b2b4be7..9f5ba0072 100644
--- a/rpc/client/local/local.go
+++ b/rpc/client/local/local.go
@@ -220,7 +220,7 @@ func (c *Local) Subscribe(
 	}
 
 	ctx, cancel := context.WithCancel(ctx)
-	go func() { <-c.Quit(); cancel() }()
+	go func() { c.Wait(); cancel() }()
 
 	subArgs := pubsub.SubscribeArgs{
 		ClientID: subscriber,