From 12814db90dc082f3f8ae86ac341b6e46912bfc4e Mon Sep 17 00:00:00 2001 From: Ethan Buchman Date: Tue, 7 Apr 2015 02:28:21 -0500 Subject: [PATCH] rpc: use gorilla websockets Conflicts: rpc/handlers.go rpc/http_server.go --- rpc/handlers.go | 43 ++++++++++++++++++++++++++++++------- rpc/http_server.go | 19 +++++++++++++++- rpc/test/client_rpc_test.go | 39 +++++++++++++++++++++++++++++++++ rpc/test/helpers.go | 5 +++-- 4 files changed, 95 insertions(+), 11 deletions(-) diff --git a/rpc/handlers.go b/rpc/handlers.go index 933329162..4e55e0267 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -4,9 +4,9 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/gorilla/websocket" "github.com/tendermint/tendermint/binary" "github.com/tendermint/tendermint/events" - "golang.org/x/net/websocket" "io/ioutil" "net/http" "reflect" @@ -26,7 +26,7 @@ func RegisterRPCFuncs(mux *http.ServeMux, funcMap map[string]*RPCFunc) { func RegisterEventsHandler(mux *http.ServeMux, evsw *events.EventSwitch) { // websocket endpoint w := NewWebsocketManager(evsw) - mux.Handle("/events", websocket.Handler(w.eventsHandler)) + http.HandleFunc("/events", w.websocketHandler) // websocket.Handler(w.eventsHandler)) } //------------------------------------- @@ -233,6 +233,7 @@ func (c *Connection) Close() { // main manager for all websocket connections // holds the event switch type WebsocketManager struct { + websocket.Upgrader ew *events.EventSwitch cons map[string]*Connection } @@ -241,18 +242,38 @@ func NewWebsocketManager(ew *events.EventSwitch) *WebsocketManager { return &WebsocketManager{ ew: ew, cons: make(map[string]*Connection), + Upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // TODO + return true + }, + }, } } -func (w *WebsocketManager) eventsHandler(con *websocket.Conn) { +func (wm *WebsocketManager) websocketHandler(w http.ResponseWriter, r *http.Request) { + conn, err := wm.Upgrade(w, r, nil) + if err != nil { + // TODO + log.Error("Failed to upgrade to websocket connection", "error", err) + return + } + wm.handleWebsocket(conn) + +} + +func (w *WebsocketManager) handleWebsocket(con *websocket.Conn) { // register connection c := NewConnection(con) - w.cons[con.RemoteAddr().String()] = c + w.cons[c.id] = c + log.Info("New websocket connection", "origin", c.id) // read subscriptions/unsubscriptions to events go w.read(c) // write responses - go w.write(c) + w.write(c) } const ( @@ -274,19 +295,22 @@ func (w *WebsocketManager) read(con *Connection) { } default: var in []byte - if err := websocket.Message.Receive(con.wsCon, &in); err != nil { + _, in, err := con.wsCon.ReadMessage() + if err != nil { + //if err := websocket.Message.Receive(con.wsCon, &in); err != nil { // an error reading the connection, // so kill the connection con.quitChan <- struct{}{} } var req WsRequest - err := json.Unmarshal(in, &req) + err = json.Unmarshal(in, &req) if err != nil { errStr := fmt.Sprintf("Error unmarshaling data: %s", err.Error()) con.writeChan <- WsResponse{Error: errStr} } switch req.Type { case "subscribe": + log.Info("New event subscription", "con id", con.id, "event", req.Event) w.ew.AddListenerForEvent(con.id, req.Event, func(msg interface{}) { resp := WsResponse{ Event: req.Event, @@ -328,7 +352,10 @@ func (w *WebsocketManager) write(con *Connection) { if *err != nil { log.Error("Failed to write JSON WsResponse", "error", err) } else { - websocket.Message.Send(con.wsCon, buf.Bytes()) + //websocket.Message.Send(con.wsCon, buf.Bytes()) + if err := con.wsCon.WriteMessage(websocket.TextMessage, buf.Bytes()); err != nil { + log.Error("Failed to write response on websocket", "error", err) + } } case <-con.quitChan: w.closeConn(con) diff --git a/rpc/http_server.go b/rpc/http_server.go index 2defdf5d0..1603ab79b 100644 --- a/rpc/http_server.go +++ b/rpc/http_server.go @@ -2,8 +2,10 @@ package rpc import ( + "bufio" "bytes" "fmt" + "net" "net/http" "runtime/debug" "time" @@ -50,7 +52,7 @@ func WriteRPCResponse(w http.ResponseWriter, res RPCResponse) { func RecoverAndLogHandler(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Wrap the ResponseWriter to remember the status - rww := &ResponseWriterWrapper{-1, w} + rww := &ResponseWriterWrapper{-1, w, w.(http.Hijacker)} begin := time.Now() // Common headers @@ -97,9 +99,24 @@ func RecoverAndLogHandler(handler http.Handler) http.Handler { type ResponseWriterWrapper struct { Status int http.ResponseWriter + hj http.Hijacker // necessary for websocket upgrades } func (w *ResponseWriterWrapper) WriteHeader(status int) { w.Status = status w.ResponseWriter.WriteHeader(status) } + +// implements http.Hijacker +func (w *ResponseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return w.hj.Hijack() +} + +// Stick it as a deferred statement in gouroutines to prevent the program from crashing. +func Recover(daemonName string) { + if e := recover(); e != nil { + stack := string(debug.Stack()) + errorString := fmt.Sprintf("[%s] %s\n%s", daemonName, e, stack) + alert.Alert(errorString) + } +} diff --git a/rpc/test/client_rpc_test.go b/rpc/test/client_rpc_test.go index 0d455c3d3..9c261d9f9 100644 --- a/rpc/test/client_rpc_test.go +++ b/rpc/test/client_rpc_test.go @@ -1,6 +1,10 @@ package rpc import ( + "fmt" + "github.com/gorilla/websocket" + "github.com/tendermint/tendermint/rpc" + "net/http" "testing" ) @@ -73,3 +77,38 @@ func TestJSONCallCode(t *testing.T) { func TestJSONCallContract(t *testing.T) { testCall(t, "JSONRPC") } + +//-------------------------------------------------------------------------------- +// Test the websocket client + +func TestWSConnect(t *testing.T) { + dialer := websocket.DefaultDialer + rHeader := http.Header{} + _, r, err := dialer.Dial(websocketAddr, rHeader) + if err != nil { + t.Fatal(err) + } + fmt.Println("respoinse:", r) + +} + +func TestWSSubscribe(t *testing.T) { + dialer := websocket.DefaultDialer + rHeader := http.Header{} + con, _, err := dialer.Dial(websocketAddr, rHeader) + if err != nil { + t.Fatal(err) + } + err = con.WriteJSON(rpc.WsRequest{ + Type: "subscribe", + Event: "newblock", + }) + if err != nil { + t.Fatal(err) + } + /* + typ, p, err := con.ReadMessage() + fmt.Println("RESPONSE:", typ, string(p), err) + */ + +} diff --git a/rpc/test/helpers.go b/rpc/test/helpers.go index 8fb594da4..82650c186 100644 --- a/rpc/test/helpers.go +++ b/rpc/test/helpers.go @@ -18,8 +18,9 @@ import ( // global variables for use across all tests var ( - rpcAddr = "127.0.0.1:8089" - requestAddr = "http://" + rpcAddr + "/" + rpcAddr = "127.0.0.1:8089" + requestAddr = "http://" + rpcAddr + "/" + websocketAddr = "ws://" + rpcAddr + "/events" node *nm.Node