diff --git a/rpc/jsonrpc/server/parse_test.go b/rpc/jsonrpc/server/parse_test.go index c661cc3d4..6533a5d44 100644 --- a/rpc/jsonrpc/server/parse_test.go +++ b/rpc/jsonrpc/server/parse_test.go @@ -134,7 +134,7 @@ func TestParseJSONArray(t *testing.T) { } func TestParseJSONRPC(t *testing.T) { - demo := func(ctx context.Context, height int, name string) {} + demo := func(ctx context.Context, height int, name string) error { return nil } call := NewRPCFunc(demo, "height", "name") cases := []struct { @@ -171,7 +171,7 @@ func TestParseJSONRPC(t *testing.T) { } func TestParseURI(t *testing.T) { - demo := func(ctx context.Context, height int, name string) {} + demo := func(ctx context.Context, height int, name string) error { return nil } call := NewRPCFunc(demo, "height", "name") cases := []struct { diff --git a/rpc/jsonrpc/server/rpc_func.go b/rpc/jsonrpc/server/rpc_func.go index 354315567..4ccd045b9 100644 --- a/rpc/jsonrpc/server/rpc_func.go +++ b/rpc/jsonrpc/server/rpc_func.go @@ -1,6 +1,9 @@ package server import ( + "context" + "errors" + "fmt" "net/http" "reflect" @@ -23,7 +26,7 @@ func RegisterRPCFuncs(mux *http.ServeMux, funcMap map[string]*RPCFunc, logger lo // Function introspection -// RPCFunc contains the introspected type information for a function +// RPCFunc contains the introspected type information for a function. type RPCFunc struct { f reflect.Value // underlying rpc function args []reflect.Type // type of each function arg @@ -32,47 +35,80 @@ type RPCFunc struct { ws bool // websocket only } -// NewRPCFunc wraps a function for introspection. -// f is the function, args are comma separated argument names +// NewRPCFunc constructs an RPCFunc for f, which must be a function whose type +// signature matches one of these schemes: +// +// func(context.Context, T1, T2, ...) error +// func(context.Context, T1, T2, ...) (R, error) +// +// for arbitrary types T_i and R. The number of argNames must exactly match the +// number of non-context arguments to f. Otherwise, NewRPCFunc panics. +// +// The parameter names given are used to map JSON object keys to the +// corresonding parameter of the function. The names do not need to match the +// declared names, but must match what the client sends in a request. func NewRPCFunc(f interface{}, argNames ...string) *RPCFunc { - return newRPCFunc(f, argNames, false) + rf, err := newRPCFunc(f, argNames) + if err != nil { + panic("invalid RPC function: " + err.Error()) + } + return rf } -// NewWSRPCFunc wraps a function for introspection and use in the websockets. +// NewWSRPCFunc behaves as NewRPCFunc, but marks the resulting function for use +// via websocket. func NewWSRPCFunc(f interface{}, argNames ...string) *RPCFunc { - return newRPCFunc(f, argNames, true) + rf := NewRPCFunc(f, argNames...) + rf.ws = true + return rf } -func newRPCFunc(f interface{}, argNames []string, wsOnly bool) *RPCFunc { - return &RPCFunc{ - f: reflect.ValueOf(f), - args: funcArgTypes(f), - returns: funcReturnTypes(f), - argNames: argNames, - ws: wsOnly, +var ( + ctxType = reflect.TypeOf((*context.Context)(nil)).Elem() + errType = reflect.TypeOf((*error)(nil)).Elem() +) + +// newRPCFunc constructs an RPCFunc for f. See the comment at NewRPCFunc. +func newRPCFunc(f interface{}, argNames []string) (*RPCFunc, error) { + if f == nil { + return nil, errors.New("nil function") } -} -// return a function's argument types -func funcArgTypes(f interface{}) []reflect.Type { - t := reflect.TypeOf(f) - n := t.NumIn() - typez := make([]reflect.Type, n) - for i := 0; i < n; i++ { - typez[i] = t.In(i) + // Check the type and signature of f. + fv := reflect.ValueOf(f) + if fv.Kind() != reflect.Func { + return nil, errors.New("not a function") } - return typez -} -// return a function's return types -func funcReturnTypes(f interface{}) []reflect.Type { - t := reflect.TypeOf(f) - n := t.NumOut() - typez := make([]reflect.Type, n) - for i := 0; i < n; i++ { - typez[i] = t.Out(i) + ft := fv.Type() + if np := ft.NumIn(); np == 0 { + return nil, errors.New("wrong number of parameters") + } else if ft.In(0) != ctxType { + return nil, errors.New("first parameter is not context.Context") + } else if np-1 != len(argNames) { + return nil, fmt.Errorf("have %d names for %d parameters", len(argNames), np-1) } - return typez + + if no := ft.NumOut(); no < 1 || no > 2 { + return nil, errors.New("wrong number of results") + } else if ft.Out(no-1) != errType { + return nil, errors.New("last result is not error") + } + + args := make([]reflect.Type, ft.NumIn()) + for i := 0; i < ft.NumIn(); i++ { + args[i] = ft.In(i) + } + outs := make([]reflect.Type, ft.NumOut()) + for i := 0; i < ft.NumOut(); i++ { + outs[i] = ft.Out(i) + } + return &RPCFunc{ + f: fv, + args: args, + returns: outs, + argNames: argNames, + }, nil } //-------------------------------------------------------------