From a4d0a431003b21751b3011908a08badafe890f72 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Tue, 11 Jan 2022 11:04:55 -0800 Subject: [PATCH] rpc: refactor the HTTP POST handler (#7555) No functional changes. - Pull out a some helper code to simplify the control flow within the body of the HTTP request handler. - Front-load the URL path check so it does not get repeated for each request. --- rpc/jsonrpc/server/http_json_handler.go | 161 +++++++++++++----------- 1 file changed, 86 insertions(+), 75 deletions(-) diff --git a/rpc/jsonrpc/server/http_json_handler.go b/rpc/jsonrpc/server/http_json_handler.go index dabeee074..879a58df9 100644 --- a/rpc/jsonrpc/server/http_json_handler.go +++ b/rpc/jsonrpc/server/http_json_handler.go @@ -20,127 +20,103 @@ import ( // jsonrpc calls grab the given method's function info and runs reflect.Call func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - b, err := io.ReadAll(r.Body) - if err != nil { - res := rpctypes.RPCInvalidRequestError(nil, - fmt.Errorf("error reading request body: %w", err), - ) - if wErr := WriteRPCResponseHTTPError(w, res); wErr != nil { - logger.Error("failed to write response", "res", res, "err", wErr) + return func(w http.ResponseWriter, hreq *http.Request) { + fail := func(res rpctypes.RPCResponse) { + if err := WriteRPCResponseHTTPError(w, res); err != nil { + logger.Error("Failed writing error response", "res", res, "err", err) } + } + + // For POST requests, reject a non-root URL path. This should not happen + // in the standard configuration, since the wrapper checks the path. + if hreq.URL.Path != "/" { + fail(rpctypes.RPCInvalidRequestError(nil, fmt.Errorf("invalid path: %q", hreq.URL.Path))) + return + } + + b, err := io.ReadAll(hreq.Body) + if err != nil { + fail(rpctypes.RPCInvalidRequestError(nil, fmt.Errorf("reading request body: %w", err))) return } // if its an empty request (like from a browser), just display a list of // functions if len(b) == 0 { - writeListOfEndpoints(w, r, funcMap) + writeListOfEndpoints(w, hreq, funcMap) return } - // first try to unmarshal the incoming request as an array of RPC requests - var ( - requests []rpctypes.RPCRequest - responses []rpctypes.RPCResponse - ) - if err := json.Unmarshal(b, &requests); err != nil { - // next, try to unmarshal as a single request - var request rpctypes.RPCRequest - if err := json.Unmarshal(b, &request); err != nil { - res := rpctypes.RPCParseError(fmt.Errorf("error unmarshaling request: %w", err)) - if wErr := WriteRPCResponseHTTPError(w, res); wErr != nil { - logger.Error("failed to write response", "res", res, "err", wErr) - } - return - } - requests = []rpctypes.RPCRequest{request} + requests, err := parseRequests(b) + if err != nil { + fail(rpctypes.RPCParseError(fmt.Errorf("decoding request: %w", err))) + return } // Set the default response cache to true unless // 1. Any RPC request rrror. // 2. Any RPC request doesn't allow to be cached. // 3. Any RPC request has the height argument and the value is 0 (the default). - var c = true - for _, request := range requests { - request := request - - // A Notification is a Request object without an "id" member. - // The Server MUST NOT reply to a Notification, including those that are within a batch request. - if request.ID == nil { - logger.Debug( - "HTTPJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)", - "req", request, - ) + var responses []rpctypes.RPCResponse + mayCache := true + for _, req := range requests { + // Ignore notifications, which this service does not support. + if req.ID == nil { + logger.Debug("Ignoring notification", "req", req) continue } - if len(r.URL.Path) > 1 { - responses = append( - responses, - rpctypes.RPCInvalidRequestError(request.ID, fmt.Errorf("path %s is invalid", r.URL.Path)), - ) - c = false - continue - } - rpcFunc, ok := funcMap[request.Method] + + rpcFunc, ok := funcMap[req.Method] if !ok || rpcFunc.ws { - responses = append(responses, rpctypes.RPCMethodNotFoundError(request.ID)) - c = false + responses = append(responses, rpctypes.RPCMethodNotFoundError(req.ID)) + mayCache = false continue } - ctx := &rpctypes.Context{JSONReq: &request, HTTPReq: r} - args := []reflect.Value{reflect.ValueOf(ctx)} - if len(request.Params) > 0 { - fnArgs, err := jsonParamsToArgs(rpcFunc, request.Params) - if err != nil { - responses = append( - responses, - rpctypes.RPCInvalidParamsError(request.ID, fmt.Errorf("error converting json params to arguments: %w", err)), - ) - c = false - continue - } - args = append(args, fnArgs...) + if !rpcFunc.cache { + mayCache = false + } + args, err := parseParams(rpcFunc, hreq, req) + if err != nil { + responses = append(responses, rpctypes.RPCInvalidParamsError( + req.ID, fmt.Errorf("converting JSON parameters: %w", err))) + mayCache = false + continue } - if hasDefaultHeight(request, args) { - c = false + if hasDefaultHeight(req, args) { + mayCache = false } returns := rpcFunc.f.Call(args) - logger.Debug("HTTPJSONRPC", "method", request.Method, "args", args, "returns", returns) + logger.Debug("HTTPJSONRPC", "method", req.Method, "args", args, "returns", returns) result, err := unreflectResult(returns) switch e := err.(type) { // if no error then return a success response case nil: - responses = append(responses, rpctypes.NewRPCSuccessResponse(request.ID, result)) + responses = append(responses, rpctypes.NewRPCSuccessResponse(req.ID, result)) // if this already of type RPC error then forward that error case *rpctypes.RPCError: - responses = append(responses, rpctypes.NewRPCErrorResponse(request.ID, e.Code, e.Message, e.Data)) - c = false + responses = append(responses, rpctypes.NewRPCErrorResponse(req.ID, e.Code, e.Message, e.Data)) + mayCache = false default: // we need to unwrap the error and parse it accordingly switch errors.Unwrap(err) { // check if the error was due to an invald request case coretypes.ErrZeroOrNegativeHeight, coretypes.ErrZeroOrNegativePerPage, coretypes.ErrPageOutOfRange, coretypes.ErrInvalidRequest: - responses = append(responses, rpctypes.RPCInvalidRequestError(request.ID, err)) - c = false + responses = append(responses, rpctypes.RPCInvalidRequestError(req.ID, err)) + mayCache = false // lastly default all remaining errors as internal errors default: // includes ctypes.ErrHeightNotAvailable and ctypes.ErrHeightExceedsChainHead - responses = append(responses, rpctypes.RPCInternalError(request.ID, err)) - c = false + responses = append(responses, rpctypes.RPCInternalError(req.ID, err)) + mayCache = false } } - - if c && !rpcFunc.cache { - c = false - } } if len(responses) > 0 { - if wErr := WriteRPCResponseHTTP(w, c, responses...); wErr != nil { + if wErr := WriteRPCResponseHTTP(w, mayCache, responses...); wErr != nil { logger.Error("failed to write responses", "err", wErr) } } @@ -160,6 +136,24 @@ func handleInvalidJSONRPCPaths(next http.HandlerFunc) http.HandlerFunc { } } +// parseRequests parses a JSON-RPC request or request batch from data. +func parseRequests(data []byte) ([]rpctypes.RPCRequest, error) { + var reqs []rpctypes.RPCRequest + var err error + + isArray := bytes.HasPrefix(bytes.TrimSpace(data), []byte("[")) + if isArray { + err = json.Unmarshal(data, &reqs) + } else { + reqs = append(reqs, rpctypes.RPCRequest{}) + err = json.Unmarshal(data, &reqs[0]) + } + if err != nil { + return nil, err + } + return reqs, nil +} + func mapParamsToArgs( rpcFunc *RPCFunc, params map[string]json.RawMessage, @@ -209,6 +203,23 @@ func arrayParamsToArgs( return values, nil } +// parseParams parses the JSON parameters of rpcReq into the arguments of fn, +// returning the corresponding argument values or an error. +func parseParams(fn *RPCFunc, httpReq *http.Request, rpcReq rpctypes.RPCRequest) ([]reflect.Value, error) { + args := []reflect.Value{reflect.ValueOf(&rpctypes.Context{ + JSONReq: &rpcReq, + HTTPReq: httpReq, + })} + if len(rpcReq.Params) == 0 { + return args, nil + } + fargs, err := jsonParamsToArgs(fn, rpcReq.Params) + if err != nil { + return nil, err + } + return append(args, fargs...), nil +} + // raw is unparsed json (from json.RawMessage) encoding either a map or an // array. //