diff --git a/rpc/jsonrpc/client/http_json_client.go b/rpc/jsonrpc/client/http_json_client.go index 59727390a..fbfe073e0 100644 --- a/rpc/jsonrpc/client/http_json_client.go +++ b/rpc/jsonrpc/client/http_json_client.go @@ -21,6 +21,7 @@ const ( protoWSS = "wss" protoWS = "ws" protoTCP = "tcp" + protoUNIX = "unix" ) //------------------------------------------------------------- @@ -28,6 +29,8 @@ const ( // Parsed URL structure type parsedURL struct { url.URL + + isUnixSocket bool } // Parse URL and set defaults @@ -42,7 +45,16 @@ func newParsedURL(remoteAddr string) (*parsedURL, error) { u.Scheme = protoTCP } - return &parsedURL{*u}, nil + pu := &parsedURL{ + URL: *u, + isUnixSocket: false, + } + + if u.Scheme == protoUNIX { + pu.isUnixSocket = true + } + + return pu, nil } // Change protocol to HTTP for unknown protocols and TCP protocol - useful for RPC connections @@ -65,10 +77,26 @@ func (u parsedURL) GetHostWithPath() string { // Get a trimmed address - useful for WS connections func (u parsedURL) GetTrimmedHostWithPath() string { - // replace / with . for http requests (kvstore domain) + // if it's not an unix socket we return the normal URL + if !u.isUnixSocket { + return u.GetHostWithPath() + } + // if it's a unix socket we replace the host slashes with a period + // this is because otherwise the http.Client would think that the + // domain is invalid. return strings.ReplaceAll(u.GetHostWithPath(), "/", ".") } +// GetDialAddress returns the endpoint to dial for the parsed URL +func (u parsedURL) GetDialAddress() string { + // if it's not a unix socket we return the host, example: localhost:443 + if !u.isUnixSocket { + return u.Host + } + // otherwise we return the path of the unix socket, ex /tmp/socket + return u.GetHostWithPath() +} + // Get a trimmed address with protocol - useful as address in RPC connections func (u parsedURL) GetTrimmedURL() string { return u.Scheme + "://" + u.GetTrimmedHostWithPath() @@ -350,7 +378,7 @@ func makeHTTPDialer(remoteAddr string) (func(string, string) (net.Conn, error), } dialFn := func(proto, addr string) (net.Conn, error) { - return net.Dial(protocol, u.GetHostWithPath()) + return net.Dial(protocol, u.GetDialAddress()) } return dialFn, nil diff --git a/rpc/jsonrpc/client/http_json_client_test.go b/rpc/jsonrpc/client/http_json_client_test.go index 5c8ef1a25..4b82ff1eb 100644 --- a/rpc/jsonrpc/client/http_json_client_test.go +++ b/rpc/jsonrpc/client/http_json_client_test.go @@ -34,3 +34,53 @@ func TestHTTPClientMakeHTTPDialer(t *testing.T) { require.NotNil(t, addr) } } + +func Test_parsedURL(t *testing.T) { + type test struct { + url string + expectedURL string + expectedHostWithPath string + expectedDialAddress string + } + + tests := map[string]test{ + "unix endpoint": { + url: "unix:///tmp/test", + expectedURL: "unix://.tmp.test", + expectedHostWithPath: "/tmp/test", + expectedDialAddress: "/tmp/test", + }, + + "http endpoint": { + url: "https://example.com", + expectedURL: "https://example.com", + expectedHostWithPath: "example.com", + expectedDialAddress: "example.com", + }, + + "http endpoint with port": { + url: "https://example.com:8080", + expectedURL: "https://example.com:8080", + expectedHostWithPath: "example.com:8080", + expectedDialAddress: "example.com:8080", + }, + + "http path routed endpoint": { + url: "https://example.com:8080/rpc", + expectedURL: "https://example.com:8080/rpc", + expectedHostWithPath: "example.com:8080/rpc", + expectedDialAddress: "example.com:8080", + }, + } + + for name, tt := range tests { + tt := tt // suppressing linter + t.Run(name, func(t *testing.T) { + parsed, err := newParsedURL(tt.url) + require.NoError(t, err) + require.Equal(t, tt.expectedDialAddress, parsed.GetDialAddress()) + require.Equal(t, tt.expectedURL, parsed.GetTrimmedURL()) + require.Equal(t, tt.expectedHostWithPath, parsed.GetHostWithPath()) + }) + } +}