Browse Source

p2p: remove unneeded close channels from p2p layer (#7392)

pull/7399/head
Sam Kleinman 2 years ago
committed by GitHub
parent
commit
6b35cc1a47
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 264 additions and 268 deletions
  1. +2
    -2
      internal/consensus/reactor.go
  2. +19
    -19
      internal/p2p/conn/connection.go
  3. +82
    -36
      internal/p2p/conn/connection_test.go
  4. +14
    -14
      internal/p2p/mocks/connection.go
  5. +7
    -7
      internal/p2p/mocks/transport.go
  6. +11
    -2
      internal/p2p/p2ptest/network.go
  7. +17
    -33
      internal/p2p/peermanager.go
  8. +2
    -3
      internal/p2p/peermanager_scoring_test.go
  9. +0
    -10
      internal/p2p/peermanager_test.go
  10. +2
    -16
      internal/p2p/pex/reactor.go
  11. +11
    -17
      internal/p2p/router.go
  12. +23
    -37
      internal/p2p/router_test.go
  13. +3
    -3
      internal/p2p/transport.go
  14. +24
    -19
      internal/p2p/transport_mconn.go
  15. +6
    -4
      internal/p2p/transport_mconn_test.go
  16. +22
    -26
      internal/p2p/transport_memory.go
  17. +19
    -20
      internal/p2p/transport_test.go

+ 2
- 2
internal/consensus/reactor.go View File

@ -1470,7 +1470,7 @@ func (r *Reactor) peerStatsRoutine(ctx context.Context) {
switch msg.Msg.(type) {
case *VoteMessage:
if numVotes := ps.RecordVote(); numVotes%votesToContributeToBecomeGoodPeer == 0 {
r.peerUpdates.SendUpdate(p2p.PeerUpdate{
r.peerUpdates.SendUpdate(ctx, p2p.PeerUpdate{
NodeID: msg.PeerID,
Status: p2p.PeerStatusGood,
})
@ -1478,7 +1478,7 @@ func (r *Reactor) peerStatsRoutine(ctx context.Context) {
case *BlockPartMessage:
if numParts := ps.RecordBlockPart(); numParts%blocksToContributeToBecomeGoodPeer == 0 {
r.peerUpdates.SendUpdate(p2p.PeerUpdate{
r.peerUpdates.SendUpdate(ctx, p2p.PeerUpdate{
NodeID: msg.PeerID,
Status: p2p.PeerStatusGood,
})


+ 19
- 19
internal/p2p/conn/connection.go View File

@ -49,8 +49,8 @@ const (
defaultPongTimeout = 45 * time.Second
)
type receiveCbFunc func(chID ChannelID, msgBytes []byte)
type errorCbFunc func(interface{})
type receiveCbFunc func(ctx context.Context, chID ChannelID, msgBytes []byte)
type errorCbFunc func(context.Context, interface{})
/*
Each peer has one `MConnection` (multiplex connection) instance.
@ -286,21 +286,21 @@ func (c *MConnection) flush() {
}
// Catch panics, usually caused by remote disconnects.
func (c *MConnection) _recover() {
func (c *MConnection) _recover(ctx context.Context) {
if r := recover(); r != nil {
c.logger.Error("MConnection panicked", "err", r, "stack", string(debug.Stack()))
c.stopForError(fmt.Errorf("recovered from panic: %v", r))
c.stopForError(ctx, fmt.Errorf("recovered from panic: %v", r))
}
}
func (c *MConnection) stopForError(r interface{}) {
func (c *MConnection) stopForError(ctx context.Context, r interface{}) {
if err := c.Stop(); err != nil {
c.logger.Error("Error stopping connection", "err", err)
}
if atomic.CompareAndSwapUint32(&c.errored, 0, 1) {
if c.onError != nil {
c.onError(r)
c.onError(ctx, r)
}
}
}
@ -335,7 +335,7 @@ func (c *MConnection) Send(chID ChannelID, msgBytes []byte) bool {
// sendRoutine polls for packets to send from channels.
func (c *MConnection) sendRoutine(ctx context.Context) {
defer c._recover()
defer c._recover(ctx)
protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter)
FOR_LOOP:
@ -390,7 +390,7 @@ FOR_LOOP:
break FOR_LOOP
case <-c.send:
// Send some PacketMsgs
eof := c.sendSomePacketMsgs()
eof := c.sendSomePacketMsgs(ctx)
if !eof {
// Keep sendRoutine awake.
select {
@ -405,7 +405,7 @@ FOR_LOOP:
}
if err != nil {
c.logger.Error("Connection failed @ sendRoutine", "conn", c, "err", err)
c.stopForError(err)
c.stopForError(ctx, err)
break FOR_LOOP
}
}
@ -417,7 +417,7 @@ FOR_LOOP:
// Returns true if messages from channels were exhausted.
// Blocks in accordance to .sendMonitor throttling.
func (c *MConnection) sendSomePacketMsgs() bool {
func (c *MConnection) sendSomePacketMsgs(ctx context.Context) bool {
// Block until .sendMonitor says we can write.
// Once we're ready we send more than we asked for,
// but amortized it should even out.
@ -425,7 +425,7 @@ func (c *MConnection) sendSomePacketMsgs() bool {
// Now send some PacketMsgs.
for i := 0; i < numBatchPacketMsgs; i++ {
if c.sendPacketMsg() {
if c.sendPacketMsg(ctx) {
return true
}
}
@ -433,7 +433,7 @@ func (c *MConnection) sendSomePacketMsgs() bool {
}
// Returns true if messages from channels were exhausted.
func (c *MConnection) sendPacketMsg() bool {
func (c *MConnection) sendPacketMsg(ctx context.Context) bool {
// Choose a channel to create a PacketMsg from.
// The chosen channel will be the one whose recentlySent/priority is the least.
var leastRatio float32 = math.MaxFloat32
@ -461,7 +461,7 @@ func (c *MConnection) sendPacketMsg() bool {
_n, err := leastChannel.writePacketMsgTo(c.bufConnWriter)
if err != nil {
c.logger.Error("Failed to write PacketMsg", "err", err)
c.stopForError(err)
c.stopForError(ctx, err)
return true
}
c.sendMonitor.Update(_n)
@ -474,7 +474,7 @@ func (c *MConnection) sendPacketMsg() bool {
// Blocks depending on how the connection is throttled.
// Otherwise, it never blocks.
func (c *MConnection) recvRoutine(ctx context.Context) {
defer c._recover()
defer c._recover(ctx)
protoReader := protoio.NewDelimitedReader(c.bufConnReader, c._maxPacketMsgSize)
@ -518,7 +518,7 @@ FOR_LOOP:
} else {
c.logger.Debug("Connection failed @ recvRoutine (reading byte)", "conn", c, "err", err)
}
c.stopForError(err)
c.stopForError(ctx, err)
}
break FOR_LOOP
}
@ -547,7 +547,7 @@ FOR_LOOP:
if pkt.PacketMsg.ChannelID < 0 || pkt.PacketMsg.ChannelID > math.MaxUint8 || !ok || channel == nil {
err := fmt.Errorf("unknown channel %X", pkt.PacketMsg.ChannelID)
c.logger.Debug("Connection failed @ recvRoutine", "conn", c, "err", err)
c.stopForError(err)
c.stopForError(ctx, err)
break FOR_LOOP
}
@ -555,19 +555,19 @@ FOR_LOOP:
if err != nil {
if c.IsRunning() {
c.logger.Debug("Connection failed @ recvRoutine", "conn", c, "err", err)
c.stopForError(err)
c.stopForError(ctx, err)
}
break FOR_LOOP
}
if msgBytes != nil {
c.logger.Debug("Received bytes", "chID", channelID, "msgBytes", msgBytes)
// NOTE: This means the reactor.Receive runs in the same thread as the p2p recv routine
c.onReceive(channelID, msgBytes)
c.onReceive(ctx, channelID, msgBytes)
}
default:
err := fmt.Errorf("unknown message type %v", reflect.TypeOf(packet))
c.logger.Error("Connection failed @ recvRoutine", "conn", c, "err", err)
c.stopForError(err)
c.stopForError(ctx, err)
break FOR_LOOP
}
}


+ 82
- 36
internal/p2p/conn/connection_test.go View File

@ -25,18 +25,18 @@ const maxPingPongPacketSize = 1024 // bytes
func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection {
return createMConnectionWithCallbacks(logger, conn,
// onRecieve
func(chID ChannelID, msgBytes []byte) {
func(ctx context.Context, chID ChannelID, msgBytes []byte) {
},
// onError
func(r interface{}) {
func(ctx context.Context, r interface{}) {
})
}
func createMConnectionWithCallbacks(
logger log.Logger,
conn net.Conn,
onReceive func(chID ChannelID, msgBytes []byte),
onError func(r interface{}),
onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte),
onError func(ctx context.Context, r interface{}),
) *MConnection {
cfg := DefaultMConnConfig()
cfg.PingInterval = 90 * time.Millisecond
@ -120,11 +120,17 @@ func TestMConnectionReceive(t *testing.T) {
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID ChannelID, msgBytes []byte) {
receivedCh <- msgBytes
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case receivedCh <- msgBytes:
case <-ctx.Done():
}
}
onError := func(r interface{}) {
errorsCh <- r
onError := func(ctx context.Context, r interface{}) {
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
logger := log.TestingLogger()
@ -160,11 +166,17 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID ChannelID, msgBytes []byte) {
receivedCh <- msgBytes
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case receivedCh <- msgBytes:
case <-ctx.Done():
}
}
onError := func(r interface{}) {
errorsCh <- r
onError := func(ctx context.Context, r interface{}) {
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
ctx, cancel := context.WithCancel(context.Background())
@ -202,12 +214,19 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID ChannelID, msgBytes []byte) {
receivedCh <- msgBytes
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case receivedCh <- msgBytes:
case <-ctx.Done():
}
}
onError := func(r interface{}) {
errorsCh <- r
onError := func(ctx context.Context, r interface{}) {
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -259,11 +278,17 @@ func TestMConnectionMultiplePings(t *testing.T) {
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID ChannelID, msgBytes []byte) {
receivedCh <- msgBytes
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case receivedCh <- msgBytes:
case <-ctx.Done():
}
}
onError := func(r interface{}) {
errorsCh <- r
onError := func(ctx context.Context, r interface{}) {
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -309,11 +334,17 @@ func TestMConnectionPingPongs(t *testing.T) {
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID ChannelID, msgBytes []byte) {
receivedCh <- msgBytes
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case receivedCh <- msgBytes:
case <-ctx.Done():
}
}
onError := func(r interface{}) {
errorsCh <- r
onError := func(ctx context.Context, r interface{}) {
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
ctx, cancel := context.WithCancel(context.Background())
@ -370,11 +401,17 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) {
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID ChannelID, msgBytes []byte) {
receivedCh <- msgBytes
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case receivedCh <- msgBytes:
case <-ctx.Done():
}
}
onError := func(r interface{}) {
errorsCh <- r
onError := func(ctx context.Context, r interface{}) {
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -406,8 +443,8 @@ func newClientAndServerConnsForReadErrors(
) (*MConnection, *MConnection) {
server, client := NetPipe()
onReceive := func(chID ChannelID, msgBytes []byte) {}
onError := func(r interface{}) {}
onReceive := func(context.Context, ChannelID, []byte) {}
onError := func(context.Context, interface{}) {}
// create client conn with two channels
chDescs := []*ChannelDescriptor{
@ -423,8 +460,11 @@ func newClientAndServerConnsForReadErrors(
// create server conn with 1 channel
// it fires on chOnErr when there's an error
serverLogger := logger.With("module", "server")
onError = func(r interface{}) {
chOnErr <- struct{}{}
onError = func(ctx context.Context, r interface{}) {
select {
case <-ctx.Done():
case chOnErr <- struct{}{}:
}
}
mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError)
@ -488,8 +528,11 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) {
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
t.Cleanup(waitAll(mconnClient, mconnServer))
mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
chOnRcv <- struct{}{}
mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case <-ctx.Done():
case chOnRcv <- struct{}{}:
}
}
client := mconnClient.conn
@ -590,8 +633,11 @@ func TestMConnectionChannelOverflow(t *testing.T) {
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
t.Cleanup(waitAll(mconnClient, mconnServer))
mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
chOnRcv <- struct{}{}
mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case <-ctx.Done():
case chOnRcv <- struct{}{}:
}
}
client := mconnClient.conn


+ 14
- 14
internal/p2p/mocks/connection.go View File

@ -79,20 +79,20 @@ func (_m *Connection) LocalEndpoint() p2p.Endpoint {
return r0
}
// ReceiveMessage provides a mock function with given fields:
func (_m *Connection) ReceiveMessage() (conn.ChannelID, []byte, error) {
ret := _m.Called()
// ReceiveMessage provides a mock function with given fields: _a0
func (_m *Connection) ReceiveMessage(_a0 context.Context) (conn.ChannelID, []byte, error) {
ret := _m.Called(_a0)
var r0 conn.ChannelID
if rf, ok := ret.Get(0).(func() conn.ChannelID); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) conn.ChannelID); ok {
r0 = rf(_a0)
} else {
r0 = ret.Get(0).(conn.ChannelID)
}
var r1 []byte
if rf, ok := ret.Get(1).(func() []byte); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) []byte); ok {
r1 = rf(_a0)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]byte)
@ -100,8 +100,8 @@ func (_m *Connection) ReceiveMessage() (conn.ChannelID, []byte, error) {
}
var r2 error
if rf, ok := ret.Get(2).(func() error); ok {
r2 = rf()
if rf, ok := ret.Get(2).(func(context.Context) error); ok {
r2 = rf(_a0)
} else {
r2 = ret.Error(2)
}
@ -123,13 +123,13 @@ func (_m *Connection) RemoteEndpoint() p2p.Endpoint {
return r0
}
// SendMessage provides a mock function with given fields: _a0, _a1
func (_m *Connection) SendMessage(_a0 conn.ChannelID, _a1 []byte) error {
ret := _m.Called(_a0, _a1)
// SendMessage provides a mock function with given fields: _a0, _a1, _a2
func (_m *Connection) SendMessage(_a0 context.Context, _a1 conn.ChannelID, _a2 []byte) error {
ret := _m.Called(_a0, _a1, _a2)
var r0 error
if rf, ok := ret.Get(0).(func(conn.ChannelID, []byte) error); ok {
r0 = rf(_a0, _a1)
if rf, ok := ret.Get(0).(func(context.Context, conn.ChannelID, []byte) error); ok {
r0 = rf(_a0, _a1, _a2)
} else {
r0 = ret.Error(0)
}


+ 7
- 7
internal/p2p/mocks/transport.go View File

@ -17,13 +17,13 @@ type Transport struct {
mock.Mock
}
// Accept provides a mock function with given fields:
func (_m *Transport) Accept() (p2p.Connection, error) {
ret := _m.Called()
// Accept provides a mock function with given fields: _a0
func (_m *Transport) Accept(_a0 context.Context) (p2p.Connection, error) {
ret := _m.Called(_a0)
var r0 p2p.Connection
if rf, ok := ret.Get(0).(func() p2p.Connection); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) p2p.Connection); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(p2p.Connection)
@ -31,8 +31,8 @@ func (_m *Transport) Accept() (p2p.Connection, error) {
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}


+ 11
- 2
internal/p2p/p2ptest/network.go View File

@ -24,6 +24,7 @@ type Network struct {
logger log.Logger
memoryNetwork *p2p.MemoryNetwork
cancel context.CancelFunc
}
// NetworkOptions is an argument structure to parameterize the
@ -68,6 +69,9 @@ func MakeNetwork(ctx context.Context, t *testing.T, opts NetworkOptions) *Networ
// addition to creating a peer update subscription for each node. Finally, all
// nodes are connected to each other.
func (n *Network) Start(ctx context.Context, t *testing.T) {
ctx, n.cancel = context.WithCancel(ctx)
t.Cleanup(n.cancel)
// Set up a list of node addresses to dial, and a peer update subscription
// for each node.
dialQueue := []p2p.NodeAddress{}
@ -200,10 +204,10 @@ func (n *Network) Remove(ctx context.Context, t *testing.T, id types.NodeID) {
}
require.NoError(t, node.Transport.Close())
node.cancel()
if node.Router.IsRunning() {
require.NoError(t, node.Router.Stop())
}
node.PeerManager.Close()
for _, sub := range subs {
RequireUpdate(t, sub, p2p.PeerUpdate{
@ -222,12 +226,16 @@ type Node struct {
Router *p2p.Router
PeerManager *p2p.PeerManager
Transport *p2p.MemoryTransport
cancel context.CancelFunc
}
// MakeNode creates a new Node configured for the network with a
// running peer manager, but does not add it to the existing
// network. Callers are responsible for updating peering relationships.
func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) *Node {
ctx, cancel := context.WithCancel(ctx)
privKey := ed25519.GenPrivKey()
nodeID := types.NodeIDFromPubKey(privKey.PubKey())
nodeInfo := types.NodeInfo{
@ -267,8 +275,8 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions)
if router.IsRunning() {
require.NoError(t, router.Stop())
}
peerManager.Close()
require.NoError(t, transport.Close())
cancel()
})
return &Node{
@ -279,6 +287,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions)
Router: router,
PeerManager: peerManager,
Transport: transport,
cancel: cancel,
}
}


+ 17
- 33
internal/p2p/peermanager.go View File

@ -56,8 +56,8 @@ type PeerUpdate struct {
type PeerUpdates struct {
routerUpdatesCh chan PeerUpdate
reactorUpdatesCh chan PeerUpdate
closeCh chan struct{}
closeOnce sync.Once
doneCh chan struct{}
}
// NewPeerUpdates creates a new PeerUpdates subscription. It is primarily for
@ -67,7 +67,7 @@ func NewPeerUpdates(updatesCh chan PeerUpdate, buf int) *PeerUpdates {
return &PeerUpdates{
reactorUpdatesCh: updatesCh,
routerUpdatesCh: make(chan PeerUpdate, buf),
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
@ -76,28 +76,28 @@ func (pu *PeerUpdates) Updates() <-chan PeerUpdate {
return pu.reactorUpdatesCh
}
// SendUpdate pushes information about a peer into the routing layer,
// presumably from a peer.
func (pu *PeerUpdates) SendUpdate(update PeerUpdate) {
select {
case <-pu.closeCh:
case pu.routerUpdatesCh <- update:
}
// Done returns a channel that is closed when the subscription is closed.
func (pu *PeerUpdates) Done() <-chan struct{} {
return pu.doneCh
}
// Close closes the peer updates subscription.
func (pu *PeerUpdates) Close() {
pu.closeOnce.Do(func() {
// NOTE: We don't close updatesCh since multiple goroutines may be
// sending on it. The PeerManager senders will select on closeCh as well
// sending on it. The PeerManager senders will select on doneCh as well
// to avoid blocking on a closed subscription.
close(pu.closeCh)
close(pu.doneCh)
})
}
// Done returns a channel that is closed when the subscription is closed.
func (pu *PeerUpdates) Done() <-chan struct{} {
return pu.closeCh
// SendUpdate pushes information about a peer into the routing layer,
// presumably from a peer.
func (pu *PeerUpdates) SendUpdate(ctx context.Context, update PeerUpdate) {
select {
case <-ctx.Done():
case pu.routerUpdatesCh <- update:
}
}
// PeerManagerOptions specifies options for a PeerManager.
@ -276,8 +276,6 @@ type PeerManager struct {
rand *rand.Rand
dialWaker *tmsync.Waker // wakes up DialNext() on relevant peer changes
evictWaker *tmsync.Waker // wakes up EvictNext() on relevant peer changes
closeCh chan struct{} // signal channel for Close()
closeOnce sync.Once
mtx sync.Mutex
store *peerStore
@ -312,7 +310,6 @@ func NewPeerManager(selfID types.NodeID, peerDB dbm.DB, options PeerManagerOptio
rand: rand.New(rand.NewSource(time.Now().UnixNano())), // nolint:gosec
dialWaker: tmsync.NewWaker(),
evictWaker: tmsync.NewWaker(),
closeCh: make(chan struct{}),
store: store,
dialing: map[types.NodeID]bool{},
@ -552,7 +549,6 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error
select {
case <-timer.C:
m.dialWaker.Wake()
case <-m.closeCh:
case <-ctx.Done():
}
}()
@ -864,10 +860,6 @@ func (m *PeerManager) Register(ctx context.Context, peerUpdates *PeerUpdates) {
go func() {
for {
select {
case <-peerUpdates.closeCh:
return
case <-m.closeCh:
return
case <-ctx.Done():
return
case pu := <-peerUpdates.routerUpdatesCh:
@ -882,7 +874,6 @@ func (m *PeerManager) Register(ctx context.Context, peerUpdates *PeerUpdates) {
m.mtx.Lock()
delete(m.subscriptions, peerUpdates)
m.mtx.Unlock()
case <-m.closeCh:
case <-ctx.Done():
}
}()
@ -913,27 +904,20 @@ func (m *PeerManager) processPeerEvent(pu PeerUpdate) {
// maintaining order if this is a problem.
func (m *PeerManager) broadcast(peerUpdate PeerUpdate) {
for _, sub := range m.subscriptions {
// We have to check closeCh separately first, otherwise there's a 50%
// We have to check doneChan separately first, otherwise there's a 50%
// chance the second select will send on a closed subscription.
select {
case <-sub.closeCh:
case <-sub.doneCh:
continue
default:
}
select {
case sub.reactorUpdatesCh <- peerUpdate:
case <-sub.closeCh:
case <-sub.doneCh:
}
}
}
// Close closes the peer manager, releasing resources (i.e. goroutines).
func (m *PeerManager) Close() {
m.closeOnce.Do(func() {
close(m.closeCh)
})
}
// Addresses returns all known addresses for a peer, primarily for testing.
// The order is arbitrary.
func (m *PeerManager) Addresses(peerID types.NodeID) []NodeAddress {


+ 2
- 3
internal/p2p/peermanager_scoring_test.go View File

@ -22,7 +22,6 @@ func TestPeerScoring(t *testing.T) {
db := dbm.NewMemDB()
peerManager, err := NewPeerManager(selfID, db, PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
// create a fake node
id := types.NodeID(strings.Repeat("a1", 20))
@ -59,7 +58,7 @@ func TestPeerScoring(t *testing.T) {
start := peerManager.Scores()[id]
pu := peerManager.Subscribe(ctx)
defer pu.Close()
pu.SendUpdate(PeerUpdate{
pu.SendUpdate(ctx, PeerUpdate{
NodeID: id,
Status: PeerStatusGood,
})
@ -73,7 +72,7 @@ func TestPeerScoring(t *testing.T) {
start := peerManager.Scores()[id]
pu := peerManager.Subscribe(ctx)
defer pu.Close()
pu.SendUpdate(PeerUpdate{
pu.SendUpdate(ctx, PeerUpdate{
NodeID: id,
Status: PeerStatusBad,
})


+ 0
- 10
internal/p2p/peermanager_test.go View File

@ -154,7 +154,6 @@ func TestNewPeerManager_Persistence(t *testing.T) {
PeerScores: map[types.NodeID]p2p.PeerScore{bID: 1},
})
require.NoError(t, err)
defer peerManager.Close()
for _, addr := range append(append(aAddresses, bAddresses...), cAddresses...) {
added, err := peerManager.Add(addr)
@ -171,8 +170,6 @@ func TestNewPeerManager_Persistence(t *testing.T) {
cID: 0,
}, peerManager.Scores())
peerManager.Close()
// Creating a new peer manager with the same database should retain the
// peers, but they should have updated scores from the new PersistentPeers
// configuration.
@ -181,7 +178,6 @@ func TestNewPeerManager_Persistence(t *testing.T) {
PeerScores: map[types.NodeID]p2p.PeerScore{cID: 1},
})
require.NoError(t, err)
defer peerManager.Close()
require.ElementsMatch(t, aAddresses, peerManager.Addresses(aID))
require.ElementsMatch(t, bAddresses, peerManager.Addresses(bID))
@ -208,7 +204,6 @@ func TestNewPeerManager_SelfIDChange(t *testing.T) {
require.NoError(t, err)
require.True(t, added)
require.ElementsMatch(t, []types.NodeID{a.NodeID, b.NodeID}, peerManager.Peers())
peerManager.Close()
// If we change our selfID to one of the peers in the peer store, it
// should be removed from the store.
@ -1755,9 +1750,6 @@ func TestPeerManager_Close(t *testing.T) {
require.NoError(t, err)
require.Equal(t, a, dial)
require.NoError(t, peerManager.DialFailed(ctx, a))
// This should clean up the goroutines.
peerManager.Close()
}
func TestPeerManager_Advertise(t *testing.T) {
@ -1780,7 +1772,6 @@ func TestPeerManager_Advertise(t *testing.T) {
PeerScores: map[types.NodeID]p2p.PeerScore{aID: 3, bID: 2, cID: 1},
})
require.NoError(t, err)
defer peerManager.Close()
added, err := peerManager.Add(aTCP)
require.NoError(t, err)
@ -1847,7 +1838,6 @@ func TestPeerManager_SetHeight_GetHeight(t *testing.T) {
require.ElementsMatch(t, []types.NodeID{a.NodeID, b.NodeID}, peerManager.Peers())
// The heights should not be persisted.
peerManager.Close()
peerManager, err = p2p.NewPeerManager(selfID, db, p2p.PeerManagerOptions{})
require.NoError(t, err)


+ 2
- 16
internal/p2p/pex/reactor.go View File

@ -83,7 +83,6 @@ type Reactor struct {
peerManager *p2p.PeerManager
pexCh *p2p.Channel
peerUpdates *p2p.PeerUpdates
closeCh chan struct{}
// list of available peers to loop through and send peer requests to
availablePeers map[types.NodeID]struct{}
@ -128,7 +127,6 @@ func NewReactor(
peerManager: peerManager,
pexCh: pexCh,
peerUpdates: peerUpdates,
closeCh: make(chan struct{}),
availablePeers: make(map[types.NodeID]struct{}),
requestsSent: make(map[types.NodeID]struct{}),
lastReceivedRequests: make(map[types.NodeID]time.Time),
@ -150,13 +148,7 @@ func (r *Reactor) OnStart(ctx context.Context) error {
// OnStop stops the reactor by signaling to all spawned goroutines to exit and
// blocking until they all exit.
func (r *Reactor) OnStop() {
// Close closeCh to signal to all spawned goroutines to gracefully exit. All
// p2p Channels should execute Close().
close(r.closeCh)
<-r.peerUpdates.Done()
}
func (r *Reactor) OnStop() {}
// processPexCh implements a blocking event loop where we listen for p2p
// Envelope messages from the pexCh.
@ -168,8 +160,6 @@ func (r *Reactor) processPexCh(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-r.closeCh:
r.logger.Debug("stopped listening on PEX channel; closing...")
return
@ -196,17 +186,13 @@ func (r *Reactor) processPexCh(ctx context.Context) {
// close the p2p PeerUpdatesCh gracefully.
func (r *Reactor) processPeerUpdates(ctx context.Context) {
defer r.peerUpdates.Close()
for {
select {
case <-ctx.Done():
r.logger.Debug("stopped listening on peer updates channel; closing...")
return
case peerUpdate := <-r.peerUpdates.Updates():
r.processPeerUpdate(peerUpdate)
case <-r.closeCh:
r.logger.Debug("stopped listening on peer updates channel; closing...")
return
}
}
}


+ 11
- 17
internal/p2p/router.go View File

@ -158,7 +158,6 @@ type Router struct {
endpoints []Endpoint
connTracker connectionTracker
protocolTransports map[Protocol]Transport
stopCh chan struct{} // signals Router shutdown
peerMtx sync.RWMutex
peerQueues map[types.NodeID]queue // outbound messages per peer for all channels
@ -208,7 +207,6 @@ func NewRouter(
protocolTransports: map[Protocol]Transport{},
peerManager: peerManager,
options: options,
stopCh: make(chan struct{}),
channelQueues: map[ChannelID]queue{},
channelMessages: map[ChannelID]proto.Message{},
peerQueues: map[types.NodeID]queue{},
@ -399,7 +397,7 @@ func (r *Router) routeChannel(
case <-q.closed():
r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chID)
case <-r.stopCh:
case <-ctx.Done():
return
}
}
@ -414,8 +412,6 @@ func (r *Router) routeChannel(
r.peerManager.Errored(peerError.NodeID, peerError.Err)
case <-ctx.Done():
return
case <-r.stopCh:
return
}
}
}
@ -474,7 +470,7 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) {
r.logger.Debug("starting accept routine", "transport", transport)
for {
conn, err := transport.Accept()
conn, err := transport.Accept(ctx)
switch err {
case nil:
case io.EOF:
@ -783,14 +779,14 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec
go func() {
select {
case errCh <- r.receivePeer(peerID, conn):
case errCh <- r.receivePeer(ctx, peerID, conn):
case <-ctx.Done():
}
}()
go func() {
select {
case errCh <- r.sendPeer(peerID, conn, sendQueue):
case errCh <- r.sendPeer(ctx, peerID, conn, sendQueue):
case <-ctx.Done():
}
}()
@ -829,9 +825,9 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec
// receivePeer receives inbound messages from a peer, deserializes them and
// passes them on to the appropriate channel.
func (r *Router) receivePeer(peerID types.NodeID, conn Connection) error {
func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Connection) error {
for {
chID, bz, err := conn.ReceiveMessage()
chID, bz, err := conn.ReceiveMessage(ctx)
if err != nil {
return err
}
@ -874,14 +870,14 @@ func (r *Router) receivePeer(peerID types.NodeID, conn Connection) error {
case <-queue.closed():
r.logger.Debug("channel closed, dropping message", "peer", peerID, "channel", chID)
case <-r.stopCh:
case <-ctx.Done():
return nil
}
}
}
// sendPeer sends queued messages to a peer.
func (r *Router) sendPeer(peerID types.NodeID, conn Connection, peerQueue queue) error {
func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue queue) error {
for {
start := time.Now().UTC()
@ -899,7 +895,7 @@ func (r *Router) sendPeer(peerID types.NodeID, conn Connection, peerQueue queue)
continue
}
if err = conn.SendMessage(envelope.channelID, bz); err != nil {
if err = conn.SendMessage(ctx, envelope.channelID, bz); err != nil {
return err
}
@ -908,7 +904,7 @@ func (r *Router) sendPeer(peerID types.NodeID, conn Connection, peerQueue queue)
case <-peerQueue.closed():
return nil
case <-r.stopCh:
case <-ctx.Done():
return nil
}
}
@ -983,9 +979,6 @@ func (r *Router) OnStart(ctx context.Context) error {
// here, since that would cause any reactor senders to panic, so it is the
// sender's responsibility.
func (r *Router) OnStop() {
// Signal router shutdown.
close(r.stopCh)
// Close transport listeners (unblocks Accept calls).
for _, transport := range r.transports {
if err := transport.Close(); err != nil {
@ -1009,6 +1002,7 @@ func (r *Router) OnStop() {
r.peerMtx.RUnlock()
for _, q := range queues {
q.close()
<-q.closed()
}
}


+ 23
- 37
internal/p2p/router_test.go View File

@ -106,7 +106,6 @@ func TestRouter_Channel_Basic(t *testing.T) {
// Set up a router with no transports (so no peers).
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
router, err := p2p.NewRouter(
ctx,
@ -392,25 +391,22 @@ func TestRouter_AcceptPeers(t *testing.T) {
mockConnection.On("String").Maybe().Return("mock")
mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey).
Return(tc.peerInfo, tc.peerKey, nil)
mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil)
mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil).Maybe()
mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{})
if tc.ok {
// without the sleep after RequireUpdate this method isn't
// always called. Consider making this call optional.
mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF)
mockConnection.On("ReceiveMessage", mock.Anything).Return(chID, nil, io.EOF).Maybe()
}
mockTransport := &mocks.Transport{}
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Close").Return(nil)
mockTransport.On("Accept").Once().Return(mockConnection, nil)
mockTransport.On("Accept").Maybe().Return(nil, io.EOF)
mockTransport.On("Close").Return(nil).Maybe()
mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil)
mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF)
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
sub := peerManager.Subscribe(ctx)
defer sub.Close()
@ -464,13 +460,12 @@ func TestRouter_AcceptPeers_Error(t *testing.T) {
mockTransport := &mocks.Transport{}
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Accept").Once().Return(nil, errors.New("boom"))
mockTransport.On("Accept", mock.Anything).Once().Return(nil, errors.New("boom"))
mockTransport.On("Close").Return(nil)
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
router, err := p2p.NewRouter(
ctx,
@ -503,13 +498,12 @@ func TestRouter_AcceptPeers_ErrorEOF(t *testing.T) {
mockTransport := &mocks.Transport{}
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Accept").Once().Return(nil, io.EOF)
mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF)
mockTransport.On("Close").Return(nil)
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
router, err := p2p.NewRouter(
ctx,
@ -554,15 +548,14 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) {
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Close").Return(nil)
mockTransport.On("Accept").Times(3).Run(func(_ mock.Arguments) {
mockTransport.On("Accept", mock.Anything).Times(3).Run(func(_ mock.Arguments) {
acceptCh <- true
}).Return(mockConnection, nil)
mockTransport.On("Accept").Once().Return(nil, io.EOF)
mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF)
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
router, err := p2p.NewRouter(
ctx,
@ -580,7 +573,7 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) {
require.Eventually(t, func() bool {
return len(acceptCh) == 3
}, time.Second, 10*time.Millisecond)
}, time.Second, 10*time.Millisecond, "num", len(acceptCh))
close(closeCh)
time.Sleep(100 * time.Millisecond)
@ -636,19 +629,17 @@ func TestRouter_DialPeers(t *testing.T) {
if tc.dialErr == nil {
mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey).
Return(tc.peerInfo, tc.peerKey, nil)
mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil)
mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil).Maybe()
}
if tc.ok {
// without the sleep after RequireUpdate this method isn't
// always called. Consider making this call optional.
mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF)
mockConnection.On("ReceiveMessage", mock.Anything).Return(chID, nil, io.EOF).Maybe()
}
mockTransport := &mocks.Transport{}
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Close").Return(nil)
mockTransport.On("Accept").Maybe().Return(nil, io.EOF)
mockTransport.On("Close").Return(nil).Maybe()
mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF)
if tc.dialErr == nil {
mockTransport.On("Dial", mock.Anything, endpoint).Once().Return(mockConnection, nil)
// This handles the retry when a dialed connection gets closed after ReceiveMessage
@ -663,7 +654,6 @@ func TestRouter_DialPeers(t *testing.T) {
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
added, err := peerManager.Add(address)
require.NoError(t, err)
@ -734,7 +724,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) {
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Close").Return(nil)
mockTransport.On("Accept").Once().Return(nil, io.EOF)
mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF)
for _, address := range []p2p.NodeAddress{a, b, c} {
endpoint := p2p.Endpoint{Protocol: address.Protocol, Path: string(address.NodeID)}
mockTransport.On("Dial", mock.Anything, endpoint).Run(func(_ mock.Arguments) {
@ -745,7 +735,6 @@ func TestRouter_DialPeers_Parallel(t *testing.T) {
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
added, err := peerManager.Add(a)
require.NoError(t, err)
@ -813,7 +802,7 @@ func TestRouter_EvictPeers(t *testing.T) {
mockConnection.On("String").Maybe().Return("mock")
mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey).
Return(peerInfo, peerKey.PubKey(), nil)
mockConnection.On("ReceiveMessage").WaitUntil(closeCh).Return(chID, nil, io.EOF)
mockConnection.On("ReceiveMessage", mock.Anything).WaitUntil(closeCh).Return(chID, nil, io.EOF)
mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{})
mockConnection.On("Close").Run(func(_ mock.Arguments) {
closeOnce.Do(func() {
@ -825,13 +814,12 @@ func TestRouter_EvictPeers(t *testing.T) {
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Close").Return(nil)
mockTransport.On("Accept").Once().Return(mockConnection, nil)
mockTransport.On("Accept").Maybe().Return(nil, io.EOF)
mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil)
mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF)
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
sub := peerManager.Subscribe(ctx)
defer sub.Close()
@ -893,13 +881,12 @@ func TestRouter_ChannelCompatability(t *testing.T) {
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Close").Return(nil)
mockTransport.On("Accept").Once().Return(mockConnection, nil)
mockTransport.On("Accept").Once().Return(nil, io.EOF)
mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil)
mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF)
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
router, err := p2p.NewRouter(
ctx,
@ -941,20 +928,19 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) {
Return(peer, peerKey.PubKey(), nil)
mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{})
mockConnection.On("Close").Return(nil)
mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF)
mockConnection.On("ReceiveMessage", mock.Anything).Return(chID, nil, io.EOF)
mockTransport := &mocks.Transport{}
mockTransport.On("AddChannelDescriptors", mock.Anything).Return()
mockTransport.On("String").Maybe().Return("mock")
mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"})
mockTransport.On("Close").Return(nil)
mockTransport.On("Accept").Once().Return(mockConnection, nil)
mockTransport.On("Accept").Maybe().Return(nil, io.EOF)
mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil)
mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF)
// Set up and start the router.
peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
defer peerManager.Close()
sub := peerManager.Subscribe(ctx)
defer sub.Close()


+ 3
- 3
internal/p2p/transport.go View File

@ -39,7 +39,7 @@ type Transport interface {
// Accept waits for the next inbound connection on a listening endpoint, blocking
// until either a connection is available or the transport is closed. On closure,
// io.EOF is returned and further Accept calls are futile.
Accept() (Connection, error)
Accept(context.Context) (Connection, error)
// Dial creates an outbound connection to an endpoint.
Dial(context.Context, Endpoint) (Connection, error)
@ -85,10 +85,10 @@ type Connection interface {
// ReceiveMessage returns the next message received on the connection,
// blocking until one is available. Returns io.EOF if closed.
ReceiveMessage() (ChannelID, []byte, error)
ReceiveMessage(context.Context) (ChannelID, []byte, error)
// SendMessage sends a message on the connection. Returns io.EOF if closed.
SendMessage(ChannelID, []byte) error
SendMessage(context.Context, ChannelID, []byte) error
// LocalEndpoint returns the local endpoint for the connection.
LocalEndpoint() Endpoint


+ 24
- 19
internal/p2p/transport_mconn.go View File

@ -44,10 +44,10 @@ type MConnTransport struct {
options MConnTransportOptions
mConnConfig conn.MConnConfig
channelDescs []*ChannelDescriptor
closeCh chan struct{}
closeOnce sync.Once
listener net.Listener
closeOnce sync.Once
doneCh chan struct{}
listener net.Listener
}
// NewMConnTransport sets up a new MConnection transport. This uses the
@ -63,7 +63,7 @@ func NewMConnTransport(
logger: logger,
options: options,
mConnConfig: mConnConfig,
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
channelDescs: channelDescs,
}
}
@ -84,10 +84,11 @@ func (m *MConnTransport) Endpoints() []Endpoint {
return []Endpoint{}
}
select {
case <-m.closeCh:
case <-m.doneCh:
return []Endpoint{}
default:
}
endpoint := Endpoint{
Protocol: MConnProtocol,
}
@ -132,7 +133,7 @@ func (m *MConnTransport) Listen(endpoint Endpoint) error {
}
// Accept implements Transport.
func (m *MConnTransport) Accept() (Connection, error) {
func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) {
if m.listener == nil {
return nil, errors.New("transport is not listening")
}
@ -140,7 +141,9 @@ func (m *MConnTransport) Accept() (Connection, error) {
tcpConn, err := m.listener.Accept()
if err != nil {
select {
case <-m.closeCh:
case <-ctx.Done():
return nil, io.EOF
case <-m.doneCh:
return nil, io.EOF
default:
return nil, err
@ -178,7 +181,7 @@ func (m *MConnTransport) Dial(ctx context.Context, endpoint Endpoint) (Connectio
func (m *MConnTransport) Close() error {
var err error
m.closeOnce.Do(func() {
close(m.closeCh) // must be closed first, to handle error in Accept()
close(m.doneCh)
if m.listener != nil {
err = m.listener.Close()
}
@ -222,7 +225,7 @@ type mConnConnection struct {
channelDescs []*ChannelDescriptor
receiveCh chan mConnMessage
errorCh chan error
closeCh chan struct{}
doneCh chan struct{}
closeOnce sync.Once
mconn *conn.MConnection // set during Handshake()
@ -248,7 +251,7 @@ func newMConnConnection(
channelDescs: channelDescs,
receiveCh: make(chan mConnMessage),
errorCh: make(chan error, 1), // buffered to avoid onError leak
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
@ -370,16 +373,16 @@ func (c *mConnConnection) handshake(
}
// onReceive is a callback for MConnection received messages.
func (c *mConnConnection) onReceive(chID ChannelID, payload []byte) {
func (c *mConnConnection) onReceive(ctx context.Context, chID ChannelID, payload []byte) {
select {
case c.receiveCh <- mConnMessage{channelID: chID, payload: payload}:
case <-c.closeCh:
case <-ctx.Done():
}
}
// onError is a callback for MConnection errors. The error is passed via errorCh
// to ReceiveMessage (but not SendMessage, for legacy P2P stack behavior).
func (c *mConnConnection) onError(e interface{}) {
func (c *mConnConnection) onError(ctx context.Context, e interface{}) {
err, ok := e.(error)
if !ok {
err = fmt.Errorf("%v", err)
@ -389,7 +392,7 @@ func (c *mConnConnection) onError(e interface{}) {
_ = c.Close()
select {
case c.errorCh <- err:
case <-c.closeCh:
case <-ctx.Done():
}
}
@ -399,14 +402,14 @@ func (c *mConnConnection) String() string {
}
// SendMessage implements Connection.
func (c *mConnConnection) SendMessage(chID ChannelID, msg []byte) error {
func (c *mConnConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error {
if chID > math.MaxUint8 {
return fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID)
}
select {
case err := <-c.errorCh:
return err
case <-c.closeCh:
case <-ctx.Done():
return io.EOF
default:
if ok := c.mconn.Send(chID, msg); !ok {
@ -418,11 +421,13 @@ func (c *mConnConnection) SendMessage(chID ChannelID, msg []byte) error {
}
// ReceiveMessage implements Connection.
func (c *mConnConnection) ReceiveMessage() (ChannelID, []byte, error) {
func (c *mConnConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) {
select {
case err := <-c.errorCh:
return 0, nil, err
case <-c.closeCh:
case <-c.doneCh:
return 0, nil, io.EOF
case <-ctx.Done():
return 0, nil, io.EOF
case msg := <-c.receiveCh:
return msg.channelID, msg.payload, nil
@ -462,7 +467,7 @@ func (c *mConnConnection) Close() error {
} else {
err = c.conn.Close()
}
close(c.closeCh)
close(c.doneCh)
})
return err
}

+ 6
- 4
internal/p2p/transport_mconn_test.go View File

@ -52,8 +52,10 @@ func TestMConnTransport_AcceptBeforeListen(t *testing.T) {
t.Cleanup(func() {
_ = transport.Close()
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := transport.Accept()
_, err := transport.Accept(ctx)
require.Error(t, err)
require.NotEqual(t, io.EOF, err) // io.EOF should be returned after Close()
}
@ -85,7 +87,7 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) {
acceptCh := make(chan p2p.Connection, 10)
go func() {
for {
conn, err := transport.Accept()
conn, err := transport.Accept(ctx)
if err != nil {
return
}
@ -203,7 +205,7 @@ func TestMConnTransport_Listen(t *testing.T) {
close(dialedChan)
}()
conn, err := transport.Accept()
conn, err := transport.Accept(ctx)
require.NoError(t, err)
_ = conn.Close()
<-dialedChan
@ -212,7 +214,7 @@ func TestMConnTransport_Listen(t *testing.T) {
require.NoError(t, peerConn.Close())
// try to read from the connection should error
_, _, err = peerConn.ReceiveMessage()
_, _, err = peerConn.ReceiveMessage(ctx)
require.Error(t, err)
// Trying to listen again should error.


+ 22
- 26
internal/p2p/transport_memory.go View File

@ -94,9 +94,7 @@ type MemoryTransport struct {
nodeID types.NodeID
bufferSize int
acceptCh chan *MemoryConnection
closeCh chan struct{}
closeOnce sync.Once
acceptCh chan *MemoryConnection
}
// newMemoryTransport creates a new MemoryTransport. This is for internal use by
@ -108,7 +106,6 @@ func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTran
nodeID: nodeID,
bufferSize: network.bufferSize,
acceptCh: make(chan *MemoryConnection),
closeCh: make(chan struct{}),
}
}
@ -128,28 +125,27 @@ func (t *MemoryTransport) Protocols() []Protocol {
// Endpoints implements Transport.
func (t *MemoryTransport) Endpoints() []Endpoint {
select {
case <-t.closeCh:
if n := t.network.GetTransport(t.nodeID); n == nil {
return []Endpoint{}
default:
return []Endpoint{{
Protocol: MemoryProtocol,
Path: string(t.nodeID),
// An arbitrary IP and port is used in order for the pex
// reactor to be able to send addresses to one another.
IP: net.IPv4zero,
Port: 0,
}}
}
return []Endpoint{{
Protocol: MemoryProtocol,
Path: string(t.nodeID),
// An arbitrary IP and port is used in order for the pex
// reactor to be able to send addresses to one another.
IP: net.IPv4zero,
Port: 0,
}}
}
// Accept implements Transport.
func (t *MemoryTransport) Accept() (Connection, error) {
func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
select {
case conn := <-t.acceptCh:
t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
return conn, nil
case <-t.closeCh:
case <-ctx.Done():
return nil, io.EOF
}
}
@ -187,20 +183,14 @@ func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connecti
select {
case peer.acceptCh <- inConn:
return outConn, nil
case <-peer.closeCh:
return nil, io.EOF
case <-ctx.Done():
return nil, ctx.Err()
return nil, io.EOF
}
}
// Close implements Transport.
func (t *MemoryTransport) Close() error {
t.network.RemoveTransport(t.nodeID)
t.closeOnce.Do(func() {
close(t.closeCh)
t.logger.Info("closed transport")
})
return nil
}
@ -295,12 +285,14 @@ func (c *MemoryConnection) Handshake(
}
// ReceiveMessage implements Connection.
func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
func (c *MemoryConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) {
// Check close first, since channels are buffered. Otherwise, below select
// may non-deterministically return non-error even when closed.
select {
case <-c.closer.Done():
return 0, nil, io.EOF
case <-ctx.Done():
return 0, nil, io.EOF
default:
}
@ -314,12 +306,14 @@ func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
}
// SendMessage implements Connection.
func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) error {
func (c *MemoryConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error {
// Check close first, since channels are buffered. Otherwise, below select
// may non-deterministically return non-error even when closed.
select {
case <-c.closer.Done():
return io.EOF
case <-ctx.Done():
return io.EOF
default:
}
@ -327,6 +321,8 @@ func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) error {
case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
c.logger.Debug("sent message", "chID", chID, "msg", msg)
return nil
case <-ctx.Done():
return io.EOF
case <-c.closer.Done():
return io.EOF
}


+ 19
- 20
internal/p2p/transport_test.go View File

@ -46,21 +46,23 @@ func TestTransport_AcceptClose(t *testing.T) {
withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
opctx, opcancel := context.WithCancel(ctx)
// In-progress Accept should error on concurrent close.
errCh := make(chan error, 1)
go func() {
time.Sleep(200 * time.Millisecond)
opcancel()
errCh <- a.Close()
}()
_, err := a.Accept()
_, err := a.Accept(opctx)
require.Error(t, err)
require.Equal(t, io.EOF, err)
require.NoError(t, <-errCh)
// Closed transport should return error immediately.
_, err = a.Accept()
_, err = a.Accept(opctx)
require.Error(t, err)
require.Equal(t, io.EOF, err)
})
@ -93,7 +95,7 @@ func TestTransport_DialEndpoints(t *testing.T) {
// Spawn a goroutine to simply accept any connections until closed.
go func() {
for {
conn, err := a.Accept()
conn, err := a.Accept(ctx)
if err != nil {
return
}
@ -177,7 +179,6 @@ func TestTransport_Dial(t *testing.T) {
cancel()
_, err := a.Dial(cancelCtx, bEndpoint)
require.Error(t, err)
require.Equal(t, err, context.Canceled)
// Unavailable endpoint should error.
err = b.Close()
@ -188,7 +189,7 @@ func TestTransport_Dial(t *testing.T) {
// Dialing from a closed transport should still work.
errCh := make(chan error, 1)
go func() {
conn, err := a.Accept()
conn, err := a.Accept(ctx)
if err == nil {
_ = conn.Close()
}
@ -351,13 +352,12 @@ func TestConnection_FlushClose(t *testing.T) {
err := ab.Close()
require.NoError(t, err)
_, _, err = ab.ReceiveMessage()
_, _, err = ab.ReceiveMessage(ctx)
require.Error(t, err)
require.Equal(t, io.EOF, err)
err = ab.SendMessage(chID, []byte("closed"))
err = ab.SendMessage(ctx, chID, []byte("closed"))
require.Error(t, err)
require.Equal(t, io.EOF, err)
})
}
@ -388,19 +388,19 @@ func TestConnection_SendReceive(t *testing.T) {
ab, ba := dialAcceptHandshake(ctx, t, a, b)
// Can send and receive a to b.
err := ab.SendMessage(chID, []byte("foo"))
err := ab.SendMessage(ctx, chID, []byte("foo"))
require.NoError(t, err)
ch, msg, err := ba.ReceiveMessage()
ch, msg, err := ba.ReceiveMessage(ctx)
require.NoError(t, err)
require.Equal(t, []byte("foo"), msg)
require.Equal(t, chID, ch)
// Can send and receive b to a.
err = ba.SendMessage(chID, []byte("bar"))
err = ba.SendMessage(ctx, chID, []byte("bar"))
require.NoError(t, err)
_, msg, err = ab.ReceiveMessage()
_, msg, err = ab.ReceiveMessage(ctx)
require.NoError(t, err)
require.Equal(t, []byte("bar"), msg)
@ -410,9 +410,9 @@ func TestConnection_SendReceive(t *testing.T) {
err = b.Close()
require.NoError(t, err)
err = ab.SendMessage(chID, []byte("still here"))
err = ab.SendMessage(ctx, chID, []byte("still here"))
require.NoError(t, err)
ch, msg, err = ba.ReceiveMessage()
ch, msg, err = ba.ReceiveMessage(ctx)
require.NoError(t, err)
require.Equal(t, chID, ch)
require.Equal(t, []byte("still here"), msg)
@ -422,21 +422,20 @@ func TestConnection_SendReceive(t *testing.T) {
err = ba.Close()
require.NoError(t, err)
_, _, err = ab.ReceiveMessage()
_, _, err = ab.ReceiveMessage(ctx)
require.Error(t, err)
require.Equal(t, io.EOF, err)
err = ab.SendMessage(chID, []byte("closed"))
err = ab.SendMessage(ctx, chID, []byte("closed"))
require.Error(t, err)
require.Equal(t, io.EOF, err)
_, _, err = ba.ReceiveMessage()
_, _, err = ba.ReceiveMessage(ctx)
require.Error(t, err)
require.Equal(t, io.EOF, err)
err = ba.SendMessage(chID, []byte("closed"))
err = ba.SendMessage(ctx, chID, []byte("closed"))
require.Error(t, err)
require.Equal(t, io.EOF, err)
})
}
@ -606,7 +605,7 @@ func dialAccept(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Conn
acceptCh := make(chan p2p.Connection, 1)
errCh := make(chan error, 1)
go func() {
conn, err := b.Accept()
conn, err := b.Accept(ctx)
errCh <- err
acceptCh <- conn
}()


Loading…
Cancel
Save