diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index 803834d96..d0d625e26 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -1470,7 +1470,7 @@ func (r *Reactor) peerStatsRoutine(ctx context.Context) { switch msg.Msg.(type) { case *VoteMessage: if numVotes := ps.RecordVote(); numVotes%votesToContributeToBecomeGoodPeer == 0 { - r.peerUpdates.SendUpdate(p2p.PeerUpdate{ + r.peerUpdates.SendUpdate(ctx, p2p.PeerUpdate{ NodeID: msg.PeerID, Status: p2p.PeerStatusGood, }) @@ -1478,7 +1478,7 @@ func (r *Reactor) peerStatsRoutine(ctx context.Context) { case *BlockPartMessage: if numParts := ps.RecordBlockPart(); numParts%blocksToContributeToBecomeGoodPeer == 0 { - r.peerUpdates.SendUpdate(p2p.PeerUpdate{ + r.peerUpdates.SendUpdate(ctx, p2p.PeerUpdate{ NodeID: msg.PeerID, Status: p2p.PeerStatusGood, }) diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index a2808f216..fa21358c1 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -49,8 +49,8 @@ const ( defaultPongTimeout = 45 * time.Second ) -type receiveCbFunc func(chID ChannelID, msgBytes []byte) -type errorCbFunc func(interface{}) +type receiveCbFunc func(ctx context.Context, chID ChannelID, msgBytes []byte) +type errorCbFunc func(context.Context, interface{}) /* Each peer has one `MConnection` (multiplex connection) instance. @@ -286,21 +286,21 @@ func (c *MConnection) flush() { } // Catch panics, usually caused by remote disconnects. -func (c *MConnection) _recover() { +func (c *MConnection) _recover(ctx context.Context) { if r := recover(); r != nil { c.logger.Error("MConnection panicked", "err", r, "stack", string(debug.Stack())) - c.stopForError(fmt.Errorf("recovered from panic: %v", r)) + c.stopForError(ctx, fmt.Errorf("recovered from panic: %v", r)) } } -func (c *MConnection) stopForError(r interface{}) { +func (c *MConnection) stopForError(ctx context.Context, r interface{}) { if err := c.Stop(); err != nil { c.logger.Error("Error stopping connection", "err", err) } if atomic.CompareAndSwapUint32(&c.errored, 0, 1) { if c.onError != nil { - c.onError(r) + c.onError(ctx, r) } } } @@ -335,7 +335,7 @@ func (c *MConnection) Send(chID ChannelID, msgBytes []byte) bool { // sendRoutine polls for packets to send from channels. func (c *MConnection) sendRoutine(ctx context.Context) { - defer c._recover() + defer c._recover(ctx) protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter) FOR_LOOP: @@ -390,7 +390,7 @@ FOR_LOOP: break FOR_LOOP case <-c.send: // Send some PacketMsgs - eof := c.sendSomePacketMsgs() + eof := c.sendSomePacketMsgs(ctx) if !eof { // Keep sendRoutine awake. select { @@ -405,7 +405,7 @@ FOR_LOOP: } if err != nil { c.logger.Error("Connection failed @ sendRoutine", "conn", c, "err", err) - c.stopForError(err) + c.stopForError(ctx, err) break FOR_LOOP } } @@ -417,7 +417,7 @@ FOR_LOOP: // Returns true if messages from channels were exhausted. // Blocks in accordance to .sendMonitor throttling. -func (c *MConnection) sendSomePacketMsgs() bool { +func (c *MConnection) sendSomePacketMsgs(ctx context.Context) bool { // Block until .sendMonitor says we can write. // Once we're ready we send more than we asked for, // but amortized it should even out. @@ -425,7 +425,7 @@ func (c *MConnection) sendSomePacketMsgs() bool { // Now send some PacketMsgs. for i := 0; i < numBatchPacketMsgs; i++ { - if c.sendPacketMsg() { + if c.sendPacketMsg(ctx) { return true } } @@ -433,7 +433,7 @@ func (c *MConnection) sendSomePacketMsgs() bool { } // Returns true if messages from channels were exhausted. -func (c *MConnection) sendPacketMsg() bool { +func (c *MConnection) sendPacketMsg(ctx context.Context) bool { // Choose a channel to create a PacketMsg from. // The chosen channel will be the one whose recentlySent/priority is the least. var leastRatio float32 = math.MaxFloat32 @@ -461,7 +461,7 @@ func (c *MConnection) sendPacketMsg() bool { _n, err := leastChannel.writePacketMsgTo(c.bufConnWriter) if err != nil { c.logger.Error("Failed to write PacketMsg", "err", err) - c.stopForError(err) + c.stopForError(ctx, err) return true } c.sendMonitor.Update(_n) @@ -474,7 +474,7 @@ func (c *MConnection) sendPacketMsg() bool { // Blocks depending on how the connection is throttled. // Otherwise, it never blocks. func (c *MConnection) recvRoutine(ctx context.Context) { - defer c._recover() + defer c._recover(ctx) protoReader := protoio.NewDelimitedReader(c.bufConnReader, c._maxPacketMsgSize) @@ -518,7 +518,7 @@ FOR_LOOP: } else { c.logger.Debug("Connection failed @ recvRoutine (reading byte)", "conn", c, "err", err) } - c.stopForError(err) + c.stopForError(ctx, err) } break FOR_LOOP } @@ -547,7 +547,7 @@ FOR_LOOP: if pkt.PacketMsg.ChannelID < 0 || pkt.PacketMsg.ChannelID > math.MaxUint8 || !ok || channel == nil { err := fmt.Errorf("unknown channel %X", pkt.PacketMsg.ChannelID) c.logger.Debug("Connection failed @ recvRoutine", "conn", c, "err", err) - c.stopForError(err) + c.stopForError(ctx, err) break FOR_LOOP } @@ -555,19 +555,19 @@ FOR_LOOP: if err != nil { if c.IsRunning() { c.logger.Debug("Connection failed @ recvRoutine", "conn", c, "err", err) - c.stopForError(err) + c.stopForError(ctx, err) } break FOR_LOOP } if msgBytes != nil { c.logger.Debug("Received bytes", "chID", channelID, "msgBytes", msgBytes) // NOTE: This means the reactor.Receive runs in the same thread as the p2p recv routine - c.onReceive(channelID, msgBytes) + c.onReceive(ctx, channelID, msgBytes) } default: err := fmt.Errorf("unknown message type %v", reflect.TypeOf(packet)) c.logger.Error("Connection failed @ recvRoutine", "conn", c, "err", err) - c.stopForError(err) + c.stopForError(ctx, err) break FOR_LOOP } } diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index f1b2ae24c..0700db1b0 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -25,18 +25,18 @@ const maxPingPongPacketSize = 1024 // bytes func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection { return createMConnectionWithCallbacks(logger, conn, // onRecieve - func(chID ChannelID, msgBytes []byte) { + func(ctx context.Context, chID ChannelID, msgBytes []byte) { }, // onError - func(r interface{}) { + func(ctx context.Context, r interface{}) { }) } func createMConnectionWithCallbacks( logger log.Logger, conn net.Conn, - onReceive func(chID ChannelID, msgBytes []byte), - onError func(r interface{}), + onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte), + onError func(ctx context.Context, r interface{}), ) *MConnection { cfg := DefaultMConnConfig() cfg.PingInterval = 90 * time.Millisecond @@ -120,11 +120,17 @@ func TestMConnectionReceive(t *testing.T) { receivedCh := make(chan []byte) errorsCh := make(chan interface{}) - onReceive := func(chID ChannelID, msgBytes []byte) { - receivedCh <- msgBytes + onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case receivedCh <- msgBytes: + case <-ctx.Done(): + } } - onError := func(r interface{}) { - errorsCh <- r + onError := func(ctx context.Context, r interface{}) { + select { + case errorsCh <- r: + case <-ctx.Done(): + } } logger := log.TestingLogger() @@ -160,11 +166,17 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) { receivedCh := make(chan []byte) errorsCh := make(chan interface{}) - onReceive := func(chID ChannelID, msgBytes []byte) { - receivedCh <- msgBytes + onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case receivedCh <- msgBytes: + case <-ctx.Done(): + } } - onError := func(r interface{}) { - errorsCh <- r + onError := func(ctx context.Context, r interface{}) { + select { + case errorsCh <- r: + case <-ctx.Done(): + } } ctx, cancel := context.WithCancel(context.Background()) @@ -202,12 +214,19 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { receivedCh := make(chan []byte) errorsCh := make(chan interface{}) - onReceive := func(chID ChannelID, msgBytes []byte) { - receivedCh <- msgBytes + onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case receivedCh <- msgBytes: + case <-ctx.Done(): + } } - onError := func(r interface{}) { - errorsCh <- r + onError := func(ctx context.Context, r interface{}) { + select { + case errorsCh <- r: + case <-ctx.Done(): + } } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -259,11 +278,17 @@ func TestMConnectionMultiplePings(t *testing.T) { receivedCh := make(chan []byte) errorsCh := make(chan interface{}) - onReceive := func(chID ChannelID, msgBytes []byte) { - receivedCh <- msgBytes + onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case receivedCh <- msgBytes: + case <-ctx.Done(): + } } - onError := func(r interface{}) { - errorsCh <- r + onError := func(ctx context.Context, r interface{}) { + select { + case errorsCh <- r: + case <-ctx.Done(): + } } ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -309,11 +334,17 @@ func TestMConnectionPingPongs(t *testing.T) { receivedCh := make(chan []byte) errorsCh := make(chan interface{}) - onReceive := func(chID ChannelID, msgBytes []byte) { - receivedCh <- msgBytes + onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case receivedCh <- msgBytes: + case <-ctx.Done(): + } } - onError := func(r interface{}) { - errorsCh <- r + onError := func(ctx context.Context, r interface{}) { + select { + case errorsCh <- r: + case <-ctx.Done(): + } } ctx, cancel := context.WithCancel(context.Background()) @@ -370,11 +401,17 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { receivedCh := make(chan []byte) errorsCh := make(chan interface{}) - onReceive := func(chID ChannelID, msgBytes []byte) { - receivedCh <- msgBytes + onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case receivedCh <- msgBytes: + case <-ctx.Done(): + } } - onError := func(r interface{}) { - errorsCh <- r + onError := func(ctx context.Context, r interface{}) { + select { + case errorsCh <- r: + case <-ctx.Done(): + } } ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -406,8 +443,8 @@ func newClientAndServerConnsForReadErrors( ) (*MConnection, *MConnection) { server, client := NetPipe() - onReceive := func(chID ChannelID, msgBytes []byte) {} - onError := func(r interface{}) {} + onReceive := func(context.Context, ChannelID, []byte) {} + onError := func(context.Context, interface{}) {} // create client conn with two channels chDescs := []*ChannelDescriptor{ @@ -423,8 +460,11 @@ func newClientAndServerConnsForReadErrors( // create server conn with 1 channel // it fires on chOnErr when there's an error serverLogger := logger.With("module", "server") - onError = func(r interface{}) { - chOnErr <- struct{}{} + onError = func(ctx context.Context, r interface{}) { + select { + case <-ctx.Done(): + case chOnErr <- struct{}{}: + } } mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError) @@ -488,8 +528,11 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) t.Cleanup(waitAll(mconnClient, mconnServer)) - mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) { - chOnRcv <- struct{}{} + mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case <-ctx.Done(): + case chOnRcv <- struct{}{}: + } } client := mconnClient.conn @@ -590,8 +633,11 @@ func TestMConnectionChannelOverflow(t *testing.T) { mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) t.Cleanup(waitAll(mconnClient, mconnServer)) - mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) { - chOnRcv <- struct{}{} + mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) { + select { + case <-ctx.Done(): + case chOnRcv <- struct{}{}: + } } client := mconnClient.conn diff --git a/internal/p2p/mocks/connection.go b/internal/p2p/mocks/connection.go index 65b9afafb..576fb2386 100644 --- a/internal/p2p/mocks/connection.go +++ b/internal/p2p/mocks/connection.go @@ -79,20 +79,20 @@ func (_m *Connection) LocalEndpoint() p2p.Endpoint { return r0 } -// ReceiveMessage provides a mock function with given fields: -func (_m *Connection) ReceiveMessage() (conn.ChannelID, []byte, error) { - ret := _m.Called() +// ReceiveMessage provides a mock function with given fields: _a0 +func (_m *Connection) ReceiveMessage(_a0 context.Context) (conn.ChannelID, []byte, error) { + ret := _m.Called(_a0) var r0 conn.ChannelID - if rf, ok := ret.Get(0).(func() conn.ChannelID); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) conn.ChannelID); ok { + r0 = rf(_a0) } else { r0 = ret.Get(0).(conn.ChannelID) } var r1 []byte - if rf, ok := ret.Get(1).(func() []byte); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) []byte); ok { + r1 = rf(_a0) } else { if ret.Get(1) != nil { r1 = ret.Get(1).([]byte) @@ -100,8 +100,8 @@ func (_m *Connection) ReceiveMessage() (conn.ChannelID, []byte, error) { } var r2 error - if rf, ok := ret.Get(2).(func() error); ok { - r2 = rf() + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(_a0) } else { r2 = ret.Error(2) } @@ -123,13 +123,13 @@ func (_m *Connection) RemoteEndpoint() p2p.Endpoint { return r0 } -// SendMessage provides a mock function with given fields: _a0, _a1 -func (_m *Connection) SendMessage(_a0 conn.ChannelID, _a1 []byte) error { - ret := _m.Called(_a0, _a1) +// SendMessage provides a mock function with given fields: _a0, _a1, _a2 +func (_m *Connection) SendMessage(_a0 context.Context, _a1 conn.ChannelID, _a2 []byte) error { + ret := _m.Called(_a0, _a1, _a2) var r0 error - if rf, ok := ret.Get(0).(func(conn.ChannelID, []byte) error); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(context.Context, conn.ChannelID, []byte) error); ok { + r0 = rf(_a0, _a1, _a2) } else { r0 = ret.Error(0) } diff --git a/internal/p2p/mocks/transport.go b/internal/p2p/mocks/transport.go index eea1de4c5..b17290118 100644 --- a/internal/p2p/mocks/transport.go +++ b/internal/p2p/mocks/transport.go @@ -17,13 +17,13 @@ type Transport struct { mock.Mock } -// Accept provides a mock function with given fields: -func (_m *Transport) Accept() (p2p.Connection, error) { - ret := _m.Called() +// Accept provides a mock function with given fields: _a0 +func (_m *Transport) Accept(_a0 context.Context) (p2p.Connection, error) { + ret := _m.Called(_a0) var r0 p2p.Connection - if rf, ok := ret.Get(0).(func() p2p.Connection); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) p2p.Connection); ok { + r0 = rf(_a0) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(p2p.Connection) @@ -31,8 +31,8 @@ func (_m *Transport) Accept() (p2p.Connection, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) } else { r1 = ret.Error(1) } diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 6fc5d7c11..30f1a435f 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -24,6 +24,7 @@ type Network struct { logger log.Logger memoryNetwork *p2p.MemoryNetwork + cancel context.CancelFunc } // NetworkOptions is an argument structure to parameterize the @@ -68,6 +69,9 @@ func MakeNetwork(ctx context.Context, t *testing.T, opts NetworkOptions) *Networ // addition to creating a peer update subscription for each node. Finally, all // nodes are connected to each other. func (n *Network) Start(ctx context.Context, t *testing.T) { + ctx, n.cancel = context.WithCancel(ctx) + t.Cleanup(n.cancel) + // Set up a list of node addresses to dial, and a peer update subscription // for each node. dialQueue := []p2p.NodeAddress{} @@ -200,10 +204,10 @@ func (n *Network) Remove(ctx context.Context, t *testing.T, id types.NodeID) { } require.NoError(t, node.Transport.Close()) + node.cancel() if node.Router.IsRunning() { require.NoError(t, node.Router.Stop()) } - node.PeerManager.Close() for _, sub := range subs { RequireUpdate(t, sub, p2p.PeerUpdate{ @@ -222,12 +226,16 @@ type Node struct { Router *p2p.Router PeerManager *p2p.PeerManager Transport *p2p.MemoryTransport + + cancel context.CancelFunc } // MakeNode creates a new Node configured for the network with a // running peer manager, but does not add it to the existing // network. Callers are responsible for updating peering relationships. func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) *Node { + ctx, cancel := context.WithCancel(ctx) + privKey := ed25519.GenPrivKey() nodeID := types.NodeIDFromPubKey(privKey.PubKey()) nodeInfo := types.NodeInfo{ @@ -267,8 +275,8 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) if router.IsRunning() { require.NoError(t, router.Stop()) } - peerManager.Close() require.NoError(t, transport.Close()) + cancel() }) return &Node{ @@ -279,6 +287,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) Router: router, PeerManager: peerManager, Transport: transport, + cancel: cancel, } } diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 0ab0128ca..40dcf8464 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -56,8 +56,8 @@ type PeerUpdate struct { type PeerUpdates struct { routerUpdatesCh chan PeerUpdate reactorUpdatesCh chan PeerUpdate - closeCh chan struct{} closeOnce sync.Once + doneCh chan struct{} } // NewPeerUpdates creates a new PeerUpdates subscription. It is primarily for @@ -67,7 +67,7 @@ func NewPeerUpdates(updatesCh chan PeerUpdate, buf int) *PeerUpdates { return &PeerUpdates{ reactorUpdatesCh: updatesCh, routerUpdatesCh: make(chan PeerUpdate, buf), - closeCh: make(chan struct{}), + doneCh: make(chan struct{}), } } @@ -76,28 +76,28 @@ func (pu *PeerUpdates) Updates() <-chan PeerUpdate { return pu.reactorUpdatesCh } -// SendUpdate pushes information about a peer into the routing layer, -// presumably from a peer. -func (pu *PeerUpdates) SendUpdate(update PeerUpdate) { - select { - case <-pu.closeCh: - case pu.routerUpdatesCh <- update: - } +// Done returns a channel that is closed when the subscription is closed. +func (pu *PeerUpdates) Done() <-chan struct{} { + return pu.doneCh } // Close closes the peer updates subscription. func (pu *PeerUpdates) Close() { pu.closeOnce.Do(func() { // NOTE: We don't close updatesCh since multiple goroutines may be - // sending on it. The PeerManager senders will select on closeCh as well + // sending on it. The PeerManager senders will select on doneCh as well // to avoid blocking on a closed subscription. - close(pu.closeCh) + close(pu.doneCh) }) } -// Done returns a channel that is closed when the subscription is closed. -func (pu *PeerUpdates) Done() <-chan struct{} { - return pu.closeCh +// SendUpdate pushes information about a peer into the routing layer, +// presumably from a peer. +func (pu *PeerUpdates) SendUpdate(ctx context.Context, update PeerUpdate) { + select { + case <-ctx.Done(): + case pu.routerUpdatesCh <- update: + } } // PeerManagerOptions specifies options for a PeerManager. @@ -276,8 +276,6 @@ type PeerManager struct { rand *rand.Rand dialWaker *tmsync.Waker // wakes up DialNext() on relevant peer changes evictWaker *tmsync.Waker // wakes up EvictNext() on relevant peer changes - closeCh chan struct{} // signal channel for Close() - closeOnce sync.Once mtx sync.Mutex store *peerStore @@ -312,7 +310,6 @@ func NewPeerManager(selfID types.NodeID, peerDB dbm.DB, options PeerManagerOptio rand: rand.New(rand.NewSource(time.Now().UnixNano())), // nolint:gosec dialWaker: tmsync.NewWaker(), evictWaker: tmsync.NewWaker(), - closeCh: make(chan struct{}), store: store, dialing: map[types.NodeID]bool{}, @@ -552,7 +549,6 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error select { case <-timer.C: m.dialWaker.Wake() - case <-m.closeCh: case <-ctx.Done(): } }() @@ -864,10 +860,6 @@ func (m *PeerManager) Register(ctx context.Context, peerUpdates *PeerUpdates) { go func() { for { select { - case <-peerUpdates.closeCh: - return - case <-m.closeCh: - return case <-ctx.Done(): return case pu := <-peerUpdates.routerUpdatesCh: @@ -882,7 +874,6 @@ func (m *PeerManager) Register(ctx context.Context, peerUpdates *PeerUpdates) { m.mtx.Lock() delete(m.subscriptions, peerUpdates) m.mtx.Unlock() - case <-m.closeCh: case <-ctx.Done(): } }() @@ -913,27 +904,20 @@ func (m *PeerManager) processPeerEvent(pu PeerUpdate) { // maintaining order if this is a problem. func (m *PeerManager) broadcast(peerUpdate PeerUpdate) { for _, sub := range m.subscriptions { - // We have to check closeCh separately first, otherwise there's a 50% + // We have to check doneChan separately first, otherwise there's a 50% // chance the second select will send on a closed subscription. select { - case <-sub.closeCh: + case <-sub.doneCh: continue default: } select { case sub.reactorUpdatesCh <- peerUpdate: - case <-sub.closeCh: + case <-sub.doneCh: } } } -// Close closes the peer manager, releasing resources (i.e. goroutines). -func (m *PeerManager) Close() { - m.closeOnce.Do(func() { - close(m.closeCh) - }) -} - // Addresses returns all known addresses for a peer, primarily for testing. // The order is arbitrary. func (m *PeerManager) Addresses(peerID types.NodeID) []NodeAddress { diff --git a/internal/p2p/peermanager_scoring_test.go b/internal/p2p/peermanager_scoring_test.go index fe23767c4..ecaf71c98 100644 --- a/internal/p2p/peermanager_scoring_test.go +++ b/internal/p2p/peermanager_scoring_test.go @@ -22,7 +22,6 @@ func TestPeerScoring(t *testing.T) { db := dbm.NewMemDB() peerManager, err := NewPeerManager(selfID, db, PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() // create a fake node id := types.NodeID(strings.Repeat("a1", 20)) @@ -59,7 +58,7 @@ func TestPeerScoring(t *testing.T) { start := peerManager.Scores()[id] pu := peerManager.Subscribe(ctx) defer pu.Close() - pu.SendUpdate(PeerUpdate{ + pu.SendUpdate(ctx, PeerUpdate{ NodeID: id, Status: PeerStatusGood, }) @@ -73,7 +72,7 @@ func TestPeerScoring(t *testing.T) { start := peerManager.Scores()[id] pu := peerManager.Subscribe(ctx) defer pu.Close() - pu.SendUpdate(PeerUpdate{ + pu.SendUpdate(ctx, PeerUpdate{ NodeID: id, Status: PeerStatusBad, }) diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index cf1b0707e..dec92dab0 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -154,7 +154,6 @@ func TestNewPeerManager_Persistence(t *testing.T) { PeerScores: map[types.NodeID]p2p.PeerScore{bID: 1}, }) require.NoError(t, err) - defer peerManager.Close() for _, addr := range append(append(aAddresses, bAddresses...), cAddresses...) { added, err := peerManager.Add(addr) @@ -171,8 +170,6 @@ func TestNewPeerManager_Persistence(t *testing.T) { cID: 0, }, peerManager.Scores()) - peerManager.Close() - // Creating a new peer manager with the same database should retain the // peers, but they should have updated scores from the new PersistentPeers // configuration. @@ -181,7 +178,6 @@ func TestNewPeerManager_Persistence(t *testing.T) { PeerScores: map[types.NodeID]p2p.PeerScore{cID: 1}, }) require.NoError(t, err) - defer peerManager.Close() require.ElementsMatch(t, aAddresses, peerManager.Addresses(aID)) require.ElementsMatch(t, bAddresses, peerManager.Addresses(bID)) @@ -208,7 +204,6 @@ func TestNewPeerManager_SelfIDChange(t *testing.T) { require.NoError(t, err) require.True(t, added) require.ElementsMatch(t, []types.NodeID{a.NodeID, b.NodeID}, peerManager.Peers()) - peerManager.Close() // If we change our selfID to one of the peers in the peer store, it // should be removed from the store. @@ -1755,9 +1750,6 @@ func TestPeerManager_Close(t *testing.T) { require.NoError(t, err) require.Equal(t, a, dial) require.NoError(t, peerManager.DialFailed(ctx, a)) - - // This should clean up the goroutines. - peerManager.Close() } func TestPeerManager_Advertise(t *testing.T) { @@ -1780,7 +1772,6 @@ func TestPeerManager_Advertise(t *testing.T) { PeerScores: map[types.NodeID]p2p.PeerScore{aID: 3, bID: 2, cID: 1}, }) require.NoError(t, err) - defer peerManager.Close() added, err := peerManager.Add(aTCP) require.NoError(t, err) @@ -1847,7 +1838,6 @@ func TestPeerManager_SetHeight_GetHeight(t *testing.T) { require.ElementsMatch(t, []types.NodeID{a.NodeID, b.NodeID}, peerManager.Peers()) // The heights should not be persisted. - peerManager.Close() peerManager, err = p2p.NewPeerManager(selfID, db, p2p.PeerManagerOptions{}) require.NoError(t, err) diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index 69ff5206c..b42bb2f4b 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -83,7 +83,6 @@ type Reactor struct { peerManager *p2p.PeerManager pexCh *p2p.Channel peerUpdates *p2p.PeerUpdates - closeCh chan struct{} // list of available peers to loop through and send peer requests to availablePeers map[types.NodeID]struct{} @@ -128,7 +127,6 @@ func NewReactor( peerManager: peerManager, pexCh: pexCh, peerUpdates: peerUpdates, - closeCh: make(chan struct{}), availablePeers: make(map[types.NodeID]struct{}), requestsSent: make(map[types.NodeID]struct{}), lastReceivedRequests: make(map[types.NodeID]time.Time), @@ -150,13 +148,7 @@ func (r *Reactor) OnStart(ctx context.Context) error { // OnStop stops the reactor by signaling to all spawned goroutines to exit and // blocking until they all exit. -func (r *Reactor) OnStop() { - // Close closeCh to signal to all spawned goroutines to gracefully exit. All - // p2p Channels should execute Close(). - close(r.closeCh) - - <-r.peerUpdates.Done() -} +func (r *Reactor) OnStop() {} // processPexCh implements a blocking event loop where we listen for p2p // Envelope messages from the pexCh. @@ -168,8 +160,6 @@ func (r *Reactor) processPexCh(ctx context.Context) { select { case <-ctx.Done(): - return - case <-r.closeCh: r.logger.Debug("stopped listening on PEX channel; closing...") return @@ -196,17 +186,13 @@ func (r *Reactor) processPexCh(ctx context.Context) { // close the p2p PeerUpdatesCh gracefully. func (r *Reactor) processPeerUpdates(ctx context.Context) { defer r.peerUpdates.Close() - for { select { case <-ctx.Done(): + r.logger.Debug("stopped listening on peer updates channel; closing...") return case peerUpdate := <-r.peerUpdates.Updates(): r.processPeerUpdate(peerUpdate) - - case <-r.closeCh: - r.logger.Debug("stopped listening on peer updates channel; closing...") - return } } } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 7d1529ace..8f751ec6a 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -158,7 +158,6 @@ type Router struct { endpoints []Endpoint connTracker connectionTracker protocolTransports map[Protocol]Transport - stopCh chan struct{} // signals Router shutdown peerMtx sync.RWMutex peerQueues map[types.NodeID]queue // outbound messages per peer for all channels @@ -208,7 +207,6 @@ func NewRouter( protocolTransports: map[Protocol]Transport{}, peerManager: peerManager, options: options, - stopCh: make(chan struct{}), channelQueues: map[ChannelID]queue{}, channelMessages: map[ChannelID]proto.Message{}, peerQueues: map[types.NodeID]queue{}, @@ -399,7 +397,7 @@ func (r *Router) routeChannel( case <-q.closed(): r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chID) - case <-r.stopCh: + case <-ctx.Done(): return } } @@ -414,8 +412,6 @@ func (r *Router) routeChannel( r.peerManager.Errored(peerError.NodeID, peerError.Err) case <-ctx.Done(): return - case <-r.stopCh: - return } } } @@ -474,7 +470,7 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) { r.logger.Debug("starting accept routine", "transport", transport) for { - conn, err := transport.Accept() + conn, err := transport.Accept(ctx) switch err { case nil: case io.EOF: @@ -783,14 +779,14 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec go func() { select { - case errCh <- r.receivePeer(peerID, conn): + case errCh <- r.receivePeer(ctx, peerID, conn): case <-ctx.Done(): } }() go func() { select { - case errCh <- r.sendPeer(peerID, conn, sendQueue): + case errCh <- r.sendPeer(ctx, peerID, conn, sendQueue): case <-ctx.Done(): } }() @@ -829,9 +825,9 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec // receivePeer receives inbound messages from a peer, deserializes them and // passes them on to the appropriate channel. -func (r *Router) receivePeer(peerID types.NodeID, conn Connection) error { +func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Connection) error { for { - chID, bz, err := conn.ReceiveMessage() + chID, bz, err := conn.ReceiveMessage(ctx) if err != nil { return err } @@ -874,14 +870,14 @@ func (r *Router) receivePeer(peerID types.NodeID, conn Connection) error { case <-queue.closed(): r.logger.Debug("channel closed, dropping message", "peer", peerID, "channel", chID) - case <-r.stopCh: + case <-ctx.Done(): return nil } } } // sendPeer sends queued messages to a peer. -func (r *Router) sendPeer(peerID types.NodeID, conn Connection, peerQueue queue) error { +func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue queue) error { for { start := time.Now().UTC() @@ -899,7 +895,7 @@ func (r *Router) sendPeer(peerID types.NodeID, conn Connection, peerQueue queue) continue } - if err = conn.SendMessage(envelope.channelID, bz); err != nil { + if err = conn.SendMessage(ctx, envelope.channelID, bz); err != nil { return err } @@ -908,7 +904,7 @@ func (r *Router) sendPeer(peerID types.NodeID, conn Connection, peerQueue queue) case <-peerQueue.closed(): return nil - case <-r.stopCh: + case <-ctx.Done(): return nil } } @@ -983,9 +979,6 @@ func (r *Router) OnStart(ctx context.Context) error { // here, since that would cause any reactor senders to panic, so it is the // sender's responsibility. func (r *Router) OnStop() { - // Signal router shutdown. - close(r.stopCh) - // Close transport listeners (unblocks Accept calls). for _, transport := range r.transports { if err := transport.Close(); err != nil { @@ -1009,6 +1002,7 @@ func (r *Router) OnStop() { r.peerMtx.RUnlock() for _, q := range queues { + q.close() <-q.closed() } } diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 8a4c9e4bc..a561f68cd 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -106,7 +106,6 @@ func TestRouter_Channel_Basic(t *testing.T) { // Set up a router with no transports (so no peers). peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() router, err := p2p.NewRouter( ctx, @@ -392,25 +391,22 @@ func TestRouter_AcceptPeers(t *testing.T) { mockConnection.On("String").Maybe().Return("mock") mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). Return(tc.peerInfo, tc.peerKey, nil) - mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil) + mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil).Maybe() mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) if tc.ok { - // without the sleep after RequireUpdate this method isn't - // always called. Consider making this call optional. - mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF) + mockConnection.On("ReceiveMessage", mock.Anything).Return(chID, nil, io.EOF).Maybe() } mockTransport := &mocks.Transport{} mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) - mockTransport.On("Close").Return(nil) - mockTransport.On("Accept").Once().Return(mockConnection, nil) - mockTransport.On("Accept").Maybe().Return(nil, io.EOF) + mockTransport.On("Close").Return(nil).Maybe() + mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() sub := peerManager.Subscribe(ctx) defer sub.Close() @@ -464,13 +460,12 @@ func TestRouter_AcceptPeers_Error(t *testing.T) { mockTransport := &mocks.Transport{} mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) - mockTransport.On("Accept").Once().Return(nil, errors.New("boom")) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, errors.New("boom")) mockTransport.On("Close").Return(nil) // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() router, err := p2p.NewRouter( ctx, @@ -503,13 +498,12 @@ func TestRouter_AcceptPeers_ErrorEOF(t *testing.T) { mockTransport := &mocks.Transport{} mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) - mockTransport.On("Accept").Once().Return(nil, io.EOF) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) mockTransport.On("Close").Return(nil) // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() router, err := p2p.NewRouter( ctx, @@ -554,15 +548,14 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) mockTransport.On("Close").Return(nil) - mockTransport.On("Accept").Times(3).Run(func(_ mock.Arguments) { + mockTransport.On("Accept", mock.Anything).Times(3).Run(func(_ mock.Arguments) { acceptCh <- true }).Return(mockConnection, nil) - mockTransport.On("Accept").Once().Return(nil, io.EOF) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() router, err := p2p.NewRouter( ctx, @@ -580,7 +573,7 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { require.Eventually(t, func() bool { return len(acceptCh) == 3 - }, time.Second, 10*time.Millisecond) + }, time.Second, 10*time.Millisecond, "num", len(acceptCh)) close(closeCh) time.Sleep(100 * time.Millisecond) @@ -636,19 +629,17 @@ func TestRouter_DialPeers(t *testing.T) { if tc.dialErr == nil { mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). Return(tc.peerInfo, tc.peerKey, nil) - mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil) + mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil).Maybe() } if tc.ok { - // without the sleep after RequireUpdate this method isn't - // always called. Consider making this call optional. - mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF) + mockConnection.On("ReceiveMessage", mock.Anything).Return(chID, nil, io.EOF).Maybe() } mockTransport := &mocks.Transport{} mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) - mockTransport.On("Close").Return(nil) - mockTransport.On("Accept").Maybe().Return(nil, io.EOF) + mockTransport.On("Close").Return(nil).Maybe() + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) if tc.dialErr == nil { mockTransport.On("Dial", mock.Anything, endpoint).Once().Return(mockConnection, nil) // This handles the retry when a dialed connection gets closed after ReceiveMessage @@ -663,7 +654,6 @@ func TestRouter_DialPeers(t *testing.T) { // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() added, err := peerManager.Add(address) require.NoError(t, err) @@ -734,7 +724,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) mockTransport.On("Close").Return(nil) - mockTransport.On("Accept").Once().Return(nil, io.EOF) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) for _, address := range []p2p.NodeAddress{a, b, c} { endpoint := p2p.Endpoint{Protocol: address.Protocol, Path: string(address.NodeID)} mockTransport.On("Dial", mock.Anything, endpoint).Run(func(_ mock.Arguments) { @@ -745,7 +735,6 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() added, err := peerManager.Add(a) require.NoError(t, err) @@ -813,7 +802,7 @@ func TestRouter_EvictPeers(t *testing.T) { mockConnection.On("String").Maybe().Return("mock") mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). Return(peerInfo, peerKey.PubKey(), nil) - mockConnection.On("ReceiveMessage").WaitUntil(closeCh).Return(chID, nil, io.EOF) + mockConnection.On("ReceiveMessage", mock.Anything).WaitUntil(closeCh).Return(chID, nil, io.EOF) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockConnection.On("Close").Run(func(_ mock.Arguments) { closeOnce.Do(func() { @@ -825,13 +814,12 @@ func TestRouter_EvictPeers(t *testing.T) { mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) mockTransport.On("Close").Return(nil) - mockTransport.On("Accept").Once().Return(mockConnection, nil) - mockTransport.On("Accept").Maybe().Return(nil, io.EOF) + mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() sub := peerManager.Subscribe(ctx) defer sub.Close() @@ -893,13 +881,12 @@ func TestRouter_ChannelCompatability(t *testing.T) { mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) mockTransport.On("Close").Return(nil) - mockTransport.On("Accept").Once().Return(mockConnection, nil) - mockTransport.On("Accept").Once().Return(nil, io.EOF) + mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() router, err := p2p.NewRouter( ctx, @@ -941,20 +928,19 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { Return(peer, peerKey.PubKey(), nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockConnection.On("Close").Return(nil) - mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF) + mockConnection.On("ReceiveMessage", mock.Anything).Return(chID, nil, io.EOF) mockTransport := &mocks.Transport{} mockTransport.On("AddChannelDescriptors", mock.Anything).Return() mockTransport.On("String").Maybe().Return("mock") mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) mockTransport.On("Close").Return(nil) - mockTransport.On("Accept").Once().Return(mockConnection, nil) - mockTransport.On("Accept").Maybe().Return(nil, io.EOF) + mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) // Set up and start the router. peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - defer peerManager.Close() sub := peerManager.Subscribe(ctx) defer sub.Close() diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index 08de0d3b0..041bbda3a 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -39,7 +39,7 @@ type Transport interface { // Accept waits for the next inbound connection on a listening endpoint, blocking // until either a connection is available or the transport is closed. On closure, // io.EOF is returned and further Accept calls are futile. - Accept() (Connection, error) + Accept(context.Context) (Connection, error) // Dial creates an outbound connection to an endpoint. Dial(context.Context, Endpoint) (Connection, error) @@ -85,10 +85,10 @@ type Connection interface { // ReceiveMessage returns the next message received on the connection, // blocking until one is available. Returns io.EOF if closed. - ReceiveMessage() (ChannelID, []byte, error) + ReceiveMessage(context.Context) (ChannelID, []byte, error) // SendMessage sends a message on the connection. Returns io.EOF if closed. - SendMessage(ChannelID, []byte) error + SendMessage(context.Context, ChannelID, []byte) error // LocalEndpoint returns the local endpoint for the connection. LocalEndpoint() Endpoint diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index b89671670..46227ff8f 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -44,10 +44,10 @@ type MConnTransport struct { options MConnTransportOptions mConnConfig conn.MConnConfig channelDescs []*ChannelDescriptor - closeCh chan struct{} - closeOnce sync.Once - listener net.Listener + closeOnce sync.Once + doneCh chan struct{} + listener net.Listener } // NewMConnTransport sets up a new MConnection transport. This uses the @@ -63,7 +63,7 @@ func NewMConnTransport( logger: logger, options: options, mConnConfig: mConnConfig, - closeCh: make(chan struct{}), + doneCh: make(chan struct{}), channelDescs: channelDescs, } } @@ -84,10 +84,11 @@ func (m *MConnTransport) Endpoints() []Endpoint { return []Endpoint{} } select { - case <-m.closeCh: + case <-m.doneCh: return []Endpoint{} default: } + endpoint := Endpoint{ Protocol: MConnProtocol, } @@ -132,7 +133,7 @@ func (m *MConnTransport) Listen(endpoint Endpoint) error { } // Accept implements Transport. -func (m *MConnTransport) Accept() (Connection, error) { +func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { if m.listener == nil { return nil, errors.New("transport is not listening") } @@ -140,7 +141,9 @@ func (m *MConnTransport) Accept() (Connection, error) { tcpConn, err := m.listener.Accept() if err != nil { select { - case <-m.closeCh: + case <-ctx.Done(): + return nil, io.EOF + case <-m.doneCh: return nil, io.EOF default: return nil, err @@ -178,7 +181,7 @@ func (m *MConnTransport) Dial(ctx context.Context, endpoint Endpoint) (Connectio func (m *MConnTransport) Close() error { var err error m.closeOnce.Do(func() { - close(m.closeCh) // must be closed first, to handle error in Accept() + close(m.doneCh) if m.listener != nil { err = m.listener.Close() } @@ -222,7 +225,7 @@ type mConnConnection struct { channelDescs []*ChannelDescriptor receiveCh chan mConnMessage errorCh chan error - closeCh chan struct{} + doneCh chan struct{} closeOnce sync.Once mconn *conn.MConnection // set during Handshake() @@ -248,7 +251,7 @@ func newMConnConnection( channelDescs: channelDescs, receiveCh: make(chan mConnMessage), errorCh: make(chan error, 1), // buffered to avoid onError leak - closeCh: make(chan struct{}), + doneCh: make(chan struct{}), } } @@ -370,16 +373,16 @@ func (c *mConnConnection) handshake( } // onReceive is a callback for MConnection received messages. -func (c *mConnConnection) onReceive(chID ChannelID, payload []byte) { +func (c *mConnConnection) onReceive(ctx context.Context, chID ChannelID, payload []byte) { select { case c.receiveCh <- mConnMessage{channelID: chID, payload: payload}: - case <-c.closeCh: + case <-ctx.Done(): } } // onError is a callback for MConnection errors. The error is passed via errorCh // to ReceiveMessage (but not SendMessage, for legacy P2P stack behavior). -func (c *mConnConnection) onError(e interface{}) { +func (c *mConnConnection) onError(ctx context.Context, e interface{}) { err, ok := e.(error) if !ok { err = fmt.Errorf("%v", err) @@ -389,7 +392,7 @@ func (c *mConnConnection) onError(e interface{}) { _ = c.Close() select { case c.errorCh <- err: - case <-c.closeCh: + case <-ctx.Done(): } } @@ -399,14 +402,14 @@ func (c *mConnConnection) String() string { } // SendMessage implements Connection. -func (c *mConnConnection) SendMessage(chID ChannelID, msg []byte) error { +func (c *mConnConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error { if chID > math.MaxUint8 { return fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID) } select { case err := <-c.errorCh: return err - case <-c.closeCh: + case <-ctx.Done(): return io.EOF default: if ok := c.mconn.Send(chID, msg); !ok { @@ -418,11 +421,13 @@ func (c *mConnConnection) SendMessage(chID ChannelID, msg []byte) error { } // ReceiveMessage implements Connection. -func (c *mConnConnection) ReceiveMessage() (ChannelID, []byte, error) { +func (c *mConnConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) { select { case err := <-c.errorCh: return 0, nil, err - case <-c.closeCh: + case <-c.doneCh: + return 0, nil, io.EOF + case <-ctx.Done(): return 0, nil, io.EOF case msg := <-c.receiveCh: return msg.channelID, msg.payload, nil @@ -462,7 +467,7 @@ func (c *mConnConnection) Close() error { } else { err = c.conn.Close() } - close(c.closeCh) + close(c.doneCh) }) return err } diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index 4d9a945cb..0851fe0e2 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -52,8 +52,10 @@ func TestMConnTransport_AcceptBeforeListen(t *testing.T) { t.Cleanup(func() { _ = transport.Close() }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - _, err := transport.Accept() + _, err := transport.Accept(ctx) require.Error(t, err) require.NotEqual(t, io.EOF, err) // io.EOF should be returned after Close() } @@ -85,7 +87,7 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { acceptCh := make(chan p2p.Connection, 10) go func() { for { - conn, err := transport.Accept() + conn, err := transport.Accept(ctx) if err != nil { return } @@ -203,7 +205,7 @@ func TestMConnTransport_Listen(t *testing.T) { close(dialedChan) }() - conn, err := transport.Accept() + conn, err := transport.Accept(ctx) require.NoError(t, err) _ = conn.Close() <-dialedChan @@ -212,7 +214,7 @@ func TestMConnTransport_Listen(t *testing.T) { require.NoError(t, peerConn.Close()) // try to read from the connection should error - _, _, err = peerConn.ReceiveMessage() + _, _, err = peerConn.ReceiveMessage(ctx) require.Error(t, err) // Trying to listen again should error. diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 5d9291675..27b9e77e1 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -94,9 +94,7 @@ type MemoryTransport struct { nodeID types.NodeID bufferSize int - acceptCh chan *MemoryConnection - closeCh chan struct{} - closeOnce sync.Once + acceptCh chan *MemoryConnection } // newMemoryTransport creates a new MemoryTransport. This is for internal use by @@ -108,7 +106,6 @@ func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTran nodeID: nodeID, bufferSize: network.bufferSize, acceptCh: make(chan *MemoryConnection), - closeCh: make(chan struct{}), } } @@ -128,28 +125,27 @@ func (t *MemoryTransport) Protocols() []Protocol { // Endpoints implements Transport. func (t *MemoryTransport) Endpoints() []Endpoint { - select { - case <-t.closeCh: + if n := t.network.GetTransport(t.nodeID); n == nil { return []Endpoint{} - default: - return []Endpoint{{ - Protocol: MemoryProtocol, - Path: string(t.nodeID), - // An arbitrary IP and port is used in order for the pex - // reactor to be able to send addresses to one another. - IP: net.IPv4zero, - Port: 0, - }} } + + return []Endpoint{{ + Protocol: MemoryProtocol, + Path: string(t.nodeID), + // An arbitrary IP and port is used in order for the pex + // reactor to be able to send addresses to one another. + IP: net.IPv4zero, + Port: 0, + }} } // Accept implements Transport. -func (t *MemoryTransport) Accept() (Connection, error) { +func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) { select { case conn := <-t.acceptCh: t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path) return conn, nil - case <-t.closeCh: + case <-ctx.Done(): return nil, io.EOF } } @@ -187,20 +183,14 @@ func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connecti select { case peer.acceptCh <- inConn: return outConn, nil - case <-peer.closeCh: - return nil, io.EOF case <-ctx.Done(): - return nil, ctx.Err() + return nil, io.EOF } } // Close implements Transport. func (t *MemoryTransport) Close() error { t.network.RemoveTransport(t.nodeID) - t.closeOnce.Do(func() { - close(t.closeCh) - t.logger.Info("closed transport") - }) return nil } @@ -295,12 +285,14 @@ func (c *MemoryConnection) Handshake( } // ReceiveMessage implements Connection. -func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) { +func (c *MemoryConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) { // Check close first, since channels are buffered. Otherwise, below select // may non-deterministically return non-error even when closed. select { case <-c.closer.Done(): return 0, nil, io.EOF + case <-ctx.Done(): + return 0, nil, io.EOF default: } @@ -314,12 +306,14 @@ func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) { } // SendMessage implements Connection. -func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) error { +func (c *MemoryConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error { // Check close first, since channels are buffered. Otherwise, below select // may non-deterministically return non-error even when closed. select { case <-c.closer.Done(): return io.EOF + case <-ctx.Done(): + return io.EOF default: } @@ -327,6 +321,8 @@ func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) error { case c.sendCh <- memoryMessage{channelID: chID, message: msg}: c.logger.Debug("sent message", "chID", chID, "msg", msg) return nil + case <-ctx.Done(): + return io.EOF case <-c.closer.Done(): return io.EOF } diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index a53be251d..63ce5ad5b 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -46,21 +46,23 @@ func TestTransport_AcceptClose(t *testing.T) { withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) + opctx, opcancel := context.WithCancel(ctx) // In-progress Accept should error on concurrent close. errCh := make(chan error, 1) go func() { time.Sleep(200 * time.Millisecond) + opcancel() errCh <- a.Close() }() - _, err := a.Accept() + _, err := a.Accept(opctx) require.Error(t, err) require.Equal(t, io.EOF, err) require.NoError(t, <-errCh) // Closed transport should return error immediately. - _, err = a.Accept() + _, err = a.Accept(opctx) require.Error(t, err) require.Equal(t, io.EOF, err) }) @@ -93,7 +95,7 @@ func TestTransport_DialEndpoints(t *testing.T) { // Spawn a goroutine to simply accept any connections until closed. go func() { for { - conn, err := a.Accept() + conn, err := a.Accept(ctx) if err != nil { return } @@ -177,7 +179,6 @@ func TestTransport_Dial(t *testing.T) { cancel() _, err := a.Dial(cancelCtx, bEndpoint) require.Error(t, err) - require.Equal(t, err, context.Canceled) // Unavailable endpoint should error. err = b.Close() @@ -188,7 +189,7 @@ func TestTransport_Dial(t *testing.T) { // Dialing from a closed transport should still work. errCh := make(chan error, 1) go func() { - conn, err := a.Accept() + conn, err := a.Accept(ctx) if err == nil { _ = conn.Close() } @@ -351,13 +352,12 @@ func TestConnection_FlushClose(t *testing.T) { err := ab.Close() require.NoError(t, err) - _, _, err = ab.ReceiveMessage() + _, _, err = ab.ReceiveMessage(ctx) require.Error(t, err) require.Equal(t, io.EOF, err) - err = ab.SendMessage(chID, []byte("closed")) + err = ab.SendMessage(ctx, chID, []byte("closed")) require.Error(t, err) - require.Equal(t, io.EOF, err) }) } @@ -388,19 +388,19 @@ func TestConnection_SendReceive(t *testing.T) { ab, ba := dialAcceptHandshake(ctx, t, a, b) // Can send and receive a to b. - err := ab.SendMessage(chID, []byte("foo")) + err := ab.SendMessage(ctx, chID, []byte("foo")) require.NoError(t, err) - ch, msg, err := ba.ReceiveMessage() + ch, msg, err := ba.ReceiveMessage(ctx) require.NoError(t, err) require.Equal(t, []byte("foo"), msg) require.Equal(t, chID, ch) // Can send and receive b to a. - err = ba.SendMessage(chID, []byte("bar")) + err = ba.SendMessage(ctx, chID, []byte("bar")) require.NoError(t, err) - _, msg, err = ab.ReceiveMessage() + _, msg, err = ab.ReceiveMessage(ctx) require.NoError(t, err) require.Equal(t, []byte("bar"), msg) @@ -410,9 +410,9 @@ func TestConnection_SendReceive(t *testing.T) { err = b.Close() require.NoError(t, err) - err = ab.SendMessage(chID, []byte("still here")) + err = ab.SendMessage(ctx, chID, []byte("still here")) require.NoError(t, err) - ch, msg, err = ba.ReceiveMessage() + ch, msg, err = ba.ReceiveMessage(ctx) require.NoError(t, err) require.Equal(t, chID, ch) require.Equal(t, []byte("still here"), msg) @@ -422,21 +422,20 @@ func TestConnection_SendReceive(t *testing.T) { err = ba.Close() require.NoError(t, err) - _, _, err = ab.ReceiveMessage() + _, _, err = ab.ReceiveMessage(ctx) require.Error(t, err) require.Equal(t, io.EOF, err) - err = ab.SendMessage(chID, []byte("closed")) + err = ab.SendMessage(ctx, chID, []byte("closed")) require.Error(t, err) require.Equal(t, io.EOF, err) - _, _, err = ba.ReceiveMessage() + _, _, err = ba.ReceiveMessage(ctx) require.Error(t, err) require.Equal(t, io.EOF, err) - err = ba.SendMessage(chID, []byte("closed")) + err = ba.SendMessage(ctx, chID, []byte("closed")) require.Error(t, err) - require.Equal(t, io.EOF, err) }) } @@ -606,7 +605,7 @@ func dialAccept(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Conn acceptCh := make(chan p2p.Connection, 1) errCh := make(chan error, 1) go func() { - conn, err := b.Accept() + conn, err := b.Accept(ctx) errCh <- err acceptCh <- conn }()