From bd6dc3ca8858446745240aadab32f1bce5c15f84 Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Thu, 9 Dec 2021 14:03:41 -0500 Subject: [PATCH] p2p: refactor channel Send/out (#7414) --- internal/blocksync/reactor.go | 40 +++-- internal/consensus/byzantine_test.go | 26 +-- internal/consensus/invalid_test.go | 5 +- internal/consensus/reactor.go | 229 +++++++++++++++---------- internal/consensus/replay_test.go | 8 +- internal/consensus/state.go | 67 +++++--- internal/consensus/state_test.go | 26 +-- internal/evidence/reactor.go | 8 +- internal/evidence/reactor_test.go | 21 +-- internal/mempool/reactor.go | 8 +- internal/mempool/reactor_test.go | 33 +--- internal/p2p/channel.go | 6 +- internal/p2p/channel_test.go | 2 +- internal/p2p/p2ptest/require.go | 20 ++- internal/p2p/pex/reactor.go | 20 ++- internal/p2p/pex/reactor_test.go | 27 +-- internal/p2p/router_test.go | 43 ++--- internal/statesync/dispatcher.go | 11 +- internal/statesync/dispatcher_test.go | 63 +++++-- internal/statesync/reactor.go | 66 ++++---- internal/statesync/reactor_test.go | 4 +- internal/statesync/stateprovider.go | 12 +- internal/statesync/syncer.go | 54 +++--- internal/statesync/syncer_test.go | 14 +- libs/events/event_cache.go | 6 +- libs/events/event_cache_test.go | 12 +- libs/events/events.go | 15 +- libs/events/events_test.go | 231 ++++++++++++++++++-------- 28 files changed, 625 insertions(+), 452 deletions(-) diff --git a/internal/blocksync/reactor.go b/internal/blocksync/reactor.go index 2f93a3cf3..53a63fb84 100644 --- a/internal/blocksync/reactor.go +++ b/internal/blocksync/reactor.go @@ -2,6 +2,7 @@ package blocksync import ( "context" + "errors" "fmt" "runtime/debug" "sync" @@ -185,40 +186,38 @@ func (r *Reactor) OnStop() { // respondToPeer loads a block and sends it to the requesting peer, if we have it. // Otherwise, we'll respond saying we do not have it. -func (r *Reactor) respondToPeer(msg *bcproto.BlockRequest, peerID types.NodeID) { +func (r *Reactor) respondToPeer(ctx context.Context, msg *bcproto.BlockRequest, peerID types.NodeID) error { block := r.store.LoadBlock(msg.Height) if block != nil { blockProto, err := block.ToProto() if err != nil { r.logger.Error("failed to convert msg to protobuf", "err", err) - return + return err } - r.blockSyncCh.Out <- p2p.Envelope{ + return r.blockSyncCh.Send(ctx, p2p.Envelope{ To: peerID, Message: &bcproto.BlockResponse{Block: blockProto}, - } - - return + }) } r.logger.Info("peer requesting a block we do not have", "peer", peerID, "height", msg.Height) - r.blockSyncCh.Out <- p2p.Envelope{ + + return r.blockSyncCh.Send(ctx, p2p.Envelope{ To: peerID, Message: &bcproto.NoBlockResponse{Height: msg.Height}, - } + }) } // handleBlockSyncMessage handles envelopes sent from peers on the // BlockSyncChannel. It returns an error only if the Envelope.Message is unknown // for this channel. This should never be called outside of handleMessage. 
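The respondToPeer hunk above is the template for the whole patch: writes to the exported Channel.Out field become calls to the context-aware Channel.Send method, and helpers grow an error return so cancellation propagates to their callers. A minimal sketch of the new calling convention, assuming a reactor holding a blockSyncCh *p2p.Channel (the helper name here is illustrative, not part of the patch):

    // sendStatusRequest shows the new shape: Send blocks until the envelope
    // is accepted for delivery or ctx is done, returning ctx.Err() in the
    // latter case instead of blocking forever on a channel write.
    func (r *Reactor) sendStatusRequest(ctx context.Context, peerID types.NodeID) error {
        return r.blockSyncCh.Send(ctx, p2p.Envelope{
            To:      peerID,
            Message: &bcproto.StatusRequest{},
        })
    }
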
-func (r *Reactor) handleBlockSyncMessage(envelope p2p.Envelope) error { +func (r *Reactor) handleBlockSyncMessage(ctx context.Context, envelope p2p.Envelope) error { logger := r.logger.With("peer", envelope.From) switch msg := envelope.Message.(type) { case *bcproto.BlockRequest: - r.respondToPeer(msg, envelope.From) - + return r.respondToPeer(ctx, msg, envelope.From) case *bcproto.BlockResponse: block, err := types.BlockFromProto(msg.Block) if err != nil { @@ -229,14 +228,13 @@ func (r *Reactor) handleBlockSyncMessage(envelope p2p.Envelope) error { r.pool.AddBlock(envelope.From, block, block.Size()) case *bcproto.StatusRequest: - r.blockSyncCh.Out <- p2p.Envelope{ + return r.blockSyncCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &bcproto.StatusResponse{ Height: r.store.Height(), Base: r.store.Base(), }, - } - + }) case *bcproto.StatusResponse: r.pool.SetPeerRange(envelope.From, msg.Base, msg.Height) @@ -253,7 +251,7 @@ func (r *Reactor) handleBlockSyncMessage(envelope p2p.Envelope) error { // handleMessage handles an Envelope sent from a peer on a specific p2p Channel. // It will handle errors and any possible panics gracefully. A caller can handle // any error returned by sending a PeerError on the respective channel. -func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) { +func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("panic in processing message: %v", e) @@ -269,7 +267,7 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err switch chID { case BlockSyncChannel: - err = r.handleBlockSyncMessage(envelope) + err = r.handleBlockSyncMessage(ctx, envelope) default: err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope) @@ -290,7 +288,11 @@ func (r *Reactor) processBlockSyncCh(ctx context.Context) { r.logger.Debug("stopped listening on block sync channel; closing...") return case envelope := <-r.blockSyncCh.In: - if err := r.handleMessage(r.blockSyncCh.ID, envelope); err != nil { + if err := r.handleMessage(ctx, r.blockSyncCh.ID, envelope); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + r.logger.Error("failed to process message", "ch_id", r.blockSyncCh.ID, "envelope", envelope, "err", err) if serr := r.blockSyncCh.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, @@ -300,7 +302,9 @@ func (r *Reactor) processBlockSyncCh(ctx context.Context) { } } case envelope := <-r.blockSyncOutBridgeCh: - r.blockSyncCh.Out <- envelope + if err := r.blockSyncCh.Send(ctx, envelope); err != nil { + return + } } } } diff --git a/internal/consensus/byzantine_test.go b/internal/consensus/byzantine_test.go index 9526f4ae1..3133e3659 100644 --- a/internal/consensus/byzantine_test.go +++ b/internal/consensus/byzantine_test.go @@ -141,20 +141,22 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { for _, ps := range bzReactor.peers { if i < len(bzReactor.peers)/2 { bzNodeState.logger.Info("signed and pushed vote", "vote", prevote1, "peer", ps.peerID) - bzReactor.voteCh.Out <- p2p.Envelope{ - To: ps.peerID, - Message: &tmcons.Vote{ - Vote: prevote1.ToProto(), - }, - } + require.NoError(t, bzReactor.voteCh.Send(ctx, + p2p.Envelope{ + To: ps.peerID, + Message: &tmcons.Vote{ + Vote: prevote1.ToProto(), + }, + })) } else { bzNodeState.logger.Info("signed and pushed vote", "vote", prevote2, "peer", ps.peerID) - bzReactor.voteCh.Out <- 
p2p.Envelope{ - To: ps.peerID, - Message: &tmcons.Vote{ - Vote: prevote2.ToProto(), - }, - } + require.NoError(t, bzReactor.voteCh.Send(ctx, + p2p.Envelope{ + To: ps.peerID, + Message: &tmcons.Vote{ + Vote: prevote2.ToProto(), + }, + })) } i++ diff --git a/internal/consensus/invalid_test.go b/internal/consensus/invalid_test.go index 0c0528d6f..1b3636f02 100644 --- a/internal/consensus/invalid_test.go +++ b/internal/consensus/invalid_test.go @@ -124,13 +124,12 @@ func invalidDoPrevoteFunc( for _, ps := range r.peers { cs.logger.Info("sending bad vote", "block", blockHash, "peer", ps.peerID) - - r.voteCh.Out <- p2p.Envelope{ + require.NoError(t, r.voteCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.Vote{ Vote: precommit.ToProto(), }, - } + })) } }() } diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index 5e2a6b535..88a831ede 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -317,16 +317,16 @@ func (r *Reactor) GetPeerState(peerID types.NodeID) (*PeerState, bool) { return ps, ok } -func (r *Reactor) broadcastNewRoundStepMessage(rs *cstypes.RoundState) { - r.stateCh.Out <- p2p.Envelope{ +func (r *Reactor) broadcastNewRoundStepMessage(ctx context.Context, rs *cstypes.RoundState) error { + return r.stateCh.Send(ctx, p2p.Envelope{ Broadcast: true, Message: makeRoundStepMessage(rs), - } + }) } -func (r *Reactor) broadcastNewValidBlockMessage(rs *cstypes.RoundState) { +func (r *Reactor) broadcastNewValidBlockMessage(ctx context.Context, rs *cstypes.RoundState) error { psHeader := rs.ProposalBlockParts.Header() - r.stateCh.Out <- p2p.Envelope{ + return r.stateCh.Send(ctx, p2p.Envelope{ Broadcast: true, Message: &tmcons.NewValidBlock{ Height: rs.Height, @@ -335,11 +335,11 @@ func (r *Reactor) broadcastNewValidBlockMessage(rs *cstypes.RoundState) { BlockParts: rs.ProposalBlockParts.BitArray().ToProto(), IsCommit: rs.Step == cstypes.RoundStepCommit, }, - } + }) } -func (r *Reactor) broadcastHasVoteMessage(vote *types.Vote) { - r.stateCh.Out <- p2p.Envelope{ +func (r *Reactor) broadcastHasVoteMessage(ctx context.Context, vote *types.Vote) error { + return r.stateCh.Send(ctx, p2p.Envelope{ Broadcast: true, Message: &tmcons.HasVote{ Height: vote.Height, @@ -347,7 +347,7 @@ func (r *Reactor) broadcastHasVoteMessage(vote *types.Vote) { Type: vote.Type, Index: vote.ValidatorIndex, }, - } + }) } // subscribeToBroadcastEvents subscribes for new round steps and votes using the @@ -357,11 +357,17 @@ func (r *Reactor) subscribeToBroadcastEvents() { err := r.state.evsw.AddListenerForEvent( listenerIDConsensus, types.EventNewRoundStepValue, - func(data tmevents.EventData) { - r.broadcastNewRoundStepMessage(data.(*cstypes.RoundState)) + func(ctx context.Context, data tmevents.EventData) error { + if err := r.broadcastNewRoundStepMessage(ctx, data.(*cstypes.RoundState)); err != nil { + return err + } select { case r.state.onStopCh <- data.(*cstypes.RoundState): + return nil + case <-ctx.Done(): + return ctx.Err() default: + return nil } }, ) @@ -372,8 +378,8 @@ func (r *Reactor) subscribeToBroadcastEvents() { err = r.state.evsw.AddListenerForEvent( listenerIDConsensus, types.EventValidBlockValue, - func(data tmevents.EventData) { - r.broadcastNewValidBlockMessage(data.(*cstypes.RoundState)) + func(ctx context.Context, data tmevents.EventData) error { + return r.broadcastNewValidBlockMessage(ctx, data.(*cstypes.RoundState)) }, ) if err != nil { @@ -383,8 +389,8 @@ func (r *Reactor) subscribeToBroadcastEvents() { err = 
r.state.evsw.AddListenerForEvent( listenerIDConsensus, types.EventVoteValue, - func(data tmevents.EventData) { - r.broadcastHasVoteMessage(data.(*types.Vote)) + func(ctx context.Context, data tmevents.EventData) error { + return r.broadcastHasVoteMessage(ctx, data.(*types.Vote)) }, ) if err != nil { @@ -406,19 +412,14 @@ func makeRoundStepMessage(rs *cstypes.RoundState) *tmcons.NewRoundStep { } } -func (r *Reactor) sendNewRoundStepMessage(ctx context.Context, peerID types.NodeID) { - rs := r.state.GetRoundState() - msg := makeRoundStepMessage(rs) - select { - case <-ctx.Done(): - case r.stateCh.Out <- p2p.Envelope{ +func (r *Reactor) sendNewRoundStepMessage(ctx context.Context, peerID types.NodeID) error { + return r.stateCh.Send(ctx, p2p.Envelope{ To: peerID, - Message: msg, - }: - } + Message: makeRoundStepMessage(r.state.GetRoundState()), + }) } -func (r *Reactor) gossipDataForCatchup(rs *cstypes.RoundState, prs *cstypes.PeerRoundState, ps *PeerState) { +func (r *Reactor) gossipDataForCatchup(ctx context.Context, rs *cstypes.RoundState, prs *cstypes.PeerRoundState, ps *PeerState) { logger := r.logger.With("height", prs.Height).With("peer", ps.peerID) if index, ok := prs.ProposalBlockParts.Not().PickRandom(); ok { @@ -467,14 +468,14 @@ func (r *Reactor) gossipDataForCatchup(rs *cstypes.RoundState, prs *cstypes.Peer } logger.Debug("sending block part for catchup", "round", prs.Round, "index", index) - r.dataCh.Out <- p2p.Envelope{ + _ = r.dataCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.BlockPart{ Height: prs.Height, // not our height, so it does not matter. Round: prs.Round, // not our height, so it does not matter Part: *partProto, }, - } + }) return } @@ -521,13 +522,15 @@ OUTER_LOOP: } logger.Debug("sending block part", "height", prs.Height, "round", prs.Round) - r.dataCh.Out <- p2p.Envelope{ + if err := r.dataCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.BlockPart{ Height: rs.Height, // this tells peer that this part applies to us Round: rs.Round, // this tells peer that this part applies to us Part: *partProto, }, + }); err != nil { + return } ps.SetHasProposalBlockPart(prs.Height, prs.Round, index) @@ -566,7 +569,7 @@ OUTER_LOOP: continue OUTER_LOOP } - r.gossipDataForCatchup(rs, prs, ps) + r.gossipDataForCatchup(ctx, rs, prs, ps) continue OUTER_LOOP } @@ -593,11 +596,13 @@ OUTER_LOOP: propProto := rs.Proposal.ToProto() logger.Debug("sending proposal", "height", prs.Height, "round", prs.Round) - r.dataCh.Out <- p2p.Envelope{ + if err := r.dataCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.Proposal{ Proposal: *propProto, }, + }); err != nil { + return } // NOTE: A peer might have received a different proposal message, so @@ -614,13 +619,15 @@ OUTER_LOOP: pPolProto := pPol.ToProto() logger.Debug("sending POL", "height", prs.Height, "round", prs.Round) - r.dataCh.Out <- p2p.Envelope{ + if err := r.dataCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.ProposalPOL{ Height: rs.Height, ProposalPolRound: rs.Proposal.POLRound, ProposalPol: *pPolProto, }, + }); err != nil { + return } } @@ -640,24 +647,24 @@ OUTER_LOOP: // pickSendVote picks a vote and sends it to the peer. It will return true if // there is a vote to send and false otherwise. 
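After this refactor the helper returns (bool, error) rather than a bare bool: the boolean still means "a vote was picked and sent", while a non-nil error means Send was interrupted by the context and the gossip routine should exit. The caller pattern used throughout gossipVotesForHeight below, sketched:

    // ok distinguishes "nothing to send" from "sent"; err is fatal for the
    // per-peer routine and only occurs when ctx is canceled mid-send.
    if ok, err := r.pickSendVote(ctx, ps, votes); err != nil {
        return false, err
    } else if ok {
        return true, nil
    }
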
-func (r *Reactor) pickSendVote(ctx context.Context, ps *PeerState, votes types.VoteSetReader) bool { - if vote, ok := ps.PickVoteToSend(votes); ok { - r.logger.Debug("sending vote message", "ps", ps, "vote", vote) - select { - case <-ctx.Done(): - case r.voteCh.Out <- p2p.Envelope{ - To: ps.peerID, - Message: &tmcons.Vote{ - Vote: vote.ToProto(), - }, - }: - } +func (r *Reactor) pickSendVote(ctx context.Context, ps *PeerState, votes types.VoteSetReader) (bool, error) { + vote, ok := ps.PickVoteToSend(votes) + if !ok { + return false, nil + } - ps.SetHasVote(vote) - return true + r.logger.Debug("sending vote message", "ps", ps, "vote", vote) + if err := r.voteCh.Send(ctx, p2p.Envelope{ + To: ps.peerID, + Message: &tmcons.Vote{ + Vote: vote.ToProto(), + }, + }); err != nil { + return false, err } - return false + ps.SetHasVote(vote) + return true, nil } func (r *Reactor) gossipVotesForHeight( @@ -665,62 +672,75 @@ func (r *Reactor) gossipVotesForHeight( rs *cstypes.RoundState, prs *cstypes.PeerRoundState, ps *PeerState, -) bool { +) (bool, error) { logger := r.logger.With("height", prs.Height).With("peer", ps.peerID) // if there are lastCommits to send... if prs.Step == cstypes.RoundStepNewHeight { - if r.pickSendVote(ctx, ps, rs.LastCommit) { + if ok, err := r.pickSendVote(ctx, ps, rs.LastCommit); err != nil { + return false, err + } else if ok { logger.Debug("picked rs.LastCommit to send") - return true + return true, nil + } } // if there are POL prevotes to send... if prs.Step <= cstypes.RoundStepPropose && prs.Round != -1 && prs.Round <= rs.Round && prs.ProposalPOLRound != -1 { if polPrevotes := rs.Votes.Prevotes(prs.ProposalPOLRound); polPrevotes != nil { - if r.pickSendVote(ctx, ps, polPrevotes) { + if ok, err := r.pickSendVote(ctx, ps, polPrevotes); err != nil { + return false, err + } else if ok { logger.Debug("picked rs.Prevotes(prs.ProposalPOLRound) to send", "round", prs.ProposalPOLRound) - return true + return true, nil } } } // if there are prevotes to send... if prs.Step <= cstypes.RoundStepPrevoteWait && prs.Round != -1 && prs.Round <= rs.Round { - if r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)) { + if ok, err := r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)); err != nil { + return false, err + } else if ok { logger.Debug("picked rs.Prevotes(prs.Round) to send", "round", prs.Round) - return true + return true, nil } } // if there are precommits to send... if prs.Step <= cstypes.RoundStepPrecommitWait && prs.Round != -1 && prs.Round <= rs.Round { - if r.pickSendVote(ctx, ps, rs.Votes.Precommits(prs.Round)) { + if ok, err := r.pickSendVote(ctx, ps, rs.Votes.Precommits(prs.Round)); err != nil { + return false, err + } else if ok { logger.Debug("picked rs.Precommits(prs.Round) to send", "round", prs.Round) - return true + return true, nil } } // if there are prevotes to send...(which are needed because of validBlock mechanism) if prs.Round != -1 && prs.Round <= rs.Round { - if r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)) { + if ok, err := r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)); err != nil { + return false, err + } else if ok { logger.Debug("picked rs.Prevotes(prs.Round) to send", "round", prs.Round) - return true + return true, nil } } // if there are POLPrevotes to send... 
if prs.ProposalPOLRound != -1 { if polPrevotes := rs.Votes.Prevotes(prs.ProposalPOLRound); polPrevotes != nil { - if r.pickSendVote(ctx, ps, polPrevotes) { + if ok, err := r.pickSendVote(ctx, ps, polPrevotes); err != nil { + return false, err + } else if ok { logger.Debug("picked rs.Prevotes(prs.ProposalPOLRound) to send", "round", prs.ProposalPOLRound) - return true + return true, nil } } } - return false + return false, nil } func (r *Reactor) gossipVotesRoutine(ctx context.Context, ps *PeerState) { @@ -763,14 +783,18 @@ OUTER_LOOP: // if height matches, then send LastCommit, Prevotes, and Precommits if rs.Height == prs.Height { - if r.gossipVotesForHeight(ctx, rs, prs, ps) { + if ok, err := r.gossipVotesForHeight(ctx, rs, prs, ps); err != nil { + return + } else if ok { continue OUTER_LOOP } } // special catchup logic -- if peer is lagging by height 1, send LastCommit if prs.Height != 0 && rs.Height == prs.Height+1 { - if r.pickSendVote(ctx, ps, rs.LastCommit) { + if ok, err := r.pickSendVote(ctx, ps, rs.LastCommit); err != nil { + return + } else if ok { logger.Debug("picked rs.LastCommit to send", "height", prs.Height) continue OUTER_LOOP } @@ -782,7 +806,9 @@ OUTER_LOOP: // Load the block commit for prs.Height, which contains precommit // signatures for prs.Height. if commit := r.state.blockStore.LoadBlockCommit(prs.Height); commit != nil { - if r.pickSendVote(ctx, ps, commit) { + if ok, err := r.pickSendVote(ctx, ps, commit); err != nil { + return + } else if ok { logger.Debug("picked Catchup commit to send", "height", prs.Height) continue OUTER_LOOP } @@ -844,7 +870,7 @@ OUTER_LOOP: if rs.Height == prs.Height { if maj23, ok := rs.Votes.Prevotes(prs.Round).TwoThirdsMajority(); ok { - r.stateCh.Out <- p2p.Envelope{ + if err := r.stateCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.VoteSetMaj23{ Height: prs.Height, @@ -852,6 +878,8 @@ OUTER_LOOP: Type: tmproto.PrevoteType, BlockID: maj23.ToProto(), }, + }); err != nil { + return } timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) @@ -871,7 +899,7 @@ OUTER_LOOP: if rs.Height == prs.Height { if maj23, ok := rs.Votes.Precommits(prs.Round).TwoThirdsMajority(); ok { - r.stateCh.Out <- p2p.Envelope{ + if err := r.stateCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.VoteSetMaj23{ Height: prs.Height, @@ -879,6 +907,8 @@ OUTER_LOOP: Type: tmproto.PrecommitType, BlockID: maj23.ToProto(), }, + }); err != nil { + return } select { @@ -898,7 +928,7 @@ OUTER_LOOP: if rs.Height == prs.Height && prs.ProposalPOLRound >= 0 { if maj23, ok := rs.Votes.Prevotes(prs.ProposalPOLRound).TwoThirdsMajority(); ok { - r.stateCh.Out <- p2p.Envelope{ + if err := r.stateCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.VoteSetMaj23{ Height: prs.Height, @@ -906,6 +936,8 @@ OUTER_LOOP: Type: tmproto.PrevoteType, BlockID: maj23.ToProto(), }, + }); err != nil { + return } timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) @@ -928,7 +960,7 @@ OUTER_LOOP: if prs.CatchupCommitRound != -1 && prs.Height > 0 && prs.Height <= r.state.blockStore.Height() && prs.Height >= r.state.blockStore.Base() { if commit := r.state.LoadCommit(prs.Height); commit != nil { - r.stateCh.Out <- p2p.Envelope{ + if err := r.stateCh.Send(ctx, p2p.Envelope{ To: ps.peerID, Message: &tmcons.VoteSetMaj23{ Height: prs.Height, @@ -936,6 +968,8 @@ OUTER_LOOP: Type: tmproto.PrecommitType, BlockID: commit.BlockID.ToProto(), }, + }); err != nil { + return } timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) @@ -1006,7 +1040,7 @@ func (r *Reactor) 
processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpda // Send our state to the peer. If we're block-syncing, broadcast a // RoundStepMessage later upon SwitchToConsensus(). if !r.waitSync { - go r.sendNewRoundStepMessage(ctx, ps.peerID) + go func() { _ = r.sendNewRoundStepMessage(ctx, ps.peerID) }() } } @@ -1036,7 +1070,7 @@ func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpda // If we fail to find the peer state for the envelope sender, we perform a no-op // and return. This can happen when we process the envelope after the peer is // removed. -func (r *Reactor) handleStateMessage(envelope p2p.Envelope, msgI Message) error { +func (r *Reactor) handleStateMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error { ps, ok := r.GetPeerState(envelope.From) if !ok || ps == nil { r.logger.Debug("failed to find peer state", "peer", envelope.From, "ch_id", "StateChannel") @@ -1104,9 +1138,11 @@ func (r *Reactor) handleStateMessage(envelope p2p.Envelope, msgI Message) error eMsg.Votes = *votesProto } - r.voteSetBitsCh.Out <- p2p.Envelope{ + if err := r.voteSetBitsCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: eMsg, + }); err != nil { + return err } default: @@ -1120,7 +1156,7 @@ func (r *Reactor) handleStateMessage(envelope p2p.Envelope, msgI Message) error // fail to find the peer state for the envelope sender, we perform a no-op and // return. This can happen when we process the envelope after the peer is // removed. -func (r *Reactor) handleDataMessage(envelope p2p.Envelope, msgI Message) error { +func (r *Reactor) handleDataMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error { logger := r.logger.With("peer", envelope.From, "ch_id", "DataChannel") ps, ok := r.GetPeerState(envelope.From) @@ -1139,17 +1175,24 @@ func (r *Reactor) handleDataMessage(envelope p2p.Envelope, msgI Message) error { pMsg := msgI.(*ProposalMessage) ps.SetHasProposal(pMsg.Proposal) - r.state.peerMsgQueue <- msgInfo{pMsg, envelope.From} - + select { + case <-ctx.Done(): + return ctx.Err() + case r.state.peerMsgQueue <- msgInfo{pMsg, envelope.From}: + } case *tmcons.ProposalPOL: ps.ApplyProposalPOLMessage(msgI.(*ProposalPOLMessage)) - case *tmcons.BlockPart: bpMsg := msgI.(*BlockPartMessage) ps.SetHasProposalBlockPart(bpMsg.Height, bpMsg.Round, int(bpMsg.Part.Index)) r.Metrics.BlockParts.With("peer_id", string(envelope.From)).Add(1) - r.state.peerMsgQueue <- msgInfo{bpMsg, envelope.From} + select { + case r.state.peerMsgQueue <- msgInfo{bpMsg, envelope.From}: + return nil + case <-ctx.Done(): + return ctx.Err() + } default: return fmt.Errorf("received unknown message on DataChannel: %T", msg) @@ -1162,7 +1205,7 @@ func (r *Reactor) handleDataMessage(envelope p2p.Envelope, msgI Message) error { // fail to find the peer state for the envelope sender, we perform a no-op and // return. This can happen when we process the envelope after the peer is // removed. 
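Hand-offs from the p2p handlers into the consensus state machine follow the same discipline: the bare send r.state.peerMsgQueue <- msgInfo{...} becomes a select against ctx.Done(), so a handler blocked on a full queue unwinds cleanly at shutdown. The shape, as used for proposals and block parts above and for votes below:

    select {
    case r.state.peerMsgQueue <- msgInfo{msg, envelope.From}:
        return nil
    case <-ctx.Done():
        // Shutdown while the queue is full: surface ctx.Err() rather than
        // leaving the reader goroutine blocked forever.
        return ctx.Err()
    }
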
-func (r *Reactor) handleVoteMessage(envelope p2p.Envelope, msgI Message) error { +func (r *Reactor) handleVoteMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error { logger := r.logger.With("peer", envelope.From, "ch_id", "VoteChannel") ps, ok := r.GetPeerState(envelope.From) @@ -1188,20 +1231,22 @@ func (r *Reactor) handleVoteMessage(envelope p2p.Envelope, msgI Message) error { ps.EnsureVoteBitArrays(height-1, lastCommitSize) ps.SetHasVote(vMsg.Vote) - r.state.peerMsgQueue <- msgInfo{vMsg, envelope.From} - + select { + case r.state.peerMsgQueue <- msgInfo{vMsg, envelope.From}: + return nil + case <-ctx.Done(): + return ctx.Err() + } default: return fmt.Errorf("received unknown message on VoteChannel: %T", msg) } - - return nil } // handleVoteSetBitsMessage handles envelopes sent from peers on the // VoteSetBitsChannel. If we fail to find the peer state for the envelope sender, // we perform a no-op and return. This can happen when we process the envelope // after the peer is removed. -func (r *Reactor) handleVoteSetBitsMessage(envelope p2p.Envelope, msgI Message) error { +func (r *Reactor) handleVoteSetBitsMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error { logger := r.logger.With("peer", envelope.From, "ch_id", "VoteSetBitsChannel") ps, ok := r.GetPeerState(envelope.From) @@ -1259,7 +1304,7 @@ func (r *Reactor) handleVoteSetBitsMessage(envelope p2p.Envelope, msgI Message) // the p2p channel. // // NOTE: We block on consensus state for proposals, block parts, and votes. -func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) { +func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("panic in processing message: %v", e) @@ -1290,16 +1335,16 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err switch chID { case StateChannel: - err = r.handleStateMessage(envelope, msgI) + err = r.handleStateMessage(ctx, envelope, msgI) case DataChannel: - err = r.handleDataMessage(envelope, msgI) + err = r.handleDataMessage(ctx, envelope, msgI) case VoteChannel: - err = r.handleVoteMessage(envelope, msgI) + err = r.handleVoteMessage(ctx, envelope, msgI) case VoteSetBitsChannel: - err = r.handleVoteSetBitsMessage(envelope, msgI) + err = r.handleVoteSetBitsMessage(ctx, envelope, msgI) default: err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope) @@ -1320,7 +1365,7 @@ func (r *Reactor) processStateCh(ctx context.Context) { r.logger.Debug("stopped listening on StateChannel; closing...") return case envelope := <-r.stateCh.In: - if err := r.handleMessage(r.stateCh.ID, envelope); err != nil { + if err := r.handleMessage(ctx, r.stateCh.ID, envelope); err != nil { r.logger.Error("failed to process message", "ch_id", r.stateCh.ID, "envelope", envelope, "err", err) if serr := r.stateCh.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, @@ -1345,7 +1390,7 @@ func (r *Reactor) processDataCh(ctx context.Context) { r.logger.Debug("stopped listening on DataChannel; closing...") return case envelope := <-r.dataCh.In: - if err := r.handleMessage(r.dataCh.ID, envelope); err != nil { + if err := r.handleMessage(ctx, r.dataCh.ID, envelope); err != nil { r.logger.Error("failed to process message", "ch_id", r.dataCh.ID, "envelope", envelope, "err", err) if serr := r.dataCh.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, @@ -1370,7 +1415,7 @@ func (r *Reactor) processVoteCh(ctx 
context.Context) { r.logger.Debug("stopped listening on VoteChannel; closing...") return case envelope := <-r.voteCh.In: - if err := r.handleMessage(r.voteCh.ID, envelope); err != nil { + if err := r.handleMessage(ctx, r.voteCh.ID, envelope); err != nil { r.logger.Error("failed to process message", "ch_id", r.voteCh.ID, "envelope", envelope, "err", err) if serr := r.voteCh.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, @@ -1395,7 +1440,11 @@ func (r *Reactor) processVoteSetBitsCh(ctx context.Context) { r.logger.Debug("stopped listening on VoteSetBitsChannel; closing...") return case envelope := <-r.voteSetBitsCh.In: - if err := r.handleMessage(r.voteSetBitsCh.ID, envelope); err != nil { + if err := r.handleMessage(ctx, r.voteSetBitsCh.ID, envelope); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + r.logger.Error("failed to process message", "ch_id", r.voteSetBitsCh.ID, "envelope", envelope, "err", err) if serr := r.voteSetBitsCh.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, diff --git a/internal/consensus/replay_test.go b/internal/consensus/replay_test.go index 036614b71..56a4924cd 100644 --- a/internal/consensus/replay_test.go +++ b/internal/consensus/replay_test.go @@ -391,7 +391,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite { proposal.Signature = p.Signature // set the proposal block - if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil { + if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } ensureNewProposal(proposalCh, height, round) @@ -423,7 +423,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite { proposal.Signature = p.Signature // set the proposal block - if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil { + if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } ensureNewProposal(proposalCh, height, round) @@ -482,7 +482,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite { proposal.Signature = p.Signature // set the proposal block - if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil { + if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } ensureNewProposal(proposalCh, height, round) @@ -545,7 +545,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite { proposal.Signature = p.Signature // set the proposal block - if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil { + if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } ensureNewProposal(proposalCh, height, round) diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 02ab2ae54..051b7afba 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -511,58 +511,85 @@ func (cs *State) OpenWAL(ctx context.Context, walFile string) (WAL, error) { // TODO: should these return anything or let callers just use events? // AddVote inputs a vote. 
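AddVote's signature changes from (added bool, err error) to a plain error and gains a context: the method only enqueues the vote, so the "added" result was never meaningful (the old code unconditionally returned false). A hedged sketch of a caller under the new signature:

    // Queue a locally generated vote (empty peerID routes it to the internal
    // queue); an error means ctx was canceled before the queue accepted it.
    if err := cs.AddVote(ctx, vote, ""); err != nil {
        return err
    }
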
-func (cs *State) AddVote(vote *types.Vote, peerID types.NodeID) (added bool, err error) { +func (cs *State) AddVote(ctx context.Context, vote *types.Vote, peerID types.NodeID) error { if peerID == "" { - cs.internalMsgQueue <- msgInfo{&VoteMessage{vote}, ""} + select { + case <-ctx.Done(): + return ctx.Err() + case cs.internalMsgQueue <- msgInfo{&VoteMessage{vote}, ""}: + return nil + } } else { - cs.peerMsgQueue <- msgInfo{&VoteMessage{vote}, peerID} + select { + case <-ctx.Done(): + return ctx.Err() + case cs.peerMsgQueue <- msgInfo{&VoteMessage{vote}, peerID}: + return nil + } } // TODO: wait for event?! - return false, nil } // SetProposal inputs a proposal. -func (cs *State) SetProposal(proposal *types.Proposal, peerID types.NodeID) error { +func (cs *State) SetProposal(ctx context.Context, proposal *types.Proposal, peerID types.NodeID) error { if peerID == "" { - cs.internalMsgQueue <- msgInfo{&ProposalMessage{proposal}, ""} + select { + case <-ctx.Done(): + return ctx.Err() + case cs.internalMsgQueue <- msgInfo{&ProposalMessage{proposal}, ""}: + return nil + } } else { - cs.peerMsgQueue <- msgInfo{&ProposalMessage{proposal}, peerID} + select { + case <-ctx.Done(): + return ctx.Err() + case cs.peerMsgQueue <- msgInfo{&ProposalMessage{proposal}, peerID}: + return nil + } } // TODO: wait for event?! - return nil } // AddProposalBlockPart inputs a part of the proposal block. -func (cs *State) AddProposalBlockPart(height int64, round int32, part *types.Part, peerID types.NodeID) error { - +func (cs *State) AddProposalBlockPart(ctx context.Context, height int64, round int32, part *types.Part, peerID types.NodeID) error { if peerID == "" { - cs.internalMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, ""} + select { + case <-ctx.Done(): + return ctx.Err() + case cs.internalMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, ""}: + return nil + } } else { - cs.peerMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, peerID} + select { + case <-ctx.Done(): + return ctx.Err() + case cs.peerMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, peerID}: + return nil + } } // TODO: wait for event?! - return nil } // SetProposalAndBlock inputs the proposal and all block parts. 
func (cs *State) SetProposalAndBlock( + ctx context.Context, proposal *types.Proposal, block *types.Block, parts *types.PartSet, peerID types.NodeID, ) error { - if err := cs.SetProposal(proposal, peerID); err != nil { + if err := cs.SetProposal(ctx, proposal, peerID); err != nil { return err } for i := 0; i < int(parts.Total()); i++ { part := parts.GetPart(i) - if err := cs.AddProposalBlockPart(proposal.Height, proposal.Round, part, peerID); err != nil { + if err := cs.AddProposalBlockPart(ctx, proposal.Height, proposal.Round, part, peerID); err != nil { return err } } @@ -761,7 +788,7 @@ func (cs *State) newStep(ctx context.Context) { cs.logger.Error("failed publishing new round step", "err", err) } - cs.evsw.FireEvent(types.EventNewRoundStepValue, &cs.RoundState) + cs.evsw.FireEvent(ctx, types.EventNewRoundStepValue, &cs.RoundState) } } @@ -1607,7 +1634,7 @@ func (cs *State) enterCommit(ctx context.Context, height int64, commitRound int3 logger.Error("failed publishing valid block", "err", err) } - cs.evsw.FireEvent(types.EventValidBlockValue, &cs.RoundState) + cs.evsw.FireEvent(ctx, types.EventValidBlockValue, &cs.RoundState) } } } @@ -2075,7 +2102,7 @@ func (cs *State) addVote( return added, err } - cs.evsw.FireEvent(types.EventVoteValue, vote) + cs.evsw.FireEvent(ctx, types.EventVoteValue, vote) // if we can skip timeoutCommit and have all the votes now, if cs.config.SkipTimeoutCommit && cs.LastCommit.HasAll() { @@ -2104,7 +2131,7 @@ func (cs *State) addVote( if err := cs.eventBus.PublishEventVote(ctx, types.EventDataVote{Vote: vote}); err != nil { return added, err } - cs.evsw.FireEvent(types.EventVoteValue, vote) + cs.evsw.FireEvent(ctx, types.EventVoteValue, vote) switch vote.Type { case tmproto.PrevoteType: @@ -2158,7 +2185,7 @@ func (cs *State) addVote( cs.ProposalBlockParts = types.NewPartSetFromHeader(blockID.PartSetHeader) } - cs.evsw.FireEvent(types.EventValidBlockValue, &cs.RoundState) + cs.evsw.FireEvent(ctx, types.EventValidBlockValue, &cs.RoundState) if err := cs.eventBus.PublishEventValidBlock(ctx, cs.RoundStateEvent()); err != nil { return added, err } diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 5d09908aa..387650704 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -251,7 +251,7 @@ func TestStateBadProposal(t *testing.T) { proposal.Signature = p.Signature // set the proposal block - if err := cs1.SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -314,7 +314,7 @@ func TestStateOversizedBlock(t *testing.T) { totalBytes += len(part.Bytes) } - if err := cs1.SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -621,7 +621,7 @@ func TestStateLockNoPOL(t *testing.T) { // now we're on a new round and not the proposer // so set the proposal block - if err := cs1.SetProposalAndBlock(prop, propBlock, propBlock.MakePartSet(partSize), ""); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlock.MakePartSet(partSize), ""); err != nil { t.Fatal(err) } @@ -723,7 +723,7 @@ func TestStateLockPOLRelock(t *testing.T) { round++ // moving to the next round //XXX: this isnt guaranteed to get there before the timeoutPropose ... 
- if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -828,7 +828,7 @@ func TestStateLockPOLUnlock(t *testing.T) { cs1 unlocks! */ //XXX: this isnt guaranteed to get there before the timeoutPropose ... - if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -940,7 +940,7 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { // we should have unlocked and locked on the new block, sending a precommit for this new block validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) - if err := cs1.SetProposalAndBlock(prop, propBlock, secondBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, secondBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -971,7 +971,7 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { Round2 (vs3, C) // C C C C // C nil nil nil) */ - if err := cs1.SetProposalAndBlock(prop, propBlock, thirdPropBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, thirdPropBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -1048,7 +1048,7 @@ func TestStateLockPOLSafety1(t *testing.T) { ensureNewRound(newRoundCh, height, round) //XXX: this isnt guaranteed to get there before the timeoutPropose ... - if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } /*Round2 @@ -1160,7 +1160,7 @@ func TestStateLockPOLSafety2(t *testing.T) { startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) - if err := cs1.SetProposalAndBlock(prop1, propBlock1, propBlockParts1, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop1, propBlock1, propBlockParts1, "some peer"); err != nil { t.Fatal(err) } ensureNewProposal(proposalCh, height, round) @@ -1193,7 +1193,7 @@ func TestStateLockPOLSafety2(t *testing.T) { newProp.Signature = p.Signature - if err := cs1.SetProposalAndBlock(newProp, propBlock0, propBlockParts0, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, newProp, propBlock0, propBlockParts0, "some peer"); err != nil { t.Fatal(err) } @@ -1428,7 +1428,7 @@ func TestSetValidBlockOnDelayedProposal(t *testing.T) { ensurePrecommit(voteCh, height, round) validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) - if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -1658,7 +1658,7 @@ func TestCommitFromPreviousRound(t *testing.T) { assert.True(t, rs.ProposalBlock == nil) assert.True(t, rs.ProposalBlockParts.Header().Equals(propBlockParts.Header())) - if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } @@ -1797,7 +1797,7 @@ func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { prop, propBlock := decideProposal(ctx, cs1, vs2, height+1, 0) propBlockParts := propBlock.MakePartSet(partSize) - if err := cs1.SetProposalAndBlock(prop, propBlock, 
propBlockParts, "some peer"); err != nil { + if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) } ensureNewProposal(proposalCh, height+1, 0) diff --git a/internal/evidence/reactor.go b/internal/evidence/reactor.go index 29712581c..7302773ae 100644 --- a/internal/evidence/reactor.go +++ b/internal/evidence/reactor.go @@ -319,15 +319,13 @@ func (r *Reactor) broadcastEvidenceLoop(ctx context.Context, peerID types.NodeID // and thus would not be able to process the evidence correctly. Also, the // peer may receive this piece of evidence multiple times if it added and // removed frequently from the broadcasting peer. - select { - case <-ctx.Done(): - return - case r.evidenceCh.Out <- p2p.Envelope{ + if err := r.evidenceCh.Send(ctx, p2p.Envelope{ To: peerID, Message: &tmproto.EvidenceList{ Evidence: []tmproto.Evidence{*evProto}, }, - }: + }); err != nil { + return } r.logger.Debug("gossiped evidence to peer", "evidence", ev, "peer", peerID) diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index df636ba66..156d47c6f 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -108,8 +108,8 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store, chBuf uint } } - leaktest.Check(t) }) + t.Cleanup(leaktest.Check(t)) return rts } @@ -191,21 +191,6 @@ func (rts *reactorTestSuite) waitForEvidence(t *testing.T, evList types.Evidence wg.Wait() } -func (rts *reactorTestSuite) assertEvidenceChannelsEmpty(t *testing.T) { - t.Helper() - - for id, r := range rts.reactors { - require.NoError(t, r.Stop(), "stopping reactor #%s", id) - r.Wait() - require.False(t, r.IsRunning(), "reactor #%d did not stop", id) - - } - - for id, ech := range rts.evidenceChannels { - require.Empty(t, ech.Out, "checking channel #%q", id) - } -} - func createEvidenceList( t *testing.T, pool *evidence.Pool, @@ -325,8 +310,6 @@ func TestReactorBroadcastEvidence(t *testing.T) { for _, pool := range rts.pools { require.Equal(t, numEvidence, int(pool.Size())) } - - rts.assertEvidenceChannelsEmpty(t) } // TestReactorSelectiveBroadcast tests a context where we have two reactors @@ -367,8 +350,6 @@ func TestReactorBroadcastEvidence_Lagging(t *testing.T) { require.Equal(t, numEvidence, int(rts.pools[primary.NodeID].Size())) require.Equal(t, int(height2), int(rts.pools[secondary.NodeID].Size())) - - rts.assertEvidenceChannelsEmpty(t) } func TestReactorBroadcastEvidence_Pending(t *testing.T) { diff --git a/internal/mempool/reactor.go b/internal/mempool/reactor.go index 19d857614..2e1a94f01 100644 --- a/internal/mempool/reactor.go +++ b/internal/mempool/reactor.go @@ -349,15 +349,15 @@ func (r *Reactor) broadcastTxRoutine(ctx context.Context, peerID types.NodeID, c if ok := r.mempool.txStore.TxHasPeer(memTx.hash, peerMempoolID); !ok { // Send the mempool tx to the corresponding peer. Note, the peer may be // behind and thus would not be able to process the mempool tx correctly. 
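Both per-peer broadcast goroutines (evidence above, mempool here) adopt the same exit rule: a failed Send replaces the explicit select on ctx.Done(), because Send only fails once the context has ended. Condensed sketch of the evidence variant:

    if err := r.evidenceCh.Send(ctx, p2p.Envelope{
        To:      peerID,
        Message: &tmproto.EvidenceList{Evidence: []tmproto.Evidence{*evProto}},
    }); err != nil {
        return // ctx canceled: stop gossiping to this peer
    }
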
- select { - case r.mempoolCh.Out <- p2p.Envelope{ + if err := r.mempoolCh.Send(ctx, p2p.Envelope{ To: peerID, Message: &protomem.Txs{ Txs: [][]byte{memTx.tx}, }, - }: - case <-ctx.Done(): + }); err != nil { + return } + r.logger.Debug( "gossiped tx to peer", "tx", fmt.Sprintf("%X", memTx.tx.Hash()), diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index f75809744..e3f0b5718 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -109,14 +109,6 @@ func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) { "network does not have expected number of nodes") } -func (rts *reactorTestSuite) assertMempoolChannelsDrained(t *testing.T) { - t.Helper() - - for _, mch := range rts.mempoolChannels { - require.Empty(t, mch.Out, "checking channel %q (len=%d)", mch.ID, len(mch.Out)) - } -} - func (rts *reactorTestSuite) waitForTxns(t *testing.T, txs []types.Tx, ids ...types.NodeID) { t.Helper() @@ -296,8 +288,6 @@ func TestReactorNoBroadcastToSender(t *testing.T) { require.Eventually(t, func() bool { return rts.mempools[secondary].Size() == 0 }, time.Minute, 100*time.Millisecond) - - rts.assertMempoolChannelsDrained(t) } func TestReactor_MaxTxBytes(t *testing.T) { @@ -334,8 +324,6 @@ func TestReactor_MaxTxBytes(t *testing.T) { tx2 := tmrand.Bytes(cfg.Mempool.MaxTxBytes + 1) err = rts.mempools[primary].CheckTx(ctx, tx2, nil, TxInfo{SenderID: UnknownPeerID}) require.Error(t, err) - - rts.assertMempoolChannelsDrained(t) } func TestDontExhaustMaxActiveIDs(t *testing.T) { @@ -359,30 +347,13 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) { NodeID: peerID, } - rts.mempoolChannels[nodeID].Out <- p2p.Envelope{ + require.NoError(t, rts.mempoolChannels[nodeID].Send(ctx, p2p.Envelope{ To: peerID, Message: &protomem.Txs{ Txs: [][]byte{}, }, - } + })) } - - require.Eventually( - t, - func() bool { - for _, mch := range rts.mempoolChannels { - if len(mch.Out) > 0 { - return false - } - } - - return true - }, - time.Minute, - 10*time.Millisecond, - ) - - rts.assertMempoolChannelsDrained(t) } func TestMempoolIDsPanicsIfNodeRequestsOvermaxActiveIDs(t *testing.T) { diff --git a/internal/p2p/channel.go b/internal/p2p/channel.go index 9296ca15e..da6955596 100644 --- a/internal/p2p/channel.go +++ b/internal/p2p/channel.go @@ -63,7 +63,7 @@ func (pe PeerError) Unwrap() error { return pe.Err } type Channel struct { ID ChannelID In <-chan Envelope // inbound messages (peers to reactors) - Out chan<- Envelope // outbound messages (reactors to peers) + outCh chan<- Envelope // outbound messages (reactors to peers) errCh chan<- PeerError // peer error reporting messageType proto.Message // the channel's message type, used for unmarshaling @@ -82,7 +82,7 @@ func NewChannel( ID: id, messageType: messageType, In: inCh, - Out: outCh, + outCh: outCh, errCh: errCh, } } @@ -93,7 +93,7 @@ func (ch *Channel) Send(ctx context.Context, envelope Envelope) error { select { case <-ctx.Done(): return ctx.Err() - case ch.Out <- envelope: + case ch.outCh <- envelope: return nil } } diff --git a/internal/p2p/channel_test.go b/internal/p2p/channel_test.go index 4b2ce5937..525eb18fb 100644 --- a/internal/p2p/channel_test.go +++ b/internal/p2p/channel_test.go @@ -24,7 +24,7 @@ func testChannel(size int) (*channelInternal, *Channel) { } ch := &Channel{ In: in.In, - Out: in.Out, + outCh: in.Out, errCh: in.Error, } return in, ch diff --git a/internal/p2p/p2ptest/require.go b/internal/p2p/p2ptest/require.go index b55d6a51f..22a1d2a81 100644 --- a/internal/p2p/p2ptest/require.go 
+++ b/internal/p2p/p2ptest/require.go @@ -64,26 +64,30 @@ func RequireReceiveUnordered(t *testing.T, channel *p2p.Channel, expect []p2p.En } // RequireSend requires that the given envelope is sent on the channel. -func RequireSend(t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { - timer := time.NewTimer(time.Second) // not time.After due to goroutine leaks - defer timer.Stop() - select { - case channel.Out <- envelope: - case <-timer.C: - require.Fail(t, "timed out sending message", "%v on channel %v", envelope, channel.ID) +func RequireSend(ctx context.Context, t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { + tctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + err := channel.Send(tctx, envelope) + switch { + case errors.Is(err, context.DeadlineExceeded): + require.Fail(t, "timed out sending message to %q", envelope.To) + default: + require.NoError(t, err, "unexpected error") } } // RequireSendReceive requires that a given Protobuf message is sent to the // given peer, and then that the given response is received back. func RequireSendReceive( + ctx context.Context, t *testing.T, channel *p2p.Channel, peerID types.NodeID, send proto.Message, receive proto.Message, ) { - RequireSend(t, channel, p2p.Envelope{To: peerID, Message: send}) + RequireSend(ctx, t, channel, p2p.Envelope{To: peerID, Message: send}) RequireReceive(t, channel, p2p.Envelope{From: peerID, Message: send}) } diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index 24aeec05f..f5eb2ab7f 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -165,12 +165,12 @@ func (r *Reactor) processPexCh(ctx context.Context) { // outbound requests for new peers case <-timer.C: - r.sendRequestForPeers() + r.sendRequestForPeers(ctx) // inbound requests for new peers or responses to requests sent by this // reactor case envelope := <-r.pexCh.In: - if err := r.handleMessage(r.pexCh.ID, envelope); err != nil { + if err := r.handleMessage(ctx, r.pexCh.ID, envelope); err != nil { r.logger.Error("failed to process message", "ch_id", r.pexCh.ID, "envelope", envelope, "err", err) if serr := r.pexCh.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, @@ -199,7 +199,7 @@ func (r *Reactor) processPeerUpdates(ctx context.Context) { } // handlePexMessage handles envelopes sent from peers on the PexChannel. -func (r *Reactor) handlePexMessage(envelope p2p.Envelope) error { +func (r *Reactor) handlePexMessage(ctx context.Context, envelope p2p.Envelope) error { logger := r.logger.With("peer", envelope.From) switch msg := envelope.Message.(type) { @@ -219,9 +219,11 @@ func (r *Reactor) handlePexMessage(envelope p2p.Envelope) error { URL: addr.String(), } } - r.pexCh.Out <- p2p.Envelope{ + if err := r.pexCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &protop2p.PexResponse{Addresses: pexAddresses}, + }); err != nil { + return err } case *protop2p.PexResponse: @@ -264,7 +266,7 @@ func (r *Reactor) handlePexMessage(envelope p2p.Envelope) error { // handleMessage handles an Envelope sent from a peer on a specific p2p Channel. // It will handle errors and any possible panics gracefully. A caller can handle // any error returned by sending a PeerError on the respective channel. 
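The p2ptest helpers above now thread a context instead of hand-rolling a timer, so sends in tests share production cancellation semantics. Typical call in a test body, assuming a fixture channel and peer as in the router tests further down:

    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
    defer cancel()

    p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{
        To:      peer.NodeID,
        Message: &p2ptest.Message{Value: "foo"},
    })
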
-func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) { +func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("panic in processing message: %v", e) @@ -280,7 +282,7 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err switch chID { case p2p.ChannelID(PexChannel): - err = r.handlePexMessage(envelope) + err = r.handlePexMessage(ctx, envelope) default: err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope) @@ -312,7 +314,7 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { // peer a request for more peer addresses. The function then moves the // peer into the requestsSent bucket and calculates when the next request // time should be -func (r *Reactor) sendRequestForPeers() { +func (r *Reactor) sendRequestForPeers(ctx context.Context) { r.mtx.Lock() defer r.mtx.Unlock() if len(r.availablePeers) == 0 { @@ -330,9 +332,11 @@ func (r *Reactor) sendRequestForPeers() { } // send out the pex request - r.pexCh.Out <- p2p.Envelope{ + if err := r.pexCh.Send(ctx, p2p.Envelope{ To: peerID, Message: &protop2p.PexRequest{}, + }); err != nil { + return } // remove the peer from the abvailable peers list and mark it in the requestsSent map diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 28da5c72c..3f0adcf89 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -45,7 +45,7 @@ func TestReactorBasic(t *testing.T) { // assert that when a mock node sends a request it receives a response (and // the correct one) - testNet.sendRequest(t, firstNode, secondNode) + testNet.sendRequest(ctx, t, firstNode, secondNode) testNet.listenForResponse(t, secondNode, firstNode, shortWait, []p2pproto.PexAddress(nil)) } @@ -112,8 +112,8 @@ func TestReactorSendsResponseWithoutRequest(t *testing.T) { // firstNode sends the secondNode an unrequested response // NOTE: secondNode will send a request by default during startup so we send // two responses to counter that. 
- testNet.sendResponse(t, firstNode, secondNode, []int{thirdNode}) - testNet.sendResponse(t, firstNode, secondNode, []int{thirdNode}) + testNet.sendResponse(ctx, t, firstNode, secondNode, []int{thirdNode}) + testNet.sendResponse(ctx, t, firstNode, secondNode, []int{thirdNode}) // secondNode should evict the firstNode testNet.listenForPeerUpdate(ctx, t, secondNode, firstNode, p2p.PeerStatusDown, shortWait) @@ -139,7 +139,7 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) { // first we check that even although we have 110 peers, honest pex reactors // only send 100 (test if secondNode sends firstNode 100 addresses) - testNet.pingAndlistenForNAddresses(t, secondNode, firstNode, shortWait, 100) + testNet.pingAndlistenForNAddresses(ctx, t, secondNode, firstNode, shortWait, 100) } func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { @@ -475,11 +475,13 @@ func (r *reactorTestSuite) listenForRequest(t *testing.T, fromNode, toNode int, } func (r *reactorTestSuite) pingAndlistenForNAddresses( + ctx context.Context, t *testing.T, fromNode, toNode int, waitPeriod time.Duration, addresses int, ) { + t.Helper() r.logger.Info("Listening for addresses", "from", fromNode, "to", toNode) to, from := r.checkNodePair(t, toNode, fromNode) conditional := func(msg p2p.Envelope) bool { @@ -499,10 +501,10 @@ func (r *reactorTestSuite) pingAndlistenForNAddresses( // if we didn't get the right length, we wait and send the // request again time.Sleep(300 * time.Millisecond) - r.sendRequest(t, toNode, fromNode) + r.sendRequest(ctx, t, toNode, fromNode) return false } - r.sendRequest(t, toNode, fromNode) + r.sendRequest(ctx, t, toNode, fromNode) r.listenFor(t, to, conditional, assertion, waitPeriod) } @@ -566,27 +568,30 @@ func (r *reactorTestSuite) getAddressesFor(nodes []int) []p2pproto.PexAddress { return addresses } -func (r *reactorTestSuite) sendRequest(t *testing.T, fromNode, toNode int) { +func (r *reactorTestSuite) sendRequest(ctx context.Context, t *testing.T, fromNode, toNode int) { + t.Helper() to, from := r.checkNodePair(t, toNode, fromNode) - r.pexChannels[from].Out <- p2p.Envelope{ + require.NoError(t, r.pexChannels[from].Send(ctx, p2p.Envelope{ To: to, Message: &p2pproto.PexRequest{}, - } + })) } func (r *reactorTestSuite) sendResponse( + ctx context.Context, t *testing.T, fromNode, toNode int, withNodes []int, ) { + t.Helper() from, to := r.checkNodePair(t, fromNode, toNode) addrs := r.getAddressesFor(withNodes) - r.pexChannels[from].Out <- p2p.Envelope{ + require.NoError(t, r.pexChannels[from].Send(ctx, p2p.Envelope{ To: to, Message: &p2pproto.PexResponse{ Addresses: addrs, }, - } + })) } func (r *reactorTestSuite) requireNumberOfPeers( diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 2974c1e88..a6d5fdc03 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -32,11 +32,12 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) { select { case envelope := <-channel.In: value := envelope.Message.(*p2ptest.Message).Value - channel.Out <- p2p.Envelope{ + if err := channel.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &p2ptest.Message{Value: value}, + }); err != nil { + return } - case <-ctx.Done(): return } @@ -64,14 +65,14 @@ func TestRouter_Network(t *testing.T) { // Sending a message to each peer should work. 
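echoReactor above is the canonical consumer under the new API: the reply goes out via Send, and a Send error doubles as the shutdown signal, so the send path needs no separate ctx case. Condensed sketch of that loop:

    for {
        select {
        case env := <-channel.In:
            if err := channel.Send(ctx, p2p.Envelope{To: env.From, Message: env.Message}); err != nil {
                return // ctx ended while enqueueing the reply
            }
        case <-ctx.Done():
            return
        }
    }
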
for _, peer := range peers { - p2ptest.RequireSendReceive(t, channel, peer.NodeID, + p2ptest.RequireSendReceive(ctx, t, channel, peer.NodeID, &p2ptest.Message{Value: "foo"}, &p2ptest.Message{Value: "foo"}, ) } // Sending a broadcast should return back a message from all peers. - p2ptest.RequireSend(t, channel, p2p.Envelope{ + p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ Broadcast: true, Message: &p2ptest.Message{Value: "bar"}, }) @@ -151,13 +152,13 @@ func TestRouter_Channel_Basic(t *testing.T) { require.NoError(t, err) // We should be able to send on the channel, even though there are no peers. - p2ptest.RequireSend(t, channel, p2p.Envelope{ + p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ To: types.NodeID(strings.Repeat("a", 40)), Message: &p2ptest.Message{Value: "foo"}, }) // A message to ourselves should be dropped. - p2ptest.RequireSend(t, channel, p2p.Envelope{ + p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ To: selfID, Message: &p2ptest.Message{Value: "self"}, }) @@ -184,40 +185,40 @@ func TestRouter_Channel_SendReceive(t *testing.T) { // Sending a message a->b should work, and not send anything // further to a, b, or c. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireEmpty(t, a, b, c) // Sending a nil message a->b should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: nil}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: nil}) p2ptest.RequireEmpty(t, a, b, c) // Sending a different message type should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) p2ptest.RequireEmpty(t, a, b, c) // Sending to an unknown peer should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{ + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{ To: types.NodeID(strings.Repeat("a", 40)), Message: &p2ptest.Message{Value: "a"}, }) p2ptest.RequireEmpty(t, a, b, c) // Sending without a recipient should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}}) p2ptest.RequireEmpty(t, a, b, c) // Sending to self should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}}) p2ptest.RequireEmpty(t, a, b, c) // Removing b and sending to it should be dropped. network.Remove(ctx, t, bID) - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}}) p2ptest.RequireEmpty(t, a, b, c) // After all this, sending a message c->a should work. 
- p2ptest.RequireSend(t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireSend(ctx, t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}}) p2ptest.RequireReceive(t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}}) p2ptest.RequireEmpty(t, a, b, c) @@ -244,7 +245,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) { network.Start(ctx, t) // Sending a broadcast from b should work. - p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireSend(ctx, t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireReceive(t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireReceive(t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireReceive(t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) @@ -252,7 +253,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) { // Removing one node from the network shouldn't prevent broadcasts from working. network.Remove(ctx, t, dID) - p2ptest.RequireSend(t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}}) p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) p2ptest.RequireReceive(t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) p2ptest.RequireEmpty(t, a, b, c, d) @@ -285,16 +286,16 @@ func TestRouter_Channel_Wrapper(t *testing.T) { // Since wrapperMessage implements p2p.Wrapper and handles Message, it // should automatically wrap and unwrap sent messages -- we prepend the // wrapper actions to the message value to signal this. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}}) // If we send a different message that can't be wrapped, it should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) p2ptest.RequireEmpty(t, b) // If we send the wrapper message itself, it should also be passed through // since WrapperMessage supports it, and should only be unwrapped at the receiver. - p2ptest.RequireSend(t, a, p2p.Envelope{ + p2ptest.RequireSend(ctx, t, a, p2p.Envelope{ To: bID, Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}}, }) @@ -960,10 +961,10 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { channel, err := router.OpenChannel(ctx, chDesc) require.NoError(t, err) - channel.Out <- p2p.Envelope{ + require.NoError(t, channel.Send(ctx, p2p.Envelope{ To: peer.NodeID, Message: &p2ptest.Message{Value: "Hi"}, - } + })) require.NoError(t, router.Stop()) mockTransport.AssertExpectations(t) diff --git a/internal/statesync/dispatcher.go b/internal/statesync/dispatcher.go index 8620e6285..2e476c25d 100644 --- a/internal/statesync/dispatcher.go +++ b/internal/statesync/dispatcher.go @@ -26,16 +26,16 @@ var ( // NOTE: It is not the responsibility of the dispatcher to verify the light blocks. 
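As a usage sketch (editor's illustration, not from the patch): callers now construct the dispatcher from the block channel itself and request light blocks by height and peer; blockCh, peerID, and height are assumed to exist in the caller's scope.

	func fetchLightBlock(ctx context.Context, blockCh *p2p.Channel, peerID types.NodeID, height int64) (*types.LightBlock, error) {
		d := NewDispatcher(blockCh)
		// LightBlock blocks until the peer responds, the send fails, or ctx ends.
		lb, err := d.LightBlock(ctx, height, peerID)
		if err != nil {
			return nil, err
		}
		// lb may be nil when the peer reports that it does not have the block.
		return lb, nil
	}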
type Dispatcher struct { // the channel with which to send light block requests on - requestCh chan<- p2p.Envelope + requestCh *p2p.Channel mtx sync.Mutex // all pending calls that have been dispatched and are awaiting an answer calls map[types.NodeID]chan *types.LightBlock } -func NewDispatcher(requestCh chan<- p2p.Envelope) *Dispatcher { +func NewDispatcher(requestChannel *p2p.Channel) *Dispatcher { return &Dispatcher{ - requestCh: requestCh, + requestCh: requestChannel, calls: make(map[types.NodeID]chan *types.LightBlock), } } @@ -91,11 +91,14 @@ func (d *Dispatcher) dispatch(ctx context.Context, peer types.NodeID, height int d.calls[peer] = ch // send request - d.requestCh <- p2p.Envelope{ + if err := d.requestCh.Send(ctx, p2p.Envelope{ To: peer, Message: &ssproto.LightBlockRequest{ Height: uint64(height), }, + }); err != nil { + close(ch) + return ch, err } return ch, nil diff --git a/internal/statesync/dispatcher_test.go b/internal/statesync/dispatcher_test.go index e717dad12..7441327a8 100644 --- a/internal/statesync/dispatcher_test.go +++ b/internal/statesync/dispatcher_test.go @@ -18,16 +18,32 @@ import ( "github.com/tendermint/tendermint/types" ) +type channelInternal struct { + In chan p2p.Envelope + Out chan p2p.Envelope + Error chan p2p.PeerError +} + +func testChannel(size int) (*channelInternal, *p2p.Channel) { + in := &channelInternal{ + In: make(chan p2p.Envelope, size), + Out: make(chan p2p.Envelope, size), + Error: make(chan p2p.PeerError, size), + } + return in, p2p.NewChannel(0, nil, in.In, in.Out, in.Error) +} + func TestDispatcherBasic(t *testing.T) { t.Cleanup(leaktest.Check(t)) const numPeers = 5 - ch := make(chan p2p.Envelope, 100) - closeCh := make(chan struct{}) - defer close(closeCh) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + chans, ch := testChannel(100) d := NewDispatcher(ch) - go handleRequests(t, d, ch, closeCh) + go handleRequests(ctx, t, d, chans.Out) peers := createPeerSet(numPeers) wg := sync.WaitGroup{} @@ -52,19 +68,24 @@ func TestDispatcherBasic(t *testing.T) { func TestDispatcherReturnsNoBlock(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ch := make(chan p2p.Envelope, 100) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + chans, ch := testChannel(100) + d := NewDispatcher(ch) - doneCh := make(chan struct{}) + peer := factory.NodeID("a") go func() { - <-ch + <-chans.Out require.NoError(t, d.Respond(nil, peer)) - close(doneCh) + cancel() }() - lb, err := d.LightBlock(context.Background(), 1, peer) - <-doneCh + lb, err := d.LightBlock(ctx, 1, peer) + <-ctx.Done() require.Nil(t, lb) require.Nil(t, err) @@ -72,11 +93,15 @@ func TestDispatcherReturnsNoBlock(t *testing.T) { func TestDispatcherTimeOutWaitingOnLightBlock(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ch := make(chan p2p.Envelope, 100) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, ch := testChannel(100) d := NewDispatcher(ch) peer := factory.NodeID("a") - ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancelFunc := context.WithTimeout(ctx, 10*time.Millisecond) defer cancelFunc() lb, err := d.LightBlock(ctx, 1, peer) @@ -89,13 +114,15 @@ func TestDispatcherTimeOutWaitingOnLightBlock(t *testing.T) { func TestDispatcherProviders(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ch := make(chan p2p.Envelope, 100) chainID := "test-chain" - closeCh := make(chan struct{}) - defer close(closeCh) + + ctx, cancel := context.WithCancel(context.Background()) + 
defer cancel() + + chans, ch := testChannel(100) d := NewDispatcher(ch) - go handleRequests(t, d, ch, closeCh) + go handleRequests(ctx, t, d, chans.Out) peers := createPeerSet(5) providers := make([]*BlockProvider, len(peers)) @@ -270,7 +297,7 @@ func TestPeerListRemove(t *testing.T) { // handleRequests is a helper function usually run in a separate go routine to // imitate the expected responses of the reactor wired to the dispatcher -func handleRequests(t *testing.T, d *Dispatcher, ch chan p2p.Envelope, closeCh chan struct{}) { +func handleRequests(ctx context.Context, t *testing.T, d *Dispatcher, ch chan p2p.Envelope) { t.Helper() for { select { @@ -280,7 +307,7 @@ func handleRequests(t *testing.T, d *Dispatcher, ch chan p2p.Envelope, closeCh c resp := mockLBResp(t, peer, int64(height), time.Now()) block, _ := resp.block.ToProto() require.NoError(t, d.Respond(block, resp.peer)) - case <-closeCh: + case <-ctx.Done(): return } } diff --git a/internal/statesync/reactor.go b/internal/statesync/reactor.go index 61e3dec08..09716fb23 100644 --- a/internal/statesync/reactor.go +++ b/internal/statesync/reactor.go @@ -195,7 +195,7 @@ func NewReactor( stateStore: stateStore, blockStore: blockStore, peers: newPeerList(), - dispatcher: NewDispatcher(blockCh.Out), + dispatcher: NewDispatcher(blockCh), providers: make(map[types.NodeID]*BlockProvider), metrics: ssMetrics, } @@ -256,8 +256,8 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { r.conn, r.connQuery, r.stateProvider, - r.snapshotCh.Out, - r.chunkCh.Out, + r.snapshotCh, + r.chunkCh, r.tempDir, r.metrics, ) @@ -270,17 +270,12 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { r.mtx.Unlock() }() - requestSnapshotsHook := func() { + requestSnapshotsHook := func() error { // request snapshots from all currently connected peers - msg := p2p.Envelope{ + return r.snapshotCh.Send(ctx, p2p.Envelope{ Broadcast: true, Message: &ssproto.SnapshotsRequest{}, - } - - select { - case <-ctx.Done(): - case r.snapshotCh.Out <- msg: - } + }) } state, commit, err := r.syncer.SyncAny(ctx, r.cfg.DiscoveryTime, requestSnapshotsHook) @@ -508,7 +503,7 @@ func (r *Reactor) backfill( // handleSnapshotMessage handles envelopes sent from peers on the // SnapshotChannel. It returns an error only if the Envelope.Message is unknown // for this channel. This should never be called outside of handleMessage. -func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error { +func (r *Reactor) handleSnapshotMessage(ctx context.Context, envelope p2p.Envelope) error { logger := r.logger.With("peer", envelope.From) switch msg := envelope.Message.(type) { @@ -526,7 +521,8 @@ func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error { "format", snapshot.Format, "peer", envelope.From, ) - r.snapshotCh.Out <- p2p.Envelope{ + + if err := r.snapshotCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &ssproto.SnapshotsResponse{ Height: snapshot.Height, @@ -535,6 +531,8 @@ func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error { Hash: snapshot.Hash, Metadata: snapshot.Metadata, }, + }); err != nil { + return err } } @@ -577,7 +575,7 @@ func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error { // handleChunkMessage handles envelopes sent from peers on the ChunkChannel. // It returns an error only if the Envelope.Message is unknown for this channel. // This should never be called outside of handleMessage. 
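Each handler in this file now follows the same shape: it receives the caller's context and returns the first reply that fails to send. A condensed sketch of that shape (editor's illustration with a hypothetical handler name but the real message types):

	func (r *Reactor) replySnapshots(ctx context.Context, envelope p2p.Envelope) error {
		// Replying is now fallible (e.g. during shutdown); the error
		// propagates to handleMessage instead of blocking on a raw channel.
		return r.snapshotCh.Send(ctx, p2p.Envelope{
			To:      envelope.From,
			Message: &ssproto.SnapshotsResponse{},
		})
	}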
-func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error { +func (r *Reactor) handleChunkMessage(ctx context.Context, envelope p2p.Envelope) error { switch msg := envelope.Message.(type) { case *ssproto.ChunkRequest: r.logger.Debug( @@ -611,7 +609,7 @@ func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error { "chunk", msg.Index, "peer", envelope.From, ) - r.chunkCh.Out <- p2p.Envelope{ + if err := r.chunkCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &ssproto.ChunkResponse{ Height: msg.Height, @@ -620,6 +618,8 @@ func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error { Chunk: resp.Chunk, Missing: resp.Chunk == nil, }, + }); err != nil { + return err } case *ssproto.ChunkResponse: @@ -664,7 +664,7 @@ func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error { return nil } -func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error { +func (r *Reactor) handleLightBlockMessage(ctx context.Context, envelope p2p.Envelope) error { switch msg := envelope.Message.(type) { case *ssproto.LightBlockRequest: r.logger.Info("received light block request", "height", msg.Height) @@ -674,11 +674,13 @@ func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error { return err } if lb == nil { - r.blockCh.Out <- p2p.Envelope{ + if err := r.blockCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &ssproto.LightBlockResponse{ LightBlock: nil, }, + }); err != nil { + return err } return nil } @@ -691,13 +693,14 @@ func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error { // NOTE: If we don't have the light block we will send a nil light block // back to the requested node, indicating that we don't have it. - r.blockCh.Out <- p2p.Envelope{ + if err := r.blockCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &ssproto.LightBlockResponse{ LightBlock: lbproto, }, + }); err != nil { + return err } - case *ssproto.LightBlockResponse: var height int64 if msg.LightBlock != nil { @@ -715,7 +718,7 @@ func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error { return nil } -func (r *Reactor) handleParamsMessage(envelope p2p.Envelope) error { +func (r *Reactor) handleParamsMessage(ctx context.Context, envelope p2p.Envelope) error { switch msg := envelope.Message.(type) { case *ssproto.ParamsRequest: r.logger.Debug("received consensus params request", "height", msg.Height) @@ -726,14 +729,15 @@ func (r *Reactor) handleParamsMessage(envelope p2p.Envelope) error { } cpproto := cp.ToProto() - r.paramsCh.Out <- p2p.Envelope{ + if err := r.paramsCh.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &ssproto.ParamsResponse{ Height: msg.Height, ConsensusParams: cpproto, }, + }); err != nil { + return err } - case *ssproto.ParamsResponse: r.mtx.RLock() defer r.mtx.RUnlock() @@ -761,7 +765,7 @@ func (r *Reactor) handleParamsMessage(envelope p2p.Envelope) error { // handleMessage handles an Envelope sent from a peer on a specific p2p Channel. // It will handle errors and any possible panics gracefully. A caller can handle // any error returned by sending a PeerError on the respective channel. 
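The comment above describes the caller's side of that contract: the processing loop turns handler failures into PeerErrors on the same channel. Roughly, as an editor's sketch consistent with processCh below (consume is an assumed name):

	func (r *Reactor) consume(ctx context.Context, ch *p2p.Channel) {
		for {
			select {
			case <-ctx.Done():
				return
			case envelope := <-ch.In:
				if err := r.handleMessage(ctx, ch.ID, envelope); err != nil {
					// SendError is context-aware too; if it fails, the
					// reactor is shutting down.
					if serr := ch.SendError(ctx, p2p.PeerError{
						NodeID: envelope.From,
						Err:    err,
					}); serr != nil {
						return
					}
				}
			}
		}
	}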
-func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) { +func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("panic in processing message: %v", e) @@ -777,17 +781,13 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err switch chID { case SnapshotChannel: - err = r.handleSnapshotMessage(envelope) - + err = r.handleSnapshotMessage(ctx, envelope) case ChunkChannel: - err = r.handleChunkMessage(envelope) - + err = r.handleChunkMessage(ctx, envelope) case LightBlockChannel: - err = r.handleLightBlockMessage(envelope) - + err = r.handleLightBlockMessage(ctx, envelope) case ParamsChannel: - err = r.handleParamsMessage(envelope) - + err = r.handleParamsMessage(ctx, envelope) default: err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope) } @@ -806,7 +806,7 @@ func (r *Reactor) processCh(ctx context.Context, ch *p2p.Channel, chName string) r.logger.Debug("channel closed", "channel", chName) return case envelope := <-ch.In: - if err := r.handleMessage(ch.ID, envelope); err != nil { + if err := r.handleMessage(ctx, ch.ID, envelope); err != nil { r.logger.Error("failed to process message", "err", err, "channel", chName, @@ -999,7 +999,7 @@ func (r *Reactor) initStateProvider(ctx context.Context, chainID string, initial providers[idx] = NewBlockProvider(p, chainID, r.dispatcher) } - r.stateProvider, err = NewP2PStateProvider(ctx, chainID, initialHeight, providers, to, r.paramsCh.Out, spLogger) + r.stateProvider, err = NewP2PStateProvider(ctx, chainID, initialHeight, providers, to, r.paramsCh, spLogger) if err != nil { return fmt.Errorf("failed to initialize P2P state provider: %w", err) } diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index b1863f17b..e6273aca3 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -170,8 +170,8 @@ func setup( conn, connQuery, stateProvider, - rts.snapshotOutCh, - rts.chunkOutCh, + rts.snapshotChannel, + rts.chunkChannel, "", rts.reactor.metrics, ) diff --git a/internal/statesync/stateprovider.go b/internal/statesync/stateprovider.go index b622824cd..4f398ce77 100644 --- a/internal/statesync/stateprovider.go +++ b/internal/statesync/stateprovider.go @@ -200,7 +200,7 @@ type stateProviderP2P struct { tmsync.Mutex // light.Client is not concurrency-safe lc *light.Client initialHeight int64 - paramsSendCh chan<- p2p.Envelope + paramsSendCh *p2p.Channel paramsRecvCh chan types.ConsensusParams } @@ -212,7 +212,7 @@ func NewP2PStateProvider( initialHeight int64, providers []lightprovider.Provider, trustOptions light.TrustOptions, - paramsSendCh chan<- p2p.Envelope, + paramsSendCh *p2p.Channel, logger log.Logger, ) (StateProvider, error) { if len(providers) < 2 { @@ -382,15 +382,13 @@ func (s *stateProviderP2P) tryGetConsensusParamsFromWitnesses( return nil, fmt.Errorf("invalid provider (%s) node id: %w", p.String(), err) } - select { - case s.paramsSendCh <- p2p.Envelope{ + if err := s.paramsSendCh.Send(ctx, p2p.Envelope{ To: peer, Message: &ssproto.ParamsRequest{ Height: uint64(height), }, - }: - case <-ctx.Done(): - return nil, ctx.Err() + }); err != nil { + return nil, err } select { diff --git a/internal/statesync/syncer.go b/internal/statesync/syncer.go index a0f79494a..b5ea158a4 100644 --- a/internal/statesync/syncer.go +++ b/internal/statesync/syncer.go @@ -57,8 +57,8 @@ type syncer struct { 
conn proxy.AppConnSnapshot connQuery proxy.AppConnQuery snapshots *snapshotPool - snapshotCh chan<- p2p.Envelope - chunkCh chan<- p2p.Envelope + snapshotCh *p2p.Channel + chunkCh *p2p.Channel tempDir string fetchers int32 retryTimeout time.Duration @@ -79,8 +79,8 @@ func newSyncer( conn proxy.AppConnSnapshot, connQuery proxy.AppConnQuery, stateProvider StateProvider, - snapshotCh chan<- p2p.Envelope, - chunkCh chan<- p2p.Envelope, + snapshotCh *p2p.Channel, + chunkCh *p2p.Channel, tempDir string, metrics *Metrics, ) *syncer { @@ -138,29 +138,13 @@ func (s *syncer) AddSnapshot(peerID types.NodeID, snapshot *snapshot) (bool, err // AddPeer adds a peer to the pool. For now we just keep it simple and send a // single request to discover snapshots, later we may want to do retries and stuff. -func (s *syncer) AddPeer(ctx context.Context, peerID types.NodeID) (err error) { - defer func() { - // TODO: remove panic recover once AddPeer can no longer accientally send on - // closed channel. - // This recover was added to protect against the p2p message being sent - // to the snapshot channel after the snapshot channel was closed. - if r := recover(); r != nil { - err = fmt.Errorf("panic sending peer snapshot request: %v", r) - } - }() - +func (s *syncer) AddPeer(ctx context.Context, peerID types.NodeID) error { s.logger.Debug("Requesting snapshots from peer", "peer", peerID) - msg := p2p.Envelope{ + return s.snapshotCh.Send(ctx, p2p.Envelope{ To: peerID, Message: &ssproto.SnapshotsRequest{}, - } - - select { - case <-ctx.Done(): - case s.snapshotCh <- msg: - } - return err + }) } // RemovePeer removes a peer from the pool. @@ -175,14 +159,16 @@ func (s *syncer) RemovePeer(peerID types.NodeID) { func (s *syncer) SyncAny( ctx context.Context, discoveryTime time.Duration, - requestSnapshots func(), + requestSnapshots func() error, ) (sm.State, *types.Commit, error) { if discoveryTime != 0 && discoveryTime < minimumDiscoveryTime { discoveryTime = minimumDiscoveryTime } if discoveryTime > 0 { - requestSnapshots() + if err := requestSnapshots(); err != nil { + return sm.State{}, nil, err + } s.logger.Info(fmt.Sprintf("Discovering snapshots for %v", discoveryTime)) time.Sleep(discoveryTime) } @@ -506,7 +492,9 @@ func (s *syncer) fetchChunks(ctx context.Context, snapshot *snapshot, chunks *ch ticker := time.NewTicker(s.retryTimeout) defer ticker.Stop() - s.requestChunk(ctx, snapshot, index) + if err := s.requestChunk(ctx, snapshot, index); err != nil { + return + } select { case <-chunks.WaitFor(index): @@ -524,12 +512,16 @@ func (s *syncer) fetchChunks(ctx context.Context, snapshot *snapshot, chunks *ch } // requestChunk requests a chunk from a peer. 
-func (s *syncer) requestChunk(ctx context.Context, snapshot *snapshot, chunk uint32) {
+//
+// requestChunk returns nil when the request was sent successfully or when
+// no peers are available for the given snapshot; it returns an error only
+// if the request cannot be completed.
+func (s *syncer) requestChunk(ctx context.Context, snapshot *snapshot, chunk uint32) error {
 	peer := s.snapshots.GetPeer(snapshot)
 	if peer == "" {
 		s.logger.Error("No valid peers found for snapshot", "height", snapshot.Height,
 			"format", snapshot.Format, "hash", snapshot.Hash)
-		return
+		return nil
 	}

 	s.logger.Debug(
@@ -549,10 +541,10 @@ func (s *syncer) requestChunk(ctx context.Context, snapshot *snapshot, chunk uin
 		},
 	}

-	select {
-	case s.chunkCh <- msg:
-	case <-ctx.Done():
+	if err := s.chunkCh.Send(ctx, msg); err != nil {
+		return err
 	}
+	return nil
 }

 // verifyApp verifies the sync, checking the app hash and last block height. It returns the
diff --git a/internal/statesync/syncer_test.go b/internal/statesync/syncer_test.go
index 816e6301a..bd4640fe0 100644
--- a/internal/statesync/syncer_test.go
+++ b/internal/statesync/syncer_test.go
@@ -184,7 +184,7 @@ func TestSyncer_SyncAny(t *testing.T) {
 		LastBlockAppHash: []byte("app_hash"),
 	}, nil)

-	newState, lastCommit, err := rts.syncer.SyncAny(ctx, 0, func() {})
+	newState, lastCommit, err := rts.syncer.SyncAny(ctx, 0, func() error { return nil })
 	require.NoError(t, err)

 	wg.Wait()
@@ -223,7 +223,7 @@ func TestSyncer_SyncAny_noSnapshots(t *testing.T) {

 	rts := setup(ctx, t, nil, nil, stateProvider, 2)

-	_, _, err := rts.syncer.SyncAny(ctx, 0, func() {})
+	_, _, err := rts.syncer.SyncAny(ctx, 0, func() error { return nil })
 	require.Equal(t, errNoSnapshots, err)
 }

@@ -246,7 +246,7 @@ func TestSyncer_SyncAny_abort(t *testing.T) {
 		Snapshot: toABCI(s), AppHash: []byte("app_hash"),
 	}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil)
-	_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
+	_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
 	require.Equal(t, errAbort, err)
 	rts.conn.AssertExpectations(t)
 }
@@ -288,7 +288,7 @@ func TestSyncer_SyncAny_reject(t *testing.T) {
 		Snapshot: toABCI(s11), AppHash: []byte("app_hash"),
 	}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil)
-	_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
+	_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
 	require.Equal(t, errNoSnapshots, err)
 	rts.conn.AssertExpectations(t)
 }
@@ -326,7 +326,7 @@ func TestSyncer_SyncAny_reject_format(t *testing.T) {
 		Snapshot: toABCI(s11), AppHash: []byte("app_hash"),
 	}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil)
-	_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
+	_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
 	require.Equal(t, errAbort, err)
 	rts.conn.AssertExpectations(t)
 }
@@ -375,7 +375,7 @@ func TestSyncer_SyncAny_reject_sender(t *testing.T) {
 		Snapshot: toABCI(sa), AppHash: []byte("app_hash"),
 	}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil)
-	_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
+	_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
 	require.Equal(t, errNoSnapshots, err)
 	rts.conn.AssertExpectations(t)
 }
@@ -401,7 +401,7 @@ func TestSyncer_SyncAny_abciError(t *testing.T) {
 		Snapshot: toABCI(s), AppHash: []byte("app_hash"),
 	}).Once().Return(nil, errBoom)
-	_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
+	_, _, err = rts.syncer.SyncAny(ctx, 0, func() error
{ return nil }) require.True(t, errors.Is(err, errBoom)) rts.conn.AssertExpectations(t) } diff --git a/libs/events/event_cache.go b/libs/events/event_cache.go index f508e873d..41633cbef 100644 --- a/libs/events/event_cache.go +++ b/libs/events/event_cache.go @@ -1,5 +1,7 @@ package events +import "context" + // An EventCache buffers events for a Fireable // All events are cached. Filtering happens on Flush type EventCache struct { @@ -28,9 +30,9 @@ func (evc *EventCache) FireEvent(event string, data EventData) { // Fire events by running evsw.FireEvent on all cached events. Blocks. // Clears cached events -func (evc *EventCache) Flush() { +func (evc *EventCache) Flush(ctx context.Context) { for _, ei := range evc.events { - evc.evsw.FireEvent(ei.event, ei.data) + evc.evsw.FireEvent(ctx, ei.event, ei.data) } // Clear the buffer, since we only add to it with append it's safe to just set it to nil and maybe safe an allocation evc.events = nil diff --git a/libs/events/event_cache_test.go b/libs/events/event_cache_test.go index a5bb975c9..13ab341f6 100644 --- a/libs/events/event_cache_test.go +++ b/libs/events/event_cache_test.go @@ -16,23 +16,25 @@ func TestEventCache_Flush(t *testing.T) { err := evsw.Start(ctx) require.NoError(t, err) - err = evsw.AddListenerForEvent("nothingness", "", func(data EventData) { + err = evsw.AddListenerForEvent("nothingness", "", func(_ context.Context, data EventData) error { // Check we are not initializing an empty buffer full of zeroed eventInfos in the EventCache require.FailNow(t, "We should never receive a message on this switch since none are fired") + return nil }) require.NoError(t, err) evc := NewEventCache(evsw) - evc.Flush() + evc.Flush(ctx) // Check after reset - evc.Flush() + evc.Flush(ctx) fail := true pass := false - err = evsw.AddListenerForEvent("somethingness", "something", func(data EventData) { + err = evsw.AddListenerForEvent("somethingness", "something", func(_ context.Context, data EventData) error { if fail { require.FailNow(t, "Shouldn't see a message until flushed") } pass = true + return nil }) require.NoError(t, err) @@ -40,6 +42,6 @@ func TestEventCache_Flush(t *testing.T) { evc.FireEvent("something", struct{ int }{2}) evc.FireEvent("something", struct{ int }{3}) fail = false - evc.Flush() + evc.Flush(ctx) assert.True(t, pass) } diff --git a/libs/events/events.go b/libs/events/events.go index f6151e734..29ebd672f 100644 --- a/libs/events/events.go +++ b/libs/events/events.go @@ -33,7 +33,7 @@ type Eventable interface { // // FireEvent fires an event with the given name and data. 
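With the new signatures both sides of the switch are context-aware: FireEvent threads a context through to listeners, and listeners may fail. A small sketch against the updated interface (editor's illustration; watch, out, and the event name are assumed):

	func watch(ctx context.Context, evsw EventSwitch, out chan<- EventData) error {
		err := evsw.AddListenerForEvent("watcher", "block", func(ctx context.Context, data EventData) error {
			select {
			case out <- data:
				return nil
			case <-ctx.Done():
				return ctx.Err() // callbacks can now abort on cancellation
			}
		})
		if err != nil {
			return err
		}
		// FireEvent runs listeners synchronously; as the eventCell loop below
		// shows, a callback error causes that listener to be skipped.
		evsw.FireEvent(ctx, "block", "data")
		return nil
	}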
type Fireable interface { - FireEvent(eventValue string, data EventData) + FireEvent(ctx context.Context, eventValue string, data EventData) } // EventSwitch is the interface for synchronous pubsub, where listeners @@ -148,7 +148,7 @@ func (evsw *eventSwitch) RemoveListenerForEvent(event string, listenerID string) } } -func (evsw *eventSwitch) FireEvent(event string, data EventData) { +func (evsw *eventSwitch) FireEvent(ctx context.Context, event string, data EventData) { // Get the eventCell evsw.mtx.RLock() eventCell := evsw.eventCells[event] @@ -159,7 +159,7 @@ func (evsw *eventSwitch) FireEvent(event string, data EventData) { } // Fire event for all listeners in eventCell - eventCell.FireEvent(data) + eventCell.FireEvent(ctx, data) } //----------------------------------------------------------------------------- @@ -190,7 +190,7 @@ func (cell *eventCell) RemoveListener(listenerID string) int { return numListeners } -func (cell *eventCell) FireEvent(data EventData) { +func (cell *eventCell) FireEvent(ctx context.Context, data EventData) { cell.mtx.RLock() eventCallbacks := make([]EventCallback, 0, len(cell.listeners)) for _, cb := range cell.listeners { @@ -199,13 +199,16 @@ func (cell *eventCell) FireEvent(data EventData) { cell.mtx.RUnlock() for _, cb := range eventCallbacks { - cb(data) + if err := cb(ctx, data); err != nil { + // should we log or abort here? + continue + } } } //----------------------------------------------------------------------------- -type EventCallback func(data EventData) +type EventCallback func(ctx context.Context, data EventData) error type eventListener struct { id string diff --git a/libs/events/events_test.go b/libs/events/events_test.go index 0e8667908..db9385ec3 100644 --- a/libs/events/events_test.go +++ b/libs/events/events_test.go @@ -24,12 +24,17 @@ func TestAddListenerForEventFireOnce(t *testing.T) { messages := make(chan EventData) require.NoError(t, evsw.AddListenerForEvent("listener", "event", - func(data EventData) { + func(ctx context.Context, data EventData) error { // test there's no deadlock if we remove the listener inside a callback evsw.RemoveListener("listener") - messages <- data + select { + case messages <- data: + return nil + case <-ctx.Done(): + return ctx.Err() + } })) - go evsw.FireEvent("event", "data") + go evsw.FireEvent(ctx, "event", "data") received := <-messages if received != "data" { t.Errorf("message received does not match: %v", received) @@ -51,13 +56,18 @@ func TestAddListenerForEventFireMany(t *testing.T) { numbers := make(chan uint64, 4) // subscribe one listener for one event require.NoError(t, evsw.AddListenerForEvent("listener", "event", - func(data EventData) { - numbers <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) // collect received events go sumReceivedNumbers(numbers, doneSum) // go fire events - go fireEvents(evsw, "event", doneSending, uint64(1)) + go fireEvents(ctx, evsw, "event", doneSending, uint64(1)) checkSum := <-doneSending close(numbers) eventSum := <-doneSum @@ -84,23 +94,38 @@ func TestAddListenerForDifferentEvents(t *testing.T) { numbers := make(chan uint64, 4) // subscribe one listener to three events require.NoError(t, evsw.AddListenerForEvent("listener", "event1", - func(data EventData) { - numbers <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + 
} })) require.NoError(t, evsw.AddListenerForEvent("listener", "event2", - func(data EventData) { - numbers <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener", "event3", - func(data EventData) { - numbers <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) // collect received events go sumReceivedNumbers(numbers, doneSum) // go fire events - go fireEvents(evsw, "event1", doneSending1, uint64(1)) - go fireEvents(evsw, "event2", doneSending2, uint64(1)) - go fireEvents(evsw, "event3", doneSending3, uint64(1)) + go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1)) + go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1)) + go fireEvents(ctx, evsw, "event3", doneSending3, uint64(1)) var checkSum uint64 checkSum += <-doneSending1 checkSum += <-doneSending2 @@ -134,33 +159,58 @@ func TestAddDifferentListenerForDifferentEvents(t *testing.T) { numbers2 := make(chan uint64, 4) // subscribe two listener to three events require.NoError(t, evsw.AddListenerForEvent("listener1", "event1", - func(data EventData) { - numbers1 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers1 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener1", "event2", - func(data EventData) { - numbers1 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers1 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener1", "event3", - func(data EventData) { - numbers1 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers1 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener2", "event2", - func(data EventData) { - numbers2 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers2 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener2", "event3", - func(data EventData) { - numbers2 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers2 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) // collect received events for listener1 go sumReceivedNumbers(numbers1, doneSum1) // collect received events for listener2 go sumReceivedNumbers(numbers2, doneSum2) // go fire events - go fireEvents(evsw, "event1", doneSending1, uint64(1)) - go fireEvents(evsw, "event2", doneSending2, uint64(1001)) - go fireEvents(evsw, "event3", doneSending3, uint64(2001)) + go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1)) + go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1001)) + go fireEvents(ctx, evsw, "event3", doneSending3, uint64(2001)) checkSumEvent1 := <-doneSending1 checkSumEvent2 := <-doneSending2 checkSumEvent3 := <-doneSending3 @@ -209,9 +259,10 @@ func TestAddAndRemoveListenerConcurrency(t *testing.T) { // we explicitly ignore errors here, since the listener will sometimes be removed // (that's what we're testing) _ = evsw.AddListenerForEvent("listener", fmt.Sprintf("event%d", index), - func(data 
EventData) { + func(ctx context.Context, data EventData) error { t.Errorf("should not run callback for %d.\n", index) stopInputEvent = true + return nil }) } }() @@ -222,7 +273,7 @@ func TestAddAndRemoveListenerConcurrency(t *testing.T) { evsw.RemoveListener("listener") // remove the last listener for i := 0; i < roundCount && !stopInputEvent; i++ { - evsw.FireEvent(fmt.Sprintf("event%d", i), uint64(1001)) + evsw.FireEvent(ctx, fmt.Sprintf("event%d", i), uint64(1001)) } } @@ -245,23 +296,33 @@ func TestAddAndRemoveListener(t *testing.T) { numbers2 := make(chan uint64, 4) // subscribe two listener to three events require.NoError(t, evsw.AddListenerForEvent("listener", "event1", - func(data EventData) { - numbers1 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers1 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener", "event2", - func(data EventData) { - numbers2 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers2 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) // collect received events for event1 go sumReceivedNumbers(numbers1, doneSum1) // collect received events for event2 go sumReceivedNumbers(numbers2, doneSum2) // go fire events - go fireEvents(evsw, "event1", doneSending1, uint64(1)) + go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1)) checkSumEvent1 := <-doneSending1 // after sending all event1, unsubscribe for all events evsw.RemoveListener("listener") - go fireEvents(evsw, "event2", doneSending2, uint64(1001)) + go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1001)) checkSumEvent2 := <-doneSending2 close(numbers1) close(numbers2) @@ -287,17 +348,19 @@ func TestRemoveListener(t *testing.T) { sum1, sum2 := 0, 0 // add some listeners and make sure they work require.NoError(t, evsw.AddListenerForEvent("listener", "event1", - func(data EventData) { + func(ctx context.Context, data EventData) error { sum1++ + return nil })) require.NoError(t, evsw.AddListenerForEvent("listener", "event2", - func(data EventData) { + func(ctx context.Context, data EventData) error { sum2++ + return nil })) for i := 0; i < count; i++ { - evsw.FireEvent("event1", true) - evsw.FireEvent("event2", true) + evsw.FireEvent(ctx, "event1", true) + evsw.FireEvent(ctx, "event2", true) } assert.Equal(t, count, sum1) assert.Equal(t, count, sum2) @@ -305,8 +368,8 @@ func TestRemoveListener(t *testing.T) { // remove one by event and make sure it is gone evsw.RemoveListenerForEvent("event2", "listener") for i := 0; i < count; i++ { - evsw.FireEvent("event1", true) - evsw.FireEvent("event2", true) + evsw.FireEvent(ctx, "event1", true) + evsw.FireEvent(ctx, "event2", true) } assert.Equal(t, count*2, sum1) assert.Equal(t, count, sum2) @@ -314,8 +377,8 @@ func TestRemoveListener(t *testing.T) { // remove the listener entirely and make sure both gone evsw.RemoveListener("listener") for i := 0; i < count; i++ { - evsw.FireEvent("event1", true) - evsw.FireEvent("event2", true) + evsw.FireEvent(ctx, "event1", true) + evsw.FireEvent(ctx, "event2", true) } assert.Equal(t, count*2, sum1) assert.Equal(t, count, sum2) @@ -347,28 +410,58 @@ func TestRemoveListenersAsync(t *testing.T) { numbers2 := make(chan uint64, 4) // subscribe two listener to three events require.NoError(t, evsw.AddListenerForEvent("listener1", "event1", - func(data EventData) { - numbers1 <- data.(uint64) + func(ctx context.Context, data EventData) 
error { + select { + case numbers1 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener1", "event2", - func(data EventData) { - numbers1 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers1 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener1", "event3", - func(data EventData) { - numbers1 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers1 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener2", "event1", - func(data EventData) { - numbers2 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers2 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener2", "event2", - func(data EventData) { - numbers2 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers2 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) require.NoError(t, evsw.AddListenerForEvent("listener2", "event3", - func(data EventData) { - numbers2 <- data.(uint64) + func(ctx context.Context, data EventData) error { + select { + case numbers2 <- data.(uint64): + return nil + case <-ctx.Done(): + return ctx.Err() + } })) // collect received events for event1 go sumReceivedNumbers(numbers1, doneSum1) @@ -382,7 +475,7 @@ func TestRemoveListenersAsync(t *testing.T) { eventNumber := r1.Intn(3) + 1 go evsw.AddListenerForEvent(fmt.Sprintf("listener%v", listenerNumber), //nolint:errcheck // ignore for tests fmt.Sprintf("event%v", eventNumber), - func(_ EventData) {}) + func(context.Context, EventData) error { return nil }) } } removeListenersStress := func() { @@ -395,10 +488,10 @@ func TestRemoveListenersAsync(t *testing.T) { } addListenersStress() // go fire events - go fireEvents(evsw, "event1", doneSending1, uint64(1)) + go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1)) removeListenersStress() - go fireEvents(evsw, "event2", doneSending2, uint64(1001)) - go fireEvents(evsw, "event3", doneSending3, uint64(2001)) + go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1001)) + go fireEvents(ctx, evsw, "event3", doneSending3, uint64(2001)) checkSumEvent1 := <-doneSending1 checkSumEvent2 := <-doneSending2 checkSumEvent3 := <-doneSending3 @@ -437,13 +530,21 @@ func sumReceivedNumbers(numbers, doneSum chan uint64) { // to `offset` + 999. It additionally returns the addition of all integers // sent on `doneChan` for assertion that all events have been sent, and enabling // the test to assert all events have also been received. -func fireEvents(evsw Fireable, event string, doneChan chan uint64, - offset uint64) { +func fireEvents(ctx context.Context, evsw Fireable, event string, doneChan chan uint64, offset uint64) { + defer close(doneChan) + var sentSum uint64 for i := offset; i <= offset+uint64(999); i++ { + if ctx.Err() != nil { + break + } + + evsw.FireEvent(ctx, event, i) sentSum += i - evsw.FireEvent(event, i) } - doneChan <- sentSum - close(doneChan) + + select { + case <-ctx.Done(): + case doneChan <- sentSum: + } }
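A closing sketch ties the events changes together: EventCache still buffers FireEvent calls without a context and only needs one at flush time (editor's illustration, not part of the patch; evsw is an assumed, already-started switch):

	func flushLater(ctx context.Context, evsw Fireable) {
		evc := NewEventCache(evsw)
		evc.FireEvent("tx", "first")  // buffered; nothing is delivered yet
		evc.FireEvent("tx", "second") // buffered
		evc.Flush(ctx)                // replays the buffer via evsw.FireEvent(ctx, ...)
	}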