
p2p: refactor channel Send/out (#7414)

Branch: pull/7420/head
Author: Sam Kleinman, 3 years ago (committed by GitHub)
Commit: bd6dc3ca88
28 changed files with 625 additions and 452 deletions
  1. +22 -18  internal/blocksync/reactor.go
  2. +14 -12  internal/consensus/byzantine_test.go
  3. +2 -3  internal/consensus/invalid_test.go
  4. +139 -90  internal/consensus/reactor.go
  5. +4 -4  internal/consensus/replay_test.go
  6. +47 -20  internal/consensus/state.go
  7. +13 -13  internal/consensus/state_test.go
  8. +3 -5  internal/evidence/reactor.go
  9. +1 -20  internal/evidence/reactor_test.go
  10. +4 -4  internal/mempool/reactor.go
  11. +2 -31  internal/mempool/reactor_test.go
  12. +3 -3  internal/p2p/channel.go
  13. +1 -1  internal/p2p/channel_test.go
  14. +12 -8  internal/p2p/p2ptest/require.go
  15. +12 -8  internal/p2p/pex/reactor.go
  16. +16 -11  internal/p2p/pex/reactor_test.go
  17. +22 -21  internal/p2p/router_test.go
  18. +7 -4  internal/statesync/dispatcher.go
  19. +45 -18  internal/statesync/dispatcher_test.go
  20. +33 -33  internal/statesync/reactor.go
  21. +2 -2  internal/statesync/reactor_test.go
  22. +5 -7  internal/statesync/stateprovider.go
  23. +23 -31  internal/statesync/syncer.go
  24. +7 -7  internal/statesync/syncer_test.go
  25. +4 -2  libs/events/event_cache.go
  26. +7 -5  libs/events/event_cache_test.go
  27. +9 -6  libs/events/events.go
  28. +166 -65  libs/events/events_test.go
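
The pattern repeated across every file below: instead of writing envelopes directly to a channel's exported Out field, reactors call the new context-aware Channel.Send and treat a returned error (which, per the new implementation in internal/p2p/channel.go, can only be a context error) as the signal to stop sending. A minimal sketch of the resulting shape, modeled on the blocksync status response in this diff; the package name, function name, and import paths are illustrative assumptions rather than code from the repository:

package example

import (
	"context"

	"github.com/tendermint/tendermint/internal/p2p"
	bcproto "github.com/tendermint/tendermint/proto/tendermint/blocksync"
	"github.com/tendermint/tendermint/types"
)

// respondStatus sends a StatusResponse to peerID over ch. Channel.Send blocks
// until the router accepts the envelope or ctx is done, returning ctx.Err()
// in the latter case; the old `ch.Out <- envelope` form had no way to give up.
func respondStatus(ctx context.Context, ch *p2p.Channel, peerID types.NodeID, height, base int64) error {
	return ch.Send(ctx, p2p.Envelope{
		To: peerID,
		Message: &bcproto.StatusResponse{
			Height: height,
			Base:   base,
		},
	})
}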

+22 -18  internal/blocksync/reactor.go

@ -2,6 +2,7 @@ package blocksync
import (
"context"
"errors"
"fmt"
"runtime/debug"
"sync"
@ -185,40 +186,38 @@ func (r *Reactor) OnStop() {
// respondToPeer loads a block and sends it to the requesting peer, if we have it.
// Otherwise, we'll respond saying we do not have it.
func (r *Reactor) respondToPeer(msg *bcproto.BlockRequest, peerID types.NodeID) {
func (r *Reactor) respondToPeer(ctx context.Context, msg *bcproto.BlockRequest, peerID types.NodeID) error {
block := r.store.LoadBlock(msg.Height)
if block != nil {
blockProto, err := block.ToProto()
if err != nil {
r.logger.Error("failed to convert msg to protobuf", "err", err)
return
return err
}
r.blockSyncCh.Out <- p2p.Envelope{
return r.blockSyncCh.Send(ctx, p2p.Envelope{
To: peerID,
Message: &bcproto.BlockResponse{Block: blockProto},
}
return
})
}
r.logger.Info("peer requesting a block we do not have", "peer", peerID, "height", msg.Height)
r.blockSyncCh.Out <- p2p.Envelope{
return r.blockSyncCh.Send(ctx, p2p.Envelope{
To: peerID,
Message: &bcproto.NoBlockResponse{Height: msg.Height},
}
})
}
// handleBlockSyncMessage handles envelopes sent from peers on the
// BlockSyncChannel. It returns an error only if the Envelope.Message is unknown
// for this channel. This should never be called outside of handleMessage.
func (r *Reactor) handleBlockSyncMessage(envelope p2p.Envelope) error {
func (r *Reactor) handleBlockSyncMessage(ctx context.Context, envelope p2p.Envelope) error {
logger := r.logger.With("peer", envelope.From)
switch msg := envelope.Message.(type) {
case *bcproto.BlockRequest:
r.respondToPeer(msg, envelope.From)
return r.respondToPeer(ctx, msg, envelope.From)
case *bcproto.BlockResponse:
block, err := types.BlockFromProto(msg.Block)
if err != nil {
@ -229,14 +228,13 @@ func (r *Reactor) handleBlockSyncMessage(envelope p2p.Envelope) error {
r.pool.AddBlock(envelope.From, block, block.Size())
case *bcproto.StatusRequest:
r.blockSyncCh.Out <- p2p.Envelope{
return r.blockSyncCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &bcproto.StatusResponse{
Height: r.store.Height(),
Base: r.store.Base(),
},
}
})
case *bcproto.StatusResponse:
r.pool.SetPeerRange(envelope.From, msg.Base, msg.Height)
@ -253,7 +251,7 @@ func (r *Reactor) handleBlockSyncMessage(envelope p2p.Envelope) error {
// handleMessage handles an Envelope sent from a peer on a specific p2p Channel.
// It will handle errors and any possible panics gracefully. A caller can handle
// any error returned by sending a PeerError on the respective channel.
func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic in processing message: %v", e)
@ -269,7 +267,7 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err
switch chID {
case BlockSyncChannel:
err = r.handleBlockSyncMessage(envelope)
err = r.handleBlockSyncMessage(ctx, envelope)
default:
err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope)
@ -290,7 +288,11 @@ func (r *Reactor) processBlockSyncCh(ctx context.Context) {
r.logger.Debug("stopped listening on block sync channel; closing...")
return
case envelope := <-r.blockSyncCh.In:
if err := r.handleMessage(r.blockSyncCh.ID, envelope); err != nil {
if err := r.handleMessage(ctx, r.blockSyncCh.ID, envelope); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
r.logger.Error("failed to process message", "ch_id", r.blockSyncCh.ID, "envelope", envelope, "err", err)
if serr := r.blockSyncCh.SendError(ctx, p2p.PeerError{
NodeID: envelope.From,
@ -300,7 +302,9 @@ func (r *Reactor) processBlockSyncCh(ctx context.Context) {
}
}
case envelope := <-r.blockSyncOutBridgeCh:
r.blockSyncCh.Out <- envelope
if err := r.blockSyncCh.Send(ctx, envelope); err != nil {
return
}
}
}
}


+14 -12  internal/consensus/byzantine_test.go

@ -141,20 +141,22 @@ func TestByzantinePrevoteEquivocation(t *testing.T) {
for _, ps := range bzReactor.peers {
if i < len(bzReactor.peers)/2 {
bzNodeState.logger.Info("signed and pushed vote", "vote", prevote1, "peer", ps.peerID)
bzReactor.voteCh.Out <- p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Vote{
Vote: prevote1.ToProto(),
},
}
require.NoError(t, bzReactor.voteCh.Send(ctx,
p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Vote{
Vote: prevote1.ToProto(),
},
}))
} else {
bzNodeState.logger.Info("signed and pushed vote", "vote", prevote2, "peer", ps.peerID)
bzReactor.voteCh.Out <- p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Vote{
Vote: prevote2.ToProto(),
},
}
require.NoError(t, bzReactor.voteCh.Send(ctx,
p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Vote{
Vote: prevote2.ToProto(),
},
}))
}
i++


+2 -3  internal/consensus/invalid_test.go

@ -124,13 +124,12 @@ func invalidDoPrevoteFunc(
for _, ps := range r.peers {
cs.logger.Info("sending bad vote", "block", blockHash, "peer", ps.peerID)
r.voteCh.Out <- p2p.Envelope{
require.NoError(t, r.voteCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Vote{
Vote: precommit.ToProto(),
},
}
}))
}
}()
}

+139 -90  internal/consensus/reactor.go

@ -317,16 +317,16 @@ func (r *Reactor) GetPeerState(peerID types.NodeID) (*PeerState, bool) {
return ps, ok
}
func (r *Reactor) broadcastNewRoundStepMessage(rs *cstypes.RoundState) {
r.stateCh.Out <- p2p.Envelope{
func (r *Reactor) broadcastNewRoundStepMessage(ctx context.Context, rs *cstypes.RoundState) error {
return r.stateCh.Send(ctx, p2p.Envelope{
Broadcast: true,
Message: makeRoundStepMessage(rs),
}
})
}
func (r *Reactor) broadcastNewValidBlockMessage(rs *cstypes.RoundState) {
func (r *Reactor) broadcastNewValidBlockMessage(ctx context.Context, rs *cstypes.RoundState) error {
psHeader := rs.ProposalBlockParts.Header()
r.stateCh.Out <- p2p.Envelope{
return r.stateCh.Send(ctx, p2p.Envelope{
Broadcast: true,
Message: &tmcons.NewValidBlock{
Height: rs.Height,
@ -335,11 +335,11 @@ func (r *Reactor) broadcastNewValidBlockMessage(rs *cstypes.RoundState) {
BlockParts: rs.ProposalBlockParts.BitArray().ToProto(),
IsCommit: rs.Step == cstypes.RoundStepCommit,
},
}
})
}
func (r *Reactor) broadcastHasVoteMessage(vote *types.Vote) {
r.stateCh.Out <- p2p.Envelope{
func (r *Reactor) broadcastHasVoteMessage(ctx context.Context, vote *types.Vote) error {
return r.stateCh.Send(ctx, p2p.Envelope{
Broadcast: true,
Message: &tmcons.HasVote{
Height: vote.Height,
@ -347,7 +347,7 @@ func (r *Reactor) broadcastHasVoteMessage(vote *types.Vote) {
Type: vote.Type,
Index: vote.ValidatorIndex,
},
}
})
}
// subscribeToBroadcastEvents subscribes for new round steps and votes using the
@ -357,11 +357,17 @@ func (r *Reactor) subscribeToBroadcastEvents() {
err := r.state.evsw.AddListenerForEvent(
listenerIDConsensus,
types.EventNewRoundStepValue,
func(data tmevents.EventData) {
r.broadcastNewRoundStepMessage(data.(*cstypes.RoundState))
func(ctx context.Context, data tmevents.EventData) error {
if err := r.broadcastNewRoundStepMessage(ctx, data.(*cstypes.RoundState)); err != nil {
return err
}
select {
case r.state.onStopCh <- data.(*cstypes.RoundState):
return nil
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
},
)
@ -372,8 +378,8 @@ func (r *Reactor) subscribeToBroadcastEvents() {
err = r.state.evsw.AddListenerForEvent(
listenerIDConsensus,
types.EventValidBlockValue,
func(data tmevents.EventData) {
r.broadcastNewValidBlockMessage(data.(*cstypes.RoundState))
func(ctx context.Context, data tmevents.EventData) error {
return r.broadcastNewValidBlockMessage(ctx, data.(*cstypes.RoundState))
},
)
if err != nil {
@ -383,8 +389,8 @@ func (r *Reactor) subscribeToBroadcastEvents() {
err = r.state.evsw.AddListenerForEvent(
listenerIDConsensus,
types.EventVoteValue,
func(data tmevents.EventData) {
r.broadcastHasVoteMessage(data.(*types.Vote))
func(ctx context.Context, data tmevents.EventData) error {
return r.broadcastHasVoteMessage(ctx, data.(*types.Vote))
},
)
if err != nil {
@ -406,19 +412,14 @@ func makeRoundStepMessage(rs *cstypes.RoundState) *tmcons.NewRoundStep {
}
}
func (r *Reactor) sendNewRoundStepMessage(ctx context.Context, peerID types.NodeID) {
rs := r.state.GetRoundState()
msg := makeRoundStepMessage(rs)
select {
case <-ctx.Done():
case r.stateCh.Out <- p2p.Envelope{
func (r *Reactor) sendNewRoundStepMessage(ctx context.Context, peerID types.NodeID) error {
return r.stateCh.Send(ctx, p2p.Envelope{
To: peerID,
Message: msg,
}:
}
Message: makeRoundStepMessage(r.state.GetRoundState()),
})
}
func (r *Reactor) gossipDataForCatchup(rs *cstypes.RoundState, prs *cstypes.PeerRoundState, ps *PeerState) {
func (r *Reactor) gossipDataForCatchup(ctx context.Context, rs *cstypes.RoundState, prs *cstypes.PeerRoundState, ps *PeerState) {
logger := r.logger.With("height", prs.Height).With("peer", ps.peerID)
if index, ok := prs.ProposalBlockParts.Not().PickRandom(); ok {
@ -467,14 +468,14 @@ func (r *Reactor) gossipDataForCatchup(rs *cstypes.RoundState, prs *cstypes.Peer
}
logger.Debug("sending block part for catchup", "round", prs.Round, "index", index)
r.dataCh.Out <- p2p.Envelope{
_ = r.dataCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.BlockPart{
Height: prs.Height, // not our height, so it does not matter.
Round: prs.Round, // not our height, so it does not matter
Part: *partProto,
},
}
})
return
}
@ -521,13 +522,15 @@ OUTER_LOOP:
}
logger.Debug("sending block part", "height", prs.Height, "round", prs.Round)
r.dataCh.Out <- p2p.Envelope{
if err := r.dataCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.BlockPart{
Height: rs.Height, // this tells peer that this part applies to us
Round: rs.Round, // this tells peer that this part applies to us
Part: *partProto,
},
}); err != nil {
return
}
ps.SetHasProposalBlockPart(prs.Height, prs.Round, index)
@ -566,7 +569,7 @@ OUTER_LOOP:
continue OUTER_LOOP
}
r.gossipDataForCatchup(rs, prs, ps)
r.gossipDataForCatchup(ctx, rs, prs, ps)
continue OUTER_LOOP
}
@ -593,11 +596,13 @@ OUTER_LOOP:
propProto := rs.Proposal.ToProto()
logger.Debug("sending proposal", "height", prs.Height, "round", prs.Round)
r.dataCh.Out <- p2p.Envelope{
if err := r.dataCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Proposal{
Proposal: *propProto,
},
}); err != nil {
return
}
// NOTE: A peer might have received a different proposal message, so
@ -614,13 +619,15 @@ OUTER_LOOP:
pPolProto := pPol.ToProto()
logger.Debug("sending POL", "height", prs.Height, "round", prs.Round)
r.dataCh.Out <- p2p.Envelope{
if err := r.dataCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.ProposalPOL{
Height: rs.Height,
ProposalPolRound: rs.Proposal.POLRound,
ProposalPol: *pPolProto,
},
}); err != nil {
return
}
}
@ -640,24 +647,24 @@ OUTER_LOOP:
// pickSendVote picks a vote and sends it to the peer. It will return true if
// there is a vote to send and false otherwise.
func (r *Reactor) pickSendVote(ctx context.Context, ps *PeerState, votes types.VoteSetReader) bool {
if vote, ok := ps.PickVoteToSend(votes); ok {
r.logger.Debug("sending vote message", "ps", ps, "vote", vote)
select {
case <-ctx.Done():
case r.voteCh.Out <- p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Vote{
Vote: vote.ToProto(),
},
}:
}
func (r *Reactor) pickSendVote(ctx context.Context, ps *PeerState, votes types.VoteSetReader) (bool, error) {
vote, ok := ps.PickVoteToSend(votes)
if !ok {
return false, nil
}
ps.SetHasVote(vote)
return true
r.logger.Debug("sending vote message", "ps", ps, "vote", vote)
if err := r.voteCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.Vote{
Vote: vote.ToProto(),
},
}); err != nil {
return false, err
}
return false
ps.SetHasVote(vote)
return true, nil
}
func (r *Reactor) gossipVotesForHeight(
@ -665,62 +672,75 @@ func (r *Reactor) gossipVotesForHeight(
rs *cstypes.RoundState,
prs *cstypes.PeerRoundState,
ps *PeerState,
) bool {
) (bool, error) {
logger := r.logger.With("height", prs.Height).With("peer", ps.peerID)
// if there are lastCommits to send...
if prs.Step == cstypes.RoundStepNewHeight {
if r.pickSendVote(ctx, ps, rs.LastCommit) {
if ok, err := r.pickSendVote(ctx, ps, rs.LastCommit); err != nil {
return false, err
} else if ok {
logger.Debug("picked rs.LastCommit to send")
return true
return true, nil
}
}
// if there are POL prevotes to send...
if prs.Step <= cstypes.RoundStepPropose && prs.Round != -1 && prs.Round <= rs.Round && prs.ProposalPOLRound != -1 {
if polPrevotes := rs.Votes.Prevotes(prs.ProposalPOLRound); polPrevotes != nil {
if r.pickSendVote(ctx, ps, polPrevotes) {
if ok, err := r.pickSendVote(ctx, ps, polPrevotes); err != nil {
return false, err
} else if ok {
logger.Debug("picked rs.Prevotes(prs.ProposalPOLRound) to send", "round", prs.ProposalPOLRound)
return true
return true, nil
}
}
}
// if there are prevotes to send...
if prs.Step <= cstypes.RoundStepPrevoteWait && prs.Round != -1 && prs.Round <= rs.Round {
if r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)) {
if ok, err := r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)); err != nil {
return false, err
} else if ok {
logger.Debug("picked rs.Prevotes(prs.Round) to send", "round", prs.Round)
return true
return true, nil
}
}
// if there are precommits to send...
if prs.Step <= cstypes.RoundStepPrecommitWait && prs.Round != -1 && prs.Round <= rs.Round {
if r.pickSendVote(ctx, ps, rs.Votes.Precommits(prs.Round)) {
if ok, err := r.pickSendVote(ctx, ps, rs.Votes.Precommits(prs.Round)); err != nil {
return false, err
} else if ok {
logger.Debug("picked rs.Precommits(prs.Round) to send", "round", prs.Round)
return true
return true, nil
}
}
// if there are prevotes to send...(which are needed because of validBlock mechanism)
if prs.Round != -1 && prs.Round <= rs.Round {
if r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)) {
if ok, err := r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)); err != nil {
return false, err
} else if ok {
logger.Debug("picked rs.Prevotes(prs.Round) to send", "round", prs.Round)
return true
return true, nil
}
}
// if there are POLPrevotes to send...
if prs.ProposalPOLRound != -1 {
if polPrevotes := rs.Votes.Prevotes(prs.ProposalPOLRound); polPrevotes != nil {
if r.pickSendVote(ctx, ps, polPrevotes) {
if ok, err := r.pickSendVote(ctx, ps, polPrevotes); err != nil {
return false, err
} else if ok {
logger.Debug("picked rs.Prevotes(prs.ProposalPOLRound) to send", "round", prs.ProposalPOLRound)
return true
return true, nil
}
}
}
return false
return false, nil
}
func (r *Reactor) gossipVotesRoutine(ctx context.Context, ps *PeerState) {
@ -763,14 +783,18 @@ OUTER_LOOP:
// if height matches, then send LastCommit, Prevotes, and Precommits
if rs.Height == prs.Height {
if r.gossipVotesForHeight(ctx, rs, prs, ps) {
if ok, err := r.gossipVotesForHeight(ctx, rs, prs, ps); err != nil {
return
} else if ok {
continue OUTER_LOOP
}
}
// special catchup logic -- if peer is lagging by height 1, send LastCommit
if prs.Height != 0 && rs.Height == prs.Height+1 {
if r.pickSendVote(ctx, ps, rs.LastCommit) {
if ok, err := r.pickSendVote(ctx, ps, rs.LastCommit); err != nil {
return
} else if ok {
logger.Debug("picked rs.LastCommit to send", "height", prs.Height)
continue OUTER_LOOP
}
@ -782,7 +806,9 @@ OUTER_LOOP:
// Load the block commit for prs.Height, which contains precommit
// signatures for prs.Height.
if commit := r.state.blockStore.LoadBlockCommit(prs.Height); commit != nil {
if r.pickSendVote(ctx, ps, commit) {
if ok, err := r.pickSendVote(ctx, ps, commit); err != nil {
return
} else if ok {
logger.Debug("picked Catchup commit to send", "height", prs.Height)
continue OUTER_LOOP
}
@ -844,7 +870,7 @@ OUTER_LOOP:
if rs.Height == prs.Height {
if maj23, ok := rs.Votes.Prevotes(prs.Round).TwoThirdsMajority(); ok {
r.stateCh.Out <- p2p.Envelope{
if err := r.stateCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.VoteSetMaj23{
Height: prs.Height,
@ -852,6 +878,8 @@ OUTER_LOOP:
Type: tmproto.PrevoteType,
BlockID: maj23.ToProto(),
},
}); err != nil {
return
}
timer.Reset(r.state.config.PeerQueryMaj23SleepDuration)
@ -871,7 +899,7 @@ OUTER_LOOP:
if rs.Height == prs.Height {
if maj23, ok := rs.Votes.Precommits(prs.Round).TwoThirdsMajority(); ok {
r.stateCh.Out <- p2p.Envelope{
if err := r.stateCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.VoteSetMaj23{
Height: prs.Height,
@ -879,6 +907,8 @@ OUTER_LOOP:
Type: tmproto.PrecommitType,
BlockID: maj23.ToProto(),
},
}); err != nil {
return
}
select {
@ -898,7 +928,7 @@ OUTER_LOOP:
if rs.Height == prs.Height && prs.ProposalPOLRound >= 0 {
if maj23, ok := rs.Votes.Prevotes(prs.ProposalPOLRound).TwoThirdsMajority(); ok {
r.stateCh.Out <- p2p.Envelope{
if err := r.stateCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.VoteSetMaj23{
Height: prs.Height,
@ -906,6 +936,8 @@ OUTER_LOOP:
Type: tmproto.PrevoteType,
BlockID: maj23.ToProto(),
},
}); err != nil {
return
}
timer.Reset(r.state.config.PeerQueryMaj23SleepDuration)
@ -928,7 +960,7 @@ OUTER_LOOP:
if prs.CatchupCommitRound != -1 && prs.Height > 0 && prs.Height <= r.state.blockStore.Height() &&
prs.Height >= r.state.blockStore.Base() {
if commit := r.state.LoadCommit(prs.Height); commit != nil {
r.stateCh.Out <- p2p.Envelope{
if err := r.stateCh.Send(ctx, p2p.Envelope{
To: ps.peerID,
Message: &tmcons.VoteSetMaj23{
Height: prs.Height,
@ -936,6 +968,8 @@ OUTER_LOOP:
Type: tmproto.PrecommitType,
BlockID: commit.BlockID.ToProto(),
},
}); err != nil {
return
}
timer.Reset(r.state.config.PeerQueryMaj23SleepDuration)
@ -1006,7 +1040,7 @@ func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpda
// Send our state to the peer. If we're block-syncing, broadcast a
// RoundStepMessage later upon SwitchToConsensus().
if !r.waitSync {
go r.sendNewRoundStepMessage(ctx, ps.peerID)
go func() { _ = r.sendNewRoundStepMessage(ctx, ps.peerID) }()
}
}
@ -1036,7 +1070,7 @@ func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpda
// If we fail to find the peer state for the envelope sender, we perform a no-op
// and return. This can happen when we process the envelope after the peer is
// removed.
func (r *Reactor) handleStateMessage(envelope p2p.Envelope, msgI Message) error {
func (r *Reactor) handleStateMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error {
ps, ok := r.GetPeerState(envelope.From)
if !ok || ps == nil {
r.logger.Debug("failed to find peer state", "peer", envelope.From, "ch_id", "StateChannel")
@ -1104,9 +1138,11 @@ func (r *Reactor) handleStateMessage(envelope p2p.Envelope, msgI Message) error
eMsg.Votes = *votesProto
}
r.voteSetBitsCh.Out <- p2p.Envelope{
if err := r.voteSetBitsCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: eMsg,
}); err != nil {
return err
}
default:
@ -1120,7 +1156,7 @@ func (r *Reactor) handleStateMessage(envelope p2p.Envelope, msgI Message) error
// fail to find the peer state for the envelope sender, we perform a no-op and
// return. This can happen when we process the envelope after the peer is
// removed.
func (r *Reactor) handleDataMessage(envelope p2p.Envelope, msgI Message) error {
func (r *Reactor) handleDataMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error {
logger := r.logger.With("peer", envelope.From, "ch_id", "DataChannel")
ps, ok := r.GetPeerState(envelope.From)
@ -1139,17 +1175,24 @@ func (r *Reactor) handleDataMessage(envelope p2p.Envelope, msgI Message) error {
pMsg := msgI.(*ProposalMessage)
ps.SetHasProposal(pMsg.Proposal)
r.state.peerMsgQueue <- msgInfo{pMsg, envelope.From}
select {
case <-ctx.Done():
return ctx.Err()
case r.state.peerMsgQueue <- msgInfo{pMsg, envelope.From}:
}
case *tmcons.ProposalPOL:
ps.ApplyProposalPOLMessage(msgI.(*ProposalPOLMessage))
case *tmcons.BlockPart:
bpMsg := msgI.(*BlockPartMessage)
ps.SetHasProposalBlockPart(bpMsg.Height, bpMsg.Round, int(bpMsg.Part.Index))
r.Metrics.BlockParts.With("peer_id", string(envelope.From)).Add(1)
r.state.peerMsgQueue <- msgInfo{bpMsg, envelope.From}
select {
case r.state.peerMsgQueue <- msgInfo{bpMsg, envelope.From}:
return nil
case <-ctx.Done():
return ctx.Err()
}
default:
return fmt.Errorf("received unknown message on DataChannel: %T", msg)
@ -1162,7 +1205,7 @@ func (r *Reactor) handleDataMessage(envelope p2p.Envelope, msgI Message) error {
// fail to find the peer state for the envelope sender, we perform a no-op and
// return. This can happen when we process the envelope after the peer is
// removed.
func (r *Reactor) handleVoteMessage(envelope p2p.Envelope, msgI Message) error {
func (r *Reactor) handleVoteMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error {
logger := r.logger.With("peer", envelope.From, "ch_id", "VoteChannel")
ps, ok := r.GetPeerState(envelope.From)
@ -1188,20 +1231,22 @@ func (r *Reactor) handleVoteMessage(envelope p2p.Envelope, msgI Message) error {
ps.EnsureVoteBitArrays(height-1, lastCommitSize)
ps.SetHasVote(vMsg.Vote)
r.state.peerMsgQueue <- msgInfo{vMsg, envelope.From}
select {
case r.state.peerMsgQueue <- msgInfo{vMsg, envelope.From}:
return nil
case <-ctx.Done():
return ctx.Err()
}
default:
return fmt.Errorf("received unknown message on VoteChannel: %T", msg)
}
return nil
}
// handleVoteSetBitsMessage handles envelopes sent from peers on the
// VoteSetBitsChannel. If we fail to find the peer state for the envelope sender,
// we perform a no-op and return. This can happen when we process the envelope
// after the peer is removed.
func (r *Reactor) handleVoteSetBitsMessage(envelope p2p.Envelope, msgI Message) error {
func (r *Reactor) handleVoteSetBitsMessage(ctx context.Context, envelope p2p.Envelope, msgI Message) error {
logger := r.logger.With("peer", envelope.From, "ch_id", "VoteSetBitsChannel")
ps, ok := r.GetPeerState(envelope.From)
@ -1259,7 +1304,7 @@ func (r *Reactor) handleVoteSetBitsMessage(envelope p2p.Envelope, msgI Message)
// the p2p channel.
//
// NOTE: We block on consensus state for proposals, block parts, and votes.
func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic in processing message: %v", e)
@ -1290,16 +1335,16 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err
switch chID {
case StateChannel:
err = r.handleStateMessage(envelope, msgI)
err = r.handleStateMessage(ctx, envelope, msgI)
case DataChannel:
err = r.handleDataMessage(envelope, msgI)
err = r.handleDataMessage(ctx, envelope, msgI)
case VoteChannel:
err = r.handleVoteMessage(envelope, msgI)
err = r.handleVoteMessage(ctx, envelope, msgI)
case VoteSetBitsChannel:
err = r.handleVoteSetBitsMessage(envelope, msgI)
err = r.handleVoteSetBitsMessage(ctx, envelope, msgI)
default:
err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope)
@ -1320,7 +1365,7 @@ func (r *Reactor) processStateCh(ctx context.Context) {
r.logger.Debug("stopped listening on StateChannel; closing...")
return
case envelope := <-r.stateCh.In:
if err := r.handleMessage(r.stateCh.ID, envelope); err != nil {
if err := r.handleMessage(ctx, r.stateCh.ID, envelope); err != nil {
r.logger.Error("failed to process message", "ch_id", r.stateCh.ID, "envelope", envelope, "err", err)
if serr := r.stateCh.SendError(ctx, p2p.PeerError{
NodeID: envelope.From,
@ -1345,7 +1390,7 @@ func (r *Reactor) processDataCh(ctx context.Context) {
r.logger.Debug("stopped listening on DataChannel; closing...")
return
case envelope := <-r.dataCh.In:
if err := r.handleMessage(r.dataCh.ID, envelope); err != nil {
if err := r.handleMessage(ctx, r.dataCh.ID, envelope); err != nil {
r.logger.Error("failed to process message", "ch_id", r.dataCh.ID, "envelope", envelope, "err", err)
if serr := r.dataCh.SendError(ctx, p2p.PeerError{
NodeID: envelope.From,
@ -1370,7 +1415,7 @@ func (r *Reactor) processVoteCh(ctx context.Context) {
r.logger.Debug("stopped listening on VoteChannel; closing...")
return
case envelope := <-r.voteCh.In:
if err := r.handleMessage(r.voteCh.ID, envelope); err != nil {
if err := r.handleMessage(ctx, r.voteCh.ID, envelope); err != nil {
r.logger.Error("failed to process message", "ch_id", r.voteCh.ID, "envelope", envelope, "err", err)
if serr := r.voteCh.SendError(ctx, p2p.PeerError{
NodeID: envelope.From,
@ -1395,7 +1440,11 @@ func (r *Reactor) processVoteSetBitsCh(ctx context.Context) {
r.logger.Debug("stopped listening on VoteSetBitsChannel; closing...")
return
case envelope := <-r.voteSetBitsCh.In:
if err := r.handleMessage(r.voteSetBitsCh.ID, envelope); err != nil {
if err := r.handleMessage(ctx, r.voteSetBitsCh.ID, envelope); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
r.logger.Error("failed to process message", "ch_id", r.voteSetBitsCh.ID, "envelope", envelope, "err", err)
if serr := r.voteSetBitsCh.SendError(ctx, p2p.PeerError{
NodeID: envelope.From,


+4 -4  internal/consensus/replay_test.go

@ -391,7 +391,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite {
proposal.Signature = p.Signature
// set the proposal block
if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil {
if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
ensureNewProposal(proposalCh, height, round)
@ -423,7 +423,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite {
proposal.Signature = p.Signature
// set the proposal block
if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil {
if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
ensureNewProposal(proposalCh, height, round)
@ -482,7 +482,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite {
proposal.Signature = p.Signature
// set the proposal block
if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil {
if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
ensureNewProposal(proposalCh, height, round)
@ -545,7 +545,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite {
proposal.Signature = p.Signature
// set the proposal block
if err := css[0].SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil {
if err := css[0].SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
ensureNewProposal(proposalCh, height, round)


+47 -20  internal/consensus/state.go

@ -511,58 +511,85 @@ func (cs *State) OpenWAL(ctx context.Context, walFile string) (WAL, error) {
// TODO: should these return anything or let callers just use events?
// AddVote inputs a vote.
func (cs *State) AddVote(vote *types.Vote, peerID types.NodeID) (added bool, err error) {
func (cs *State) AddVote(ctx context.Context, vote *types.Vote, peerID types.NodeID) error {
if peerID == "" {
cs.internalMsgQueue <- msgInfo{&VoteMessage{vote}, ""}
select {
case <-ctx.Done():
return ctx.Err()
case cs.internalMsgQueue <- msgInfo{&VoteMessage{vote}, ""}:
return nil
}
} else {
cs.peerMsgQueue <- msgInfo{&VoteMessage{vote}, peerID}
select {
case <-ctx.Done():
return ctx.Err()
case cs.peerMsgQueue <- msgInfo{&VoteMessage{vote}, peerID}:
return nil
}
}
// TODO: wait for event?!
return false, nil
}
// SetProposal inputs a proposal.
func (cs *State) SetProposal(proposal *types.Proposal, peerID types.NodeID) error {
func (cs *State) SetProposal(ctx context.Context, proposal *types.Proposal, peerID types.NodeID) error {
if peerID == "" {
cs.internalMsgQueue <- msgInfo{&ProposalMessage{proposal}, ""}
select {
case <-ctx.Done():
return ctx.Err()
case cs.internalMsgQueue <- msgInfo{&ProposalMessage{proposal}, ""}:
return nil
}
} else {
cs.peerMsgQueue <- msgInfo{&ProposalMessage{proposal}, peerID}
select {
case <-ctx.Done():
return ctx.Err()
case cs.peerMsgQueue <- msgInfo{&ProposalMessage{proposal}, peerID}:
return nil
}
}
// TODO: wait for event?!
return nil
}
// AddProposalBlockPart inputs a part of the proposal block.
func (cs *State) AddProposalBlockPart(height int64, round int32, part *types.Part, peerID types.NodeID) error {
func (cs *State) AddProposalBlockPart(ctx context.Context, height int64, round int32, part *types.Part, peerID types.NodeID) error {
if peerID == "" {
cs.internalMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, ""}
select {
case <-ctx.Done():
return ctx.Err()
case cs.internalMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, ""}:
return nil
}
} else {
cs.peerMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, peerID}
select {
case <-ctx.Done():
return ctx.Err()
case cs.peerMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, peerID}:
return nil
}
}
// TODO: wait for event?!
return nil
}
// SetProposalAndBlock inputs the proposal and all block parts.
func (cs *State) SetProposalAndBlock(
ctx context.Context,
proposal *types.Proposal,
block *types.Block,
parts *types.PartSet,
peerID types.NodeID,
) error {
if err := cs.SetProposal(proposal, peerID); err != nil {
if err := cs.SetProposal(ctx, proposal, peerID); err != nil {
return err
}
for i := 0; i < int(parts.Total()); i++ {
part := parts.GetPart(i)
if err := cs.AddProposalBlockPart(proposal.Height, proposal.Round, part, peerID); err != nil {
if err := cs.AddProposalBlockPart(ctx, proposal.Height, proposal.Round, part, peerID); err != nil {
return err
}
}
@ -761,7 +788,7 @@ func (cs *State) newStep(ctx context.Context) {
cs.logger.Error("failed publishing new round step", "err", err)
}
cs.evsw.FireEvent(types.EventNewRoundStepValue, &cs.RoundState)
cs.evsw.FireEvent(ctx, types.EventNewRoundStepValue, &cs.RoundState)
}
}
@ -1607,7 +1634,7 @@ func (cs *State) enterCommit(ctx context.Context, height int64, commitRound int3
logger.Error("failed publishing valid block", "err", err)
}
cs.evsw.FireEvent(types.EventValidBlockValue, &cs.RoundState)
cs.evsw.FireEvent(ctx, types.EventValidBlockValue, &cs.RoundState)
}
}
}
@ -2075,7 +2102,7 @@ func (cs *State) addVote(
return added, err
}
cs.evsw.FireEvent(types.EventVoteValue, vote)
cs.evsw.FireEvent(ctx, types.EventVoteValue, vote)
// if we can skip timeoutCommit and have all the votes now,
if cs.config.SkipTimeoutCommit && cs.LastCommit.HasAll() {
@ -2104,7 +2131,7 @@ func (cs *State) addVote(
if err := cs.eventBus.PublishEventVote(ctx, types.EventDataVote{Vote: vote}); err != nil {
return added, err
}
cs.evsw.FireEvent(types.EventVoteValue, vote)
cs.evsw.FireEvent(ctx, types.EventVoteValue, vote)
switch vote.Type {
case tmproto.PrevoteType:
@ -2158,7 +2185,7 @@ func (cs *State) addVote(
cs.ProposalBlockParts = types.NewPartSetFromHeader(blockID.PartSetHeader)
}
cs.evsw.FireEvent(types.EventValidBlockValue, &cs.RoundState)
cs.evsw.FireEvent(ctx, types.EventValidBlockValue, &cs.RoundState)
if err := cs.eventBus.PublishEventValidBlock(ctx, cs.RoundStateEvent()); err != nil {
return added, err
}


+13 -13  internal/consensus/state_test.go

@ -251,7 +251,7 @@ func TestStateBadProposal(t *testing.T) {
proposal.Signature = p.Signature
// set the proposal block
if err := cs1.SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -314,7 +314,7 @@ func TestStateOversizedBlock(t *testing.T) {
totalBytes += len(part.Bytes)
}
if err := cs1.SetProposalAndBlock(proposal, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, proposal, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -621,7 +621,7 @@ func TestStateLockNoPOL(t *testing.T) {
// now we're on a new round and not the proposer
// so set the proposal block
if err := cs1.SetProposalAndBlock(prop, propBlock, propBlock.MakePartSet(partSize), ""); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlock.MakePartSet(partSize), ""); err != nil {
t.Fatal(err)
}
@ -723,7 +723,7 @@ func TestStateLockPOLRelock(t *testing.T) {
round++ // moving to the next round
//XXX: this isnt guaranteed to get there before the timeoutPropose ...
if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -828,7 +828,7 @@ func TestStateLockPOLUnlock(t *testing.T) {
cs1 unlocks!
*/
//XXX: this isnt guaranteed to get there before the timeoutPropose ...
if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -940,7 +940,7 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) {
// we should have unlocked and locked on the new block, sending a precommit for this new block
validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil)
if err := cs1.SetProposalAndBlock(prop, propBlock, secondBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, secondBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -971,7 +971,7 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) {
Round2 (vs3, C) // C C C C // C nil nil nil)
*/
if err := cs1.SetProposalAndBlock(prop, propBlock, thirdPropBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, thirdPropBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -1048,7 +1048,7 @@ func TestStateLockPOLSafety1(t *testing.T) {
ensureNewRound(newRoundCh, height, round)
//XXX: this isnt guaranteed to get there before the timeoutPropose ...
if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
/*Round2
@ -1160,7 +1160,7 @@ func TestStateLockPOLSafety2(t *testing.T) {
startTestRound(ctx, cs1, height, round)
ensureNewRound(newRoundCh, height, round)
if err := cs1.SetProposalAndBlock(prop1, propBlock1, propBlockParts1, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop1, propBlock1, propBlockParts1, "some peer"); err != nil {
t.Fatal(err)
}
ensureNewProposal(proposalCh, height, round)
@ -1193,7 +1193,7 @@ func TestStateLockPOLSafety2(t *testing.T) {
newProp.Signature = p.Signature
if err := cs1.SetProposalAndBlock(newProp, propBlock0, propBlockParts0, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, newProp, propBlock0, propBlockParts0, "some peer"); err != nil {
t.Fatal(err)
}
@ -1428,7 +1428,7 @@ func TestSetValidBlockOnDelayedProposal(t *testing.T) {
ensurePrecommit(voteCh, height, round)
validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil)
if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -1658,7 +1658,7 @@ func TestCommitFromPreviousRound(t *testing.T) {
assert.True(t, rs.ProposalBlock == nil)
assert.True(t, rs.ProposalBlockParts.Header().Equals(propBlockParts.Header()))
if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
@ -1797,7 +1797,7 @@ func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) {
prop, propBlock := decideProposal(ctx, cs1, vs2, height+1, 0)
propBlockParts := propBlock.MakePartSet(partSize)
if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil {
if err := cs1.SetProposalAndBlock(ctx, prop, propBlock, propBlockParts, "some peer"); err != nil {
t.Fatal(err)
}
ensureNewProposal(proposalCh, height+1, 0)


+3 -5  internal/evidence/reactor.go

@ -319,15 +319,13 @@ func (r *Reactor) broadcastEvidenceLoop(ctx context.Context, peerID types.NodeID
// and thus would not be able to process the evidence correctly. Also, the
// peer may receive this piece of evidence multiple times if it added and
// removed frequently from the broadcasting peer.
select {
case <-ctx.Done():
return
case r.evidenceCh.Out <- p2p.Envelope{
if err := r.evidenceCh.Send(ctx, p2p.Envelope{
To: peerID,
Message: &tmproto.EvidenceList{
Evidence: []tmproto.Evidence{*evProto},
},
}:
}); err != nil {
return
}
r.logger.Debug("gossiped evidence to peer", "evidence", ev, "peer", peerID)


+1 -20  internal/evidence/reactor_test.go

@ -108,8 +108,8 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store, chBuf uint
}
}
leaktest.Check(t)
})
t.Cleanup(leaktest.Check(t))
return rts
}
@ -191,21 +191,6 @@ func (rts *reactorTestSuite) waitForEvidence(t *testing.T, evList types.Evidence
wg.Wait()
}
func (rts *reactorTestSuite) assertEvidenceChannelsEmpty(t *testing.T) {
t.Helper()
for id, r := range rts.reactors {
require.NoError(t, r.Stop(), "stopping reactor #%s", id)
r.Wait()
require.False(t, r.IsRunning(), "reactor #%d did not stop", id)
}
for id, ech := range rts.evidenceChannels {
require.Empty(t, ech.Out, "checking channel #%q", id)
}
}
func createEvidenceList(
t *testing.T,
pool *evidence.Pool,
@ -325,8 +310,6 @@ func TestReactorBroadcastEvidence(t *testing.T) {
for _, pool := range rts.pools {
require.Equal(t, numEvidence, int(pool.Size()))
}
rts.assertEvidenceChannelsEmpty(t)
}
// TestReactorSelectiveBroadcast tests a context where we have two reactors
@ -367,8 +350,6 @@ func TestReactorBroadcastEvidence_Lagging(t *testing.T) {
require.Equal(t, numEvidence, int(rts.pools[primary.NodeID].Size()))
require.Equal(t, int(height2), int(rts.pools[secondary.NodeID].Size()))
rts.assertEvidenceChannelsEmpty(t)
}
func TestReactorBroadcastEvidence_Pending(t *testing.T) {


+4 -4  internal/mempool/reactor.go

@ -349,15 +349,15 @@ func (r *Reactor) broadcastTxRoutine(ctx context.Context, peerID types.NodeID, c
if ok := r.mempool.txStore.TxHasPeer(memTx.hash, peerMempoolID); !ok {
// Send the mempool tx to the corresponding peer. Note, the peer may be
// behind and thus would not be able to process the mempool tx correctly.
select {
case r.mempoolCh.Out <- p2p.Envelope{
if err := r.mempoolCh.Send(ctx, p2p.Envelope{
To: peerID,
Message: &protomem.Txs{
Txs: [][]byte{memTx.tx},
},
}:
case <-ctx.Done():
}); err != nil {
return
}
r.logger.Debug(
"gossiped tx to peer",
"tx", fmt.Sprintf("%X", memTx.tx.Hash()),


+2 -31  internal/mempool/reactor_test.go

@ -109,14 +109,6 @@ func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) {
"network does not have expected number of nodes")
}
func (rts *reactorTestSuite) assertMempoolChannelsDrained(t *testing.T) {
t.Helper()
for _, mch := range rts.mempoolChannels {
require.Empty(t, mch.Out, "checking channel %q (len=%d)", mch.ID, len(mch.Out))
}
}
func (rts *reactorTestSuite) waitForTxns(t *testing.T, txs []types.Tx, ids ...types.NodeID) {
t.Helper()
@ -296,8 +288,6 @@ func TestReactorNoBroadcastToSender(t *testing.T) {
require.Eventually(t, func() bool {
return rts.mempools[secondary].Size() == 0
}, time.Minute, 100*time.Millisecond)
rts.assertMempoolChannelsDrained(t)
}
func TestReactor_MaxTxBytes(t *testing.T) {
@ -334,8 +324,6 @@ func TestReactor_MaxTxBytes(t *testing.T) {
tx2 := tmrand.Bytes(cfg.Mempool.MaxTxBytes + 1)
err = rts.mempools[primary].CheckTx(ctx, tx2, nil, TxInfo{SenderID: UnknownPeerID})
require.Error(t, err)
rts.assertMempoolChannelsDrained(t)
}
func TestDontExhaustMaxActiveIDs(t *testing.T) {
@ -359,30 +347,13 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) {
NodeID: peerID,
}
rts.mempoolChannels[nodeID].Out <- p2p.Envelope{
require.NoError(t, rts.mempoolChannels[nodeID].Send(ctx, p2p.Envelope{
To: peerID,
Message: &protomem.Txs{
Txs: [][]byte{},
},
}
}))
}
require.Eventually(
t,
func() bool {
for _, mch := range rts.mempoolChannels {
if len(mch.Out) > 0 {
return false
}
}
return true
},
time.Minute,
10*time.Millisecond,
)
rts.assertMempoolChannelsDrained(t)
}
func TestMempoolIDsPanicsIfNodeRequestsOvermaxActiveIDs(t *testing.T) {


+3 -3  internal/p2p/channel.go

@ -63,7 +63,7 @@ func (pe PeerError) Unwrap() error { return pe.Err }
type Channel struct {
ID ChannelID
In <-chan Envelope // inbound messages (peers to reactors)
Out chan<- Envelope // outbound messages (reactors to peers)
outCh chan<- Envelope // outbound messages (reactors to peers)
errCh chan<- PeerError // peer error reporting
messageType proto.Message // the channel's message type, used for unmarshaling
@ -82,7 +82,7 @@ func NewChannel(
ID: id,
messageType: messageType,
In: inCh,
Out: outCh,
outCh: outCh,
errCh: errCh,
}
}
@ -93,7 +93,7 @@ func (ch *Channel) Send(ctx context.Context, envelope Envelope) error {
select {
case <-ctx.Done():
return ctx.Err()
case ch.Out <- envelope:
case ch.outCh <- envelope:
return nil
}
}
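
Renaming Out to the unexported outCh is what makes the rest of the diff mechanical: nothing outside the p2p package can enqueue an envelope without going through Send, so every producer inherits cancellation. Tests that used to write to Out now either call Send or build the Channel with NewChannel so the raw Go channels stay visible for assertions, as the statesync dispatcher tests below do. A hedged sketch of that test-side setup, mirroring the testChannel helper added in dispatcher_test.go (the buffer size and channel ID here are arbitrary):

// newTestChannel builds a Channel for a test while keeping handles on the
// underlying queues. NewChannel takes (id, messageType, inCh, outCh, errCh);
// a nil message type is enough when the test never decodes wire bytes itself.
func newTestChannel(size int) (in, out chan p2p.Envelope, errs chan p2p.PeerError, ch *p2p.Channel) {
	in = make(chan p2p.Envelope, size)
	out = make(chan p2p.Envelope, size)
	errs = make(chan p2p.PeerError, size)
	return in, out, errs, p2p.NewChannel(0, nil, in, out, errs)
}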


+1 -1  internal/p2p/channel_test.go

@ -24,7 +24,7 @@ func testChannel(size int) (*channelInternal, *Channel) {
}
ch := &Channel{
In: in.In,
Out: in.Out,
outCh: in.Out,
errCh: in.Error,
}
return in, ch


+12 -8  internal/p2p/p2ptest/require.go

@ -64,26 +64,30 @@ func RequireReceiveUnordered(t *testing.T, channel *p2p.Channel, expect []p2p.En
}
// RequireSend requires that the given envelope is sent on the channel.
func RequireSend(t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) {
timer := time.NewTimer(time.Second) // not time.After due to goroutine leaks
defer timer.Stop()
select {
case channel.Out <- envelope:
case <-timer.C:
require.Fail(t, "timed out sending message", "%v on channel %v", envelope, channel.ID)
func RequireSend(ctx context.Context, t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) {
tctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
err := channel.Send(tctx, envelope)
switch {
case errors.Is(err, context.DeadlineExceeded):
require.Fail(t, "timed out sending message to %q", envelope.To)
default:
require.NoError(t, err, "unexpected error")
}
}
// RequireSendReceive requires that a given Protobuf message is sent to the
// given peer, and then that the given response is received back.
func RequireSendReceive(
ctx context.Context,
t *testing.T,
channel *p2p.Channel,
peerID types.NodeID,
send proto.Message,
receive proto.Message,
) {
RequireSend(t, channel, p2p.Envelope{To: peerID, Message: send})
RequireSend(ctx, t, channel, p2p.Envelope{To: peerID, Message: send})
RequireReceive(t, channel, p2p.Envelope{From: peerID, Message: send})
}
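
The test helpers pick up the same discipline: RequireSend threads the caller's context into Channel.Send under a one-second deadline instead of selecting on Out against a manual timer. Usage then looks like the router tests further down; in this sketch, t, channel, and peer are assumed to come from the surrounding p2ptest network setup:

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Fails the test if the envelope is not accepted within RequireSend's timeout.
p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{
	To:      peer.NodeID,
	Message: &p2ptest.Message{Value: "foo"},
})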


+12 -8  internal/p2p/pex/reactor.go

@ -165,12 +165,12 @@ func (r *Reactor) processPexCh(ctx context.Context) {
// outbound requests for new peers
case <-timer.C:
r.sendRequestForPeers()
r.sendRequestForPeers(ctx)
// inbound requests for new peers or responses to requests sent by this
// reactor
case envelope := <-r.pexCh.In:
if err := r.handleMessage(r.pexCh.ID, envelope); err != nil {
if err := r.handleMessage(ctx, r.pexCh.ID, envelope); err != nil {
r.logger.Error("failed to process message", "ch_id", r.pexCh.ID, "envelope", envelope, "err", err)
if serr := r.pexCh.SendError(ctx, p2p.PeerError{
NodeID: envelope.From,
@ -199,7 +199,7 @@ func (r *Reactor) processPeerUpdates(ctx context.Context) {
}
// handlePexMessage handles envelopes sent from peers on the PexChannel.
func (r *Reactor) handlePexMessage(envelope p2p.Envelope) error {
func (r *Reactor) handlePexMessage(ctx context.Context, envelope p2p.Envelope) error {
logger := r.logger.With("peer", envelope.From)
switch msg := envelope.Message.(type) {
@ -219,9 +219,11 @@ func (r *Reactor) handlePexMessage(envelope p2p.Envelope) error {
URL: addr.String(),
}
}
r.pexCh.Out <- p2p.Envelope{
if err := r.pexCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &protop2p.PexResponse{Addresses: pexAddresses},
}); err != nil {
return err
}
case *protop2p.PexResponse:
@ -264,7 +266,7 @@ func (r *Reactor) handlePexMessage(envelope p2p.Envelope) error {
// handleMessage handles an Envelope sent from a peer on a specific p2p Channel.
// It will handle errors and any possible panics gracefully. A caller can handle
// any error returned by sending a PeerError on the respective channel.
func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic in processing message: %v", e)
@ -280,7 +282,7 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err
switch chID {
case p2p.ChannelID(PexChannel):
err = r.handlePexMessage(envelope)
err = r.handlePexMessage(ctx, envelope)
default:
err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope)
@ -312,7 +314,7 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) {
// peer a request for more peer addresses. The function then moves the
// peer into the requestsSent bucket and calculates when the next request
// time should be
func (r *Reactor) sendRequestForPeers() {
func (r *Reactor) sendRequestForPeers(ctx context.Context) {
r.mtx.Lock()
defer r.mtx.Unlock()
if len(r.availablePeers) == 0 {
@ -330,9 +332,11 @@ func (r *Reactor) sendRequestForPeers() {
}
// send out the pex request
r.pexCh.Out <- p2p.Envelope{
if err := r.pexCh.Send(ctx, p2p.Envelope{
To: peerID,
Message: &protop2p.PexRequest{},
}); err != nil {
return
}
// remove the peer from the abvailable peers list and mark it in the requestsSent map


+16 -11  internal/p2p/pex/reactor_test.go

@ -45,7 +45,7 @@ func TestReactorBasic(t *testing.T) {
// assert that when a mock node sends a request it receives a response (and
// the correct one)
testNet.sendRequest(t, firstNode, secondNode)
testNet.sendRequest(ctx, t, firstNode, secondNode)
testNet.listenForResponse(t, secondNode, firstNode, shortWait, []p2pproto.PexAddress(nil))
}
@ -112,8 +112,8 @@ func TestReactorSendsResponseWithoutRequest(t *testing.T) {
// firstNode sends the secondNode an unrequested response
// NOTE: secondNode will send a request by default during startup so we send
// two responses to counter that.
testNet.sendResponse(t, firstNode, secondNode, []int{thirdNode})
testNet.sendResponse(t, firstNode, secondNode, []int{thirdNode})
testNet.sendResponse(ctx, t, firstNode, secondNode, []int{thirdNode})
testNet.sendResponse(ctx, t, firstNode, secondNode, []int{thirdNode})
// secondNode should evict the firstNode
testNet.listenForPeerUpdate(ctx, t, secondNode, firstNode, p2p.PeerStatusDown, shortWait)
@ -139,7 +139,7 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) {
// first we check that even although we have 110 peers, honest pex reactors
// only send 100 (test if secondNode sends firstNode 100 addresses)
testNet.pingAndlistenForNAddresses(t, secondNode, firstNode, shortWait, 100)
testNet.pingAndlistenForNAddresses(ctx, t, secondNode, firstNode, shortWait, 100)
}
func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) {
@ -475,11 +475,13 @@ func (r *reactorTestSuite) listenForRequest(t *testing.T, fromNode, toNode int,
}
func (r *reactorTestSuite) pingAndlistenForNAddresses(
ctx context.Context,
t *testing.T,
fromNode, toNode int,
waitPeriod time.Duration,
addresses int,
) {
t.Helper()
r.logger.Info("Listening for addresses", "from", fromNode, "to", toNode)
to, from := r.checkNodePair(t, toNode, fromNode)
conditional := func(msg p2p.Envelope) bool {
@ -499,10 +501,10 @@ func (r *reactorTestSuite) pingAndlistenForNAddresses(
// if we didn't get the right length, we wait and send the
// request again
time.Sleep(300 * time.Millisecond)
r.sendRequest(t, toNode, fromNode)
r.sendRequest(ctx, t, toNode, fromNode)
return false
}
r.sendRequest(t, toNode, fromNode)
r.sendRequest(ctx, t, toNode, fromNode)
r.listenFor(t, to, conditional, assertion, waitPeriod)
}
@ -566,27 +568,30 @@ func (r *reactorTestSuite) getAddressesFor(nodes []int) []p2pproto.PexAddress {
return addresses
}
func (r *reactorTestSuite) sendRequest(t *testing.T, fromNode, toNode int) {
func (r *reactorTestSuite) sendRequest(ctx context.Context, t *testing.T, fromNode, toNode int) {
t.Helper()
to, from := r.checkNodePair(t, toNode, fromNode)
r.pexChannels[from].Out <- p2p.Envelope{
require.NoError(t, r.pexChannels[from].Send(ctx, p2p.Envelope{
To: to,
Message: &p2pproto.PexRequest{},
}
}))
}
func (r *reactorTestSuite) sendResponse(
ctx context.Context,
t *testing.T,
fromNode, toNode int,
withNodes []int,
) {
t.Helper()
from, to := r.checkNodePair(t, fromNode, toNode)
addrs := r.getAddressesFor(withNodes)
r.pexChannels[from].Out <- p2p.Envelope{
require.NoError(t, r.pexChannels[from].Send(ctx, p2p.Envelope{
To: to,
Message: &p2pproto.PexResponse{
Addresses: addrs,
},
}
}))
}
func (r *reactorTestSuite) requireNumberOfPeers(


+22 -21  internal/p2p/router_test.go

@ -32,11 +32,12 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) {
select {
case envelope := <-channel.In:
value := envelope.Message.(*p2ptest.Message).Value
channel.Out <- p2p.Envelope{
if err := channel.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &p2ptest.Message{Value: value},
}); err != nil {
return
}
case <-ctx.Done():
return
}
@ -64,14 +65,14 @@ func TestRouter_Network(t *testing.T) {
// Sending a message to each peer should work.
for _, peer := range peers {
p2ptest.RequireSendReceive(t, channel, peer.NodeID,
p2ptest.RequireSendReceive(ctx, t, channel, peer.NodeID,
&p2ptest.Message{Value: "foo"},
&p2ptest.Message{Value: "foo"},
)
}
// Sending a broadcast should return back a message from all peers.
p2ptest.RequireSend(t, channel, p2p.Envelope{
p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{
Broadcast: true,
Message: &p2ptest.Message{Value: "bar"},
})
@ -151,13 +152,13 @@ func TestRouter_Channel_Basic(t *testing.T) {
require.NoError(t, err)
// We should be able to send on the channel, even though there are no peers.
p2ptest.RequireSend(t, channel, p2p.Envelope{
p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{
To: types.NodeID(strings.Repeat("a", 40)),
Message: &p2ptest.Message{Value: "foo"},
})
// A message to ourselves should be dropped.
p2ptest.RequireSend(t, channel, p2p.Envelope{
p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{
To: selfID,
Message: &p2ptest.Message{Value: "self"},
})
@ -184,40 +185,40 @@ func TestRouter_Channel_SendReceive(t *testing.T) {
// Sending a message a->b should work, and not send anything
// further to a, b, or c.
p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireEmpty(t, a, b, c)
// Sending a nil message a->b should be dropped.
p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: nil})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: nil})
p2ptest.RequireEmpty(t, a, b, c)
// Sending a different message type should be dropped.
p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}})
p2ptest.RequireEmpty(t, a, b, c)
// Sending to an unknown peer should be dropped.
p2ptest.RequireSend(t, a, p2p.Envelope{
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{
To: types.NodeID(strings.Repeat("a", 40)),
Message: &p2ptest.Message{Value: "a"},
})
p2ptest.RequireEmpty(t, a, b, c)
// Sending without a recipient should be dropped.
p2ptest.RequireSend(t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}})
p2ptest.RequireEmpty(t, a, b, c)
// Sending to self should be dropped.
p2ptest.RequireSend(t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}})
p2ptest.RequireEmpty(t, a, b, c)
// Removing b and sending to it should be dropped.
network.Remove(ctx, t, bID)
p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}})
p2ptest.RequireEmpty(t, a, b, c)
// After all this, sending a message c->a should work.
p2ptest.RequireSend(t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}})
p2ptest.RequireSend(ctx, t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}})
p2ptest.RequireReceive(t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}})
p2ptest.RequireEmpty(t, a, b, c)
@ -244,7 +245,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) {
network.Start(ctx, t)
// Sending a broadcast from b should work.
p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireSend(ctx, t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireReceive(t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireReceive(t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireReceive(t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}})
@ -252,7 +253,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) {
// Removing one node from the network shouldn't prevent broadcasts from working.
network.Remove(ctx, t, dID)
p2ptest.RequireSend(t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}})
p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}})
p2ptest.RequireReceive(t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}})
p2ptest.RequireEmpty(t, a, b, c, d)
@ -285,16 +286,16 @@ func TestRouter_Channel_Wrapper(t *testing.T) {
// Since wrapperMessage implements p2p.Wrapper and handles Message, it
// should automatically wrap and unwrap sent messages -- we prepend the
// wrapper actions to the message value to signal this.
p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}})
p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}})
// If we send a different message that can't be wrapped, it should be dropped.
p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}})
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}})
p2ptest.RequireEmpty(t, b)
// If we send the wrapper message itself, it should also be passed through
// since WrapperMessage supports it, and should only be unwrapped at the receiver.
p2ptest.RequireSend(t, a, p2p.Envelope{
p2ptest.RequireSend(ctx, t, a, p2p.Envelope{
To: bID,
Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}},
})
@ -960,10 +961,10 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) {
channel, err := router.OpenChannel(ctx, chDesc)
require.NoError(t, err)
channel.Out <- p2p.Envelope{
require.NoError(t, channel.Send(ctx, p2p.Envelope{
To: peer.NodeID,
Message: &p2ptest.Message{Value: "Hi"},
}
}))
require.NoError(t, router.Stop())
mockTransport.AssertExpectations(t)
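
The test call sites above now go through Channel.Send instead of writing to Channel.Out directly. The implementation in internal/p2p/channel.go is not part of this excerpt, but judging from these call sites it is presumably a context-aware send along the following lines (a sketch only; the outCh field name is an assumption, and only the Send signature is confirmed by this diff):

// Sketch: signature taken from the call sites in this PR, body assumed.
func (ch *Channel) Send(ctx context.Context, envelope Envelope) error {
	select {
	case <-ctx.Done():
		// the caller gave up (shutdown or timeout) before the envelope was queued
		return ctx.Err()
	case ch.outCh <- envelope:
		return nil
	}
}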


+ 7
- 4
internal/statesync/dispatcher.go View File

@ -26,16 +26,16 @@ var (
// NOTE: It is not the responsibility of the dispatcher to verify the light blocks.
type Dispatcher struct {
// the channel used to send light block requests
requestCh chan<- p2p.Envelope
requestCh *p2p.Channel
mtx sync.Mutex
// all pending calls that have been dispatched and are awaiting an answer
calls map[types.NodeID]chan *types.LightBlock
}
func NewDispatcher(requestCh chan<- p2p.Envelope) *Dispatcher {
func NewDispatcher(requestChannel *p2p.Channel) *Dispatcher {
return &Dispatcher{
requestCh: requestCh,
requestCh: requestChannel,
calls: make(map[types.NodeID]chan *types.LightBlock),
}
}
@ -91,11 +91,14 @@ func (d *Dispatcher) dispatch(ctx context.Context, peer types.NodeID, height int
d.calls[peer] = ch
// send request
d.requestCh <- p2p.Envelope{
if err := d.requestCh.Send(ctx, p2p.Envelope{
To: peer,
Message: &ssproto.LightBlockRequest{
Height: uint64(height),
},
}); err != nil {
close(ch)
return ch, err
}
return ch, nil
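
A minimal sketch of how a caller might use the reworked dispatcher, assuming the wiring above; fetchLightBlock, the 10-second timeout, and the variable names are illustrative, and only the LightBlock signature is taken from this diff:

// Hypothetical helper: request a light block from one peer, bounded by a deadline.
func fetchLightBlock(ctx context.Context, d *Dispatcher, height int64, peer types.NodeID) (*types.LightBlock, error) {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	lb, err := d.LightBlock(ctx, height, peer)
	if err != nil {
		// the request could not be sent, or the context expired before a response
		return nil, err
	}
	// lb may be nil if the peer responded that it does not have the block
	return lb, nil
}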


+ 45
- 18
internal/statesync/dispatcher_test.go View File

@ -18,16 +18,32 @@ import (
"github.com/tendermint/tendermint/types"
)
type channelInternal struct {
In chan p2p.Envelope
Out chan p2p.Envelope
Error chan p2p.PeerError
}
func testChannel(size int) (*channelInternal, *p2p.Channel) {
in := &channelInternal{
In: make(chan p2p.Envelope, size),
Out: make(chan p2p.Envelope, size),
Error: make(chan p2p.PeerError, size),
}
return in, p2p.NewChannel(0, nil, in.In, in.Out, in.Error)
}
func TestDispatcherBasic(t *testing.T) {
t.Cleanup(leaktest.Check(t))
const numPeers = 5
ch := make(chan p2p.Envelope, 100)
closeCh := make(chan struct{})
defer close(closeCh)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
chans, ch := testChannel(100)
d := NewDispatcher(ch)
go handleRequests(t, d, ch, closeCh)
go handleRequests(ctx, t, d, chans.Out)
peers := createPeerSet(numPeers)
wg := sync.WaitGroup{}
@ -52,19 +68,24 @@ func TestDispatcherBasic(t *testing.T) {
func TestDispatcherReturnsNoBlock(t *testing.T) {
t.Cleanup(leaktest.Check(t))
ch := make(chan p2p.Envelope, 100)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
chans, ch := testChannel(100)
d := NewDispatcher(ch)
doneCh := make(chan struct{})
peer := factory.NodeID("a")
go func() {
<-ch
<-chans.Out
require.NoError(t, d.Respond(nil, peer))
close(doneCh)
cancel()
}()
lb, err := d.LightBlock(context.Background(), 1, peer)
<-doneCh
lb, err := d.LightBlock(ctx, 1, peer)
<-ctx.Done()
require.Nil(t, lb)
require.Nil(t, err)
@ -72,11 +93,15 @@ func TestDispatcherReturnsNoBlock(t *testing.T) {
func TestDispatcherTimeOutWaitingOnLightBlock(t *testing.T) {
t.Cleanup(leaktest.Check(t))
ch := make(chan p2p.Envelope, 100)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, ch := testChannel(100)
d := NewDispatcher(ch)
peer := factory.NodeID("a")
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond)
ctx, cancelFunc := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancelFunc()
lb, err := d.LightBlock(ctx, 1, peer)
@ -89,13 +114,15 @@ func TestDispatcherTimeOutWaitingOnLightBlock(t *testing.T) {
func TestDispatcherProviders(t *testing.T) {
t.Cleanup(leaktest.Check(t))
ch := make(chan p2p.Envelope, 100)
chainID := "test-chain"
closeCh := make(chan struct{})
defer close(closeCh)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
chans, ch := testChannel(100)
d := NewDispatcher(ch)
go handleRequests(t, d, ch, closeCh)
go handleRequests(ctx, t, d, chans.Out)
peers := createPeerSet(5)
providers := make([]*BlockProvider, len(peers))
@ -270,7 +297,7 @@ func TestPeerListRemove(t *testing.T) {
// handleRequests is a helper function, usually run in a separate goroutine, to
// imitate the expected responses of the reactor wired to the dispatcher
func handleRequests(t *testing.T, d *Dispatcher, ch chan p2p.Envelope, closeCh chan struct{}) {
func handleRequests(ctx context.Context, t *testing.T, d *Dispatcher, ch chan p2p.Envelope) {
t.Helper()
for {
select {
@ -280,7 +307,7 @@ func handleRequests(t *testing.T, d *Dispatcher, ch chan p2p.Envelope, closeCh c
resp := mockLBResp(t, peer, int64(height), time.Now())
block, _ := resp.block.ToProto()
require.NoError(t, d.Respond(block, resp.peer))
case <-closeCh:
case <-ctx.Done():
return
}
}


+ 33
- 33
internal/statesync/reactor.go View File

@ -195,7 +195,7 @@ func NewReactor(
stateStore: stateStore,
blockStore: blockStore,
peers: newPeerList(),
dispatcher: NewDispatcher(blockCh.Out),
dispatcher: NewDispatcher(blockCh),
providers: make(map[types.NodeID]*BlockProvider),
metrics: ssMetrics,
}
@ -256,8 +256,8 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) {
r.conn,
r.connQuery,
r.stateProvider,
r.snapshotCh.Out,
r.chunkCh.Out,
r.snapshotCh,
r.chunkCh,
r.tempDir,
r.metrics,
)
@ -270,17 +270,12 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) {
r.mtx.Unlock()
}()
requestSnapshotsHook := func() {
requestSnapshotsHook := func() error {
// request snapshots from all currently connected peers
msg := p2p.Envelope{
return r.snapshotCh.Send(ctx, p2p.Envelope{
Broadcast: true,
Message: &ssproto.SnapshotsRequest{},
}
select {
case <-ctx.Done():
case r.snapshotCh.Out <- msg:
}
})
}
state, commit, err := r.syncer.SyncAny(ctx, r.cfg.DiscoveryTime, requestSnapshotsHook)
@ -508,7 +503,7 @@ func (r *Reactor) backfill(
// handleSnapshotMessage handles envelopes sent from peers on the
// SnapshotChannel. It returns an error only if the Envelope.Message is unknown
// for this channel. This should never be called outside of handleMessage.
func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error {
func (r *Reactor) handleSnapshotMessage(ctx context.Context, envelope p2p.Envelope) error {
logger := r.logger.With("peer", envelope.From)
switch msg := envelope.Message.(type) {
@ -526,7 +521,8 @@ func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error {
"format", snapshot.Format,
"peer", envelope.From,
)
r.snapshotCh.Out <- p2p.Envelope{
if err := r.snapshotCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &ssproto.SnapshotsResponse{
Height: snapshot.Height,
@ -535,6 +531,8 @@ func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error {
Hash: snapshot.Hash,
Metadata: snapshot.Metadata,
},
}); err != nil {
return err
}
}
@ -577,7 +575,7 @@ func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error {
// handleChunkMessage handles envelopes sent from peers on the ChunkChannel.
// It returns an error only if the Envelope.Message is unknown for this channel.
// This should never be called outside of handleMessage.
func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error {
func (r *Reactor) handleChunkMessage(ctx context.Context, envelope p2p.Envelope) error {
switch msg := envelope.Message.(type) {
case *ssproto.ChunkRequest:
r.logger.Debug(
@ -611,7 +609,7 @@ func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error {
"chunk", msg.Index,
"peer", envelope.From,
)
r.chunkCh.Out <- p2p.Envelope{
if err := r.chunkCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &ssproto.ChunkResponse{
Height: msg.Height,
@ -620,6 +618,8 @@ func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error {
Chunk: resp.Chunk,
Missing: resp.Chunk == nil,
},
}); err != nil {
return err
}
case *ssproto.ChunkResponse:
@ -664,7 +664,7 @@ func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error {
return nil
}
func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error {
func (r *Reactor) handleLightBlockMessage(ctx context.Context, envelope p2p.Envelope) error {
switch msg := envelope.Message.(type) {
case *ssproto.LightBlockRequest:
r.logger.Info("received light block request", "height", msg.Height)
@ -674,11 +674,13 @@ func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error {
return err
}
if lb == nil {
r.blockCh.Out <- p2p.Envelope{
if err := r.blockCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &ssproto.LightBlockResponse{
LightBlock: nil,
},
}); err != nil {
return err
}
return nil
}
@ -691,13 +693,14 @@ func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error {
// NOTE: If we don't have the light block we will send a nil light block
// back to the requesting node, indicating that we don't have it.
r.blockCh.Out <- p2p.Envelope{
if err := r.blockCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &ssproto.LightBlockResponse{
LightBlock: lbproto,
},
}); err != nil {
return err
}
case *ssproto.LightBlockResponse:
var height int64
if msg.LightBlock != nil {
@ -715,7 +718,7 @@ func (r *Reactor) handleLightBlockMessage(envelope p2p.Envelope) error {
return nil
}
func (r *Reactor) handleParamsMessage(envelope p2p.Envelope) error {
func (r *Reactor) handleParamsMessage(ctx context.Context, envelope p2p.Envelope) error {
switch msg := envelope.Message.(type) {
case *ssproto.ParamsRequest:
r.logger.Debug("received consensus params request", "height", msg.Height)
@ -726,14 +729,15 @@ func (r *Reactor) handleParamsMessage(envelope p2p.Envelope) error {
}
cpproto := cp.ToProto()
r.paramsCh.Out <- p2p.Envelope{
if err := r.paramsCh.Send(ctx, p2p.Envelope{
To: envelope.From,
Message: &ssproto.ParamsResponse{
Height: msg.Height,
ConsensusParams: cpproto,
},
}); err != nil {
return err
}
case *ssproto.ParamsResponse:
r.mtx.RLock()
defer r.mtx.RUnlock()
@ -761,7 +765,7 @@ func (r *Reactor) handleParamsMessage(envelope p2p.Envelope) error {
// handleMessage handles an Envelope sent from a peer on a specific p2p Channel.
// It will handle errors and any possible panics gracefully. A caller can handle
// any error returned by sending a PeerError on the respective channel.
func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
func (r *Reactor) handleMessage(ctx context.Context, chID p2p.ChannelID, envelope p2p.Envelope) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic in processing message: %v", e)
@ -777,17 +781,13 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err
switch chID {
case SnapshotChannel:
err = r.handleSnapshotMessage(envelope)
err = r.handleSnapshotMessage(ctx, envelope)
case ChunkChannel:
err = r.handleChunkMessage(envelope)
err = r.handleChunkMessage(ctx, envelope)
case LightBlockChannel:
err = r.handleLightBlockMessage(envelope)
err = r.handleLightBlockMessage(ctx, envelope)
case ParamsChannel:
err = r.handleParamsMessage(envelope)
err = r.handleParamsMessage(ctx, envelope)
default:
err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope)
}
@ -806,7 +806,7 @@ func (r *Reactor) processCh(ctx context.Context, ch *p2p.Channel, chName string)
r.logger.Debug("channel closed", "channel", chName)
return
case envelope := <-ch.In:
if err := r.handleMessage(ch.ID, envelope); err != nil {
if err := r.handleMessage(ctx, ch.ID, envelope); err != nil {
r.logger.Error("failed to process message",
"err", err,
"channel", chName,
@ -999,7 +999,7 @@ func (r *Reactor) initStateProvider(ctx context.Context, chainID string, initial
providers[idx] = NewBlockProvider(p, chainID, r.dispatcher)
}
r.stateProvider, err = NewP2PStateProvider(ctx, chainID, initialHeight, providers, to, r.paramsCh.Out, spLogger)
r.stateProvider, err = NewP2PStateProvider(ctx, chainID, initialHeight, providers, to, r.paramsCh, spLogger)
if err != nil {
return fmt.Errorf("failed to initialize P2P state provider: %w", err)
}
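
For orientation, the loop that calls handleMessage now looks roughly like this (a simplified reconstruction; the processCh hunk above is only shown in part, so the exact shape may differ):

for {
	select {
	case <-ctx.Done():
		return
	case envelope := <-ch.In:
		// the loop's context is threaded into the handler so any
		// Channel.Send inside it is canceled together with the reactor
		if err := r.handleMessage(ctx, ch.ID, envelope); err != nil {
			r.logger.Error("failed to process message",
				"err", err, "channel", chName, "envelope", envelope)
		}
	}
}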


+ 2
- 2
internal/statesync/reactor_test.go View File

@ -170,8 +170,8 @@ func setup(
conn,
connQuery,
stateProvider,
rts.snapshotOutCh,
rts.chunkOutCh,
rts.snapshotChannel,
rts.chunkChannel,
"",
rts.reactor.metrics,
)


+ 5
- 7
internal/statesync/stateprovider.go View File

@ -200,7 +200,7 @@ type stateProviderP2P struct {
tmsync.Mutex // light.Client is not concurrency-safe
lc *light.Client
initialHeight int64
paramsSendCh chan<- p2p.Envelope
paramsSendCh *p2p.Channel
paramsRecvCh chan types.ConsensusParams
}
@ -212,7 +212,7 @@ func NewP2PStateProvider(
initialHeight int64,
providers []lightprovider.Provider,
trustOptions light.TrustOptions,
paramsSendCh chan<- p2p.Envelope,
paramsSendCh *p2p.Channel,
logger log.Logger,
) (StateProvider, error) {
if len(providers) < 2 {
@ -382,15 +382,13 @@ func (s *stateProviderP2P) tryGetConsensusParamsFromWitnesses(
return nil, fmt.Errorf("invalid provider (%s) node id: %w", p.String(), err)
}
select {
case s.paramsSendCh <- p2p.Envelope{
if err := s.paramsSendCh.Send(ctx, p2p.Envelope{
To: peer,
Message: &ssproto.ParamsRequest{
Height: uint64(height),
},
}:
case <-ctx.Done():
return nil, ctx.Err()
}); err != nil {
return nil, err
}
select {


+ 23
- 31
internal/statesync/syncer.go View File

@ -57,8 +57,8 @@ type syncer struct {
conn proxy.AppConnSnapshot
connQuery proxy.AppConnQuery
snapshots *snapshotPool
snapshotCh chan<- p2p.Envelope
chunkCh chan<- p2p.Envelope
snapshotCh *p2p.Channel
chunkCh *p2p.Channel
tempDir string
fetchers int32
retryTimeout time.Duration
@ -79,8 +79,8 @@ func newSyncer(
conn proxy.AppConnSnapshot,
connQuery proxy.AppConnQuery,
stateProvider StateProvider,
snapshotCh chan<- p2p.Envelope,
chunkCh chan<- p2p.Envelope,
snapshotCh *p2p.Channel,
chunkCh *p2p.Channel,
tempDir string,
metrics *Metrics,
) *syncer {
@ -138,29 +138,13 @@ func (s *syncer) AddSnapshot(peerID types.NodeID, snapshot *snapshot) (bool, err
// AddPeer adds a peer to the pool. For now we just keep it simple and send a
// single request to discover snapshots, later we may want to do retries and stuff.
func (s *syncer) AddPeer(ctx context.Context, peerID types.NodeID) (err error) {
defer func() {
// TODO: remove panic recover once AddPeer can no longer accidentally send on
// closed channel.
// This recover was added to protect against the p2p message being sent
// to the snapshot channel after the snapshot channel was closed.
if r := recover(); r != nil {
err = fmt.Errorf("panic sending peer snapshot request: %v", r)
}
}()
func (s *syncer) AddPeer(ctx context.Context, peerID types.NodeID) error {
s.logger.Debug("Requesting snapshots from peer", "peer", peerID)
msg := p2p.Envelope{
return s.snapshotCh.Send(ctx, p2p.Envelope{
To: peerID,
Message: &ssproto.SnapshotsRequest{},
}
select {
case <-ctx.Done():
case s.snapshotCh <- msg:
}
return err
})
}
// RemovePeer removes a peer from the pool.
@ -175,14 +159,16 @@ func (s *syncer) RemovePeer(peerID types.NodeID) {
func (s *syncer) SyncAny(
ctx context.Context,
discoveryTime time.Duration,
requestSnapshots func(),
requestSnapshots func() error,
) (sm.State, *types.Commit, error) {
if discoveryTime != 0 && discoveryTime < minimumDiscoveryTime {
discoveryTime = minimumDiscoveryTime
}
if discoveryTime > 0 {
requestSnapshots()
if err := requestSnapshots(); err != nil {
return sm.State{}, nil, err
}
s.logger.Info(fmt.Sprintf("Discovering snapshots for %v", discoveryTime))
time.Sleep(discoveryTime)
}
@ -506,7 +492,9 @@ func (s *syncer) fetchChunks(ctx context.Context, snapshot *snapshot, chunks *ch
ticker := time.NewTicker(s.retryTimeout)
defer ticker.Stop()
s.requestChunk(ctx, snapshot, index)
if err := s.requestChunk(ctx, snapshot, index); err != nil {
return
}
select {
case <-chunks.WaitFor(index):
@ -524,12 +512,16 @@ func (s *syncer) fetchChunks(ctx context.Context, snapshot *snapshot, chunks *ch
}
// requestChunk requests a chunk from a peer.
func (s *syncer) requestChunk(ctx context.Context, snapshot *snapshot, chunk uint32) {
//
// It returns nil if there are no peers for the given snapshot or if the
// request is made successfully, and an error if the request cannot be
// completed.
func (s *syncer) requestChunk(ctx context.Context, snapshot *snapshot, chunk uint32) error {
peer := s.snapshots.GetPeer(snapshot)
if peer == "" {
s.logger.Error("No valid peers found for snapshot", "height", snapshot.Height,
"format", snapshot.Format, "hash", snapshot.Hash)
return
return nil
}
s.logger.Debug(
@ -549,10 +541,10 @@ func (s *syncer) requestChunk(ctx context.Context, snapshot *snapshot, chunk uin
},
}
select {
case s.chunkCh <- msg:
case <-ctx.Done():
if err := s.chunkCh.Send(ctx, msg); err != nil {
return err
}
return nil
}
// verifyApp verifies the sync, checking the app hash and last block height. It returns the
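
A short sketch of a discovery hook under the new SyncAny signature, mirroring requestSnapshotsHook in the reactor above; syncer, snapshotCh, ctx, and discoveryTime are placeholders from the caller's scope:

// The hook now returns an error so SyncAny can abort discovery early
// instead of silently continuing after a failed broadcast.
hook := func() error {
	return snapshotCh.Send(ctx, p2p.Envelope{
		Broadcast: true,
		Message:   &ssproto.SnapshotsRequest{},
	})
}

state, commit, err := syncer.SyncAny(ctx, discoveryTime, hook)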


+ 7
- 7
internal/statesync/syncer_test.go View File

@ -184,7 +184,7 @@ func TestSyncer_SyncAny(t *testing.T) {
LastBlockAppHash: []byte("app_hash"),
}, nil)
newState, lastCommit, err := rts.syncer.SyncAny(ctx, 0, func() {})
newState, lastCommit, err := rts.syncer.SyncAny(ctx, 0, func() error { return nil })
require.NoError(t, err)
wg.Wait()
@ -223,7 +223,7 @@ func TestSyncer_SyncAny_noSnapshots(t *testing.T) {
rts := setup(ctx, t, nil, nil, stateProvider, 2)
_, _, err := rts.syncer.SyncAny(ctx, 0, func() {})
_, _, err := rts.syncer.SyncAny(ctx, 0, func() error { return nil })
require.Equal(t, errNoSnapshots, err)
}
@ -246,7 +246,7 @@ func TestSyncer_SyncAny_abort(t *testing.T) {
Snapshot: toABCI(s), AppHash: []byte("app_hash"),
}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil)
_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
require.Equal(t, errAbort, err)
rts.conn.AssertExpectations(t)
}
@ -288,7 +288,7 @@ func TestSyncer_SyncAny_reject(t *testing.T) {
Snapshot: toABCI(s11), AppHash: []byte("app_hash"),
}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil)
_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
require.Equal(t, errNoSnapshots, err)
rts.conn.AssertExpectations(t)
}
@ -326,7 +326,7 @@ func TestSyncer_SyncAny_reject_format(t *testing.T) {
Snapshot: toABCI(s11), AppHash: []byte("app_hash"),
}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil)
_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
require.Equal(t, errAbort, err)
rts.conn.AssertExpectations(t)
}
@ -375,7 +375,7 @@ func TestSyncer_SyncAny_reject_sender(t *testing.T) {
Snapshot: toABCI(sa), AppHash: []byte("app_hash"),
}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil)
_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
require.Equal(t, errNoSnapshots, err)
rts.conn.AssertExpectations(t)
}
@ -401,7 +401,7 @@ func TestSyncer_SyncAny_abciError(t *testing.T) {
Snapshot: toABCI(s), AppHash: []byte("app_hash"),
}).Once().Return(nil, errBoom)
_, _, err = rts.syncer.SyncAny(ctx, 0, func() {})
_, _, err = rts.syncer.SyncAny(ctx, 0, func() error { return nil })
require.True(t, errors.Is(err, errBoom))
rts.conn.AssertExpectations(t)
}


+ 4
- 2
libs/events/event_cache.go View File

@ -1,5 +1,7 @@
package events
import "context"
// An EventCache buffers events for a Fireable
// All events are cached. Filtering happens on Flush
type EventCache struct {
@ -28,9 +30,9 @@ func (evc *EventCache) FireEvent(event string, data EventData) {
// Fire events by running evsw.FireEvent on all cached events. Blocks.
// Clears cached events
func (evc *EventCache) Flush() {
func (evc *EventCache) Flush(ctx context.Context) {
for _, ei := range evc.events {
evc.evsw.FireEvent(ei.event, ei.data)
evc.evsw.FireEvent(ctx, ei.event, ei.data)
}
// Clear the buffer; since we only add to it with append, it's safe to just set it to nil and maybe save an allocation
evc.events = nil
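
A minimal usage sketch of the context-aware cache; evsw, txData, and blockData are placeholders. Events are buffered by FireEvent and only reach the switch's listeners once Flush is called with a context:

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

evc := NewEventCache(evsw)
evc.FireEvent("tx", txData)
evc.FireEvent("block", blockData)

// replays both buffered events via evsw.FireEvent(ctx, ...), then clears the buffer
evc.Flush(ctx)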


+ 7
- 5
libs/events/event_cache_test.go View File

@ -16,23 +16,25 @@ func TestEventCache_Flush(t *testing.T) {
err := evsw.Start(ctx)
require.NoError(t, err)
err = evsw.AddListenerForEvent("nothingness", "", func(data EventData) {
err = evsw.AddListenerForEvent("nothingness", "", func(_ context.Context, data EventData) error {
// Check we are not initializing an empty buffer full of zeroed eventInfos in the EventCache
require.FailNow(t, "We should never receive a message on this switch since none are fired")
return nil
})
require.NoError(t, err)
evc := NewEventCache(evsw)
evc.Flush()
evc.Flush(ctx)
// Check after reset
evc.Flush()
evc.Flush(ctx)
fail := true
pass := false
err = evsw.AddListenerForEvent("somethingness", "something", func(data EventData) {
err = evsw.AddListenerForEvent("somethingness", "something", func(_ context.Context, data EventData) error {
if fail {
require.FailNow(t, "Shouldn't see a message until flushed")
}
pass = true
return nil
})
require.NoError(t, err)
@ -40,6 +42,6 @@ func TestEventCache_Flush(t *testing.T) {
evc.FireEvent("something", struct{ int }{2})
evc.FireEvent("something", struct{ int }{3})
fail = false
evc.Flush()
evc.Flush(ctx)
assert.True(t, pass)
}

+ 9
- 6
libs/events/events.go View File

@ -33,7 +33,7 @@ type Eventable interface {
//
// FireEvent fires an event with the given name and data.
type Fireable interface {
FireEvent(eventValue string, data EventData)
FireEvent(ctx context.Context, eventValue string, data EventData)
}
// EventSwitch is the interface for synchronous pubsub, where listeners
@ -148,7 +148,7 @@ func (evsw *eventSwitch) RemoveListenerForEvent(event string, listenerID string)
}
}
func (evsw *eventSwitch) FireEvent(event string, data EventData) {
func (evsw *eventSwitch) FireEvent(ctx context.Context, event string, data EventData) {
// Get the eventCell
evsw.mtx.RLock()
eventCell := evsw.eventCells[event]
@ -159,7 +159,7 @@ func (evsw *eventSwitch) FireEvent(event string, data EventData) {
}
// Fire event for all listeners in eventCell
eventCell.FireEvent(data)
eventCell.FireEvent(ctx, data)
}
//-----------------------------------------------------------------------------
@ -190,7 +190,7 @@ func (cell *eventCell) RemoveListener(listenerID string) int {
return numListeners
}
func (cell *eventCell) FireEvent(data EventData) {
func (cell *eventCell) FireEvent(ctx context.Context, data EventData) {
cell.mtx.RLock()
eventCallbacks := make([]EventCallback, 0, len(cell.listeners))
for _, cb := range cell.listeners {
@ -199,13 +199,16 @@ func (cell *eventCell) FireEvent(data EventData) {
cell.mtx.RUnlock()
for _, cb := range eventCallbacks {
cb(data)
if err := cb(ctx, data); err != nil {
// should we log or abort here?
continue
}
}
}
//-----------------------------------------------------------------------------
type EventCallback func(data EventData)
type EventCallback func(ctx context.Context, data EventData) error
type eventListener struct {
id string
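
To show the new callback shape in isolation (evsw, resultCh, and blockData are placeholders; the AddListenerForEvent and FireEvent signatures match the code above): a listener now receives the firing context and reports failure instead of blocking indefinitely.

// Register a context-aware listener (error handling elided in this sketch).
_ = evsw.AddListenerForEvent("indexer", "new_block",
	func(ctx context.Context, data EventData) error {
		select {
		case resultCh <- data:
			return nil
		case <-ctx.Done():
			// give up if the firer's context is canceled before delivery
			return ctx.Err()
		}
	})

// Firing now requires the caller's context as well.
evsw.FireEvent(ctx, "new_block", blockData)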


+ 166
- 65
libs/events/events_test.go View File

@ -24,12 +24,17 @@ func TestAddListenerForEventFireOnce(t *testing.T) {
messages := make(chan EventData)
require.NoError(t, evsw.AddListenerForEvent("listener", "event",
func(data EventData) {
func(ctx context.Context, data EventData) error {
// test there's no deadlock if we remove the listener inside a callback
evsw.RemoveListener("listener")
messages <- data
select {
case messages <- data:
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
go evsw.FireEvent("event", "data")
go evsw.FireEvent(ctx, "event", "data")
received := <-messages
if received != "data" {
t.Errorf("message received does not match: %v", received)
@ -51,13 +56,18 @@ func TestAddListenerForEventFireMany(t *testing.T) {
numbers := make(chan uint64, 4)
// subscribe one listener for one event
require.NoError(t, evsw.AddListenerForEvent("listener", "event",
func(data EventData) {
numbers <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
// collect received events
go sumReceivedNumbers(numbers, doneSum)
// go fire events
go fireEvents(evsw, "event", doneSending, uint64(1))
go fireEvents(ctx, evsw, "event", doneSending, uint64(1))
checkSum := <-doneSending
close(numbers)
eventSum := <-doneSum
@ -84,23 +94,38 @@ func TestAddListenerForDifferentEvents(t *testing.T) {
numbers := make(chan uint64, 4)
// subscribe one listener to three events
require.NoError(t, evsw.AddListenerForEvent("listener", "event1",
func(data EventData) {
numbers <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener", "event2",
func(data EventData) {
numbers <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener", "event3",
func(data EventData) {
numbers <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
// collect received events
go sumReceivedNumbers(numbers, doneSum)
// go fire events
go fireEvents(evsw, "event1", doneSending1, uint64(1))
go fireEvents(evsw, "event2", doneSending2, uint64(1))
go fireEvents(evsw, "event3", doneSending3, uint64(1))
go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1))
go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1))
go fireEvents(ctx, evsw, "event3", doneSending3, uint64(1))
var checkSum uint64
checkSum += <-doneSending1
checkSum += <-doneSending2
@ -134,33 +159,58 @@ func TestAddDifferentListenerForDifferentEvents(t *testing.T) {
numbers2 := make(chan uint64, 4)
// subscribe two listener to three events
require.NoError(t, evsw.AddListenerForEvent("listener1", "event1",
func(data EventData) {
numbers1 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers1 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener1", "event2",
func(data EventData) {
numbers1 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers1 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener1", "event3",
func(data EventData) {
numbers1 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers1 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener2", "event2",
func(data EventData) {
numbers2 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers2 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener2", "event3",
func(data EventData) {
numbers2 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers2 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
// collect received events for listener1
go sumReceivedNumbers(numbers1, doneSum1)
// collect received events for listener2
go sumReceivedNumbers(numbers2, doneSum2)
// go fire events
go fireEvents(evsw, "event1", doneSending1, uint64(1))
go fireEvents(evsw, "event2", doneSending2, uint64(1001))
go fireEvents(evsw, "event3", doneSending3, uint64(2001))
go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1))
go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1001))
go fireEvents(ctx, evsw, "event3", doneSending3, uint64(2001))
checkSumEvent1 := <-doneSending1
checkSumEvent2 := <-doneSending2
checkSumEvent3 := <-doneSending3
@ -209,9 +259,10 @@ func TestAddAndRemoveListenerConcurrency(t *testing.T) {
// we explicitly ignore errors here, since the listener will sometimes be removed
// (that's what we're testing)
_ = evsw.AddListenerForEvent("listener", fmt.Sprintf("event%d", index),
func(data EventData) {
func(ctx context.Context, data EventData) error {
t.Errorf("should not run callback for %d.\n", index)
stopInputEvent = true
return nil
})
}
}()
@ -222,7 +273,7 @@ func TestAddAndRemoveListenerConcurrency(t *testing.T) {
evsw.RemoveListener("listener") // remove the last listener
for i := 0; i < roundCount && !stopInputEvent; i++ {
evsw.FireEvent(fmt.Sprintf("event%d", i), uint64(1001))
evsw.FireEvent(ctx, fmt.Sprintf("event%d", i), uint64(1001))
}
}
@ -245,23 +296,33 @@ func TestAddAndRemoveListener(t *testing.T) {
numbers2 := make(chan uint64, 4)
// subscribe two listener to three events
require.NoError(t, evsw.AddListenerForEvent("listener", "event1",
func(data EventData) {
numbers1 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers1 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener", "event2",
func(data EventData) {
numbers2 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers2 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
// collect received events for event1
go sumReceivedNumbers(numbers1, doneSum1)
// collect received events for event2
go sumReceivedNumbers(numbers2, doneSum2)
// go fire events
go fireEvents(evsw, "event1", doneSending1, uint64(1))
go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1))
checkSumEvent1 := <-doneSending1
// after sending all event1, unsubscribe for all events
evsw.RemoveListener("listener")
go fireEvents(evsw, "event2", doneSending2, uint64(1001))
go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1001))
checkSumEvent2 := <-doneSending2
close(numbers1)
close(numbers2)
@ -287,17 +348,19 @@ func TestRemoveListener(t *testing.T) {
sum1, sum2 := 0, 0
// add some listeners and make sure they work
require.NoError(t, evsw.AddListenerForEvent("listener", "event1",
func(data EventData) {
func(ctx context.Context, data EventData) error {
sum1++
return nil
}))
require.NoError(t, evsw.AddListenerForEvent("listener", "event2",
func(data EventData) {
func(ctx context.Context, data EventData) error {
sum2++
return nil
}))
for i := 0; i < count; i++ {
evsw.FireEvent("event1", true)
evsw.FireEvent("event2", true)
evsw.FireEvent(ctx, "event1", true)
evsw.FireEvent(ctx, "event2", true)
}
assert.Equal(t, count, sum1)
assert.Equal(t, count, sum2)
@ -305,8 +368,8 @@ func TestRemoveListener(t *testing.T) {
// remove one by event and make sure it is gone
evsw.RemoveListenerForEvent("event2", "listener")
for i := 0; i < count; i++ {
evsw.FireEvent("event1", true)
evsw.FireEvent("event2", true)
evsw.FireEvent(ctx, "event1", true)
evsw.FireEvent(ctx, "event2", true)
}
assert.Equal(t, count*2, sum1)
assert.Equal(t, count, sum2)
@ -314,8 +377,8 @@ func TestRemoveListener(t *testing.T) {
// remove the listener entirely and make sure both gone
evsw.RemoveListener("listener")
for i := 0; i < count; i++ {
evsw.FireEvent("event1", true)
evsw.FireEvent("event2", true)
evsw.FireEvent(ctx, "event1", true)
evsw.FireEvent(ctx, "event2", true)
}
assert.Equal(t, count*2, sum1)
assert.Equal(t, count, sum2)
@ -347,28 +410,58 @@ func TestRemoveListenersAsync(t *testing.T) {
numbers2 := make(chan uint64, 4)
// subscribe two listener to three events
require.NoError(t, evsw.AddListenerForEvent("listener1", "event1",
func(data EventData) {
numbers1 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers1 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener1", "event2",
func(data EventData) {
numbers1 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers1 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener1", "event3",
func(data EventData) {
numbers1 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers1 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener2", "event1",
func(data EventData) {
numbers2 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers2 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener2", "event2",
func(data EventData) {
numbers2 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers2 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
require.NoError(t, evsw.AddListenerForEvent("listener2", "event3",
func(data EventData) {
numbers2 <- data.(uint64)
func(ctx context.Context, data EventData) error {
select {
case numbers2 <- data.(uint64):
return nil
case <-ctx.Done():
return ctx.Err()
}
}))
// collect received events for event1
go sumReceivedNumbers(numbers1, doneSum1)
@ -382,7 +475,7 @@ func TestRemoveListenersAsync(t *testing.T) {
eventNumber := r1.Intn(3) + 1
go evsw.AddListenerForEvent(fmt.Sprintf("listener%v", listenerNumber), //nolint:errcheck // ignore for tests
fmt.Sprintf("event%v", eventNumber),
func(_ EventData) {})
func(context.Context, EventData) error { return nil })
}
}
removeListenersStress := func() {
@ -395,10 +488,10 @@ func TestRemoveListenersAsync(t *testing.T) {
}
addListenersStress()
// go fire events
go fireEvents(evsw, "event1", doneSending1, uint64(1))
go fireEvents(ctx, evsw, "event1", doneSending1, uint64(1))
removeListenersStress()
go fireEvents(evsw, "event2", doneSending2, uint64(1001))
go fireEvents(evsw, "event3", doneSending3, uint64(2001))
go fireEvents(ctx, evsw, "event2", doneSending2, uint64(1001))
go fireEvents(ctx, evsw, "event3", doneSending3, uint64(2001))
checkSumEvent1 := <-doneSending1
checkSumEvent2 := <-doneSending2
checkSumEvent3 := <-doneSending3
@ -437,13 +530,21 @@ func sumReceivedNumbers(numbers, doneSum chan uint64) {
// to `offset` + 999. It additionally returns the addition of all integers
// sent on `doneChan` for assertion that all events have been sent, and enabling
// the test to assert all events have also been received.
func fireEvents(evsw Fireable, event string, doneChan chan uint64,
offset uint64) {
func fireEvents(ctx context.Context, evsw Fireable, event string, doneChan chan uint64, offset uint64) {
defer close(doneChan)
var sentSum uint64
for i := offset; i <= offset+uint64(999); i++ {
if ctx.Err() != nil {
break
}
evsw.FireEvent(ctx, event, i)
sentSum += i
evsw.FireEvent(event, i)
}
doneChan <- sentSum
close(doneChan)
select {
case <-ctx.Done():
case doneChan <- sentSum:
}
}
