From 75b1b1d6c545e904ddd80934346b78e8f2accee7 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Mon, 31 Jan 2022 12:11:42 -0800 Subject: [PATCH] rpc: simplify the handling of JSON-RPC request and response IDs (#7738) * rpc: simplify the handling of JSON-RPC request and response IDs Replace the ID wrapper interface with plain JSON. Internally, the client libraries use only integer IDs, and the server does not care about the ID structure apart from checking its validity. Basic structure of this change: - Remove the jsonrpcid interface and its helpers. - Unexport the ID field of request and response. - Add helpers for constructing requests and responses. - Fix up usage and tests. --- rpc/jsonrpc/client/decode.go | 67 +++------- rpc/jsonrpc/client/http_json_client.go | 31 +++-- rpc/jsonrpc/client/ws_client.go | 17 +-- rpc/jsonrpc/client/ws_client_test.go | 4 +- rpc/jsonrpc/server/http_json_handler.go | 2 +- rpc/jsonrpc/server/http_json_handler_test.go | 42 +++---- rpc/jsonrpc/server/http_server_test.go | 4 +- rpc/jsonrpc/server/http_uri_handler.go | 4 +- rpc/jsonrpc/server/ws_handler.go | 4 +- rpc/jsonrpc/server/ws_handler_test.go | 11 +- rpc/jsonrpc/types/types.go | 125 +++++++++---------- rpc/jsonrpc/types/types_test.go | 49 +++----- 12 files changed, 145 insertions(+), 215 deletions(-) diff --git a/rpc/jsonrpc/client/decode.go b/rpc/jsonrpc/client/decode.go index 1254949eb..2babcf70c 100644 --- a/rpc/jsonrpc/client/decode.go +++ b/rpc/jsonrpc/client/decode.go @@ -2,26 +2,25 @@ package client import ( "encoding/json" - "errors" "fmt" rpctypes "github.com/tendermint/tendermint/rpc/jsonrpc/types" ) -func unmarshalResponseBytes(responseBytes []byte, expectedID rpctypes.JSONRPCIntID, result interface{}) error { +func unmarshalResponseBytes(responseBytes []byte, expectedID string, result interface{}) error { // Read response. If rpc/core/types is imported, the result will unmarshal // into the correct type. - response := &rpctypes.RPCResponse{} - if err := json.Unmarshal(responseBytes, response); err != nil { - return fmt.Errorf("error unmarshaling: %w", err) + var response rpctypes.RPCResponse + if err := json.Unmarshal(responseBytes, &response); err != nil { + return fmt.Errorf("unmarshaling response: %w", err) } if response.Error != nil { return response.Error } - if err := validateAndVerifyID(response, expectedID); err != nil { - return fmt.Errorf("wrong ID: %w", err) + if got := response.ID(); got != expectedID { + return fmt.Errorf("got response ID %q, wanted %q", got, expectedID) } // Unmarshal the RawMessage into the result. @@ -31,7 +30,7 @@ func unmarshalResponseBytes(responseBytes []byte, expectedID rpctypes.JSONRPCInt return nil } -func unmarshalResponseBytesArray(responseBytes []byte, expectedIDs []rpctypes.JSONRPCIntID, results []interface{}) error { +func unmarshalResponseBytesArray(responseBytes []byte, expectedIDs []string, results []interface{}) error { var responses []rpctypes.RPCResponse if err := json.Unmarshal(responseBytes, &responses); err != nil { return fmt.Errorf("unmarshaling responses: %w", err) @@ -40,62 +39,32 @@ func unmarshalResponseBytesArray(responseBytes []byte, expectedIDs []rpctypes.JS } // Intersect IDs from responses with expectedIDs. - ids := make([]rpctypes.JSONRPCIntID, len(responses)) - var ok bool + ids := make([]string, len(responses)) for i, resp := range responses { - ids[i], ok = resp.ID.(rpctypes.JSONRPCIntID) - if !ok { - return fmt.Errorf("expected JSONRPCIntID, got %T", resp.ID) - } + ids[i] = resp.ID() } if err := validateResponseIDs(ids, expectedIDs); err != nil { return fmt.Errorf("wrong IDs: %w", err) } - for i := 0; i < len(responses); i++ { - if err := json.Unmarshal(responses[i].Result, results[i]); err != nil { - return fmt.Errorf("error unmarshaling #%d result: %w", i, err) + for i, resp := range responses { + if err := json.Unmarshal(resp.Result, results[i]); err != nil { + return fmt.Errorf("unmarshaling result %d: %w", i, err) } } return nil } -func validateResponseIDs(ids, expectedIDs []rpctypes.JSONRPCIntID) error { - m := make(map[rpctypes.JSONRPCIntID]bool, len(expectedIDs)) - for _, expectedID := range expectedIDs { - m[expectedID] = true +func validateResponseIDs(ids, expectedIDs []string) error { + m := make(map[string]struct{}, len(expectedIDs)) + for _, id := range expectedIDs { + m[id] = struct{}{} } for i, id := range ids { - if m[id] { - delete(m, id) - } else { - return fmt.Errorf("unsolicited ID #%d: %v", i, id) + if _, ok := m[id]; !ok { + return fmt.Errorf("unexpected response ID %d: %q", i, id) } } - - return nil -} - -// From the JSON-RPC 2.0 spec: -// id: It MUST be the same as the value of the id member in the Request Object. -func validateAndVerifyID(res *rpctypes.RPCResponse, expectedID rpctypes.JSONRPCIntID) error { - if err := validateResponseID(res.ID); err != nil { - return err - } - if expectedID != res.ID.(rpctypes.JSONRPCIntID) { // validateResponseID ensured res.ID has the right type - return fmt.Errorf("response ID (%d) does not match request ID (%d)", res.ID, expectedID) - } - return nil -} - -func validateResponseID(id interface{}) error { - if id == nil { - return errors.New("no ID") - } - _, ok := id.(rpctypes.JSONRPCIntID) - if !ok { - return fmt.Errorf("expected JSONRPCIntID, but got: %T", id) - } return nil } diff --git a/rpc/jsonrpc/client/http_json_client.go b/rpc/jsonrpc/client/http_json_client.go index 8da108890..c1cad7097 100644 --- a/rpc/jsonrpc/client/http_json_client.go +++ b/rpc/jsonrpc/client/http_json_client.go @@ -183,8 +183,8 @@ func NewWithHTTPClient(remote string, c *http.Client) (*Client, error) { func (c *Client) Call(ctx context.Context, method string, params, result interface{}) error { id := c.nextRequestID() - request, err := rpctypes.ParamsToRequest(id, method, params) - if err != nil { + request := rpctypes.NewRequest(id) + if err := request.SetMethodAndParams(method, params); err != nil { return fmt.Errorf("failed to encode params: %w", err) } @@ -210,14 +210,13 @@ func (c *Client) Call(ctx context.Context, method string, params, result interfa return err } - defer httpResponse.Body.Close() - responseBytes, err := io.ReadAll(httpResponse.Body) + httpResponse.Body.Close() if err != nil { - return fmt.Errorf("failed to read response body: %w", err) + return fmt.Errorf("reading response body: %w", err) } - return unmarshalResponseBytes(responseBytes, id, result) + return unmarshalResponseBytes(responseBytes, request.ID(), result) } // NewRequestBatch starts a batch of requests for this client. @@ -258,17 +257,16 @@ func (c *Client) sendBatch(ctx context.Context, requests []*jsonRPCBufferedReque return nil, fmt.Errorf("post: %w", err) } - defer httpResponse.Body.Close() - responseBytes, err := io.ReadAll(httpResponse.Body) + httpResponse.Body.Close() if err != nil { - return nil, fmt.Errorf("read response body: %w", err) + return nil, fmt.Errorf("reading response body: %w", err) } // collect ids to check responses IDs in unmarshalResponseBytesArray - ids := make([]rpctypes.JSONRPCIntID, len(requests)) + ids := make([]string, len(requests)) for i, req := range requests { - ids[i] = req.request.ID.(rpctypes.JSONRPCIntID) + ids[i] = req.request.ID() } if err := unmarshalResponseBytesArray(responseBytes, ids, results); err != nil { @@ -277,12 +275,12 @@ func (c *Client) sendBatch(ctx context.Context, requests []*jsonRPCBufferedReque return results, nil } -func (c *Client) nextRequestID() rpctypes.JSONRPCIntID { +func (c *Client) nextRequestID() int { c.mtx.Lock() + defer c.mtx.Unlock() id := c.nextReqID c.nextReqID++ - c.mtx.Unlock() - return rpctypes.JSONRPCIntID(id) + return id } //------------------------------------------------------------------------------------ @@ -345,9 +343,8 @@ func (b *RequestBatch) Send(ctx context.Context) ([]interface{}, error) { // Call enqueues a request to call the given RPC method with the specified // parameters, in the same way that the `Client.Call` function would. func (b *RequestBatch) Call(_ context.Context, method string, params, result interface{}) error { - id := b.client.nextRequestID() - request, err := rpctypes.ParamsToRequest(id, method, params) - if err != nil { + request := rpctypes.NewRequest(b.client.nextRequestID()) + if err := request.SetMethodAndParams(method, params); err != nil { return err } b.enqueue(&jsonRPCBufferedRequest{request: request, result: result}) diff --git a/rpc/jsonrpc/client/ws_client.go b/rpc/jsonrpc/client/ws_client.go index 8e232eb25..38664cd73 100644 --- a/rpc/jsonrpc/client/ws_client.go +++ b/rpc/jsonrpc/client/ws_client.go @@ -204,21 +204,21 @@ func (c *WSClient) Send(ctx context.Context, request rpctypes.RPCRequest) error // Call enqueues a call request onto the Send queue. Requests are JSON encoded. func (c *WSClient) Call(ctx context.Context, method string, params map[string]interface{}) error { - request, err := rpctypes.ParamsToRequest(c.nextRequestID(), method, params) - if err != nil { + req := rpctypes.NewRequest(c.nextRequestID()) + if err := req.SetMethodAndParams(method, params); err != nil { return err } - return c.Send(ctx, request) + return c.Send(ctx, req) } // Private methods -func (c *WSClient) nextRequestID() rpctypes.JSONRPCIntID { +func (c *WSClient) nextRequestID() int { c.mtx.Lock() + defer c.mtx.Unlock() id := c.nextReqID c.nextReqID++ - c.mtx.Unlock() - return rpctypes.JSONRPCIntID(id) + return id } func (c *WSClient) dial() error { @@ -456,11 +456,6 @@ func (c *WSClient) readRoutine(ctx context.Context) { continue } - if err = validateResponseID(response.ID); err != nil { - c.Logger.Error("error in response ID", "id", response.ID, "err", err) - continue - } - // TODO: events resulting from /subscribe do not work with -> // because they are implemented as responses with the subscribe request's // ID. According to the spec, they should be notifications (requests diff --git a/rpc/jsonrpc/client/ws_client_test.go b/rpc/jsonrpc/client/ws_client_test.go index 5a74210ce..c0c47a012 100644 --- a/rpc/jsonrpc/client/ws_client_test.go +++ b/rpc/jsonrpc/client/ws_client_test.go @@ -64,7 +64,9 @@ func (h *myTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() res := json.RawMessage(`{}`) - emptyRespBytes, _ := json.Marshal(rpctypes.RPCResponse{Result: res, ID: req.ID}) + + emptyRespBytes, err := json.Marshal(req.MakeResponse(res)) + require.NoError(h.t, err) if err := conn.WriteMessage(messageType, emptyRespBytes); err != nil { return } diff --git a/rpc/jsonrpc/server/http_json_handler.go b/rpc/jsonrpc/server/http_json_handler.go index 90defb5b5..4d0c19c28 100644 --- a/rpc/jsonrpc/server/http_json_handler.go +++ b/rpc/jsonrpc/server/http_json_handler.go @@ -53,7 +53,7 @@ func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.Han var responses []rpctypes.RPCResponse for _, req := range requests { // Ignore notifications, which this service does not support. - if req.ID == nil { + if req.IsNotification() { logger.Debug("Ignoring notification", "req", req) continue } diff --git a/rpc/jsonrpc/server/http_json_handler_test.go b/rpc/jsonrpc/server/http_json_handler_test.go index 5e4df724a..1f5d2c320 100644 --- a/rpc/jsonrpc/server/http_json_handler_test.go +++ b/rpc/jsonrpc/server/http_json_handler_test.go @@ -38,24 +38,24 @@ func TestRPCParams(t *testing.T) { tests := []struct { payload string wantErr string - expectedID interface{} + expectedID string }{ // bad - {`{"jsonrpc": "2.0", "id": "0"}`, "Method not found", rpctypes.JSONRPCStringID("0")}, - {`{"jsonrpc": "2.0", "method": "y", "id": "0"}`, "Method not found", rpctypes.JSONRPCStringID("0")}, + {`{"jsonrpc": "2.0", "id": "0"}`, "Method not found", `"0"`}, + {`{"jsonrpc": "2.0", "method": "y", "id": "0"}`, "Method not found", `"0"`}, // id not captured in JSON parsing failures - {`{"method": "c", "id": "0", "params": a}`, "invalid character", nil}, - {`{"method": "c", "id": "0", "params": ["a"]}`, "got 1", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": ["a", "b"]}`, "invalid syntax", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": [1, 1]}`, "of type string", rpctypes.JSONRPCStringID("0")}, + {`{"method": "c", "id": "0", "params": a}`, "invalid character", ""}, + {`{"method": "c", "id": "0", "params": ["a"]}`, "got 1", `"0"`}, + {`{"method": "c", "id": "0", "params": ["a", "b"]}`, "invalid syntax", `"0"`}, + {`{"method": "c", "id": "0", "params": [1, 1]}`, "of type string", `"0"`}, // no ID - notification // {`{"jsonrpc": "2.0", "method": "c", "params": ["a", "10"]}`, false, nil}, // good - {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": null}`, "", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": {}}`, "", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": ["a", "10"]}`, "", rpctypes.JSONRPCStringID("0")}, + {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": null}`, "", `"0"`}, + {`{"method": "c", "id": "0", "params": {}}`, "", `"0"`}, + {`{"method": "c", "id": "0", "params": ["a", "10"]}`, "", `"0"`}, } for i, tt := range tests { @@ -73,7 +73,7 @@ func TestRPCParams(t *testing.T) { recv := new(rpctypes.RPCResponse) assert.Nil(t, json.Unmarshal(blob, recv), "#%d: expecting successful parsing of an RPCResponse:\nblob: %s", i, blob) assert.NotEqual(t, recv, new(rpctypes.RPCResponse), "#%d: not expecting a blank RPCResponse", i) - assert.Equal(t, tt.expectedID, recv.ID, "#%d: expected ID not matched in RPCResponse", i) + assert.Equal(t, tt.expectedID, recv.ID(), "#%d: expected ID not matched in RPCResponse", i) if tt.wantErr == "" { assert.Nil(t, recv.Error, "#%d: not expecting an error", i) } else { @@ -89,19 +89,19 @@ func TestJSONRPCID(t *testing.T) { tests := []struct { payload string wantErr bool - expectedID interface{} + expectedID string }{ // good id - {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": ["a", "10"]}`, false, rpctypes.JSONRPCStringID("0")}, - {`{"jsonrpc": "2.0", "method": "c", "id": "abc", "params": ["a", "10"]}`, false, rpctypes.JSONRPCStringID("abc")}, - {`{"jsonrpc": "2.0", "method": "c", "id": 0, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(0)}, - {`{"jsonrpc": "2.0", "method": "c", "id": 1, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(1)}, - {`{"jsonrpc": "2.0", "method": "c", "id": 1.3, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(1)}, - {`{"jsonrpc": "2.0", "method": "c", "id": -1, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(-1)}, + {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": ["a", "10"]}`, false, `"0"`}, + {`{"jsonrpc": "2.0", "method": "c", "id": "abc", "params": ["a", "10"]}`, false, `"abc"`}, + {`{"jsonrpc": "2.0", "method": "c", "id": 0, "params": ["a", "10"]}`, false, `0`}, + {`{"jsonrpc": "2.0", "method": "c", "id": 1, "params": ["a", "10"]}`, false, `1`}, + {`{"jsonrpc": "2.0", "method": "c", "id": -1, "params": ["a", "10"]}`, false, `-1`}, // bad id - {`{"jsonrpc": "2.0", "method": "c", "id": {}, "params": ["a", "10"]}`, true, nil}, - {`{"jsonrpc": "2.0", "method": "c", "id": [], "params": ["a", "10"]}`, true, nil}, + {`{"jsonrpc": "2.0", "method": "c", "id": {}, "params": ["a", "10"]}`, true, ""}, // object + {`{"jsonrpc": "2.0", "method": "c", "id": [], "params": ["a", "10"]}`, true, ""}, // array + {`{"jsonrpc": "2.0", "method": "c", "id": 1.3, "params": ["a", "10"]}`, true, ""}, // fractional } for i, tt := range tests { @@ -123,7 +123,7 @@ func TestJSONRPCID(t *testing.T) { assert.NoError(t, err, "#%d: expecting successful parsing of an RPCResponse:\nblob: %s", i, blob) if !tt.wantErr { assert.NotEqual(t, recv, new(rpctypes.RPCResponse), "#%d: not expecting a blank RPCResponse", i) - assert.Equal(t, tt.expectedID, recv.ID, "#%d: expected ID not matched in RPCResponse", i) + assert.Equal(t, tt.expectedID, recv.ID(), "#%d: expected ID not matched in RPCResponse", i) assert.Nil(t, recv.Error, "#%d: not expecting an error", i) } else { assert.True(t, recv.Error.Code < 0, "#%d: not expecting a positive JSONRPC code", i) diff --git a/rpc/jsonrpc/server/http_server_test.go b/rpc/jsonrpc/server/http_server_test.go index 12be80d1b..838a2ef6c 100644 --- a/rpc/jsonrpc/server/http_server_test.go +++ b/rpc/jsonrpc/server/http_server_test.go @@ -125,7 +125,7 @@ func TestServeTLS(t *testing.T) { } func TestWriteRPCResponse(t *testing.T) { - req := rpctypes.RPCRequest{ID: rpctypes.JSONRPCIntID(-1)} + req := rpctypes.NewRequest(-1) // one argument w := httptest.NewRecorder() @@ -160,7 +160,7 @@ func TestWriteRPCResponse(t *testing.T) { func TestWriteHTTPResponse(t *testing.T) { w := httptest.NewRecorder() logger := log.NewNopLogger() - req := rpctypes.RPCRequest{ID: rpctypes.JSONRPCIntID(-1)} + req := rpctypes.NewRequest(-1) writeHTTPResponse(w, logger, req.MakeErrorf(rpctypes.CodeInternalError, "foo")) resp := w.Result() body, err := io.ReadAll(resp.Body) diff --git a/rpc/jsonrpc/server/http_uri_handler.go b/rpc/jsonrpc/server/http_uri_handler.go index 8d626f225..67491adf5 100644 --- a/rpc/jsonrpc/server/http_uri_handler.go +++ b/rpc/jsonrpc/server/http_uri_handler.go @@ -16,7 +16,7 @@ import ( // uriReqID is a placeholder ID used for GET requests, which do not receive a // JSON-RPC request ID from the caller. -var uriReqID = rpctypes.JSONRPCIntID(-1) +const uriReqID = -1 // convert from a function name to the http handler func makeHTTPHandler(rpcFunc *RPCFunc, logger log.Logger) func(http.ResponseWriter, *http.Request) { @@ -31,7 +31,7 @@ func makeHTTPHandler(rpcFunc *RPCFunc, logger log.Logger) func(http.ResponseWrit fmt.Fprintln(w, err.Error()) return } - jreq := rpctypes.RPCRequest{ID: uriReqID} + jreq := rpctypes.NewRequest(uriReqID) outs := rpcFunc.f.Call(args) logger.Debug("HTTPRestRPC", "method", req.URL.Path, "args", args, "returns", outs) diff --git a/rpc/jsonrpc/server/ws_handler.go b/rpc/jsonrpc/server/ws_handler.go index 2311d9b0b..8a96955de 100644 --- a/rpc/jsonrpc/server/ws_handler.go +++ b/rpc/jsonrpc/server/ws_handler.go @@ -274,7 +274,7 @@ func (wsc *wsConnection) readRoutine(ctx context.Context) { if !ok { err = fmt.Errorf("WSJSONRPC: %v", r) } - req := rpctypes.RPCRequest{ID: uriReqID} + req := rpctypes.NewRequest(uriReqID) wsc.Logger.Error("Panic in WSJSONRPC handler", "err", err, "stack", string(debug.Stack())) if err := wsc.WriteRPCResponse(writeCtx, req.MakeErrorf(rpctypes.CodeInternalError, "Panic in handler: %v", err)); err != nil { @@ -325,7 +325,7 @@ func (wsc *wsConnection) readRoutine(ctx context.Context) { // A Notification is a Request object without an "id" member. // The Server MUST NOT reply to a Notification, including those that are within a batch request. - if request.ID == nil { + if request.IsNotification() { wsc.Logger.Debug( "WSJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)", "req", request, diff --git a/rpc/jsonrpc/server/ws_handler_test.go b/rpc/jsonrpc/server/ws_handler_test.go index 3d78c0d9b..ae73a953b 100644 --- a/rpc/jsonrpc/server/ws_handler_test.go +++ b/rpc/jsonrpc/server/ws_handler_test.go @@ -32,14 +32,9 @@ func TestWebsocketManagerHandler(t *testing.T) { } // check basic functionality works - req, err := rpctypes.ParamsToRequest( - rpctypes.JSONRPCStringID("TestWebsocketManager"), - "c", - map[string]interface{}{"s": "a", "i": 10}, - ) - require.NoError(t, err) - err = c.WriteJSON(req) - require.NoError(t, err) + req := rpctypes.NewRequest(1001) + require.NoError(t, req.SetMethodAndParams("c", map[string]interface{}{"s": "a", "i": 10})) + require.NoError(t, c.WriteJSON(req)) var resp rpctypes.RPCResponse err = c.ReadJSON(&resp) diff --git a/rpc/jsonrpc/types/types.go b/rpc/jsonrpc/types/types.go index 111a9b656..0c0500bf0 100644 --- a/rpc/jsonrpc/types/types.go +++ b/rpc/jsonrpc/types/types.go @@ -1,51 +1,19 @@ package types import ( + "bytes" "context" "encoding/json" "errors" "fmt" "net/http" - "reflect" + "regexp" + "strconv" "strings" "github.com/tendermint/tendermint/rpc/coretypes" ) -// a wrapper to emulate a sum type: jsonrpcid = string | int -// TODO: refactor when Go 2.0 arrives https://github.com/golang/go/issues/19412 -type jsonrpcid interface { - isJSONRPCID() -} - -// JSONRPCStringID a wrapper for JSON-RPC string IDs -type JSONRPCStringID string - -func (JSONRPCStringID) isJSONRPCID() {} -func (id JSONRPCStringID) String() string { return string(id) } - -// JSONRPCIntID a wrapper for JSON-RPC integer IDs -type JSONRPCIntID int - -func (JSONRPCIntID) isJSONRPCID() {} -func (id JSONRPCIntID) String() string { return fmt.Sprintf("%d", id) } - -func idFromInterface(idInterface interface{}) (jsonrpcid, error) { - switch id := idInterface.(type) { - case string: - return JSONRPCStringID(id), nil - case float64: - // json.Unmarshal uses float64 for all numbers - // (https://golang.org/pkg/encoding/json/#Unmarshal), - // but the JSONRPC2.0 spec says the id SHOULD NOT contain - // decimals - so we truncate the decimals here. - return JSONRPCIntID(int(id)), nil - default: - typ := reflect.TypeOf(id) - return nil, fmt.Errorf("json-rpc ID (%v) is of unknown type (%v)", id, typ) - } -} - // ErrorCode is the type of JSON-RPC error codes. type ErrorCode int @@ -77,18 +45,39 @@ var errorCodeString = map[ErrorCode]string{ // REQUEST type RPCRequest struct { - ID jsonrpcid + id json.RawMessage + Method string Params json.RawMessage } +// NewRequest returns an empty request with the specified ID. +func NewRequest(id int) RPCRequest { + return RPCRequest{id: []byte(strconv.Itoa(id))} +} + +// ID returns a string representation of the request ID. +func (req RPCRequest) ID() string { return string(req.id) } + +// IsNotification reports whether req is a notification (has an empty ID). +func (req RPCRequest) IsNotification() bool { return len(req.id) == 0 } + type rpcRequestJSON struct { V string `json:"jsonrpc"` // must be "2.0" - ID interface{} `json:"id,omitempty"` + ID json.RawMessage `json:"id,omitempty"` M string `json:"method"` P json.RawMessage `json:"params"` } +// isNullOrEmpty reports whether data is empty or the JSON "null" value. +func isNullOrEmpty(data json.RawMessage) bool { + return len(data) == 0 || bytes.Equal(data, []byte("null")) +} + +// validID matches the text of a JSON value that is allowed to serve as a +// JSON-RPC request ID. Precondition: Target value is legal JSON. +var validID = regexp.MustCompile(`^(?:".*"|-?\d+)$`) + // UnmarshalJSON decodes a request from a JSON-RPC 2.0 request object. func (req *RPCRequest) UnmarshalJSON(data []byte) error { var wrapper rpcRequestJSON @@ -98,12 +87,11 @@ func (req *RPCRequest) UnmarshalJSON(data []byte) error { return fmt.Errorf("invalid version: %q", wrapper.V) } - if wrapper.ID != nil { - id, err := idFromInterface(wrapper.ID) - if err != nil { - return fmt.Errorf("invalid request ID: %w", err) + if !isNullOrEmpty(wrapper.ID) { + if !validID.Match(wrapper.ID) { + return fmt.Errorf("invalid request ID: %q", string(wrapper.ID)) } - req.ID = id + req.id = wrapper.ID } req.Method = wrapper.M req.Params = wrapper.P @@ -114,14 +102,14 @@ func (req *RPCRequest) UnmarshalJSON(data []byte) error { func (req RPCRequest) MarshalJSON() ([]byte, error) { return json.Marshal(rpcRequestJSON{ V: "2.0", - ID: req.ID, + ID: req.id, M: req.Method, P: req.Params, }) } func (req RPCRequest) String() string { - return fmt.Sprintf("RPCRequest{%s %s/%X}", req.ID, req.Method, req.Params) + return fmt.Sprintf("RPCRequest{%s %s/%X}", req.ID(), req.Method, req.Params) } // MakeResponse constructs a success response to req with the given result. If @@ -131,14 +119,14 @@ func (req RPCRequest) MakeResponse(result interface{}) RPCResponse { if err != nil { return req.MakeErrorf(CodeInternalError, "marshaling result: %v", err) } - return RPCResponse{ID: req.ID, Result: data} + return RPCResponse{id: req.id, Result: data} } // MakeErrorf constructs an error response to req with the given code and a // message constructed by formatting msg with args. func (req RPCRequest) MakeErrorf(code ErrorCode, msg string, args ...interface{}) RPCResponse { return RPCResponse{ - ID: req.ID, + id: req.id, Error: &RPCError{ Code: int(code), Message: code.String(), @@ -154,36 +142,35 @@ func (req RPCRequest) MakeError(err error) RPCResponse { panic("cannot construct an error response for nil") } if e, ok := err.(*RPCError); ok { - return RPCResponse{ID: req.ID, Error: e} + return RPCResponse{id: req.id, Error: e} } if errors.Is(err, coretypes.ErrZeroOrNegativeHeight) || errors.Is(err, coretypes.ErrZeroOrNegativePerPage) || errors.Is(err, coretypes.ErrPageOutOfRange) || errors.Is(err, coretypes.ErrInvalidRequest) { - return RPCResponse{ID: req.ID, Error: &RPCError{ + return RPCResponse{id: req.id, Error: &RPCError{ Code: int(CodeInvalidRequest), Message: CodeInvalidRequest.String(), Data: err.Error(), }} } - return RPCResponse{ID: req.ID, Error: &RPCError{ + return RPCResponse{id: req.id, Error: &RPCError{ Code: int(CodeInternalError), Message: CodeInternalError.String(), Data: err.Error(), }} } -// ParamsToRequest constructs a new RPCRequest with the given ID, method, and parameters. -func ParamsToRequest(id jsonrpcid, method string, params interface{}) (RPCRequest, error) { +// SetMethodAndParams updates the method and parameters of req with the given +// values, leaving the ID unchanged. +func (req *RPCRequest) SetMethodAndParams(method string, params interface{}) error { payload, err := json.Marshal(params) if err != nil { - return RPCRequest{}, err + return err } - return RPCRequest{ - ID: id, - Method: method, - Params: payload, - }, nil + req.Method = method + req.Params = payload + return nil } //---------------------------------------- @@ -204,14 +191,18 @@ func (err RPCError) Error() string { } type RPCResponse struct { - ID jsonrpcid + id json.RawMessage + Result json.RawMessage Error *RPCError } +// ID returns a representation of the response ID. +func (resp RPCResponse) ID() string { return string(resp.id) } + type rpcResponseJSON struct { V string `json:"jsonrpc"` // must be "2.0" - ID interface{} `json:"id,omitempty"` + ID json.RawMessage `json:"id,omitempty"` R json.RawMessage `json:"result,omitempty"` E *RPCError `json:"error,omitempty"` } @@ -225,14 +216,12 @@ func (resp *RPCResponse) UnmarshalJSON(data []byte) error { return fmt.Errorf("invalid version: %q", wrapper.V) } - if wrapper.ID != nil { - id, err := idFromInterface(wrapper.ID) - if err != nil { - return fmt.Errorf("invalid response ID: %w", err) + if !isNullOrEmpty(wrapper.ID) { + if !validID.Match(wrapper.ID) { + return fmt.Errorf("invalid response ID: %q", string(wrapper.ID)) } - resp.ID = id + resp.id = wrapper.ID } - resp.Error = wrapper.E resp.Result = wrapper.R return nil @@ -242,7 +231,7 @@ func (resp *RPCResponse) UnmarshalJSON(data []byte) error { func (resp RPCResponse) MarshalJSON() ([]byte, error) { return json.Marshal(rpcResponseJSON{ V: "2.0", - ID: resp.ID, + ID: resp.id, R: resp.Result, E: resp.Error, }) @@ -250,9 +239,9 @@ func (resp RPCResponse) MarshalJSON() ([]byte, error) { func (resp RPCResponse) String() string { if resp.Error == nil { - return fmt.Sprintf("RPCResponse{%s %X}", resp.ID, resp.Result) + return fmt.Sprintf("RPCResponse{%s %X}", resp.ID(), resp.Result) } - return fmt.Sprintf("RPCResponse{%s %v}", resp.ID, resp.Error) + return fmt.Sprintf("RPCResponse{%s %v}", resp.ID(), resp.Error) } //---------------------------------------- diff --git a/rpc/jsonrpc/types/types_test.go b/rpc/jsonrpc/types/types_test.go index 2dbfed895..d5be2f74d 100644 --- a/rpc/jsonrpc/types/types_test.go +++ b/rpc/jsonrpc/types/types_test.go @@ -13,65 +13,48 @@ type SampleResult struct { Value string } -type responseTest struct { - id jsonrpcid - expected string -} - -var responseTests = []responseTest{ - {JSONRPCStringID("1"), `"1"`}, - {JSONRPCStringID("alphabet"), `"alphabet"`}, - {JSONRPCStringID(""), `""`}, - {JSONRPCStringID("àáâ"), `"àáâ"`}, - {JSONRPCIntID(-1), "-1"}, - {JSONRPCIntID(0), "0"}, - {JSONRPCIntID(1), "1"}, - {JSONRPCIntID(100), "100"}, +// Valid JSON identifier texts. +var testIDs = []string{ + `"1"`, `"alphabet"`, `""`, `"àáâ"`, "-1", "0", "1", "100", } func TestResponses(t *testing.T) { - for _, tt := range responseTests { - req := RPCRequest{ - ID: tt.id, - Method: "whatever", - } + for _, id := range testIDs { + req := RPCRequest{id: json.RawMessage(id)} a := req.MakeResponse(&SampleResult{"hello"}) b, err := json.Marshal(a) - require.NoError(t, err) - s := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, tt.expected) + require.NoError(t, err, "input id: %q", id) + s := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, id) assert.Equal(t, s, string(b)) d := req.MakeErrorf(CodeParseError, "hello world") e, err := json.Marshal(d) require.NoError(t, err) - f := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32700,"message":"Parse error","data":"hello world"}}`, tt.expected) + f := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32700,"message":"Parse error","data":"hello world"}}`, id) assert.Equal(t, f, string(e)) g := req.MakeErrorf(CodeMethodNotFound, "foo") h, err := json.Marshal(g) require.NoError(t, err) - i := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32601,"message":"Method not found","data":"foo"}}`, tt.expected) + i := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32601,"message":"Method not found","data":"foo"}}`, id) assert.Equal(t, string(h), i) } } func TestUnmarshallResponses(t *testing.T) { - for _, tt := range responseTests { + for _, id := range testIDs { response := &RPCResponse{} - err := json.Unmarshal( - []byte(fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, tt.expected)), - response, - ) - require.NoError(t, err) + input := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, id) + require.NoError(t, json.Unmarshal([]byte(input), &response)) - req := RPCRequest{ID: tt.id} + req := RPCRequest{id: json.RawMessage(id)} a := req.MakeResponse(&SampleResult{"hello"}) assert.Equal(t, *response, a) } - response := &RPCResponse{} - err := json.Unmarshal([]byte(`{"jsonrpc":"2.0","id":true,"result":{"Value":"hello"}}`), response) - require.Error(t, err) + var response RPCResponse + const input = `{"jsonrpc":"2.0","id":true,"result":{"Value":"hello"}}` + require.Error(t, json.Unmarshal([]byte(input), &response)) } func TestRPCError(t *testing.T) {