diff --git a/common/bytes.go b/common/bytes.go new file mode 100644 index 000000000..1ec880c25 --- /dev/null +++ b/common/bytes.go @@ -0,0 +1,53 @@ +package common + +import ( + "encoding/hex" + "fmt" + "strings" +) + +// The main purpose of HexBytes is to enable HEX-encoding for json/encoding. +type HexBytes []byte + +// Marshal needed for protobuf compatibility +func (bz HexBytes) Marshal() ([]byte, error) { + return bz, nil +} + +// Unmarshal needed for protobuf compatibility +func (bz *HexBytes) Unmarshal(data []byte) error { + *bz = data + return nil +} + +// This is the point of Bytes. +func (bz HexBytes) MarshalJSON() ([]byte, error) { + s := strings.ToUpper(hex.EncodeToString(bz)) + jbz := make([]byte, len(s)+2) + jbz[0] = '"' + copy(jbz[1:], []byte(s)) + jbz[1] = '"' + return jbz, nil +} + +// This is the point of Bytes. +func (bz *HexBytes) UnmarshalJSON(data []byte) error { + if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' { + return fmt.Errorf("Invalid hex string: %s", data) + } + bz2, err := hex.DecodeString(string(data[1 : len(data)-1])) + if err != nil { + return err + } + *bz = bz2 + return nil +} + +// Allow it to fulfill various interfaces in light-client, etc... +func (bz HexBytes) Bytes() []byte { + return bz +} + +func (bz HexBytes) String() string { + return strings.ToUpper(hex.EncodeToString(bz)) +} diff --git a/common/bytes_test.go b/common/bytes_test.go new file mode 100644 index 000000000..3e693b239 --- /dev/null +++ b/common/bytes_test.go @@ -0,0 +1,65 @@ +package common + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +// This is a trivial test for protobuf compatibility. +func TestMarshal(t *testing.T) { + bz := []byte("hello world") + dataB := HexBytes(bz) + bz2, err := dataB.Marshal() + assert.Nil(t, err) + assert.Equal(t, bz, bz2) + + var dataB2 HexBytes + err = (&dataB2).Unmarshal(bz) + assert.Nil(t, err) + assert.Equal(t, dataB, dataB2) +} + +// Test that the hex encoding works. +func TestJSONMarshal(t *testing.T) { + + type TestStruct struct { + B1 []byte + B2 HexBytes + } + + cases := []struct { + input []byte + expected string + }{ + {[]byte(``), `{"B1":"","B2":""}`}, + {[]byte(``), `{"B1":"","B2":""}`}, + {[]byte(``), `{"B1":"","B2":""}`}, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("Case %d", i), func(t *testing.T) { + ts := TestStruct{B1: tc.input, B2: tc.input} + + // Test that it marshals correctly to JSON. + jsonBytes, err := json.Marshal(ts) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, string(jsonBytes), tc.expected) + + // TODO do fuzz testing to ensure that unmarshal fails + + // Test that unmarshaling works correctly. + ts2 := TestStruct{} + err = json.Unmarshal(jsonBytes, &ts2) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, ts2.B1, tc.input) + assert.Equal(t, ts2.B2, HexBytes(tc.input)) + }) + } +}