You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

278 lines
7.3 KiB

  1. package json
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "reflect"
  8. )
  9. // Unmarshal unmarshals JSON into the given value, using Amino-compatible JSON encoding (strings
  10. // for 64-bit numbers, and type wrappers for registered types).
  11. func Unmarshal(bz []byte, v interface{}) error {
  12. return decode(bz, v)
  13. }
  14. func decode(bz []byte, v interface{}) error {
  15. if len(bz) == 0 {
  16. return errors.New("cannot decode empty bytes")
  17. }
  18. rv := reflect.ValueOf(v)
  19. if rv.Kind() != reflect.Ptr {
  20. return errors.New("must decode into a pointer")
  21. }
  22. rv = rv.Elem()
  23. // If this is a registered type, defer to interface decoder regardless of whether the input is
  24. // an interface or a bare value. This retains Amino's behavior, but is inconsistent with
  25. // behavior in structs where an interface field will get the type wrapper while a bare value
  26. // field will not.
  27. if typeRegistry.name(rv.Type()) != "" {
  28. return decodeReflectInterface(bz, rv)
  29. }
  30. return decodeReflect(bz, rv)
  31. }
  32. func decodeReflect(bz []byte, rv reflect.Value) error {
  33. if !rv.CanAddr() {
  34. return errors.New("value is not addressable")
  35. }
  36. // Handle null for slices, interfaces, and pointers
  37. if bytes.Equal(bz, []byte("null")) {
  38. rv.Set(reflect.Zero(rv.Type()))
  39. return nil
  40. }
  41. // Dereference-and-construct pointers, to handle nested pointers.
  42. for rv.Kind() == reflect.Ptr {
  43. if rv.IsNil() {
  44. rv.Set(reflect.New(rv.Type().Elem()))
  45. }
  46. rv = rv.Elem()
  47. }
  48. // Times must be UTC and end with Z
  49. if rv.Type() == timeType {
  50. switch {
  51. case len(bz) < 2 || bz[0] != '"' || bz[len(bz)-1] != '"':
  52. return fmt.Errorf("JSON time must be an RFC3339 string, but got %q", bz)
  53. case bz[len(bz)-2] != 'Z':
  54. return fmt.Errorf("JSON time must be UTC and end with 'Z', but got %q", bz)
  55. }
  56. }
  57. // If value implements json.Umarshaler, call it.
  58. if rv.Addr().Type().Implements(jsonUnmarshalerType) {
  59. return rv.Addr().Interface().(json.Unmarshaler).UnmarshalJSON(bz)
  60. }
  61. switch rv.Type().Kind() {
  62. // Decode complex types recursively.
  63. case reflect.Slice, reflect.Array:
  64. return decodeReflectList(bz, rv)
  65. case reflect.Map:
  66. return decodeReflectMap(bz, rv)
  67. case reflect.Struct:
  68. return decodeReflectStruct(bz, rv)
  69. case reflect.Interface:
  70. return decodeReflectInterface(bz, rv)
  71. // For 64-bit integers, unwrap expected string and defer to stdlib for integer decoding.
  72. case reflect.Int64, reflect.Int, reflect.Uint64, reflect.Uint:
  73. if bz[0] != '"' || bz[len(bz)-1] != '"' {
  74. return fmt.Errorf("invalid 64-bit integer encoding %q, expected string", string(bz))
  75. }
  76. bz = bz[1 : len(bz)-1]
  77. fallthrough
  78. // Anything else we defer to the stdlib.
  79. default:
  80. return decodeStdlib(bz, rv)
  81. }
  82. }
  83. func decodeReflectList(bz []byte, rv reflect.Value) error {
  84. if !rv.CanAddr() {
  85. return errors.New("list value is not addressable")
  86. }
  87. switch rv.Type().Elem().Kind() {
  88. // Decode base64-encoded bytes using stdlib decoder, via byte slice for arrays.
  89. case reflect.Uint8:
  90. if rv.Type().Kind() == reflect.Array {
  91. var buf []byte
  92. if err := json.Unmarshal(bz, &buf); err != nil {
  93. return err
  94. }
  95. if len(buf) != rv.Len() {
  96. return fmt.Errorf("got %v bytes, expected %v", len(buf), rv.Len())
  97. }
  98. reflect.Copy(rv, reflect.ValueOf(buf))
  99. } else if err := decodeStdlib(bz, rv); err != nil {
  100. return err
  101. }
  102. // Decode anything else into a raw JSON slice, and decode values recursively.
  103. default:
  104. var rawSlice []json.RawMessage
  105. if err := json.Unmarshal(bz, &rawSlice); err != nil {
  106. return err
  107. }
  108. if rv.Type().Kind() == reflect.Slice {
  109. rv.Set(reflect.MakeSlice(reflect.SliceOf(rv.Type().Elem()), len(rawSlice), len(rawSlice)))
  110. }
  111. if rv.Len() != len(rawSlice) { // arrays of wrong size
  112. return fmt.Errorf("got list of %v elements, expected %v", len(rawSlice), rv.Len())
  113. }
  114. for i, bz := range rawSlice {
  115. if err := decodeReflect(bz, rv.Index(i)); err != nil {
  116. return err
  117. }
  118. }
  119. }
  120. // Replace empty slices with nil slices, for Amino compatibility
  121. if rv.Type().Kind() == reflect.Slice && rv.Len() == 0 {
  122. rv.Set(reflect.Zero(rv.Type()))
  123. }
  124. return nil
  125. }
  126. func decodeReflectMap(bz []byte, rv reflect.Value) error {
  127. if !rv.CanAddr() {
  128. return errors.New("map value is not addressable")
  129. }
  130. // Decode into a raw JSON map, using string keys.
  131. rawMap := make(map[string]json.RawMessage)
  132. if err := json.Unmarshal(bz, &rawMap); err != nil {
  133. return err
  134. }
  135. if rv.Type().Key().Kind() != reflect.String {
  136. return fmt.Errorf("map keys must be strings, got %v", rv.Type().Key().String())
  137. }
  138. // Recursively decode values.
  139. rv.Set(reflect.MakeMapWithSize(rv.Type(), len(rawMap)))
  140. for key, bz := range rawMap {
  141. value := reflect.New(rv.Type().Elem()).Elem()
  142. if err := decodeReflect(bz, value); err != nil {
  143. return err
  144. }
  145. rv.SetMapIndex(reflect.ValueOf(key), value)
  146. }
  147. return nil
  148. }
  149. func decodeReflectStruct(bz []byte, rv reflect.Value) error {
  150. if !rv.CanAddr() {
  151. return errors.New("struct value is not addressable")
  152. }
  153. sInfo := makeStructInfo(rv.Type())
  154. // Decode raw JSON values into a string-keyed map.
  155. rawMap := make(map[string]json.RawMessage)
  156. if err := json.Unmarshal(bz, &rawMap); err != nil {
  157. return err
  158. }
  159. for i, fInfo := range sInfo.fields {
  160. if !fInfo.hidden {
  161. frv := rv.Field(i)
  162. bz := rawMap[fInfo.jsonName]
  163. if len(bz) > 0 {
  164. if err := decodeReflect(bz, frv); err != nil {
  165. return err
  166. }
  167. } else if !fInfo.omitEmpty {
  168. frv.Set(reflect.Zero(frv.Type()))
  169. }
  170. }
  171. }
  172. return nil
  173. }
  174. func decodeReflectInterface(bz []byte, rv reflect.Value) error {
  175. if !rv.CanAddr() {
  176. return errors.New("interface value not addressable")
  177. }
  178. // Decode the interface wrapper.
  179. wrapper := interfaceWrapper{}
  180. if err := json.Unmarshal(bz, &wrapper); err != nil {
  181. return err
  182. }
  183. if wrapper.Type == "" {
  184. return errors.New("interface type cannot be empty")
  185. }
  186. if len(wrapper.Value) == 0 {
  187. return errors.New("interface value cannot be empty")
  188. }
  189. // Dereference-and-construct pointers, to handle nested pointers.
  190. for rv.Kind() == reflect.Ptr {
  191. if rv.IsNil() {
  192. rv.Set(reflect.New(rv.Type().Elem()))
  193. }
  194. rv = rv.Elem()
  195. }
  196. // Look up the interface type, and construct a concrete value.
  197. rt, returnPtr := typeRegistry.lookup(wrapper.Type)
  198. if rt == nil {
  199. return fmt.Errorf("unknown type %q", wrapper.Type)
  200. }
  201. cptr := reflect.New(rt)
  202. crv := cptr.Elem()
  203. if err := decodeReflect(wrapper.Value, crv); err != nil {
  204. return err
  205. }
  206. // This makes sure interface implementations with pointer receivers (e.g. func (c *Car)) are
  207. // constructed as pointers behind the interface. The types must be registered as pointers with
  208. // RegisterType().
  209. if rv.Type().Kind() == reflect.Interface && returnPtr {
  210. if !cptr.Type().AssignableTo(rv.Type()) {
  211. return fmt.Errorf("invalid type %q for this value", wrapper.Type)
  212. }
  213. rv.Set(cptr)
  214. } else {
  215. if !crv.Type().AssignableTo(rv.Type()) {
  216. return fmt.Errorf("invalid type %q for this value", wrapper.Type)
  217. }
  218. rv.Set(crv)
  219. }
  220. return nil
  221. }
  222. func decodeStdlib(bz []byte, rv reflect.Value) error {
  223. if !rv.CanAddr() && rv.Kind() != reflect.Ptr {
  224. return errors.New("value must be addressable or pointer")
  225. }
  226. // Make sure we are unmarshaling into a pointer.
  227. target := rv
  228. if rv.Kind() != reflect.Ptr {
  229. target = reflect.New(rv.Type())
  230. }
  231. if err := json.Unmarshal(bz, target.Interface()); err != nil {
  232. return err
  233. }
  234. rv.Set(target.Elem())
  235. return nil
  236. }
  237. type interfaceWrapper struct {
  238. Type string `json:"type"`
  239. Value json.RawMessage `json:"value"`
  240. }