diff --git a/switch.go b/switch.go index f9a146d2e..e3016f31b 100644 --- a/switch.go +++ b/switch.go @@ -65,6 +65,9 @@ type Switch struct { dialing *CMap nodeInfo *NodeInfo // our node info nodePrivKey crypto.PrivKeyEd25519 // our node privkey + + filterConnByAddr func(net.Addr) error + filterConnByPubKey func(crypto.PubKeyEd25519) error } var ( @@ -192,6 +195,11 @@ func (sw *Switch) OnStop() { // NOTE: This performs a blocking handshake before the peer is added. // CONTRACT: Iff error is returned, peer is nil, and conn is immediately closed. func (sw *Switch) AddPeerWithConnection(conn net.Conn, outbound bool) (*Peer, error) { + // Filter by ip + if err := sw.FilterConnByAddr(conn.RemoteAddr()); err != nil { + return nil, err + } + // Set deadline for handshake so we don't block forever on conn.ReadFull conn.SetDeadline(time.Now().Add( time.Duration(sw.config.GetInt(configKeyHandshakeTimeoutSeconds)) * time.Second)) @@ -206,6 +214,12 @@ func (sw *Switch) AddPeerWithConnection(conn net.Conn, outbound bool) (*Peer, er return nil, err } } + + // Filter by p2p-key + if err := sw.FilterConnByPubKey(sconn.(*SecretConnection).RemotePubKey()); err != nil { + return nil, err + } + // Then, perform node handshake peerNodeInfo, err := peerHandshake(sconn, sw.nodeInfo) if err != nil { @@ -251,6 +265,29 @@ func (sw *Switch) AddPeerWithConnection(conn net.Conn, outbound bool) (*Peer, er return peer, nil } +func (sw *Switch) FilterConnByAddr(addr net.Addr) error { + if sw.filterConnByAddr != nil { + return sw.filterConnByAddr(addr) + } + return nil +} + +func (sw *Switch) FilterConnByPubKey(pubkey crypto.PubKeyEd25519) error { + if sw.filterConnByPubKey != nil { + return sw.filterConnByPubKey(pubkey) + } + return nil + +} + +func (sw *Switch) SetAddrFilter(f func(net.Addr) error) { + sw.filterConnByAddr = f +} + +func (sw *Switch) SetPubKeyFilter(f func(crypto.PubKeyEd25519) error) { + sw.filterConnByPubKey = f +} + func (sw *Switch) startInitPeer(peer *Peer) { peer.Start() // spawn send/recv routines sw.addPeerToReactors(peer) // run AddPeer on each reactor diff --git a/switch_test.go b/switch_test.go index 54d6f7527..f77682b66 100644 --- a/switch_test.go +++ b/switch_test.go @@ -2,6 +2,7 @@ package p2p import ( "bytes" + "fmt" "net" "sync" "testing" @@ -9,6 +10,7 @@ import ( . "github.com/tendermint/go-common" cfg "github.com/tendermint/go-config" + "github.com/tendermint/go-crypto" "github.com/tendermint/go-wire" ) @@ -92,23 +94,24 @@ func makeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switc return switches[0], switches[1] } +func initSwitchFunc(i int, sw *Switch) *Switch { + // Make two reactors of two channels each + sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ + &ChannelDescriptor{ID: byte(0x00), Priority: 10}, + &ChannelDescriptor{ID: byte(0x01), Priority: 10}, + }, true)) + sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ + &ChannelDescriptor{ID: byte(0x02), Priority: 10}, + &ChannelDescriptor{ID: byte(0x03), Priority: 10}, + }, true)) + return sw +} + func TestSwitches(t *testing.T) { - s1, s2 := makeSwitchPair(t, func(i int, sw *Switch) *Switch { - // Make two reactors of two channels each - sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ - &ChannelDescriptor{ID: byte(0x00), Priority: 10}, - &ChannelDescriptor{ID: byte(0x01), Priority: 10}, - }, true)) - sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ - &ChannelDescriptor{ID: byte(0x02), Priority: 10}, - &ChannelDescriptor{ID: byte(0x03), Priority: 10}, - }, true)) - return sw - }) + s1, s2 := makeSwitchPair(t, initSwitchFunc) defer s1.Stop() defer s2.Stop() - // Lets send a message from s1 to s2. if s1.Peers().Size() != 1 { t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size()) } @@ -116,6 +119,7 @@ func TestSwitches(t *testing.T) { t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size()) } + // Lets send some messages ch0Msg := "channel zero" ch1Msg := "channel foo" ch2Msg := "channel bar" @@ -156,6 +160,67 @@ func TestSwitches(t *testing.T) { } +func TestConnAddrFilter(t *testing.T) { + s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + + c1, c2 := net.Pipe() + + s1.SetAddrFilter(func(addr net.Addr) error { + if addr.String() == c1.RemoteAddr().String() { + return fmt.Errorf("Error: pipe is blacklisted") + } + return nil + }) + + // connect to good peer + go s1.AddPeerWithConnection(c1, false) // AddPeer is blocking, requires handshake. + go s2.AddPeerWithConnection(c2, true) + + // Wait for things to happen, peers to get added... + time.Sleep(100 * time.Millisecond * time.Duration(4)) + + defer s1.Stop() + defer s2.Stop() + if s1.Peers().Size() != 0 { + t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) + } + if s2.Peers().Size() != 0 { + t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) + } +} + +func TestConnPubKeyFilter(t *testing.T) { + s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + + c1, c2 := net.Pipe() + + // set pubkey filter + s1.SetPubKeyFilter(func(pubkey crypto.PubKeyEd25519) error { + if bytes.Equal(pubkey.Bytes(), s2.nodeInfo.PubKey.Bytes()) { + return fmt.Errorf("Error: pipe is blacklisted") + } + return nil + }) + + // connect to good peer + go s1.AddPeerWithConnection(c1, false) // AddPeer is blocking, requires handshake. + go s2.AddPeerWithConnection(c2, true) + + // Wait for things to happen, peers to get added... + time.Sleep(100 * time.Millisecond * time.Duration(4)) + + defer s1.Stop() + defer s2.Stop() + if s1.Peers().Size() != 0 { + t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) + } + if s2.Peers().Size() != 0 { + t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) + } +} + func BenchmarkSwitches(b *testing.B) { b.StopTimer() diff --git a/version.go b/version.go index e9c345cbe..8608f2757 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package p2p -const Version = "0.3.3" // fuzz conn +const Version = "0.3.4" // filter by addr or pubkey