From ec59b1a1ae42949367f77f4c0cf1e79c4682c5bd Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Thu, 13 Jan 2022 13:27:45 -0800 Subject: [PATCH] rpc: check RPC service functions more carefully (#7587) Require that RPC functions take a context as their first argument, and return an error as either their only result, or the second of two results. This does not change how functions are dispatched, but will make it a little easier to make more invasive changes in the near future. --- rpc/jsonrpc/server/parse_test.go | 4 +- rpc/jsonrpc/server/rpc_func.go | 98 ++++++++++++++++++++++---------- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/rpc/jsonrpc/server/parse_test.go b/rpc/jsonrpc/server/parse_test.go index c661cc3d4..6533a5d44 100644 --- a/rpc/jsonrpc/server/parse_test.go +++ b/rpc/jsonrpc/server/parse_test.go @@ -134,7 +134,7 @@ func TestParseJSONArray(t *testing.T) { } func TestParseJSONRPC(t *testing.T) { - demo := func(ctx context.Context, height int, name string) {} + demo := func(ctx context.Context, height int, name string) error { return nil } call := NewRPCFunc(demo, "height", "name") cases := []struct { @@ -171,7 +171,7 @@ func TestParseJSONRPC(t *testing.T) { } func TestParseURI(t *testing.T) { - demo := func(ctx context.Context, height int, name string) {} + demo := func(ctx context.Context, height int, name string) error { return nil } call := NewRPCFunc(demo, "height", "name") cases := []struct { diff --git a/rpc/jsonrpc/server/rpc_func.go b/rpc/jsonrpc/server/rpc_func.go index 354315567..4ccd045b9 100644 --- a/rpc/jsonrpc/server/rpc_func.go +++ b/rpc/jsonrpc/server/rpc_func.go @@ -1,6 +1,9 @@ package server import ( + "context" + "errors" + "fmt" "net/http" "reflect" @@ -23,7 +26,7 @@ func RegisterRPCFuncs(mux *http.ServeMux, funcMap map[string]*RPCFunc, logger lo // Function introspection -// RPCFunc contains the introspected type information for a function +// RPCFunc contains the introspected type information for a function. type RPCFunc struct { f reflect.Value // underlying rpc function args []reflect.Type // type of each function arg @@ -32,47 +35,80 @@ type RPCFunc struct { ws bool // websocket only } -// NewRPCFunc wraps a function for introspection. -// f is the function, args are comma separated argument names +// NewRPCFunc constructs an RPCFunc for f, which must be a function whose type +// signature matches one of these schemes: +// +// func(context.Context, T1, T2, ...) error +// func(context.Context, T1, T2, ...) (R, error) +// +// for arbitrary types T_i and R. The number of argNames must exactly match the +// number of non-context arguments to f. Otherwise, NewRPCFunc panics. +// +// The parameter names given are used to map JSON object keys to the +// corresonding parameter of the function. The names do not need to match the +// declared names, but must match what the client sends in a request. func NewRPCFunc(f interface{}, argNames ...string) *RPCFunc { - return newRPCFunc(f, argNames, false) + rf, err := newRPCFunc(f, argNames) + if err != nil { + panic("invalid RPC function: " + err.Error()) + } + return rf } -// NewWSRPCFunc wraps a function for introspection and use in the websockets. +// NewWSRPCFunc behaves as NewRPCFunc, but marks the resulting function for use +// via websocket. func NewWSRPCFunc(f interface{}, argNames ...string) *RPCFunc { - return newRPCFunc(f, argNames, true) + rf := NewRPCFunc(f, argNames...) + rf.ws = true + return rf } -func newRPCFunc(f interface{}, argNames []string, wsOnly bool) *RPCFunc { - return &RPCFunc{ - f: reflect.ValueOf(f), - args: funcArgTypes(f), - returns: funcReturnTypes(f), - argNames: argNames, - ws: wsOnly, +var ( + ctxType = reflect.TypeOf((*context.Context)(nil)).Elem() + errType = reflect.TypeOf((*error)(nil)).Elem() +) + +// newRPCFunc constructs an RPCFunc for f. See the comment at NewRPCFunc. +func newRPCFunc(f interface{}, argNames []string) (*RPCFunc, error) { + if f == nil { + return nil, errors.New("nil function") } -} -// return a function's argument types -func funcArgTypes(f interface{}) []reflect.Type { - t := reflect.TypeOf(f) - n := t.NumIn() - typez := make([]reflect.Type, n) - for i := 0; i < n; i++ { - typez[i] = t.In(i) + // Check the type and signature of f. + fv := reflect.ValueOf(f) + if fv.Kind() != reflect.Func { + return nil, errors.New("not a function") } - return typez -} -// return a function's return types -func funcReturnTypes(f interface{}) []reflect.Type { - t := reflect.TypeOf(f) - n := t.NumOut() - typez := make([]reflect.Type, n) - for i := 0; i < n; i++ { - typez[i] = t.Out(i) + ft := fv.Type() + if np := ft.NumIn(); np == 0 { + return nil, errors.New("wrong number of parameters") + } else if ft.In(0) != ctxType { + return nil, errors.New("first parameter is not context.Context") + } else if np-1 != len(argNames) { + return nil, fmt.Errorf("have %d names for %d parameters", len(argNames), np-1) } - return typez + + if no := ft.NumOut(); no < 1 || no > 2 { + return nil, errors.New("wrong number of results") + } else if ft.Out(no-1) != errType { + return nil, errors.New("last result is not error") + } + + args := make([]reflect.Type, ft.NumIn()) + for i := 0; i < ft.NumIn(); i++ { + args[i] = ft.In(i) + } + outs := make([]reflect.Type, ft.NumOut()) + for i := 0; i < ft.NumOut(); i++ { + outs[i] = ft.Out(i) + } + return &RPCFunc{ + f: fv, + args: args, + returns: outs, + argNames: argNames, + }, nil } //-------------------------------------------------------------