diff --git a/common/net.go b/common/net.go index 2f9c9c8c2..bdbe38f79 100644 --- a/common/net.go +++ b/common/net.go @@ -5,10 +5,22 @@ import ( "strings" ) -// protoAddr: e.g. "tcp://127.0.0.1:8080" or "unix:///tmp/test.sock" +// Connect dials the given address and returns a net.Conn. The protoAddr argument should be prefixed with the protocol, +// eg. "tcp://127.0.0.1:8080" or "unix:///tmp/test.sock" func Connect(protoAddr string) (net.Conn, error) { - parts := strings.SplitN(protoAddr, "://", 2) - proto, address := parts[0], parts[1] + proto, address := ProtocolAndAddress(protoAddr) conn, err := net.Dial(proto, address) return conn, err } + +// ProtocolAndAddress splits an address into the protocol and address components. +// For instance, "tcp://127.0.0.1:8080" will be split into "tcp" and "127.0.0.1:8080". +// If the address has no protocol prefix, the default is "tcp". +func ProtocolAndAddress(listenAddr string) (string, string) { + protocol, address := "tcp", listenAddr + parts := strings.SplitN(address, "://", 2) + if len(parts) == 2 { + protocol, address = parts[0], parts[1] + } + return protocol, address +} diff --git a/common/net_test.go b/common/net_test.go new file mode 100644 index 000000000..38d2ae82d --- /dev/null +++ b/common/net_test.go @@ -0,0 +1,38 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProtocolAndAddress(t *testing.T) { + + cases := []struct { + fullAddr string + proto string + addr string + }{ + { + "tcp://mydomain:80", + "tcp", + "mydomain:80", + }, + { + "mydomain:80", + "tcp", + "mydomain:80", + }, + { + "unix://mydomain:80", + "unix", + "mydomain:80", + }, + } + + for _, c := range cases { + proto, addr := ProtocolAndAddress(c.fullAddr) + assert.Equal(t, proto, c.proto) + assert.Equal(t, addr, c.addr) + } +}