diff --git a/rpc/lib/client/ws_client.go b/rpc/lib/client/ws_client.go index c617bf21a..758df658f 100644 --- a/rpc/lib/client/ws_client.go +++ b/rpc/lib/client/ws_client.go @@ -39,22 +39,23 @@ type WSClient struct { Dialer func(string, string) (net.Conn, error) PingPongLatencyTimer metrics.Timer - sentLastPingAt time.Time // user facing channels, closed only when the client is being stopped. ResultsCh chan json.RawMessage ErrorsCh chan error // internal channels - send chan types.RPCRequest // user requests - backlog chan types.RPCRequest // stores a single user request received during a conn failure - reconnectAfter chan error // reconnect requests - receiveRoutineQuit chan struct{} // a way for receiveRoutine to close writeRoutine + send chan types.RPCRequest // user requests + backlog chan types.RPCRequest // stores a single user request received during a conn failure + reconnectAfter chan error // reconnect requests + readRoutineQuit chan struct{} // a way for readRoutine to close writeRoutine reconnecting bool - wg sync.WaitGroup - mtx sync.RWMutex + wg sync.WaitGroup + + mtx sync.RWMutex + sentLastPingAt time.Time // Time allowed to read the next pong message from the server. pongWait time.Duration @@ -147,8 +148,9 @@ func (c *WSClient) IsActive() bool { return c.IsRunning() && !c.IsReconnecting() } -// Send asynchronously sends the given RPCRequest to the server. Results will -// be available on ResultsCh, errors, if any, on ErrorsCh. +// Send the given RPC request to the server. Results will be available on +// ResultsCh, errors, if any, on ErrorsCh. Will block until send succeeds or +// ctx.Done is closed. func (c *WSClient) Send(ctx context.Context, request types.RPCRequest) error { select { case c.send <- request: @@ -159,8 +161,7 @@ func (c *WSClient) Send(ctx context.Context, request types.RPCRequest) error { } } -// Call asynchronously calls a given method by sending an RPCRequest to the -// server. Results will be available on ResultsCh, errors, if any, on ErrorsCh. +// Call the given method. See Send description. func (c *WSClient) Call(ctx context.Context, method string, params map[string]interface{}) error { request, err := types.MapToRequest("", method, params) if err != nil { @@ -169,9 +170,8 @@ func (c *WSClient) Call(ctx context.Context, method string, params map[string]in return c.Send(ctx, request) } -// CallWithArrayParams asynchronously calls a given method by sending an -// RPCRequest to the server. Results will be available on ResultsCh, errors, if -// any, on ErrorsCh. +// CallWithArrayParams the given method with params in a form of array. See +// Send description. func (c *WSClient) CallWithArrayParams(ctx context.Context, method string, params []interface{}) error { request, err := types.ArrayToRequest("", method, params) if err != nil { @@ -231,8 +231,8 @@ func (c *WSClient) reconnect() error { func (c *WSClient) startReadWriteRoutines() { c.wg.Add(2) - c.receiveRoutineQuit = make(chan struct{}) - go c.receiveRoutine() + c.readRoutineQuit = make(chan struct{}) + go c.readRoutine() go c.writeRoutine() } @@ -240,7 +240,7 @@ func (c *WSClient) reconnectRoutine() { for { select { case originalError := <-c.reconnectAfter: - // wait until writeRoutine and receiveRoutine finish + // wait until writeRoutine and readRoutine finish c.wg.Wait() err := c.reconnect() if err != nil { @@ -310,7 +310,7 @@ func (c *WSClient) writeRoutine() { c.sentLastPingAt = time.Now() c.mtx.Unlock() c.Logger.Debug("sent ping") - case <-c.receiveRoutineQuit: + case <-c.readRoutineQuit: return case <-c.Quit: c.conn.WriteMessage(websocket.CloseMessage, []byte{}) @@ -321,13 +321,14 @@ func (c *WSClient) writeRoutine() { // The client ensures that there is at most one reader to a connection by // executing all reads from this goroutine. -func (c *WSClient) receiveRoutine() { +func (c *WSClient) readRoutine() { defer func() { c.conn.Close() c.wg.Done() }() c.conn.SetReadDeadline(time.Now().Add(c.pongWait)) + c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(c.pongWait)) c.mtx.RLock() @@ -336,6 +337,7 @@ func (c *WSClient) receiveRoutine() { c.Logger.Debug("got pong") return nil }) + for { _, data, err := c.conn.ReadMessage() if err != nil { @@ -344,7 +346,7 @@ func (c *WSClient) receiveRoutine() { } c.Logger.Error("failed to read response", "err", err) - close(c.receiveRoutineQuit) + close(c.readRoutineQuit) c.reconnectAfter <- err return } diff --git a/rpc/lib/client/ws_client_test.go b/rpc/lib/client/ws_client_test.go index 6778a0894..32385bfd7 100644 --- a/rpc/lib/client/ws_client_test.go +++ b/rpc/lib/client/ws_client_test.go @@ -65,28 +65,13 @@ func TestWSClientReconnectsAfterReadFailure(t *testing.T) { defer c.Stop() wg.Add(1) - go func() { - for { - select { - case res := <-c.ResultsCh: - if res != nil { - wg.Done() - } - case err := <-c.ErrorsCh: - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - case <-c.Quit: - return - } - } - }() + go callWgDoneOnResult(t, c, &wg) h.mtx.Lock() h.closeConnAfterRead = true h.mtx.Unlock() - // results in error + // results in WS read error, no send retry because write succeeded call(t, "a", c) // expect to reconnect almost immediately @@ -112,27 +97,12 @@ func TestWSClientReconnectsAfterWriteFailure(t *testing.T) { defer c.Stop() wg.Add(2) - go func() { - for { - select { - case res := <-c.ResultsCh: - if res != nil { - wg.Done() - } - case err := <-c.ErrorsCh: - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - case <-c.Quit: - return - } - } - }() + go callWgDoneOnResult(t, c, &wg) // hacky way to abort the connection before write c.conn.Close() - // results in error, the client should resend on reconnect + // results in WS write error, the client should resend on reconnect call(t, "a", c) // expect to reconnect almost immediately @@ -167,7 +137,7 @@ func TestWSClientReconnectFailure(t *testing.T) { c.conn.Close() s.Close() - // results in error + // results in WS write error call(t, "a", c) // expect to reconnect almost immediately @@ -204,8 +174,23 @@ func startClient(t *testing.T, addr net.Addr) *WSClient { } func call(t *testing.T, method string, c *WSClient) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - err := c.Call(ctx, method, make(map[string]interface{})) + err := c.Call(context.Background(), method, make(map[string]interface{})) require.NoError(t, err) } + +func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) { + for { + select { + case res := <-c.ResultsCh: + if res != nil { + wg.Done() + } + case err := <-c.ErrorsCh: + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + case <-c.Quit: + return + } + } +}