From 4a1b714ca4decaa3fdc3b08c855220ad09aa8121 Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Wed, 3 May 2017 16:42:30 +0200 Subject: [PATCH] All tests pass without go-wire json ptr madness --- rpc/lib/server/handlers.go | 106 +++++++------- rpc/lib/server/wire.go | 285 ------------------------------------- 2 files changed, 53 insertions(+), 338 deletions(-) delete mode 100644 rpc/lib/server/wire.go diff --git a/rpc/lib/server/handlers.go b/rpc/lib/server/handlers.go index 321b72d6d..588af2765 100644 --- a/rpc/lib/server/handlers.go +++ b/rpc/lib/server/handlers.go @@ -15,6 +15,7 @@ import ( "github.com/gorilla/websocket" "github.com/pkg/errors" //wire "github.com/tendermint/go-wire" + types "github.com/tendermint/tendermint/rpc/lib/types" cmn "github.com/tendermint/tmlibs/common" events "github.com/tendermint/tmlibs/events" @@ -140,6 +141,45 @@ func makeJSONRPCHandler(funcMap map[string]*RPCFunc) http.HandlerFunc { } } +func mapParamsToArgs(rpcFunc *RPCFunc, params map[string]*json.RawMessage, argsOffset int) ([]reflect.Value, error) { + values := make([]reflect.Value, len(rpcFunc.argNames)) + for i, argName := range rpcFunc.argNames { + argType := rpcFunc.args[i+argsOffset] + + if p, ok := params[argName]; ok && len(*p) > 0 { + val := reflect.New(argType) + err := json.Unmarshal(*p, val.Interface()) + if err != nil { + return nil, err + } + values[i] = val.Elem() + } else { // use default for that type + values[i] = reflect.Zero(argType) + } + } + + return values, nil +} + +func arrayParamsToArgs(rpcFunc *RPCFunc, params []*json.RawMessage, argsOffset int) ([]reflect.Value, error) { + if len(rpcFunc.argNames) != len(params) { + return nil, errors.Errorf("Expected %v parameters (%v), got %v (%v)", + len(rpcFunc.argNames), rpcFunc.argNames, len(params), params) + } + + values := make([]reflect.Value, len(params)) + for i, p := range params { + argType := rpcFunc.args[i+argsOffset] + val := reflect.New(argType) + err := json.Unmarshal(*p, val.Interface()) + if err != nil { + return nil, err + } + values[i] = val.Elem() + } + return values, nil +} + // raw is unparsed json (from json.RawMessage). It either has // and array or a map behind it, let's parse this all without resorting to wire... // @@ -148,51 +188,22 @@ func makeJSONRPCHandler(funcMap map[string]*RPCFunc) http.HandlerFunc { // rpcFunc.args = [rpctypes.WSRPCContext string] // rpcFunc.argNames = ["arg"] func jsonParamsToArgs(rpcFunc *RPCFunc, raw []byte, argsOffset int) ([]reflect.Value, error) { - values := make([]reflect.Value, len(rpcFunc.argNames)) - - // right now, this is the same as before, but the whole parsing is in one function... - var paramsI interface{} - err := json.Unmarshal(raw, ¶msI) - if err != nil { - return nil, err + // first, try to get the map.. + var m map[string]*json.RawMessage + err := json.Unmarshal(raw, &m) + if err == nil { + return mapParamsToArgs(rpcFunc, m, argsOffset) } - switch params := paramsI.(type) { - - case map[string]interface{}: - for i, argName := range rpcFunc.argNames { - argType := rpcFunc.args[i+argsOffset] - - // decode param if provided - if param, ok := params[argName]; ok && "" != param { - v, err := _jsonObjectToArg(argType, param) - if err != nil { - return nil, err - } - values[i] = v - } else { // use default for that type - values[i] = reflect.Zero(argType) - } - } - case []interface{}: - if len(rpcFunc.argNames) != len(params) { - return nil, errors.New(fmt.Sprintf("Expected %v parameters (%v), got %v (%v)", - len(rpcFunc.argNames), rpcFunc.argNames, len(params), params)) - } - values := make([]reflect.Value, len(params)) - for i, p := range params { - ty := rpcFunc.args[i+argsOffset] - v, err := _jsonObjectToArg(ty, p) - if err != nil { - return nil, err - } - values[i] = v - } - return values, nil - default: - return nil, fmt.Errorf("Unknown type for JSON params %v. Expected map[string]interface{} or []interface{}", reflect.TypeOf(paramsI)) + // otherwise, try an array + var a []*json.RawMessage + err = json.Unmarshal(raw, &a) + if err == nil { + return arrayParamsToArgs(rpcFunc, a, argsOffset) } - return values, nil + + // otherwise, bad format, we cannot parse + return nil, errors.Errorf("Unknown type for JSON params: %v. Expected map or array", err) } // Convert a []interface{} OR a map[string]interface{} to properly typed values @@ -209,17 +220,6 @@ func jsonParamsToArgsWS(rpcFunc *RPCFunc, params *json.RawMessage, wsCtx types.W return append([]reflect.Value{reflect.ValueOf(wsCtx)}, values...), nil } -func _jsonObjectToArg(ty reflect.Type, object interface{}) (reflect.Value, error) { - var err error - v := reflect.New(ty) - readJSONObjectPtr(v.Interface(), object, &err) - if err != nil { - return v, err - } - v = v.Elem() - return v, nil -} - // rpc.json //----------------------------------------------------------------------------- // rpc.http diff --git a/rpc/lib/server/wire.go b/rpc/lib/server/wire.go deleted file mode 100644 index 4ed9771a1..000000000 --- a/rpc/lib/server/wire.go +++ /dev/null @@ -1,285 +0,0 @@ -package rpcserver - -import ( - "encoding/base64" - "encoding/hex" - "reflect" - "time" - - "github.com/pkg/errors" - - "github.com/tendermint/go-wire" - "github.com/tendermint/go-wire/data" - cmn "github.com/tendermint/tmlibs/common" -) - -// NOTE: This code is copied from go-wire in order to allow byte arrays to be treated -// differently whether they are data.Bytes (hex) or not (base64). We can't build this -// into wire directly without refactoring since we'd need to import data but data imports wire - -var ( - timeType = wire.GetTypeFromStructDeclaration(struct{ time.Time }{}) -) - -func readJSONObjectPtr(o interface{}, object interface{}, err *error) interface{} { - rv, rt := reflect.ValueOf(o), reflect.TypeOf(o) - if rv.Kind() == reflect.Ptr { - readReflectJSON(rv.Elem(), rt.Elem(), wire.Options{}, object, err) - } else { - cmn.PanicSanity("ReadJSON(Object)Ptr expects o to be a pointer") - } - return o -} - -func readByteJSON(o interface{}) (typeByte byte, rest interface{}, err error) { - oSlice, ok := o.([]interface{}) - if !ok { - err = errors.New(cmn.Fmt("Expected type [Byte,?] but got type %v", reflect.TypeOf(o))) - return - } - if len(oSlice) != 2 { - err = errors.New(cmn.Fmt("Expected [Byte,?] len 2 but got len %v", len(oSlice))) - return - } - typeByte_, ok := oSlice[0].(float64) - typeByte = byte(typeByte_) - rest = oSlice[1] - return -} - -// Contract: Caller must ensure that rt is supported -// (e.g. is recursively composed of supported native types, and structs and slices.) -// rv and rt refer to the object we're unmarhsaling into, whereas o is the result of naiive json unmarshal (map[string]interface{}) -func readReflectJSON(rv reflect.Value, rt reflect.Type, opts wire.Options, o interface{}, err *error) { - - // Get typeInfo - typeInfo := wire.GetTypeInfo(rt) - - if rt.Kind() == reflect.Interface { - if !typeInfo.IsRegisteredInterface { - // There's no way we can read such a thing. - *err = errors.New(cmn.Fmt("Cannot read unregistered interface type %v", rt)) - return - } - if o == nil { - return // nil - } - typeByte, rest, err_ := readByteJSON(o) - if err_ != nil { - *err = err_ - return - } - crt, ok := typeInfo.ByteToType[typeByte] - if !ok { - *err = errors.New(cmn.Fmt("Byte %X not registered for interface %v", typeByte, rt)) - return - } - if crt.Kind() == reflect.Ptr { - crt = crt.Elem() - crv := reflect.New(crt) - readReflectJSON(crv.Elem(), crt, opts, rest, err) - rv.Set(crv) // NOTE: orig rv is ignored. - } else { - crv := reflect.New(crt).Elem() - readReflectJSON(crv, crt, opts, rest, err) - rv.Set(crv) // NOTE: orig rv is ignored. - } - return - } - - if rt.Kind() == reflect.Ptr { - if o == nil { - return // nil - } - // Create new struct if rv is nil. - if rv.IsNil() { - newRv := reflect.New(rt.Elem()) - rv.Set(newRv) - rv = newRv - } - // Dereference pointer - rv, rt = rv.Elem(), rt.Elem() - typeInfo = wire.GetTypeInfo(rt) - // continue... - } - - switch rt.Kind() { - case reflect.Array: - elemRt := rt.Elem() - length := rt.Len() - if elemRt.Kind() == reflect.Uint8 { - // Special case: Bytearrays - oString, ok := o.(string) - if !ok { - *err = errors.New(cmn.Fmt("Expected string but got type %v", reflect.TypeOf(o))) - return - } - - // if its data.Bytes, use hex; else use base64 - dbty := reflect.TypeOf(data.Bytes{}) - var buf []byte - var err_ error - if rt == dbty { - buf, err_ = hex.DecodeString(oString) - } else { - buf, err_ = base64.StdEncoding.DecodeString(oString) - } - if err_ != nil { - *err = err_ - return - } - if len(buf) != length { - *err = errors.New(cmn.Fmt("Expected bytearray of length %v but got %v", length, len(buf))) - return - } - reflect.Copy(rv, reflect.ValueOf(buf)) - } else { - oSlice, ok := o.([]interface{}) - if !ok { - *err = errors.New(cmn.Fmt("Expected array of %v but got type %v", rt, reflect.TypeOf(o))) - return - } - if len(oSlice) != length { - *err = errors.New(cmn.Fmt("Expected array of length %v but got %v", length, len(oSlice))) - return - } - for i := 0; i < length; i++ { - elemRv := rv.Index(i) - readReflectJSON(elemRv, elemRt, opts, oSlice[i], err) - } - } - - case reflect.Slice: - elemRt := rt.Elem() - if elemRt.Kind() == reflect.Uint8 { - // Special case: Byteslices - oString, ok := o.(string) - if !ok { - *err = errors.New(cmn.Fmt("Expected string but got type %v", reflect.TypeOf(o))) - return - } - // if its data.Bytes, use hex; else use base64 - dbty := reflect.TypeOf(data.Bytes{}) - var buf []byte - var err_ error - if rt == dbty { - buf, err_ = hex.DecodeString(oString) - } else { - buf, err_ = base64.StdEncoding.DecodeString(oString) - } - if err_ != nil { - *err = err_ - return - } - rv.Set(reflect.ValueOf(buf)) - } else { - // Read length - oSlice, ok := o.([]interface{}) - if !ok { - *err = errors.New(cmn.Fmt("Expected array of %v but got type %v", rt, reflect.TypeOf(o))) - return - } - length := len(oSlice) - sliceRv := reflect.MakeSlice(rt, length, length) - // Read elems - for i := 0; i < length; i++ { - elemRv := sliceRv.Index(i) - readReflectJSON(elemRv, elemRt, opts, oSlice[i], err) - } - rv.Set(sliceRv) - } - - case reflect.Struct: - if rt == timeType { - // Special case: time.Time - str, ok := o.(string) - if !ok { - *err = errors.New(cmn.Fmt("Expected string but got type %v", reflect.TypeOf(o))) - return - } - // try three ways, seconds, milliseconds, or microseconds... - t, err_ := time.Parse(time.RFC3339Nano, str) - if err_ != nil { - *err = err_ - return - } - rv.Set(reflect.ValueOf(t)) - } else { - if typeInfo.Unwrap { - f := typeInfo.Fields[0] - fieldIdx, fieldType, opts := f.Index, f.Type, f.Options - fieldRv := rv.Field(fieldIdx) - readReflectJSON(fieldRv, fieldType, opts, o, err) - } else { - oMap, ok := o.(map[string]interface{}) - if !ok { - *err = errors.New(cmn.Fmt("Expected map but got type %v", reflect.TypeOf(o))) - return - } - // TODO: ensure that all fields are set? - // TODO: disallow unknown oMap fields? - for _, fieldInfo := range typeInfo.Fields { - f := fieldInfo - fieldIdx, fieldType, opts := f.Index, f.Type, f.Options - value, ok := oMap[opts.JSONName] - if !ok { - continue // Skip missing fields. - } - fieldRv := rv.Field(fieldIdx) - readReflectJSON(fieldRv, fieldType, opts, value, err) - } - } - } - - case reflect.String: - str, ok := o.(string) - if !ok { - *err = errors.New(cmn.Fmt("Expected string but got type %v", reflect.TypeOf(o))) - return - } - rv.SetString(str) - - case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int: - num, ok := o.(float64) - if !ok { - *err = errors.New(cmn.Fmt("Expected numeric but got type %v", reflect.TypeOf(o))) - return - } - rv.SetInt(int64(num)) - - case reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint: - num, ok := o.(float64) - if !ok { - *err = errors.New(cmn.Fmt("Expected numeric but got type %v", reflect.TypeOf(o))) - return - } - if num < 0 { - *err = errors.New(cmn.Fmt("Expected unsigned numeric but got %v", num)) - return - } - rv.SetUint(uint64(num)) - - case reflect.Float64, reflect.Float32: - if !opts.Unsafe { - *err = errors.New("Wire float* support requires `wire:\"unsafe\"`") - return - } - num, ok := o.(float64) - if !ok { - *err = errors.New(cmn.Fmt("Expected numeric but got type %v", reflect.TypeOf(o))) - return - } - rv.SetFloat(num) - - case reflect.Bool: - bl, ok := o.(bool) - if !ok { - *err = errors.New(cmn.Fmt("Expected boolean but got type %v", reflect.TypeOf(o))) - return - } - rv.SetBool(bl) - - default: - cmn.PanicSanity(cmn.Fmt("Unknown field type %v", rt.Kind())) - } -}