diff --git a/rpc/jsonrpc/types/types.go b/rpc/jsonrpc/types/types.go index 108d056a7..e53a925ed 100644 --- a/rpc/jsonrpc/types/types.go +++ b/rpc/jsonrpc/types/types.go @@ -47,49 +47,47 @@ func idFromInterface(idInterface interface{}) (jsonrpcid, error) { // REQUEST type RPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID jsonrpcid `json:"id,omitempty"` - Method string `json:"method"` - Params json.RawMessage `json:"params"` // must be map[string]interface{} or []interface{} + ID jsonrpcid + Method string + Params json.RawMessage } -// UnmarshalJSON custom JSON unmarshaling due to jsonrpcid being string or int +type rpcRequestJSON struct { + V string `json:"jsonrpc"` // must be "2.0" + ID interface{} `json:"id,omitempty"` + M string `json:"method"` + P json.RawMessage `json:"params"` +} + +// UnmarshalJSON decodes a request from a JSON-RPC 2.0 request object. func (req *RPCRequest) UnmarshalJSON(data []byte) error { - unsafeReq := struct { - JSONRPC string `json:"jsonrpc"` - ID interface{} `json:"id,omitempty"` - Method string `json:"method"` - Params json.RawMessage `json:"params"` // must be map[string]interface{} or []interface{} - }{} - - err := json.Unmarshal(data, &unsafeReq) - if err != nil { + var wrapper rpcRequestJSON + if err := json.Unmarshal(data, &wrapper); err != nil { return err + } else if wrapper.V != "" && wrapper.V != "2.0" { + return fmt.Errorf("invalid version: %q", wrapper.V) } - if unsafeReq.ID == nil { // notification - return nil - } - - req.JSONRPC = unsafeReq.JSONRPC - req.Method = unsafeReq.Method - req.Params = unsafeReq.Params - id, err := idFromInterface(unsafeReq.ID) - if err != nil { - return err + if wrapper.ID != nil { + id, err := idFromInterface(wrapper.ID) + if err != nil { + return fmt.Errorf("invalid request ID: %w", err) + } + req.ID = id } - req.ID = id - + req.Method = wrapper.M + req.Params = wrapper.P return nil } -func NewRPCRequest(id jsonrpcid, method string, params json.RawMessage) RPCRequest { - return RPCRequest{ - JSONRPC: "2.0", - ID: id, - Method: method, - Params: params, - } +// MarshalJSON marshals a request with the appropriate version tag. +func (req RPCRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(rpcRequestJSON{ + V: "2.0", + ID: req.ID, + M: req.Method, + P: req.Params, + }) } func (req RPCRequest) String() string { @@ -102,7 +100,11 @@ func ParamsToRequest(id jsonrpcid, method string, params interface{}) (RPCReques if err != nil { return RPCRequest{}, err } - return NewRPCRequest(id, method, payload), nil + return RPCRequest{ + ID: id, + Method: method, + Params: payload, + }, nil } //---------------------------------------- @@ -123,52 +125,62 @@ func (err RPCError) Error() string { } type RPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID jsonrpcid `json:"id,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *RPCError `json:"error,omitempty"` + ID jsonrpcid + Result json.RawMessage + Error *RPCError +} + +type rpcResponseJSON struct { + V string `json:"jsonrpc"` // must be "2.0" + ID interface{} `json:"id,omitempty"` + R json.RawMessage `json:"result,omitempty"` + E *RPCError `json:"error,omitempty"` } -// UnmarshalJSON custom JSON unmarshaling due to jsonrpcid being string or int +// UnmarshalJSON decodes a response from a JSON-RPC 2.0 response object. func (resp *RPCResponse) UnmarshalJSON(data []byte) error { - unsafeResp := &struct { - JSONRPC string `json:"jsonrpc"` - ID interface{} `json:"id,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *RPCError `json:"error,omitempty"` - }{} - err := json.Unmarshal(data, &unsafeResp) - if err != nil { + var wrapper rpcResponseJSON + if err := json.Unmarshal(data, &wrapper); err != nil { return err + } else if wrapper.V != "" && wrapper.V != "2.0" { + return fmt.Errorf("invalid version: %q", wrapper.V) } - resp.JSONRPC = unsafeResp.JSONRPC - resp.Error = unsafeResp.Error - resp.Result = unsafeResp.Result - if unsafeResp.ID == nil { - return nil - } - id, err := idFromInterface(unsafeResp.ID) - if err != nil { - return err + if wrapper.ID != nil { + id, err := idFromInterface(wrapper.ID) + if err != nil { + return fmt.Errorf("invalid response ID: %w", err) + } + resp.ID = id } - resp.ID = id + + resp.Error = wrapper.E + resp.Result = wrapper.R return nil } +// MarshalJSON marshals a response with the appropriate version tag. +func (resp RPCResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(rpcResponseJSON{ + V: "2.0", + ID: resp.ID, + R: resp.Result, + E: resp.Error, + }) +} + func NewRPCSuccessResponse(id jsonrpcid, res interface{}) RPCResponse { result, err := json.Marshal(res) if err != nil { return RPCInternalError(id, fmt.Errorf("error marshaling response: %w", err)) } - return RPCResponse{JSONRPC: "2.0", ID: id, Result: result} + return RPCResponse{ID: id, Result: result} } func NewRPCErrorResponse(id jsonrpcid, code int, msg string, data string) RPCResponse { return RPCResponse{ - JSONRPC: "2.0", - ID: id, - Error: &RPCError{Code: code, Message: msg, Data: data}, + ID: id, + Error: &RPCError{Code: code, Message: msg, Data: data}, } }