diff --git a/priv_key.go b/priv_key.go index e6e7ac036..11dcb686f 100644 --- a/priv_key.go +++ b/priv_key.go @@ -1,7 +1,7 @@ package crypto import ( - "bytes" + "crypto/subtle" secp256k1 "github.com/btcsuite/btcd/btcec" "github.com/tendermint/ed25519" @@ -69,9 +69,11 @@ func (privKey PrivKeyEd25519) PubKey() PubKey { return PubKeyEd25519(pubBytes).Wrap() } +// Equals - you probably don't need to use this. +// Runs in constant time based on length of the keys. func (privKey PrivKeyEd25519) Equals(other PrivKey) bool { if otherEd, ok := other.Unwrap().(PrivKeyEd25519); ok { - return bytes.Equal(privKey[:], otherEd[:]) + return subtle.ConstantTimeCompare(privKey[:], otherEd[:]) == 1 } else { return false } @@ -156,9 +158,11 @@ func (privKey PrivKeySecp256k1) PubKey() PubKey { return pub.Wrap() } +// Equals - you probably don't need to use this. +// Runs in constant time based on length of the keys. func (privKey PrivKeySecp256k1) Equals(other PrivKey) bool { if otherSecp, ok := other.Unwrap().(PrivKeySecp256k1); ok { - return bytes.Equal(privKey[:], otherSecp[:]) + return subtle.ConstantTimeCompare(privKey[:], otherSecp[:]) == 1 } else { return false } diff --git a/signature.go b/signature.go index d2ea45132..cd40331cf 100644 --- a/signature.go +++ b/signature.go @@ -87,8 +87,8 @@ func (sig SignatureSecp256k1) IsZero() bool { return len(sig) == 0 } func (sig SignatureSecp256k1) String() string { return fmt.Sprintf("/%X.../", Fingerprint(sig[:])) } func (sig SignatureSecp256k1) Equals(other Signature) bool { - if otherEd, ok := other.Unwrap().(SignatureSecp256k1); ok { - return bytes.Equal(sig[:], otherEd[:]) + if otherSecp, ok := other.Unwrap().(SignatureSecp256k1); ok { + return bytes.Equal(sig[:], otherSecp[:]) } else { return false } diff --git a/signature_test.go b/signature_test.go index 5e9f06723..4801e5fef 100644 --- a/signature_test.go +++ b/signature_test.go @@ -141,3 +141,27 @@ func TestWrapping(t *testing.T) { } } + +func TestPrivKeyEquality(t *testing.T) { + { + privKey := GenPrivKeySecp256k1().Wrap() + privKey2 := GenPrivKeySecp256k1().Wrap() + assert.False(t, privKey.Equals(privKey2)) + assert.False(t, privKey2.Equals(privKey)) + + privKeyCopy := privKey // copy + assert.True(t, privKey.Equals(privKeyCopy)) + assert.True(t, privKeyCopy.Equals(privKey)) + } + + { + privKey := GenPrivKeyEd25519().Wrap() + privKey2 := GenPrivKeyEd25519().Wrap() + assert.False(t, privKey.Equals(privKey2)) + assert.False(t, privKey2.Equals(privKey)) + + privKeyCopy := privKey // copy + assert.True(t, privKey.Equals(privKeyCopy)) + assert.True(t, privKeyCopy.Equals(privKey)) + } +}