diff --git a/p2p/netaddress.go b/p2p/netaddress.go index 9cb7dd2c3..41c2cc976 100644 --- a/p2p/netaddress.go +++ b/p2p/netaddress.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "strconv" + "strings" "time" cmn "github.com/tendermint/tmlibs/common" @@ -45,7 +46,7 @@ func NewNetAddress(addr net.Addr) *NetAddress { // address in the form of "IP:Port". Also resolves the host if host // is not an IP. func NewNetAddressString(addr string) (*NetAddress, error) { - host, portStr, err := net.SplitHostPort(addr) + host, portStr, err := net.SplitHostPort(removeProtocolIfDefined(addr)) if err != nil { return nil, err } @@ -251,3 +252,11 @@ func (na *NetAddress) RFC4843() bool { return rfc4843.Contains(na.IP) } func (na *NetAddress) RFC4862() bool { return rfc4862.Contains(na.IP) } func (na *NetAddress) RFC6052() bool { return rfc6052.Contains(na.IP) } func (na *NetAddress) RFC6145() bool { return rfc6145.Contains(na.IP) } + +func removeProtocolIfDefined(addr string) string { + if strings.Contains(addr, "://") { + return strings.Split(addr, "://")[1] + } else { + return addr + } +} diff --git a/p2p/netaddress_test.go b/p2p/netaddress_test.go index db6147500..137be090c 100644 --- a/p2p/netaddress_test.go +++ b/p2p/netaddress_test.go @@ -23,29 +23,31 @@ func TestNewNetAddress(t *testing.T) { } func TestNewNetAddressString(t *testing.T) { - assert := assert.New(t) - - tests := []struct { - addr string - correct bool + testCases := []struct { + addr string + expected string + correct bool }{ - {"127.0.0.1:8080", true}, + {"127.0.0.1:8080", "127.0.0.1:8080", true}, + {"tcp://127.0.0.1:8080", "127.0.0.1:8080", true}, + {"udp://127.0.0.1:8080", "127.0.0.1:8080", true}, + {"udp//127.0.0.1:8080", "", false}, // {"127.0.0:8080", false}, - {"notahost", false}, - {"127.0.0.1:notapath", false}, - {"notahost:8080", false}, - {"8082", false}, - {"127.0.0:8080000", false}, + {"notahost", "", false}, + {"127.0.0.1:notapath", "", false}, + {"notahost:8080", "", false}, + {"8082", "", false}, + {"127.0.0:8080000", "", false}, } - for _, t := range tests { - addr, err := NewNetAddressString(t.addr) - if t.correct { - if assert.Nil(err, t.addr) { - assert.Equal(t.addr, addr.String()) + for _, tc := range testCases { + addr, err := NewNetAddressString(tc.addr) + if tc.correct { + if assert.Nil(t, err, tc.addr) { + assert.Equal(t, tc.expected, addr.String()) } } else { - assert.NotNil(err, t.addr) + assert.NotNil(t, err, tc.addr) } } }