You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

278 lines
6.3 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package binary
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "reflect"
  7. )
  8. type TypeInfo struct {
  9. Type reflect.Type // The type
  10. Encoder Encoder // Optional custom encoder function
  11. Decoder Decoder // Optional custom decoder function
  12. HasTypeByte bool
  13. TypeByte byte
  14. }
  // HasTypeByte is implemented by types that carry a one-byte type tag.
  //
  // RegisterType detects the interface and records the tag in the type's
  // TypeInfo. writeReflect then emits the tag ahead of the encoded value, and
  // readReflect reads one byte back and checks that it matches.
  //
  // The tag exists to encode interface/union types: a decoder can read it
  // and pick the concrete type, typically with a hand-written switch
  // statement. See the reactor implementations for use-cases.
  type HasTypeByte interface {
  	// TypeByte returns the tag byte for the implementing type.
  	TypeByte() byte
  }
  25. var typeInfos = map[reflect.Type]*TypeInfo{}
  26. func RegisterType(info *TypeInfo) *TypeInfo {
  27. // Register the type info
  28. typeInfos[info.Type] = info
  29. // Also register the underlying struct's info, if info.Type is a pointer.
  30. // Or, if info.Type is not a pointer, register the pointer.
  31. if info.Type.Kind() == reflect.Ptr {
  32. rt := info.Type.Elem()
  33. typeInfos[rt] = info
  34. } else {
  35. ptrRt := reflect.PtrTo(info.Type)
  36. typeInfos[ptrRt] = info
  37. }
  38. // See if the type implements HasTypeByte
  39. if info.Type.Implements(reflect.TypeOf((*HasTypeByte)(nil)).Elem()) {
  40. zero := reflect.Zero(info.Type)
  41. typeByte := zero.Interface().(HasTypeByte).TypeByte()
  42. if info.HasTypeByte && info.TypeByte != typeByte {
  43. panic(fmt.Sprintf("Type %v expected TypeByte of %X", info.Type, typeByte))
  44. }
  45. info.HasTypeByte = true
  46. info.TypeByte = typeByte
  47. }
  48. return info
  49. }
  50. func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *error) {
  51. // First, create a new struct if rv is nil pointer.
  52. if rt.Kind() == reflect.Ptr && rv.IsNil() {
  53. newRv := reflect.New(rt.Elem())
  54. rv.Set(newRv)
  55. rv = newRv
  56. }
  57. // Dereference pointer
  58. // Still addressable, thus settable!
  59. if rv.Kind() == reflect.Ptr {
  60. rv, rt = rv.Elem(), rt.Elem()
  61. }
  62. // Get typeInfo
  63. typeInfo := typeInfos[rt]
  64. if typeInfo == nil {
  65. typeInfo = RegisterType(&TypeInfo{Type: rt})
  66. }
  67. // Custom decoder
  68. if typeInfo.Decoder != nil {
  69. decoded := typeInfo.Decoder(r, n, err)
  70. decodedRv := reflect.Indirect(reflect.ValueOf(decoded))
  71. rv.Set(decodedRv)
  72. return
  73. }
  74. // Read TypeByte prefix
  75. if typeInfo.HasTypeByte {
  76. typeByte := ReadByte(r, n, err)
  77. if typeByte != typeInfo.TypeByte {
  78. *err = errors.New(fmt.Sprintf("Expected TypeByte of %X but got %X", typeInfo.TypeByte, typeByte))
  79. return
  80. }
  81. }
  82. switch rt.Kind() {
  83. case reflect.Slice:
  84. elemRt := rt.Elem()
  85. if elemRt.Kind() == reflect.Uint8 {
  86. // Special case: Byteslices
  87. byteslice := ReadByteSlice(r, n, err)
  88. rv.Set(reflect.ValueOf(byteslice))
  89. } else {
  90. // Read length
  91. length := int(ReadUvarint(r, n, err))
  92. sliceRv := reflect.MakeSlice(rt, length, length)
  93. // Read elems
  94. for i := 0; i < length; i++ {
  95. elemRv := sliceRv.Index(i)
  96. readReflect(elemRv, elemRt, r, n, err)
  97. }
  98. rv.Set(sliceRv)
  99. }
  100. case reflect.Struct:
  101. numFields := rt.NumField()
  102. for i := 0; i < numFields; i++ {
  103. field := rt.Field(i)
  104. if field.PkgPath != "" {
  105. continue
  106. }
  107. fieldRv := rv.Field(i)
  108. readReflect(fieldRv, field.Type, r, n, err)
  109. }
  110. case reflect.String:
  111. str := ReadString(r, n, err)
  112. rv.SetString(str)
  113. case reflect.Int64:
  114. num := ReadUint64(r, n, err)
  115. rv.SetInt(int64(num))
  116. case reflect.Int32:
  117. num := ReadUint32(r, n, err)
  118. rv.SetInt(int64(num))
  119. case reflect.Int16:
  120. num := ReadUint16(r, n, err)
  121. rv.SetInt(int64(num))
  122. case reflect.Int8:
  123. num := ReadUint8(r, n, err)
  124. rv.SetInt(int64(num))
  125. case reflect.Int:
  126. num := ReadUvarint(r, n, err)
  127. rv.SetInt(int64(num))
  128. case reflect.Uint64:
  129. num := ReadUint64(r, n, err)
  130. rv.SetUint(uint64(num))
  131. case reflect.Uint32:
  132. num := ReadUint32(r, n, err)
  133. rv.SetUint(uint64(num))
  134. case reflect.Uint16:
  135. num := ReadUint16(r, n, err)
  136. rv.SetUint(uint64(num))
  137. case reflect.Uint8:
  138. num := ReadUint8(r, n, err)
  139. rv.SetUint(uint64(num))
  140. case reflect.Uint:
  141. num := ReadUvarint(r, n, err)
  142. rv.SetUint(uint64(num))
  143. default:
  144. panic(fmt.Sprintf("Unknown field type %v", rt.Kind()))
  145. }
  146. }
  147. func writeReflect(rv reflect.Value, rt reflect.Type, w io.Writer, n *int64, err *error) {
  148. // Get typeInfo
  149. typeInfo := typeInfos[rt]
  150. if typeInfo == nil {
  151. typeInfo = RegisterType(&TypeInfo{Type: rt})
  152. }
  153. // Custom encoder, say for an interface type rt.
  154. if typeInfo.Encoder != nil {
  155. typeInfo.Encoder(rv.Interface(), w, n, err)
  156. return
  157. }
  158. // Dereference pointer or interface
  159. if rt.Kind() == reflect.Ptr {
  160. rt = rt.Elem()
  161. rv = rv.Elem()
  162. // RegisterType registers the ptr type,
  163. // so typeInfo is already for the ptr.
  164. } else if rt.Kind() == reflect.Interface {
  165. rv = rv.Elem()
  166. rt = rv.Type()
  167. typeInfo = typeInfos[rt]
  168. // If interface type, get typeInfo of underlying type.
  169. if typeInfo == nil {
  170. typeInfo = RegisterType(&TypeInfo{Type: rt})
  171. }
  172. }
  173. // Write TypeByte prefix
  174. if typeInfo.HasTypeByte {
  175. WriteByte(typeInfo.TypeByte, w, n, err)
  176. }
  177. switch rt.Kind() {
  178. case reflect.Slice:
  179. elemRt := rt.Elem()
  180. if elemRt.Kind() == reflect.Uint8 {
  181. // Special case: Byteslices
  182. byteslice := rv.Interface().([]byte)
  183. WriteByteSlice(byteslice, w, n, err)
  184. } else {
  185. // Write length
  186. length := rv.Len()
  187. WriteUvarint(uint(length), w, n, err)
  188. // Write elems
  189. for i := 0; i < length; i++ {
  190. elemRv := rv.Index(i)
  191. writeReflect(elemRv, elemRt, w, n, err)
  192. }
  193. }
  194. case reflect.Struct:
  195. numFields := rt.NumField()
  196. for i := 0; i < numFields; i++ {
  197. field := rt.Field(i)
  198. if field.PkgPath != "" {
  199. continue
  200. }
  201. fieldRv := rv.Field(i)
  202. writeReflect(fieldRv, field.Type, w, n, err)
  203. }
  204. case reflect.String:
  205. WriteString(rv.String(), w, n, err)
  206. case reflect.Int64:
  207. WriteInt64(rv.Int(), w, n, err)
  208. case reflect.Int32:
  209. WriteInt32(int32(rv.Int()), w, n, err)
  210. case reflect.Int16:
  211. WriteInt16(int16(rv.Int()), w, n, err)
  212. case reflect.Int8:
  213. WriteInt8(int8(rv.Int()), w, n, err)
  214. case reflect.Int:
  215. WriteVarint(int(rv.Int()), w, n, err)
  216. case reflect.Uint64:
  217. WriteUint64(rv.Uint(), w, n, err)
  218. case reflect.Uint32:
  219. WriteUint32(uint32(rv.Uint()), w, n, err)
  220. case reflect.Uint16:
  221. WriteUint16(uint16(rv.Uint()), w, n, err)
  222. case reflect.Uint8:
  223. WriteUint8(uint8(rv.Uint()), w, n, err)
  224. case reflect.Uint:
  225. WriteUvarint(uint(rv.Uint()), w, n, err)
  226. default:
  227. panic(fmt.Sprintf("Unknown field type %v", rt.Kind()))
  228. }
  229. }