From c769e3e09b6b56af639805a4eb22fb901957196a Mon Sep 17 00:00:00 2001 From: Callum Waters Date: Mon, 5 Jul 2021 13:00:19 +0200 Subject: [PATCH] p2p: track peer channels to avoid sending across a channel a peer doesn't have (#6601) --- internal/p2p/p2p_test.go | 2 + internal/p2p/p2ptest/network.go | 1 + internal/p2p/router.go | 67 ++++++++++++---- internal/p2p/router_test.go | 131 ++++++++++++++++++++++++++++++-- types/node_info.go | 30 +++++++- types/node_info_test.go | 14 ++++ 6 files changed, 220 insertions(+), 25 deletions(-) diff --git a/internal/p2p/p2p_test.go b/internal/p2p/p2p_test.go index 20d157668..6e524d492 100644 --- a/internal/p2p/p2p_test.go +++ b/internal/p2p/p2p_test.go @@ -29,6 +29,7 @@ var ( ListenAddr: "0.0.0.0:0", Network: "test", Moniker: string(selfID), + Channels: []byte{0x01, 0x02}, } peerKey crypto.PrivKey = ed25519.GenPrivKeyFromSecret([]byte{0x84, 0xd7, 0x01, 0xbf, 0x83, 0x20, 0x1c, 0xfe}) @@ -38,5 +39,6 @@ var ( ListenAddr: "0.0.0.0:0", Network: "test", Moniker: string(peerID), + Channels: []byte{0x01, 0x02}, } ) diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 258d218f1..1daba3f14 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -285,6 +285,7 @@ func (n *Node) MakeChannel(t *testing.T, chDesc p2p.ChannelDescriptor, messageType proto.Message, size int) *p2p.Channel { channel, err := n.Router.OpenChannel(chDesc, messageType, size) require.NoError(t, err) + require.Contains(t, n.Router.NodeInfo().Channels, chDesc.ID) t.Cleanup(func() { RequireEmpty(t, channel) channel.Close() diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 30c4d4fba..7b3b2505c 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -257,8 +257,10 @@ type Router struct { protocolTransports map[Protocol]Transport stopCh chan struct{} // signals Router shutdown - peerMtx sync.RWMutex - peerQueues map[types.NodeID]queue // outbound messages per peer for all channels + peerMtx sync.RWMutex + peerQueues map[types.NodeID]queue // outbound messages per peer for all channels + // the channels that the peer queue has open + peerChannels map[types.NodeID]channelIDs queueFactory func(int) queue // FIXME: We don't strictly need to use a mutex for this if we seal the @@ -304,6 +306,7 @@ func NewRouter( channelQueues: map[ChannelID]queue{}, channelMessages: map[ChannelID]proto.Message{}, peerQueues: map[types.NodeID]queue{}, + peerChannels: make(map[types.NodeID]channelIDs), } router.BaseService = service.NewBaseService(logger, "router", router) @@ -387,6 +390,9 @@ func (r *Router) OpenChannel(chDesc ChannelDescriptor, messageType proto.Message r.channelQueues[id] = queue r.channelMessages[id] = messageType + // add the channel to the nodeInfo if it's not already there. + r.nodeInfo.AddChannel(uint16(chDesc.ID)) + go func() { defer func() { r.channelMtx.Lock() @@ -441,14 +447,27 @@ func (r *Router) routeChannel( r.peerMtx.RLock() queues = make([]queue, 0, len(r.peerQueues)) - for _, q := range r.peerQueues { - queues = append(queues, q) + for nodeID, q := range r.peerQueues { + peerChs := r.peerChannels[nodeID] + + // check whether the peer is receiving on that channel + if _, ok := peerChs[chID]; ok { + queues = append(queues, q) + } } r.peerMtx.RUnlock() } else { r.peerMtx.RLock() + q, ok := r.peerQueues[envelope.To] + contains := false + if ok { + peerChs := r.peerChannels[envelope.To] + + // check whether the peer is receiving on that channel + _, contains = peerChs[chID] + } r.peerMtx.RUnlock() if !ok { @@ -456,6 +475,12 @@ func (r *Router) routeChannel( continue } + if !contains { + r.logger.Error("tried to send message across a channel that the peer doesn't have available", + "peer", envelope.To, "channel", chID) + continue + } + queues = []queue{q} } @@ -612,7 +637,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) { return } - r.routePeer(peerInfo.NodeID, conn) + r.routePeer(peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) } // dialPeers maintains outbound connections to peers by dialing them. @@ -688,16 +713,10 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { } peerInfo, _, err := r.handshakePeer(ctx, conn, address.NodeID) - var errRejected ErrRejected switch { case errors.Is(err, context.Canceled): conn.Close() return - case errors.As(err, &errRejected) && errRejected.IsIncompatible(): - r.logger.Error("peer rejected due to incompatibility", "node", peerInfo.NodeID, "err", err) - r.peerManager.Errored(peerInfo.NodeID, err) - conn.Close() - return case err != nil: r.logger.Error("failed to handshake with peer", "peer", address, "err", err) if err = r.peerManager.DialFailed(address); err != nil { @@ -712,14 +731,13 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { "op", "outgoing/dialing", "peer", address.NodeID, "err", err) conn.Close() return - } // routePeer (also) calls connection close - go r.routePeer(address.NodeID, conn) + go r.routePeer(address.NodeID, conn, toChannelIDs(peerInfo.Channels)) } -func (r *Router) getOrMakeQueue(peerID types.NodeID) queue { +func (r *Router) getOrMakeQueue(peerID types.NodeID, channels channelIDs) queue { r.peerMtx.Lock() defer r.peerMtx.Unlock() @@ -729,6 +747,7 @@ func (r *Router) getOrMakeQueue(peerID types.NodeID) queue { peerQueue := r.queueFactory(queueBufferDefault) r.peerQueues[peerID] = peerQueue + r.peerChannels[peerID] = channels return peerQueue } @@ -830,14 +849,15 @@ func (r *Router) runWithPeerMutex(fn func() error) error { // routePeer routes inbound and outbound messages between a peer and the reactor // channels. It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. -func (r *Router) routePeer(peerID types.NodeID, conn Connection) { +func (r *Router) routePeer(peerID types.NodeID, conn Connection, channels channelIDs) { r.metrics.Peers.Add(1) r.peerManager.Ready(peerID) - sendQueue := r.getOrMakeQueue(peerID) + sendQueue := r.getOrMakeQueue(peerID, channels) defer func() { r.peerMtx.Lock() delete(r.peerQueues, peerID) + delete(r.peerChannels, peerID) r.peerMtx.Unlock() sendQueue.close() @@ -994,6 +1014,11 @@ func (r *Router) evictPeers() { } } +// NodeInfo returns a copy of the current NodeInfo. Used for testing. +func (r *Router) NodeInfo() types.NodeInfo { + return r.nodeInfo.Copy() +} + // OnStart implements service.Service. func (r *Router) OnStart() error { go r.dialPeers() @@ -1054,3 +1079,13 @@ func (r *Router) stopCtx() context.Context { return ctx } + +type channelIDs map[ChannelID]struct{} + +func toChannelIDs(bytes []byte) channelIDs { + c := make(map[ChannelID]struct{}, len(bytes)) + for _, b := range bytes { + c[ChannelID(b)] = struct{}{} + } + return c +} diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 6ad8542f4..436e3f004 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -48,12 +48,12 @@ func TestRouter_Network(t *testing.T) { // Create a test network and open a channel where all peers run echoReactor. network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 8}) - network.Start(t) - local := network.RandomNode() peers := network.Peers(local.NodeID) channels := network.MakeChannels(t, chDesc, &p2ptest.Message{}, 0) + network.Start(t) + channel := channels[local.NodeID] for _, peer := range peers { go echoReactor(channels[peer.NodeID]) @@ -94,7 +94,7 @@ func TestRouter_Network(t *testing.T) { }) } -func TestRouter_Channel(t *testing.T) { +func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(leaktest.Check(t)) // Set up a router with no transports (so no peers). @@ -121,6 +121,7 @@ func TestRouter_Channel(t *testing.T) { // Opening a channel should work. channel, err := router.OpenChannel(chDesc, &p2ptest.Message{}, 0) require.NoError(t, err) + require.Contains(t, router.NodeInfo().Channels, chDesc.ID) // Opening the same channel again should fail. _, err = router.OpenChannel(chDesc, &p2ptest.Message{}, 0) @@ -130,6 +131,7 @@ func TestRouter_Channel(t *testing.T) { chDesc2 := p2p.ChannelDescriptor{ID: byte(2)} _, err = router.OpenChannel(chDesc2, &p2ptest.Message{}, 0) require.NoError(t, err) + require.Contains(t, router.NodeInfo().Channels, chDesc2.ID) // Closing the channel, then opening it again should be fine. channel.Close() @@ -158,7 +160,6 @@ func TestRouter_Channel_SendReceive(t *testing.T) { // Create a test network and open a channel on all nodes. network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 3}) - network.Start(t) ids := network.NodeIDs() aID, bID, cID := ids[0], ids[1], ids[2] @@ -166,13 +167,15 @@ func TestRouter_Channel_SendReceive(t *testing.T) { a, b, c := channels[aID], channels[bID], channels[cID] otherChannels := network.MakeChannels(t, p2ptest.MakeChannelDesc(9), &p2ptest.Message{}, 0) + network.Start(t) + // Sending a message a->b should work, and not send anything // further to a, b, or c. p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireEmpty(t, a, b, c) - // Sending a nil message a->c should be dropped. + // Sending a nil message a->b should be dropped. p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: nil}) p2ptest.RequireEmpty(t, a, b, c) @@ -216,13 +219,14 @@ func TestRouter_Channel_Broadcast(t *testing.T) { // Create a test network and open a channel on all nodes. network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 4}) - network.Start(t) ids := network.NodeIDs() aID, bID, cID, dID := ids[0], ids[1], ids[2], ids[3] channels := network.MakeChannels(t, chDesc, &p2ptest.Message{}, 0) a, b, c, d := channels[aID], channels[bID], channels[cID], channels[dID] + network.Start(t) + // Sending a broadcast from b should work. p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) p2ptest.RequireReceive(t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) @@ -243,13 +247,14 @@ func TestRouter_Channel_Wrapper(t *testing.T) { // Create a test network and open a channel on all nodes. network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 2}) - network.Start(t) ids := network.NodeIDs() aID, bID := ids[0], ids[1] channels := network.MakeChannels(t, chDesc, &wrapperMessage{}, 0) a, b := channels[aID], channels[bID] + network.Start(t) + // Since wrapperMessage implements p2p.Wrapper and handles Message, it // should automatically wrap and unwrap sent messages -- we prepend the // wrapper actions to the message value to signal this. @@ -790,3 +795,115 @@ func TestRouter_EvictPeers(t *testing.T) { mockTransport.AssertExpectations(t) mockConnection.AssertExpectations(t) } + +func TestRouter_ChannelCompatability(t *testing.T) { + t.Cleanup(leaktest.Check(t)) + + incompatiblePeer := types.NodeInfo{ + NodeID: peerID, + ListenAddr: "0.0.0.0:0", + Network: "test", + Moniker: string(peerID), + Channels: []byte{0x03}, + } + + mockConnection := &mocks.Connection{} + mockConnection.On("String").Maybe().Return("mock") + mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + Return(incompatiblePeer, peerKey.PubKey(), nil) + mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) + mockConnection.On("Close").Return(nil) + + mockTransport := &mocks.Transport{} + mockTransport.On("String").Maybe().Return("mock") + mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) + mockTransport.On("Close").Return(nil) + mockTransport.On("Accept").Once().Return(mockConnection, nil) + mockTransport.On("Accept").Once().Return(nil, io.EOF) + + // Set up and start the router. + peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) + require.NoError(t, err) + defer peerManager.Close() + + router, err := p2p.NewRouter( + log.TestingLogger(), + p2p.NopMetrics(), + selfInfo, + selfKey, + peerManager, + []p2p.Transport{mockTransport}, + p2p.RouterOptions{}, + ) + require.NoError(t, err) + require.NoError(t, router.Start()) + time.Sleep(1 * time.Second) + require.NoError(t, router.Stop()) + require.Empty(t, peerManager.Peers()) + + mockConnection.AssertExpectations(t) + mockTransport.AssertExpectations(t) +} + +func TestRouter_DontSendOnInvalidChannel(t *testing.T) { + t.Cleanup(leaktest.Check(t)) + + peer := types.NodeInfo{ + NodeID: peerID, + ListenAddr: "0.0.0.0:0", + Network: "test", + Moniker: string(peerID), + Channels: []byte{0x02}, + } + + mockConnection := &mocks.Connection{} + mockConnection.On("String").Maybe().Return("mock") + mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + Return(peer, peerKey.PubKey(), nil) + mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) + mockConnection.On("Close").Return(nil) + mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF) + + mockTransport := &mocks.Transport{} + mockTransport.On("String").Maybe().Return("mock") + mockTransport.On("Protocols").Return([]p2p.Protocol{"mock"}) + mockTransport.On("Close").Return(nil) + mockTransport.On("Accept").Once().Return(mockConnection, nil) + mockTransport.On("Accept").Once().Return(nil, io.EOF) + + // Set up and start the router. + peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) + require.NoError(t, err) + defer peerManager.Close() + + sub := peerManager.Subscribe() + defer sub.Close() + + router, err := p2p.NewRouter( + log.TestingLogger(), + p2p.NopMetrics(), + selfInfo, + selfKey, + peerManager, + []p2p.Transport{mockTransport}, + p2p.RouterOptions{}, + ) + require.NoError(t, err) + require.NoError(t, router.Start()) + + p2ptest.RequireUpdate(t, sub, p2p.PeerUpdate{ + NodeID: peerInfo.NodeID, + Status: p2p.PeerStatusUp, + }) + + channel, err := router.OpenChannel(chDesc, &p2ptest.Message{}, 0) + require.NoError(t, err) + + channel.Out <- p2p.Envelope{ + To: peer.NodeID, + Message: &p2ptest.Message{Value: "Hi"}, + } + + require.NoError(t, router.Stop()) + mockTransport.AssertExpectations(t) +} diff --git a/types/node_info.go b/types/node_info.go index 226558eec..9dbdbf70d 100644 --- a/types/node_info.go +++ b/types/node_info.go @@ -39,8 +39,9 @@ type NodeInfo struct { // Check compatibility. // Channels are HexBytes so easier to read as JSON - Network string `json:"network"` // network/chain ID - Version string `json:"version"` // major.minor.revision + Network string `json:"network"` // network/chain ID + Version string `json:"version"` // major.minor.revision + // FIXME: This should be changed to uint16 to be consistent with the updated channel type Channels bytes.HexBytes `json:"channels"` // channels this node knows about // ASCIIText fields @@ -171,6 +172,31 @@ func (info NodeInfo) NetAddress() (*NetAddress, error) { return NewNetAddressString(idAddr) } +// AddChannel is used by the router when a channel is opened to add it to the node info +func (info *NodeInfo) AddChannel(channel uint16) { + // check that the channel doesn't already exist + for _, ch := range info.Channels { + if ch == byte(channel) { + return + } + } + + info.Channels = append(info.Channels, byte(channel)) +} + +func (info NodeInfo) Copy() NodeInfo { + return NodeInfo{ + ProtocolVersion: info.ProtocolVersion, + NodeID: info.NodeID, + ListenAddr: info.ListenAddr, + Network: info.Network, + Version: info.Version, + Channels: info.Channels, + Moniker: info.Moniker, + Other: info.Other, + } +} + func (info NodeInfo) ToProto() *tmp2p.NodeInfo { dni := new(tmp2p.NodeInfo) diff --git a/types/node_info_test.go b/types/node_info_test.go index f2663558b..812cec184 100644 --- a/types/node_info_test.go +++ b/types/node_info_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/crypto/ed25519" tmnet "github.com/tendermint/tendermint/libs/net" "github.com/tendermint/tendermint/version" @@ -159,3 +160,16 @@ func TestNodeInfoCompatible(t *testing.T) { assert.Error(t, ni1.CompatibleWith(ni)) } } + +func TestNodeInfoAddChannel(t *testing.T) { + nodeInfo := testNodeInfo(testNodeID(), "testing") + nodeInfo.Channels = []byte{} + require.Empty(t, nodeInfo.Channels) + + nodeInfo.AddChannel(2) + require.Contains(t, nodeInfo.Channels, byte(0x02)) + + // adding the same channel again shouldn't be a problem + nodeInfo.AddChannel(2) + require.Contains(t, nodeInfo.Channels, byte(0x02)) +}