diff --git a/priv_key.go b/priv_key.go index 0c6bd2ae7..e6e7ac036 100644 --- a/priv_key.go +++ b/priv_key.go @@ -13,13 +13,27 @@ import ( func PrivKeyFromBytes(privKeyBytes []byte) (privKey PrivKey, err error) { err = wire.ReadBinaryBytes(privKeyBytes, &privKey) + if err == nil { + // add support for a ValidateKey method on PrivKeys + // to make sure they load correctly + val, ok := privKey.Unwrap().(validatable) + if ok { + err = val.ValidateKey() + } + } return } +// validatable is an optional interface for keys that want to +// check integrity +type validatable interface { + ValidateKey() error +} + //---------------------------------------- // DO NOT USE THIS INTERFACE. -// You probably want to use PubKey +// You probably want to use PrivKey // +gen wrapper:"PrivKey,Impl[PrivKeyEd25519,PrivKeySecp256k1],ed25519,secp256k1" type PrivKeyInner interface { AssertIsPrivKeyInner() diff --git a/priv_key_test.go b/priv_key_test.go new file mode 100644 index 000000000..154df5593 --- /dev/null +++ b/priv_key_test.go @@ -0,0 +1,65 @@ +package crypto + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + wire "github.com/tendermint/go-wire" +) + +type BadKey struct { + PrivKeyEd25519 +} + +// Wrap fulfils interface for PrivKey struct +func (pk BadKey) Wrap() PrivKey { + return PrivKey{pk} +} + +func (pk BadKey) Bytes() []byte { + return wire.BinaryBytes(pk.Wrap()) +} + +func (pk BadKey) ValidateKey() error { + return fmt.Errorf("fuggly key") +} + +func init() { + PrivKeyMapper. + RegisterImplementation(BadKey{}, "bad", 0x66) +} + +func TestReadPrivKey(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + // garbage in, garbage out + garbage := []byte("hjgewugfbiewgofwgewr") + _, err := PrivKeyFromBytes(garbage) + require.Error(err) + + edKey := GenPrivKeyEd25519() + badKey := BadKey{edKey} + + cases := []struct { + key PrivKey + valid bool + }{ + {edKey.Wrap(), true}, + {badKey.Wrap(), false}, + } + + for i, tc := range cases { + data := tc.key.Bytes() + key, err := PrivKeyFromBytes(data) + if tc.valid { + assert.NoError(err, "%d", i) + assert.Equal(tc.key, key, "%d", i) + } else { + assert.Error(err, "%d: %#v", i, key) + } + } + +}