diff --git a/p2p/switch.go b/p2p/switch.go index 4a5191c78..ed5c8b720 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -320,6 +320,9 @@ func (sw *Switch) Peers() IPeerSet { // TODO: make record depending on reason. func (sw *Switch) StopPeerForError(peer Peer, reason interface{}) { sw.Logger.Error("Stopping peer for error", "peer", peer, "err", reason) + if peer == nil { + return + } sw.stopAndRemovePeer(peer, reason) if peer.IsPersistent() { diff --git a/p2p/switch_test.go b/p2p/switch_test.go index 50e0adb3a..4d6ea16fc 100644 --- a/p2p/switch_test.go +++ b/p2p/switch_test.go @@ -706,11 +706,15 @@ func TestSwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) { defer rp.Stop() _, err = rp.Dial(sw.NetAddress()) require.NoError(t, err) - // wait till the switch adds rp to the peer set - time.Sleep(50 * time.Millisecond) - // stop peer asynchronously - go sw.StopPeerForError(sw.Peers().Get(rp.ID()), "test") + // wait till the switch adds rp to the peer set, then stop the peer asynchronously + for { + time.Sleep(20 * time.Millisecond) + if peer := sw.Peers().Get(rp.ID()); peer != nil { + go sw.StopPeerForError(peer, "test") + break + } + } // simulate peer reconnecting to us _, err = rp.Dial(sw.NetAddress())