diff --git a/internal/p2p/peer_test.go b/internal/p2p/peer_test.go index bfc3e32b4..dfe7bc798 100644 --- a/internal/p2p/peer_test.go +++ b/internal/p2p/peer_test.go @@ -140,6 +140,7 @@ func testOutboundPeerConn( type remotePeer struct { PrivKey crypto.PrivKey Config *config.P2PConfig + Network string addr *NetAddress channels bytes.HexBytes listenAddr string @@ -222,7 +223,7 @@ func (rp *remotePeer) accept() { } func (rp *remotePeer) nodeInfo() types.NodeInfo { - return types.NodeInfo{ + ni := types.NodeInfo{ ProtocolVersion: defaultProtocolVersion, NodeID: rp.Addr().ID, ListenAddr: rp.listener.Addr().String(), @@ -231,4 +232,8 @@ func (rp *remotePeer) nodeInfo() types.NodeInfo { Channels: rp.channels, Moniker: "remote_peer", } + if rp.Network != "" { + ni.Network = rp.Network + } + return ni } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index ff4a34ccd..30c4d4fba 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -601,7 +601,6 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) { r.logger.Error("peer handshake failed", "endpoint", conn, "err", err) return } - if err := r.filterPeersID(ctx, peerInfo.NodeID); err != nil { r.logger.Debug("peer filtered by node ID", "node", peerInfo.NodeID, "err", err) return @@ -688,11 +687,17 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return } - _, _, err = r.handshakePeer(ctx, conn, address.NodeID) + peerInfo, _, err := r.handshakePeer(ctx, conn, address.NodeID) + var errRejected ErrRejected switch { case errors.Is(err, context.Canceled): conn.Close() return + case errors.As(err, &errRejected) && errRejected.IsIncompatible(): + r.logger.Error("peer rejected due to incompatibility", "node", peerInfo.NodeID, "err", err) + r.peerManager.Errored(peerInfo.NodeID, err) + conn.Close() + return case err != nil: r.logger.Error("failed to handshake with peer", "peer", address, "err", err) if err = r.peerManager.DialFailed(address); err != nil { @@ -795,7 +800,6 @@ func (r *Router) handshakePeer( if err != nil { return peerInfo, peerKey, err } - if err = peerInfo.Validate(); err != nil { return peerInfo, peerKey, fmt.Errorf("invalid handshake NodeInfo: %w", err) } @@ -807,6 +811,13 @@ func (r *Router) handshakePeer( return peerInfo, peerKey, fmt.Errorf("expected to connect with peer %q, got %q", expectID, peerInfo.NodeID) } + if err := r.nodeInfo.CompatibleWith(peerInfo); err != nil { + return peerInfo, peerKey, ErrRejected{ + err: err, + id: peerInfo.ID(), + isIncompatible: true, + } + } return peerInfo, peerKey, nil } diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 5a1518168..6ad8542f4 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -327,6 +327,16 @@ func TestRouter_AcceptPeers(t *testing.T) { "empty handshake": {types.NodeInfo{}, nil, false}, "invalid key": {peerInfo, selfKey.PubKey(), false}, "self handshake": {selfInfo, selfKey.PubKey(), false}, + "incompatible peer": { + types.NodeInfo{ + NodeID: peerID, + ListenAddr: "0.0.0.0:0", + Network: "other-network", + Moniker: string(peerID), + }, + peerKey.PubKey(), + false, + }, } for name, tc := range testcases { tc := tc @@ -532,6 +542,18 @@ func TestRouter_DialPeers(t *testing.T) { "invalid key": {peerInfo.NodeID, peerInfo, selfKey.PubKey(), nil, false}, "unexpected node ID": {peerInfo.NodeID, selfInfo, selfKey.PubKey(), nil, false}, "dial error": {peerInfo.NodeID, peerInfo, peerKey.PubKey(), errors.New("boom"), false}, + "incompatible peer": { + peerInfo.NodeID, + types.NodeInfo{ + NodeID: peerID, + ListenAddr: "0.0.0.0:0", + Network: "other-network", + Moniker: string(peerID), + }, + peerKey.PubKey(), + nil, + false, + }, } for name, tc := range testcases { tc := tc diff --git a/internal/p2p/switch.go b/internal/p2p/switch.go index e35a307d6..eeb93a994 100644 --- a/internal/p2p/switch.go +++ b/internal/p2p/switch.go @@ -690,13 +690,16 @@ func (sw *Switch) acceptRoutine() { } switch err := err.(type) { case ErrRejected: + addr := err.Addr() if err.IsSelf() { // Remove the given address from the address book and add to our addresses // to avoid dialing in the future. - addr := err.Addr() sw.addrBook.RemoveAddress(&addr) sw.addrBook.AddOurAddress(&addr) } + if err.IsIncompatible() { + sw.addrBook.RemoveAddress(&addr) + } sw.Logger.Info( "Inbound Peer rejected", @@ -822,9 +825,12 @@ func (sw *Switch) addOutboundPeerWithConfig( // to avoid dialing in the future. sw.addrBook.RemoveAddress(addr) sw.addrBook.AddOurAddress(addr) - - return err } + if e.IsIncompatible() { + sw.addrBook.RemoveAddress(addr) + } + + return err } // retry persistent peers after diff --git a/internal/p2p/switch_test.go b/internal/p2p/switch_test.go index dd15ff30c..8cb755c9f 100644 --- a/internal/p2p/switch_test.go +++ b/internal/p2p/switch_test.go @@ -213,6 +213,26 @@ func TestSwitchFiltersOutItself(t *testing.T) { assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond) } +func TestSwitchDialFailsOnIncompatiblePeer(t *testing.T) { + s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc, log.TestingLogger()) + ni := s1.NodeInfo() + ni.Network = "network-a" + s1.SetNodeInfo(ni) + + rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg, Network: "network-b"} + rp.Start() + defer rp.Stop() + + err := s1.DialPeerWithAddress(rp.Addr()) + require.Error(t, err) + errRejected, ok := err.(ErrRejected) + require.True(t, ok, "expected error to be of type IsRejected") + require.True(t, errRejected.IsIncompatible(), "expected error to be IsIncompatible") + + // remote peer should not have been added to the addressbook + require.False(t, s1.addrBook.HasAddress(rp.Addr())) +} + func TestSwitchPeerFilter(t *testing.T) { var ( filters = []PeerFilterFunc{ @@ -697,6 +717,36 @@ func TestSwitchAcceptRoutine(t *testing.T) { } } +func TestSwitchRejectsIncompatiblePeers(t *testing.T) { + sw := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc, log.TestingLogger()) + ni := sw.NodeInfo() + ni.Network = "network-a" + sw.SetNodeInfo(ni) + + err := sw.Start() + require.NoError(t, err) + t.Cleanup(func() { + err := sw.Stop() + require.NoError(t, err) + }) + + rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg, Network: "network-b"} + rp.Start() + defer rp.Stop() + + assert.Equal(t, 0, sw.Peers().Size()) + + conn, err := rp.Dial(sw.NetAddress()) + assert.Nil(t, err) + + one := make([]byte, 1) + _ = conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + _, err = conn.Read(one) + assert.Error(t, err) + + assert.Equal(t, 0, sw.Peers().Size()) +} + type errorTransport struct { acceptErr error }