Browse Source

support unix domain websockets

pull/456/head
Ethan Buchman 8 years ago
parent
commit
1410693eae
4 changed files with 70 additions and 18 deletions
  1. +6
    -12
      client/http_client.go
  2. +11
    -5
      client/ws_client.go
  3. +52
    -0
      rpc_test.go
  4. +1
    -1
      types/types.go

+ 6
- 12
client/http_client.go View File

@ -16,24 +16,18 @@ import (
// Set the net.Dial manually so we can do http over tcp or unix.
// Get/Post require a dummyDomain but it's over written by the Transport
var dummyDomain = "http://dummyDomain/"
var dummyDomain = "http://dummyDomain"
func dialFunc(sockType, remote string) func(string, string) (net.Conn, error) {
func dialer(remote string) func(string, string) (net.Conn, error) {
return func(proto, addr string) (conn net.Conn, err error) {
return net.Dial(sockType, remote)
return net.Dial(rpctypes.SocketType(remote), remote)
}
}
// remote is IP:PORT or /path/to/socket
func socketTransport(remote string) *http.Transport {
if rpctypes.SocketType(remote) == "unix" {
return &http.Transport{
Dial: dialFunc("unix", remote),
}
} else {
return &http.Transport{
Dial: dialFunc("tcp", remote),
}
return &http.Transport{
Dial: dialer(remote),
}
}
@ -105,7 +99,7 @@ func (c *ClientURI) call(method string, params map[string]interface{}, result in
return nil, err
}
log.Info(Fmt("URI request to %v (%v): %v", c.remote, method, values))
resp, err := c.client.PostForm(dummyDomain+method, values)
resp, err := c.client.PostForm(dummyDomain+"/"+method, values)
if err != nil {
return nil, err
}


+ 11
- 5
client/ws_client.go View File

@ -19,16 +19,18 @@ const (
type WSClient struct {
QuitService
Address string
Address string // IP:PORT or /path/to/socket
Endpoint string // /websocket/url/endpoint
*websocket.Conn
ResultsCh chan json.RawMessage // closes upon WSClient.Stop()
ErrorsCh chan error // closes upon WSClient.Stop()
}
// create a new connection
func NewWSClient(addr string) *WSClient {
func NewWSClient(addr, endpoint string) *WSClient {
wsClient := &WSClient{
Address: addr,
Endpoint: endpoint,
Conn: nil,
ResultsCh: make(chan json.RawMessage, wsResultsChannelCapacity),
ErrorsCh: make(chan error, wsErrorsChannelCapacity),
@ -38,7 +40,7 @@ func NewWSClient(addr string) *WSClient {
}
func (wsc *WSClient) String() string {
return wsc.Address
return wsc.Address + ", " + wsc.Endpoint
}
func (wsc *WSClient) OnStart() error {
@ -52,10 +54,14 @@ func (wsc *WSClient) OnStart() error {
}
func (wsc *WSClient) dial() error {
// Dial
dialer := websocket.DefaultDialer
dialer := &websocket.Dialer{
NetDial: dialer(wsc.Address),
Proxy: http.ProxyFromEnvironment,
}
rHeader := http.Header{}
con, _, err := dialer.Dial(wsc.Address, rHeader)
con, _, err := dialer.Dial("ws://"+dummyDomain+wsc.Endpoint, rHeader)
if err != nil {
return err
}


+ 52
- 0
rpc_test.go View File

@ -7,6 +7,7 @@ import (
"github.com/tendermint/go-rpc/client"
"github.com/tendermint/go-rpc/server"
"github.com/tendermint/go-rpc/types"
"github.com/tendermint/go-wire"
)
@ -14,6 +15,8 @@ import (
var (
tcpAddr = "0.0.0.0:46657"
unixAddr = "/tmp/go-rpc.sock" // NOTE: must remove file for test to run again
websocketEndpoint = "/websocket/endpoint"
)
// Define a type for results and register concrete versions
@ -42,6 +45,8 @@ func StatusResult(v string) (Result, error) {
func init() {
mux := http.NewServeMux()
rpcserver.RegisterRPCFuncs(mux, Routes)
wm := rpcserver.NewWebsocketManager(Routes, nil)
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
go func() {
_, err := rpcserver.StartHTTPServer(tcpAddr, mux)
if err != nil {
@ -51,6 +56,8 @@ func init() {
mux = http.NewServeMux()
rpcserver.RegisterRPCFuncs(mux, Routes)
wm = rpcserver.NewWebsocketManager(Routes, nil)
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
go func() {
_, err := rpcserver.StartHTTPServer(unixAddr, mux)
if err != nil {
@ -93,6 +100,33 @@ func testJSONRPC(t *testing.T, cl *rpcclient.ClientJSONRPC) {
}
}
func testWS(t *testing.T, cl *rpcclient.WSClient) {
val := "acbd"
params := []interface{}{val}
err := cl.WriteJSON(rpctypes.RPCRequest{
JSONRPC: "2.0",
ID: "",
Method: "status",
Params: params,
})
if err != nil {
t.Fatal(err)
}
msg := <-cl.ResultsCh
result := new(Result)
wire.ReadJSONPtr(result, msg, &err)
if err != nil {
t.Fatal(err)
}
got := (*result).(*ResultStatus).Value
if got != val {
t.Fatalf("Got: %v .... Expected: %v \n", got, val)
}
}
//-------------
func TestURI_TCP(t *testing.T) {
cl := rpcclient.NewClientURI(tcpAddr)
testURI(t, cl)
@ -112,3 +146,21 @@ func TestJSONRPC_UNIX(t *testing.T) {
cl := rpcclient.NewClientJSONRPC(unixAddr)
testJSONRPC(t, cl)
}
func TestWS_TCP(t *testing.T) {
cl := rpcclient.NewWSClient(tcpAddr, websocketEndpoint)
_, err := cl.Start()
if err != nil {
t.Fatal(err)
}
testWS(t, cl)
}
func TestWS_UNIX(t *testing.T) {
cl := rpcclient.NewWSClient(unixAddr, websocketEndpoint)
_, err := cl.Start()
if err != nil {
t.Fatal(err)
}
testWS(t, cl)
}

+ 1
- 1
types/types.go View File

@ -86,7 +86,7 @@ type WSRPCContext struct {
// If tcp, must specify the port; `0.0.0.0` will return incorrectly as "unix" since there's no port
func SocketType(listenAddr string) string {
socketType := "unix"
if len(strings.Split(listenAddr, ":")) == 2 {
if len(strings.Split(listenAddr, ":")) >= 2 {
socketType = "tcp"
}
return socketType


Loading…
Cancel
Save