diff --git a/binary/binary.go b/binary/binary.go index 890ac76aa..1a5656937 100644 --- a/binary/binary.go +++ b/binary/binary.go @@ -9,7 +9,7 @@ import ( func ReadBinary(o interface{}, r io.Reader, n *int64, err *error) interface{} { rv, rt := reflect.ValueOf(o), reflect.TypeOf(o) if rv.Kind() == reflect.Ptr { - readReflectBinary(rv.Elem(), rt.Elem(), Options{}, r, n, err) + readReflectBinary(rv, rt, Options{}, r, n, err) return o } else { ptrRv := reflect.New(rt) @@ -18,12 +18,19 @@ func ReadBinary(o interface{}, r io.Reader, n *int64, err *error) interface{} { } } +func ReadBinaryPtr(o interface{}, r io.Reader, n *int64, err *error) interface{} { + rv, rt := reflect.ValueOf(o), reflect.TypeOf(o) + if rv.Kind() == reflect.Ptr { + readReflectBinary(rv.Elem(), rt.Elem(), Options{}, r, n, err) + return o + } else { + panic("ReadBinaryPtr expects o to be a pointer") + } +} + func WriteBinary(o interface{}, w io.Writer, n *int64, err *error) { rv := reflect.ValueOf(o) rt := reflect.TypeOf(o) - if rv.Kind() == reflect.Ptr { - rv, rt = rv.Elem(), rt.Elem() - } writeReflectBinary(rv, rt, Options{}, w, n, err) } diff --git a/binary/reflect.go b/binary/reflect.go index 6eac5a46e..b497cf5f0 100644 --- a/binary/reflect.go +++ b/binary/reflect.go @@ -226,6 +226,9 @@ func readReflectBinary(rv reflect.Value, rt reflect.Type, opts Options, r io.Rea typeInfo = GetTypeInfo(rt) if typeInfo.Byte != 0x00 { r = NewPrefixedReader([]byte{typeByte}, r) + } else if typeByte != 0x01 { + *err = errors.New(Fmt("Unexpected type byte %X for ptr of untyped thing", typeByte)) + return } // continue... } @@ -250,7 +253,7 @@ func readReflectBinary(rv reflect.Value, rt reflect.Type, opts Options, r io.Rea } else { var sliceRv reflect.Value // Read length - length := int(ReadUvarint(r, n, err)) + length := ReadVarint(r, n, err) log.Debug(Fmt("Read length: %v", length)) sliceRv = reflect.MakeSlice(rt, 0, 0) // read one ReflectSliceChunk at a time and append @@ -322,7 +325,7 @@ func readReflectBinary(rv reflect.Value, rt reflect.Type, opts Options, r io.Rea case reflect.Uint64: if opts.Varint { - num := ReadUvarint(r, n, err) + num := ReadVarint(r, n, err) log.Debug(Fmt("Read num: %v", num)) rv.SetUint(uint64(num)) } else { @@ -347,7 +350,7 @@ func readReflectBinary(rv reflect.Value, rt reflect.Type, opts Options, r io.Rea rv.SetUint(uint64(num)) case reflect.Uint: - num := ReadUvarint(r, n, err) + num := ReadVarint(r, n, err) log.Debug(Fmt("Read num: %v", num)) rv.SetUint(uint64(num)) diff --git a/binary/reflect_test.go b/binary/reflect_test.go index 3fd3f765a..8d1b03861 100644 --- a/binary/reflect_test.go +++ b/binary/reflect_test.go @@ -72,13 +72,14 @@ func TestAnimalInterface(t *testing.T) { ptr := reflect.New(rte).Interface() fmt.Printf("ptr: %v", ptr) - // Make a binary byteslice that represents a snake. - snakeBytes := BinaryBytes(Snake([]byte("snake"))) + // Make a binary byteslice that represents a *snake. + foo = Snake([]byte("snake")) + snakeBytes := BinaryBytes(foo) snakeReader := bytes.NewReader(snakeBytes) // Now you can read it. n, err := new(int64), new(error) - it := *ReadBinary(ptr, snakeReader, n, err).(*Animal) + it := ReadBinary(foo, snakeReader, n, err).(Animal) fmt.Println(it, reflect.TypeOf(it)) } @@ -374,7 +375,7 @@ func TestBinary(t *testing.T) { // Read onto a pointer n, err = new(int64), new(error) - res = ReadBinary(instancePtr, bytes.NewReader(data), n, err) + res = ReadBinaryPtr(instancePtr, bytes.NewReader(data), n, err) if *err != nil { t.Fatalf("Failed to read into instance: %v", *err) } diff --git a/binary/string.go b/binary/string.go index c7ebd7f77..fb0bfc7d8 100644 --- a/binary/string.go +++ b/binary/string.go @@ -5,16 +5,16 @@ import "io" // String func WriteString(s string, w io.Writer, n *int64, err *error) { - WriteUvarint(uint(len(s)), w, n, err) + WriteVarint(len(s), w, n, err) WriteTo([]byte(s), w, n, err) } func ReadString(r io.Reader, n *int64, err *error) string { - length := ReadUvarint(r, n, err) + length := ReadVarint(r, n, err) if *err != nil { return "" } - buf := make([]byte, int(length)) + buf := make([]byte, length) ReadFull(buf, r, n, err) return string(buf) } diff --git a/p2p/connection.go b/p2p/connection.go index 856f40f30..a7bda8a98 100644 --- a/p2p/connection.go +++ b/p2p/connection.go @@ -403,7 +403,7 @@ FOR_LOOP: log.Debug("Receive Pong") case packetTypeMsg: pkt, n, err := msgPacket{}, new(int64), new(error) - binary.ReadBinary(&pkt, c.bufReader, n, err) + binary.ReadBinaryPtr(&pkt, c.bufReader, n, err) c.recvMonitor.Update(int(*n)) if *err != nil { if atomic.LoadUint32(&c.stopped) != 1 {