diff --git a/node/node.go b/node/node.go index d5548415a..c8029cf81 100644 --- a/node/node.go +++ b/node/node.go @@ -2,6 +2,7 @@ package node import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -415,7 +416,10 @@ func (n *Node) startRPC() ([]net.Listener, error) { for i, listenAddr := range listenAddrs { mux := http.NewServeMux() rpcLogger := n.Logger.With("module", "rpc-server") - wm := rpcserver.NewWebsocketManager(rpccore.Routes) + onDisconnect := rpcserver.OnDisconnect(func(remoteAddr string) { + n.eventBus.UnsubscribeAll(context.Background(), remoteAddr) + }) + wm := rpcserver.NewWebsocketManager(rpccore.Routes, onDisconnect) wm.SetLogger(rpcLogger.With("protocol", "websocket")) mux.HandleFunc("/websocket", wm.WebsocketHandler) rpcserver.RegisterRPCFuncs(mux, rpccore.Routes, rpcLogger) diff --git a/rpc/lib/server/handlers.go b/rpc/lib/server/handlers.go index ddb7f9624..deede5894 100644 --- a/rpc/lib/server/handlers.go +++ b/rpc/lib/server/handlers.go @@ -349,9 +349,10 @@ const ( defaultWSPingPeriod = (defaultWSReadWait * 9) / 10 ) -// a single websocket connection -// contains listener id, underlying ws connection, -// and the event switch for subscribing to events +// a single websocket connection contains listener id, underlying ws +// connection, and the event switch for subscribing to events. +// +// In case of an error, the connection is stopped. type wsConnection struct { cmn.BaseService @@ -374,13 +375,17 @@ type wsConnection struct { // Send pings to server with this period. Must be less than readWait, but greater than zero. pingPeriod time.Duration + + // called before stopping the connection. + onDisconnect func(remoteAddr string) } -// NewWSConnection wraps websocket.Conn. See the commentary on the -// func(*wsConnection) functions for a detailed description of how to configure -// ping period and pong wait time. -// NOTE: if the write buffer is full, pongs may be dropped, which may cause clients to disconnect. -// see https://github.com/gorilla/websocket/issues/97 +// NewWSConnection wraps websocket.Conn. +// +// See the commentary on the func(*wsConnection) functions for a detailed +// description of how to configure ping period and pong wait time. NOTE: if the +// write buffer is full, pongs may be dropped, which may cause clients to +// disconnect. see https://github.com/gorilla/websocket/issues/97 func NewWSConnection(baseConn *websocket.Conn, funcMap map[string]*RPCFunc, options ...func(*wsConnection)) *wsConnection { wsc := &wsConnection{ remoteAddr: baseConn.RemoteAddr().String(), @@ -431,7 +436,16 @@ func PingPeriod(pingPeriod time.Duration) func(*wsConnection) { } } -// OnStart starts the read and write routines. It blocks until the connection closes. +// OnDisconnect called before stopping the connection. +// It should only be used in the constructor - not Goroutine-safe. +func OnDisconnect(cb func(remoteAddr string)) func(*wsConnection) { + return func(wsc *wsConnection) { + wsc.onDisconnect = cb + } +} + +// OnStart implements cmn.Service by starting the read and write routines. It +// blocks until the connection closes. func (wsc *wsConnection) OnStart() error { wsc.writeChan = make(chan types.RPCResponse, wsc.writeChanCapacity) @@ -443,7 +457,7 @@ func (wsc *wsConnection) OnStart() error { return nil } -// OnStop unsubscribes from all events. +// OnStop is a nop. func (wsc *wsConnection) OnStop() { // Both read and write loops close the websocket connection when they exit their loops. // The writeChan is never closed, to allow WriteRPCResponse() to fail.