|
|
- package binary
-
- import (
- "errors"
- "fmt"
- "io"
- "reflect"
- )
-
- type TypeInfo struct {
- Type reflect.Type // The type
- Encoder Encoder // Optional custom encoder function
- Decoder Decoder // Optional custom decoder function
-
- HasTypeByte bool
- TypeByte byte
- }
-
// HasTypeByte is implemented by types that identify themselves with a single
// byte. This is used to encode interfaces/union types.
//
// RegisterType copies the byte into the type's TypeInfo. writeReflect then
// emits it ahead of the encoded value, and readReflect consumes one byte and
// reports an error if it differs from the registered one. readReflect never
// picks a concrete type from the byte; that dispatch has to be written by
// hand (for example a switch statement inside a custom Decoder).
// See the reactor implementations for use-cases.
type HasTypeByte interface {
	TypeByte() byte
}
-
- var typeInfos = map[reflect.Type]*TypeInfo{}
-
- func RegisterType(info *TypeInfo) *TypeInfo {
-
- // Register the type info
- typeInfos[info.Type] = info
-
- // Also register the underlying struct's info, if info.Type is a pointer.
- // Or, if info.Type is not a pointer, register the pointer.
- if info.Type.Kind() == reflect.Ptr {
- rt := info.Type.Elem()
- typeInfos[rt] = info
- } else {
- ptrRt := reflect.PtrTo(info.Type)
- typeInfos[ptrRt] = info
- }
-
- // See if the type implements HasTypeByte
- if info.Type.Implements(reflect.TypeOf((*HasTypeByte)(nil)).Elem()) {
- zero := reflect.Zero(info.Type)
- typeByte := zero.Interface().(HasTypeByte).TypeByte()
- if info.HasTypeByte && info.TypeByte != typeByte {
- panic(fmt.Sprintf("Type %v expected TypeByte of %X", info.Type, typeByte))
- }
- info.HasTypeByte = true
- info.TypeByte = typeByte
- }
-
- return info
- }
-
- func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *error) {
-
- // First, create a new struct if rv is nil pointer.
- if rt.Kind() == reflect.Ptr && rv.IsNil() {
- newRv := reflect.New(rt.Elem())
- rv.Set(newRv)
- rv = newRv
- }
-
- // Dereference pointer
- // Still addressable, thus settable!
- if rv.Kind() == reflect.Ptr {
- rv, rt = rv.Elem(), rt.Elem()
- }
-
- // Get typeInfo
- typeInfo := typeInfos[rt]
- if typeInfo == nil {
- typeInfo = RegisterType(&TypeInfo{Type: rt})
- }
-
- // Custom decoder
- if typeInfo.Decoder != nil {
- decoded := typeInfo.Decoder(r, n, err)
- decodedRv := reflect.Indirect(reflect.ValueOf(decoded))
- rv.Set(decodedRv)
- return
- }
-
- // Read TypeByte prefix
- if typeInfo.HasTypeByte {
- typeByte := ReadByte(r, n, err)
- if typeByte != typeInfo.TypeByte {
- *err = errors.New(fmt.Sprintf("Expected TypeByte of %X but got %X", typeInfo.TypeByte, typeByte))
- return
- }
- }
-
- switch rt.Kind() {
- case reflect.Slice:
- elemRt := rt.Elem()
- if elemRt.Kind() == reflect.Uint8 {
- // Special case: Byteslices
- byteslice := ReadByteSlice(r, n, err)
- rv.Set(reflect.ValueOf(byteslice))
- } else {
- // Read length
- length := int(ReadUvarint(r, n, err))
- sliceRv := reflect.MakeSlice(rt, length, length)
- // Read elems
- for i := 0; i < length; i++ {
- elemRv := sliceRv.Index(i)
- readReflect(elemRv, elemRt, r, n, err)
- }
- rv.Set(sliceRv)
- }
-
- case reflect.Struct:
- numFields := rt.NumField()
- for i := 0; i < numFields; i++ {
- field := rt.Field(i)
- if field.PkgPath != "" {
- continue
- }
- fieldRv := rv.Field(i)
- readReflect(fieldRv, field.Type, r, n, err)
- }
-
- case reflect.String:
- str := ReadString(r, n, err)
- rv.SetString(str)
-
- case reflect.Int64:
- num := ReadUint64(r, n, err)
- rv.SetInt(int64(num))
-
- case reflect.Int32:
- num := ReadUint32(r, n, err)
- rv.SetInt(int64(num))
-
- case reflect.Int16:
- num := ReadUint16(r, n, err)
- rv.SetInt(int64(num))
-
- case reflect.Int8:
- num := ReadUint8(r, n, err)
- rv.SetInt(int64(num))
-
- case reflect.Int:
- num := ReadUvarint(r, n, err)
- rv.SetInt(int64(num))
-
- case reflect.Uint64:
- num := ReadUint64(r, n, err)
- rv.SetUint(uint64(num))
-
- case reflect.Uint32:
- num := ReadUint32(r, n, err)
- rv.SetUint(uint64(num))
-
- case reflect.Uint16:
- num := ReadUint16(r, n, err)
- rv.SetUint(uint64(num))
-
- case reflect.Uint8:
- num := ReadUint8(r, n, err)
- rv.SetUint(uint64(num))
-
- case reflect.Uint:
- num := ReadUvarint(r, n, err)
- rv.SetUint(uint64(num))
-
- default:
- panic(fmt.Sprintf("Unknown field type %v", rt.Kind()))
- }
- }
-
- func writeReflect(rv reflect.Value, rt reflect.Type, w io.Writer, n *int64, err *error) {
-
- // Get typeInfo
- typeInfo := typeInfos[rt]
- if typeInfo == nil {
- typeInfo = RegisterType(&TypeInfo{Type: rt})
- }
-
- // Custom encoder, say for an interface type rt.
- if typeInfo.Encoder != nil {
- typeInfo.Encoder(rv.Interface(), w, n, err)
- return
- }
-
- // Dereference pointer or interface
- if rt.Kind() == reflect.Ptr {
- rt = rt.Elem()
- rv = rv.Elem()
- // RegisterType registers the ptr type,
- // so typeInfo is already for the ptr.
- } else if rt.Kind() == reflect.Interface {
- rv = rv.Elem()
- rt = rv.Type()
- typeInfo = typeInfos[rt]
- // If interface type, get typeInfo of underlying type.
- if typeInfo == nil {
- typeInfo = RegisterType(&TypeInfo{Type: rt})
- }
- }
-
- // Write TypeByte prefix
- if typeInfo.HasTypeByte {
- WriteByte(typeInfo.TypeByte, w, n, err)
- }
-
- switch rt.Kind() {
- case reflect.Slice:
- elemRt := rt.Elem()
- if elemRt.Kind() == reflect.Uint8 {
- // Special case: Byteslices
- byteslice := rv.Interface().([]byte)
- WriteByteSlice(byteslice, w, n, err)
- } else {
- // Write length
- length := rv.Len()
- WriteUvarint(uint(length), w, n, err)
- // Write elems
- for i := 0; i < length; i++ {
- elemRv := rv.Index(i)
- writeReflect(elemRv, elemRt, w, n, err)
- }
- }
-
- case reflect.Struct:
- numFields := rt.NumField()
- for i := 0; i < numFields; i++ {
- field := rt.Field(i)
- if field.PkgPath != "" {
- continue
- }
- fieldRv := rv.Field(i)
- writeReflect(fieldRv, field.Type, w, n, err)
- }
-
- case reflect.String:
- WriteString(rv.String(), w, n, err)
-
- case reflect.Int64:
- WriteInt64(rv.Int(), w, n, err)
-
- case reflect.Int32:
- WriteInt32(int32(rv.Int()), w, n, err)
-
- case reflect.Int16:
- WriteInt16(int16(rv.Int()), w, n, err)
-
- case reflect.Int8:
- WriteInt8(int8(rv.Int()), w, n, err)
-
- case reflect.Int:
- WriteVarint(int(rv.Int()), w, n, err)
-
- case reflect.Uint64:
- WriteUint64(rv.Uint(), w, n, err)
-
- case reflect.Uint32:
- WriteUint32(uint32(rv.Uint()), w, n, err)
-
- case reflect.Uint16:
- WriteUint16(uint16(rv.Uint()), w, n, err)
-
- case reflect.Uint8:
- WriteUint8(uint8(rv.Uint()), w, n, err)
-
- case reflect.Uint:
- WriteUvarint(uint(rv.Uint()), w, n, err)
-
- default:
- panic(fmt.Sprintf("Unknown field type %v", rt.Kind()))
- }
- }
|