diff --git a/rpc/jsonrpc/client/decode.go b/rpc/jsonrpc/client/decode.go index 1254949eb..2babcf70c 100644 --- a/rpc/jsonrpc/client/decode.go +++ b/rpc/jsonrpc/client/decode.go @@ -2,26 +2,25 @@ package client import ( "encoding/json" - "errors" "fmt" rpctypes "github.com/tendermint/tendermint/rpc/jsonrpc/types" ) -func unmarshalResponseBytes(responseBytes []byte, expectedID rpctypes.JSONRPCIntID, result interface{}) error { +func unmarshalResponseBytes(responseBytes []byte, expectedID string, result interface{}) error { // Read response. If rpc/core/types is imported, the result will unmarshal // into the correct type. - response := &rpctypes.RPCResponse{} - if err := json.Unmarshal(responseBytes, response); err != nil { - return fmt.Errorf("error unmarshaling: %w", err) + var response rpctypes.RPCResponse + if err := json.Unmarshal(responseBytes, &response); err != nil { + return fmt.Errorf("unmarshaling response: %w", err) } if response.Error != nil { return response.Error } - if err := validateAndVerifyID(response, expectedID); err != nil { - return fmt.Errorf("wrong ID: %w", err) + if got := response.ID(); got != expectedID { + return fmt.Errorf("got response ID %q, wanted %q", got, expectedID) } // Unmarshal the RawMessage into the result. @@ -31,7 +30,7 @@ func unmarshalResponseBytes(responseBytes []byte, expectedID rpctypes.JSONRPCInt return nil } -func unmarshalResponseBytesArray(responseBytes []byte, expectedIDs []rpctypes.JSONRPCIntID, results []interface{}) error { +func unmarshalResponseBytesArray(responseBytes []byte, expectedIDs []string, results []interface{}) error { var responses []rpctypes.RPCResponse if err := json.Unmarshal(responseBytes, &responses); err != nil { return fmt.Errorf("unmarshaling responses: %w", err) @@ -40,62 +39,32 @@ func unmarshalResponseBytesArray(responseBytes []byte, expectedIDs []rpctypes.JS } // Intersect IDs from responses with expectedIDs. - ids := make([]rpctypes.JSONRPCIntID, len(responses)) - var ok bool + ids := make([]string, len(responses)) for i, resp := range responses { - ids[i], ok = resp.ID.(rpctypes.JSONRPCIntID) - if !ok { - return fmt.Errorf("expected JSONRPCIntID, got %T", resp.ID) - } + ids[i] = resp.ID() } if err := validateResponseIDs(ids, expectedIDs); err != nil { return fmt.Errorf("wrong IDs: %w", err) } - for i := 0; i < len(responses); i++ { - if err := json.Unmarshal(responses[i].Result, results[i]); err != nil { - return fmt.Errorf("error unmarshaling #%d result: %w", i, err) + for i, resp := range responses { + if err := json.Unmarshal(resp.Result, results[i]); err != nil { + return fmt.Errorf("unmarshaling result %d: %w", i, err) } } return nil } -func validateResponseIDs(ids, expectedIDs []rpctypes.JSONRPCIntID) error { - m := make(map[rpctypes.JSONRPCIntID]bool, len(expectedIDs)) - for _, expectedID := range expectedIDs { - m[expectedID] = true +func validateResponseIDs(ids, expectedIDs []string) error { + m := make(map[string]struct{}, len(expectedIDs)) + for _, id := range expectedIDs { + m[id] = struct{}{} } for i, id := range ids { - if m[id] { - delete(m, id) - } else { - return fmt.Errorf("unsolicited ID #%d: %v", i, id) + if _, ok := m[id]; !ok { + return fmt.Errorf("unexpected response ID %d: %q", i, id) } } - - return nil -} - -// From the JSON-RPC 2.0 spec: -// id: It MUST be the same as the value of the id member in the Request Object. -func validateAndVerifyID(res *rpctypes.RPCResponse, expectedID rpctypes.JSONRPCIntID) error { - if err := validateResponseID(res.ID); err != nil { - return err - } - if expectedID != res.ID.(rpctypes.JSONRPCIntID) { // validateResponseID ensured res.ID has the right type - return fmt.Errorf("response ID (%d) does not match request ID (%d)", res.ID, expectedID) - } - return nil -} - -func validateResponseID(id interface{}) error { - if id == nil { - return errors.New("no ID") - } - _, ok := id.(rpctypes.JSONRPCIntID) - if !ok { - return fmt.Errorf("expected JSONRPCIntID, but got: %T", id) - } return nil } diff --git a/rpc/jsonrpc/client/http_json_client.go b/rpc/jsonrpc/client/http_json_client.go index 8da108890..c1cad7097 100644 --- a/rpc/jsonrpc/client/http_json_client.go +++ b/rpc/jsonrpc/client/http_json_client.go @@ -183,8 +183,8 @@ func NewWithHTTPClient(remote string, c *http.Client) (*Client, error) { func (c *Client) Call(ctx context.Context, method string, params, result interface{}) error { id := c.nextRequestID() - request, err := rpctypes.ParamsToRequest(id, method, params) - if err != nil { + request := rpctypes.NewRequest(id) + if err := request.SetMethodAndParams(method, params); err != nil { return fmt.Errorf("failed to encode params: %w", err) } @@ -210,14 +210,13 @@ func (c *Client) Call(ctx context.Context, method string, params, result interfa return err } - defer httpResponse.Body.Close() - responseBytes, err := io.ReadAll(httpResponse.Body) + httpResponse.Body.Close() if err != nil { - return fmt.Errorf("failed to read response body: %w", err) + return fmt.Errorf("reading response body: %w", err) } - return unmarshalResponseBytes(responseBytes, id, result) + return unmarshalResponseBytes(responseBytes, request.ID(), result) } // NewRequestBatch starts a batch of requests for this client. @@ -258,17 +257,16 @@ func (c *Client) sendBatch(ctx context.Context, requests []*jsonRPCBufferedReque return nil, fmt.Errorf("post: %w", err) } - defer httpResponse.Body.Close() - responseBytes, err := io.ReadAll(httpResponse.Body) + httpResponse.Body.Close() if err != nil { - return nil, fmt.Errorf("read response body: %w", err) + return nil, fmt.Errorf("reading response body: %w", err) } // collect ids to check responses IDs in unmarshalResponseBytesArray - ids := make([]rpctypes.JSONRPCIntID, len(requests)) + ids := make([]string, len(requests)) for i, req := range requests { - ids[i] = req.request.ID.(rpctypes.JSONRPCIntID) + ids[i] = req.request.ID() } if err := unmarshalResponseBytesArray(responseBytes, ids, results); err != nil { @@ -277,12 +275,12 @@ func (c *Client) sendBatch(ctx context.Context, requests []*jsonRPCBufferedReque return results, nil } -func (c *Client) nextRequestID() rpctypes.JSONRPCIntID { +func (c *Client) nextRequestID() int { c.mtx.Lock() + defer c.mtx.Unlock() id := c.nextReqID c.nextReqID++ - c.mtx.Unlock() - return rpctypes.JSONRPCIntID(id) + return id } //------------------------------------------------------------------------------------ @@ -345,9 +343,8 @@ func (b *RequestBatch) Send(ctx context.Context) ([]interface{}, error) { // Call enqueues a request to call the given RPC method with the specified // parameters, in the same way that the `Client.Call` function would. func (b *RequestBatch) Call(_ context.Context, method string, params, result interface{}) error { - id := b.client.nextRequestID() - request, err := rpctypes.ParamsToRequest(id, method, params) - if err != nil { + request := rpctypes.NewRequest(b.client.nextRequestID()) + if err := request.SetMethodAndParams(method, params); err != nil { return err } b.enqueue(&jsonRPCBufferedRequest{request: request, result: result}) diff --git a/rpc/jsonrpc/client/ws_client.go b/rpc/jsonrpc/client/ws_client.go index 8e232eb25..38664cd73 100644 --- a/rpc/jsonrpc/client/ws_client.go +++ b/rpc/jsonrpc/client/ws_client.go @@ -204,21 +204,21 @@ func (c *WSClient) Send(ctx context.Context, request rpctypes.RPCRequest) error // Call enqueues a call request onto the Send queue. Requests are JSON encoded. func (c *WSClient) Call(ctx context.Context, method string, params map[string]interface{}) error { - request, err := rpctypes.ParamsToRequest(c.nextRequestID(), method, params) - if err != nil { + req := rpctypes.NewRequest(c.nextRequestID()) + if err := req.SetMethodAndParams(method, params); err != nil { return err } - return c.Send(ctx, request) + return c.Send(ctx, req) } // Private methods -func (c *WSClient) nextRequestID() rpctypes.JSONRPCIntID { +func (c *WSClient) nextRequestID() int { c.mtx.Lock() + defer c.mtx.Unlock() id := c.nextReqID c.nextReqID++ - c.mtx.Unlock() - return rpctypes.JSONRPCIntID(id) + return id } func (c *WSClient) dial() error { @@ -456,11 +456,6 @@ func (c *WSClient) readRoutine(ctx context.Context) { continue } - if err = validateResponseID(response.ID); err != nil { - c.Logger.Error("error in response ID", "id", response.ID, "err", err) - continue - } - // TODO: events resulting from /subscribe do not work with -> // because they are implemented as responses with the subscribe request's // ID. According to the spec, they should be notifications (requests diff --git a/rpc/jsonrpc/client/ws_client_test.go b/rpc/jsonrpc/client/ws_client_test.go index 5a74210ce..c0c47a012 100644 --- a/rpc/jsonrpc/client/ws_client_test.go +++ b/rpc/jsonrpc/client/ws_client_test.go @@ -64,7 +64,9 @@ func (h *myTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() res := json.RawMessage(`{}`) - emptyRespBytes, _ := json.Marshal(rpctypes.RPCResponse{Result: res, ID: req.ID}) + + emptyRespBytes, err := json.Marshal(req.MakeResponse(res)) + require.NoError(h.t, err) if err := conn.WriteMessage(messageType, emptyRespBytes); err != nil { return } diff --git a/rpc/jsonrpc/server/http_json_handler.go b/rpc/jsonrpc/server/http_json_handler.go index 90defb5b5..4d0c19c28 100644 --- a/rpc/jsonrpc/server/http_json_handler.go +++ b/rpc/jsonrpc/server/http_json_handler.go @@ -53,7 +53,7 @@ func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.Han var responses []rpctypes.RPCResponse for _, req := range requests { // Ignore notifications, which this service does not support. - if req.ID == nil { + if req.IsNotification() { logger.Debug("Ignoring notification", "req", req) continue } diff --git a/rpc/jsonrpc/server/http_json_handler_test.go b/rpc/jsonrpc/server/http_json_handler_test.go index 5e4df724a..1f5d2c320 100644 --- a/rpc/jsonrpc/server/http_json_handler_test.go +++ b/rpc/jsonrpc/server/http_json_handler_test.go @@ -38,24 +38,24 @@ func TestRPCParams(t *testing.T) { tests := []struct { payload string wantErr string - expectedID interface{} + expectedID string }{ // bad - {`{"jsonrpc": "2.0", "id": "0"}`, "Method not found", rpctypes.JSONRPCStringID("0")}, - {`{"jsonrpc": "2.0", "method": "y", "id": "0"}`, "Method not found", rpctypes.JSONRPCStringID("0")}, + {`{"jsonrpc": "2.0", "id": "0"}`, "Method not found", `"0"`}, + {`{"jsonrpc": "2.0", "method": "y", "id": "0"}`, "Method not found", `"0"`}, // id not captured in JSON parsing failures - {`{"method": "c", "id": "0", "params": a}`, "invalid character", nil}, - {`{"method": "c", "id": "0", "params": ["a"]}`, "got 1", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": ["a", "b"]}`, "invalid syntax", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": [1, 1]}`, "of type string", rpctypes.JSONRPCStringID("0")}, + {`{"method": "c", "id": "0", "params": a}`, "invalid character", ""}, + {`{"method": "c", "id": "0", "params": ["a"]}`, "got 1", `"0"`}, + {`{"method": "c", "id": "0", "params": ["a", "b"]}`, "invalid syntax", `"0"`}, + {`{"method": "c", "id": "0", "params": [1, 1]}`, "of type string", `"0"`}, // no ID - notification // {`{"jsonrpc": "2.0", "method": "c", "params": ["a", "10"]}`, false, nil}, // good - {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": null}`, "", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": {}}`, "", rpctypes.JSONRPCStringID("0")}, - {`{"method": "c", "id": "0", "params": ["a", "10"]}`, "", rpctypes.JSONRPCStringID("0")}, + {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": null}`, "", `"0"`}, + {`{"method": "c", "id": "0", "params": {}}`, "", `"0"`}, + {`{"method": "c", "id": "0", "params": ["a", "10"]}`, "", `"0"`}, } for i, tt := range tests { @@ -73,7 +73,7 @@ func TestRPCParams(t *testing.T) { recv := new(rpctypes.RPCResponse) assert.Nil(t, json.Unmarshal(blob, recv), "#%d: expecting successful parsing of an RPCResponse:\nblob: %s", i, blob) assert.NotEqual(t, recv, new(rpctypes.RPCResponse), "#%d: not expecting a blank RPCResponse", i) - assert.Equal(t, tt.expectedID, recv.ID, "#%d: expected ID not matched in RPCResponse", i) + assert.Equal(t, tt.expectedID, recv.ID(), "#%d: expected ID not matched in RPCResponse", i) if tt.wantErr == "" { assert.Nil(t, recv.Error, "#%d: not expecting an error", i) } else { @@ -89,19 +89,19 @@ func TestJSONRPCID(t *testing.T) { tests := []struct { payload string wantErr bool - expectedID interface{} + expectedID string }{ // good id - {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": ["a", "10"]}`, false, rpctypes.JSONRPCStringID("0")}, - {`{"jsonrpc": "2.0", "method": "c", "id": "abc", "params": ["a", "10"]}`, false, rpctypes.JSONRPCStringID("abc")}, - {`{"jsonrpc": "2.0", "method": "c", "id": 0, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(0)}, - {`{"jsonrpc": "2.0", "method": "c", "id": 1, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(1)}, - {`{"jsonrpc": "2.0", "method": "c", "id": 1.3, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(1)}, - {`{"jsonrpc": "2.0", "method": "c", "id": -1, "params": ["a", "10"]}`, false, rpctypes.JSONRPCIntID(-1)}, + {`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": ["a", "10"]}`, false, `"0"`}, + {`{"jsonrpc": "2.0", "method": "c", "id": "abc", "params": ["a", "10"]}`, false, `"abc"`}, + {`{"jsonrpc": "2.0", "method": "c", "id": 0, "params": ["a", "10"]}`, false, `0`}, + {`{"jsonrpc": "2.0", "method": "c", "id": 1, "params": ["a", "10"]}`, false, `1`}, + {`{"jsonrpc": "2.0", "method": "c", "id": -1, "params": ["a", "10"]}`, false, `-1`}, // bad id - {`{"jsonrpc": "2.0", "method": "c", "id": {}, "params": ["a", "10"]}`, true, nil}, - {`{"jsonrpc": "2.0", "method": "c", "id": [], "params": ["a", "10"]}`, true, nil}, + {`{"jsonrpc": "2.0", "method": "c", "id": {}, "params": ["a", "10"]}`, true, ""}, // object + {`{"jsonrpc": "2.0", "method": "c", "id": [], "params": ["a", "10"]}`, true, ""}, // array + {`{"jsonrpc": "2.0", "method": "c", "id": 1.3, "params": ["a", "10"]}`, true, ""}, // fractional } for i, tt := range tests { @@ -123,7 +123,7 @@ func TestJSONRPCID(t *testing.T) { assert.NoError(t, err, "#%d: expecting successful parsing of an RPCResponse:\nblob: %s", i, blob) if !tt.wantErr { assert.NotEqual(t, recv, new(rpctypes.RPCResponse), "#%d: not expecting a blank RPCResponse", i) - assert.Equal(t, tt.expectedID, recv.ID, "#%d: expected ID not matched in RPCResponse", i) + assert.Equal(t, tt.expectedID, recv.ID(), "#%d: expected ID not matched in RPCResponse", i) assert.Nil(t, recv.Error, "#%d: not expecting an error", i) } else { assert.True(t, recv.Error.Code < 0, "#%d: not expecting a positive JSONRPC code", i) diff --git a/rpc/jsonrpc/server/http_server_test.go b/rpc/jsonrpc/server/http_server_test.go index 12be80d1b..838a2ef6c 100644 --- a/rpc/jsonrpc/server/http_server_test.go +++ b/rpc/jsonrpc/server/http_server_test.go @@ -125,7 +125,7 @@ func TestServeTLS(t *testing.T) { } func TestWriteRPCResponse(t *testing.T) { - req := rpctypes.RPCRequest{ID: rpctypes.JSONRPCIntID(-1)} + req := rpctypes.NewRequest(-1) // one argument w := httptest.NewRecorder() @@ -160,7 +160,7 @@ func TestWriteRPCResponse(t *testing.T) { func TestWriteHTTPResponse(t *testing.T) { w := httptest.NewRecorder() logger := log.NewNopLogger() - req := rpctypes.RPCRequest{ID: rpctypes.JSONRPCIntID(-1)} + req := rpctypes.NewRequest(-1) writeHTTPResponse(w, logger, req.MakeErrorf(rpctypes.CodeInternalError, "foo")) resp := w.Result() body, err := io.ReadAll(resp.Body) diff --git a/rpc/jsonrpc/server/http_uri_handler.go b/rpc/jsonrpc/server/http_uri_handler.go index 8d626f225..67491adf5 100644 --- a/rpc/jsonrpc/server/http_uri_handler.go +++ b/rpc/jsonrpc/server/http_uri_handler.go @@ -16,7 +16,7 @@ import ( // uriReqID is a placeholder ID used for GET requests, which do not receive a // JSON-RPC request ID from the caller. -var uriReqID = rpctypes.JSONRPCIntID(-1) +const uriReqID = -1 // convert from a function name to the http handler func makeHTTPHandler(rpcFunc *RPCFunc, logger log.Logger) func(http.ResponseWriter, *http.Request) { @@ -31,7 +31,7 @@ func makeHTTPHandler(rpcFunc *RPCFunc, logger log.Logger) func(http.ResponseWrit fmt.Fprintln(w, err.Error()) return } - jreq := rpctypes.RPCRequest{ID: uriReqID} + jreq := rpctypes.NewRequest(uriReqID) outs := rpcFunc.f.Call(args) logger.Debug("HTTPRestRPC", "method", req.URL.Path, "args", args, "returns", outs) diff --git a/rpc/jsonrpc/server/ws_handler.go b/rpc/jsonrpc/server/ws_handler.go index 2311d9b0b..8a96955de 100644 --- a/rpc/jsonrpc/server/ws_handler.go +++ b/rpc/jsonrpc/server/ws_handler.go @@ -274,7 +274,7 @@ func (wsc *wsConnection) readRoutine(ctx context.Context) { if !ok { err = fmt.Errorf("WSJSONRPC: %v", r) } - req := rpctypes.RPCRequest{ID: uriReqID} + req := rpctypes.NewRequest(uriReqID) wsc.Logger.Error("Panic in WSJSONRPC handler", "err", err, "stack", string(debug.Stack())) if err := wsc.WriteRPCResponse(writeCtx, req.MakeErrorf(rpctypes.CodeInternalError, "Panic in handler: %v", err)); err != nil { @@ -325,7 +325,7 @@ func (wsc *wsConnection) readRoutine(ctx context.Context) { // A Notification is a Request object without an "id" member. // The Server MUST NOT reply to a Notification, including those that are within a batch request. - if request.ID == nil { + if request.IsNotification() { wsc.Logger.Debug( "WSJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)", "req", request, diff --git a/rpc/jsonrpc/server/ws_handler_test.go b/rpc/jsonrpc/server/ws_handler_test.go index 3d78c0d9b..ae73a953b 100644 --- a/rpc/jsonrpc/server/ws_handler_test.go +++ b/rpc/jsonrpc/server/ws_handler_test.go @@ -32,14 +32,9 @@ func TestWebsocketManagerHandler(t *testing.T) { } // check basic functionality works - req, err := rpctypes.ParamsToRequest( - rpctypes.JSONRPCStringID("TestWebsocketManager"), - "c", - map[string]interface{}{"s": "a", "i": 10}, - ) - require.NoError(t, err) - err = c.WriteJSON(req) - require.NoError(t, err) + req := rpctypes.NewRequest(1001) + require.NoError(t, req.SetMethodAndParams("c", map[string]interface{}{"s": "a", "i": 10})) + require.NoError(t, c.WriteJSON(req)) var resp rpctypes.RPCResponse err = c.ReadJSON(&resp) diff --git a/rpc/jsonrpc/types/types.go b/rpc/jsonrpc/types/types.go index 111a9b656..0c0500bf0 100644 --- a/rpc/jsonrpc/types/types.go +++ b/rpc/jsonrpc/types/types.go @@ -1,51 +1,19 @@ package types import ( + "bytes" "context" "encoding/json" "errors" "fmt" "net/http" - "reflect" + "regexp" + "strconv" "strings" "github.com/tendermint/tendermint/rpc/coretypes" ) -// a wrapper to emulate a sum type: jsonrpcid = string | int -// TODO: refactor when Go 2.0 arrives https://github.com/golang/go/issues/19412 -type jsonrpcid interface { - isJSONRPCID() -} - -// JSONRPCStringID a wrapper for JSON-RPC string IDs -type JSONRPCStringID string - -func (JSONRPCStringID) isJSONRPCID() {} -func (id JSONRPCStringID) String() string { return string(id) } - -// JSONRPCIntID a wrapper for JSON-RPC integer IDs -type JSONRPCIntID int - -func (JSONRPCIntID) isJSONRPCID() {} -func (id JSONRPCIntID) String() string { return fmt.Sprintf("%d", id) } - -func idFromInterface(idInterface interface{}) (jsonrpcid, error) { - switch id := idInterface.(type) { - case string: - return JSONRPCStringID(id), nil - case float64: - // json.Unmarshal uses float64 for all numbers - // (https://golang.org/pkg/encoding/json/#Unmarshal), - // but the JSONRPC2.0 spec says the id SHOULD NOT contain - // decimals - so we truncate the decimals here. - return JSONRPCIntID(int(id)), nil - default: - typ := reflect.TypeOf(id) - return nil, fmt.Errorf("json-rpc ID (%v) is of unknown type (%v)", id, typ) - } -} - // ErrorCode is the type of JSON-RPC error codes. type ErrorCode int @@ -77,18 +45,39 @@ var errorCodeString = map[ErrorCode]string{ // REQUEST type RPCRequest struct { - ID jsonrpcid + id json.RawMessage + Method string Params json.RawMessage } +// NewRequest returns an empty request with the specified ID. +func NewRequest(id int) RPCRequest { + return RPCRequest{id: []byte(strconv.Itoa(id))} +} + +// ID returns a string representation of the request ID. +func (req RPCRequest) ID() string { return string(req.id) } + +// IsNotification reports whether req is a notification (has an empty ID). +func (req RPCRequest) IsNotification() bool { return len(req.id) == 0 } + type rpcRequestJSON struct { V string `json:"jsonrpc"` // must be "2.0" - ID interface{} `json:"id,omitempty"` + ID json.RawMessage `json:"id,omitempty"` M string `json:"method"` P json.RawMessage `json:"params"` } +// isNullOrEmpty reports whether data is empty or the JSON "null" value. +func isNullOrEmpty(data json.RawMessage) bool { + return len(data) == 0 || bytes.Equal(data, []byte("null")) +} + +// validID matches the text of a JSON value that is allowed to serve as a +// JSON-RPC request ID. Precondition: Target value is legal JSON. +var validID = regexp.MustCompile(`^(?:".*"|-?\d+)$`) + // UnmarshalJSON decodes a request from a JSON-RPC 2.0 request object. func (req *RPCRequest) UnmarshalJSON(data []byte) error { var wrapper rpcRequestJSON @@ -98,12 +87,11 @@ func (req *RPCRequest) UnmarshalJSON(data []byte) error { return fmt.Errorf("invalid version: %q", wrapper.V) } - if wrapper.ID != nil { - id, err := idFromInterface(wrapper.ID) - if err != nil { - return fmt.Errorf("invalid request ID: %w", err) + if !isNullOrEmpty(wrapper.ID) { + if !validID.Match(wrapper.ID) { + return fmt.Errorf("invalid request ID: %q", string(wrapper.ID)) } - req.ID = id + req.id = wrapper.ID } req.Method = wrapper.M req.Params = wrapper.P @@ -114,14 +102,14 @@ func (req *RPCRequest) UnmarshalJSON(data []byte) error { func (req RPCRequest) MarshalJSON() ([]byte, error) { return json.Marshal(rpcRequestJSON{ V: "2.0", - ID: req.ID, + ID: req.id, M: req.Method, P: req.Params, }) } func (req RPCRequest) String() string { - return fmt.Sprintf("RPCRequest{%s %s/%X}", req.ID, req.Method, req.Params) + return fmt.Sprintf("RPCRequest{%s %s/%X}", req.ID(), req.Method, req.Params) } // MakeResponse constructs a success response to req with the given result. If @@ -131,14 +119,14 @@ func (req RPCRequest) MakeResponse(result interface{}) RPCResponse { if err != nil { return req.MakeErrorf(CodeInternalError, "marshaling result: %v", err) } - return RPCResponse{ID: req.ID, Result: data} + return RPCResponse{id: req.id, Result: data} } // MakeErrorf constructs an error response to req with the given code and a // message constructed by formatting msg with args. func (req RPCRequest) MakeErrorf(code ErrorCode, msg string, args ...interface{}) RPCResponse { return RPCResponse{ - ID: req.ID, + id: req.id, Error: &RPCError{ Code: int(code), Message: code.String(), @@ -154,36 +142,35 @@ func (req RPCRequest) MakeError(err error) RPCResponse { panic("cannot construct an error response for nil") } if e, ok := err.(*RPCError); ok { - return RPCResponse{ID: req.ID, Error: e} + return RPCResponse{id: req.id, Error: e} } if errors.Is(err, coretypes.ErrZeroOrNegativeHeight) || errors.Is(err, coretypes.ErrZeroOrNegativePerPage) || errors.Is(err, coretypes.ErrPageOutOfRange) || errors.Is(err, coretypes.ErrInvalidRequest) { - return RPCResponse{ID: req.ID, Error: &RPCError{ + return RPCResponse{id: req.id, Error: &RPCError{ Code: int(CodeInvalidRequest), Message: CodeInvalidRequest.String(), Data: err.Error(), }} } - return RPCResponse{ID: req.ID, Error: &RPCError{ + return RPCResponse{id: req.id, Error: &RPCError{ Code: int(CodeInternalError), Message: CodeInternalError.String(), Data: err.Error(), }} } -// ParamsToRequest constructs a new RPCRequest with the given ID, method, and parameters. -func ParamsToRequest(id jsonrpcid, method string, params interface{}) (RPCRequest, error) { +// SetMethodAndParams updates the method and parameters of req with the given +// values, leaving the ID unchanged. +func (req *RPCRequest) SetMethodAndParams(method string, params interface{}) error { payload, err := json.Marshal(params) if err != nil { - return RPCRequest{}, err + return err } - return RPCRequest{ - ID: id, - Method: method, - Params: payload, - }, nil + req.Method = method + req.Params = payload + return nil } //---------------------------------------- @@ -204,14 +191,18 @@ func (err RPCError) Error() string { } type RPCResponse struct { - ID jsonrpcid + id json.RawMessage + Result json.RawMessage Error *RPCError } +// ID returns a representation of the response ID. +func (resp RPCResponse) ID() string { return string(resp.id) } + type rpcResponseJSON struct { V string `json:"jsonrpc"` // must be "2.0" - ID interface{} `json:"id,omitempty"` + ID json.RawMessage `json:"id,omitempty"` R json.RawMessage `json:"result,omitempty"` E *RPCError `json:"error,omitempty"` } @@ -225,14 +216,12 @@ func (resp *RPCResponse) UnmarshalJSON(data []byte) error { return fmt.Errorf("invalid version: %q", wrapper.V) } - if wrapper.ID != nil { - id, err := idFromInterface(wrapper.ID) - if err != nil { - return fmt.Errorf("invalid response ID: %w", err) + if !isNullOrEmpty(wrapper.ID) { + if !validID.Match(wrapper.ID) { + return fmt.Errorf("invalid response ID: %q", string(wrapper.ID)) } - resp.ID = id + resp.id = wrapper.ID } - resp.Error = wrapper.E resp.Result = wrapper.R return nil @@ -242,7 +231,7 @@ func (resp *RPCResponse) UnmarshalJSON(data []byte) error { func (resp RPCResponse) MarshalJSON() ([]byte, error) { return json.Marshal(rpcResponseJSON{ V: "2.0", - ID: resp.ID, + ID: resp.id, R: resp.Result, E: resp.Error, }) @@ -250,9 +239,9 @@ func (resp RPCResponse) MarshalJSON() ([]byte, error) { func (resp RPCResponse) String() string { if resp.Error == nil { - return fmt.Sprintf("RPCResponse{%s %X}", resp.ID, resp.Result) + return fmt.Sprintf("RPCResponse{%s %X}", resp.ID(), resp.Result) } - return fmt.Sprintf("RPCResponse{%s %v}", resp.ID, resp.Error) + return fmt.Sprintf("RPCResponse{%s %v}", resp.ID(), resp.Error) } //---------------------------------------- diff --git a/rpc/jsonrpc/types/types_test.go b/rpc/jsonrpc/types/types_test.go index 2dbfed895..d5be2f74d 100644 --- a/rpc/jsonrpc/types/types_test.go +++ b/rpc/jsonrpc/types/types_test.go @@ -13,65 +13,48 @@ type SampleResult struct { Value string } -type responseTest struct { - id jsonrpcid - expected string -} - -var responseTests = []responseTest{ - {JSONRPCStringID("1"), `"1"`}, - {JSONRPCStringID("alphabet"), `"alphabet"`}, - {JSONRPCStringID(""), `""`}, - {JSONRPCStringID("àáâ"), `"àáâ"`}, - {JSONRPCIntID(-1), "-1"}, - {JSONRPCIntID(0), "0"}, - {JSONRPCIntID(1), "1"}, - {JSONRPCIntID(100), "100"}, +// Valid JSON identifier texts. +var testIDs = []string{ + `"1"`, `"alphabet"`, `""`, `"àáâ"`, "-1", "0", "1", "100", } func TestResponses(t *testing.T) { - for _, tt := range responseTests { - req := RPCRequest{ - ID: tt.id, - Method: "whatever", - } + for _, id := range testIDs { + req := RPCRequest{id: json.RawMessage(id)} a := req.MakeResponse(&SampleResult{"hello"}) b, err := json.Marshal(a) - require.NoError(t, err) - s := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, tt.expected) + require.NoError(t, err, "input id: %q", id) + s := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, id) assert.Equal(t, s, string(b)) d := req.MakeErrorf(CodeParseError, "hello world") e, err := json.Marshal(d) require.NoError(t, err) - f := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32700,"message":"Parse error","data":"hello world"}}`, tt.expected) + f := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32700,"message":"Parse error","data":"hello world"}}`, id) assert.Equal(t, f, string(e)) g := req.MakeErrorf(CodeMethodNotFound, "foo") h, err := json.Marshal(g) require.NoError(t, err) - i := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32601,"message":"Method not found","data":"foo"}}`, tt.expected) + i := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"error":{"code":-32601,"message":"Method not found","data":"foo"}}`, id) assert.Equal(t, string(h), i) } } func TestUnmarshallResponses(t *testing.T) { - for _, tt := range responseTests { + for _, id := range testIDs { response := &RPCResponse{} - err := json.Unmarshal( - []byte(fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, tt.expected)), - response, - ) - require.NoError(t, err) + input := fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"Value":"hello"}}`, id) + require.NoError(t, json.Unmarshal([]byte(input), &response)) - req := RPCRequest{ID: tt.id} + req := RPCRequest{id: json.RawMessage(id)} a := req.MakeResponse(&SampleResult{"hello"}) assert.Equal(t, *response, a) } - response := &RPCResponse{} - err := json.Unmarshal([]byte(`{"jsonrpc":"2.0","id":true,"result":{"Value":"hello"}}`), response) - require.Error(t, err) + var response RPCResponse + const input = `{"jsonrpc":"2.0","id":true,"result":{"Value":"hello"}}` + require.Error(t, json.Unmarshal([]byte(input), &response)) } func TestRPCError(t *testing.T) {