Browse Source

privval: remove panics in privval implementation (#7475)

pull/7477/head
Sam Kleinman 2 years ago
committed by GitHub
parent
commit
1630d1cf3e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 172 additions and 137 deletions
  1. +3
    -1
      cmd/tendermint/commands/init.go
  2. +6
    -2
      cmd/tendermint/commands/reset_priv_validator.go
  3. +3
    -5
      internal/consensus/common_test.go
  4. +2
    -2
      internal/consensus/replay_test.go
  5. +75
    -56
      privval/file.go
  6. +13
    -10
      privval/file_test.go
  7. +20
    -26
      privval/grpc/client_test.go
  8. +23
    -12
      privval/secret_connection.go
  9. +4
    -6
      privval/signer_listener_endpoint_test.go
  10. +13
    -3
      privval/socket_dialers_test.go
  11. +2
    -3
      privval/socket_listeners_test.go
  12. +0
    -9
      privval/utils.go
  13. +8
    -2
      test/e2e/runner/setup.go

+ 3
- 1
cmd/tendermint/commands/init.go View File

@ -65,7 +65,9 @@ func initFilesWithConfig(ctx context.Context, config *cfg.Config) error {
if err != nil { if err != nil {
return err return err
} }
pv.Save()
if err := pv.Save(); err != nil {
return err
}
logger.Info("Generated private validator", "keyFile", privValKeyFile, logger.Info("Generated private validator", "keyFile", privValKeyFile,
"stateFile", privValStateFile) "stateFile", privValStateFile)
} }


+ 6
- 2
cmd/tendermint/commands/reset_priv_validator.go View File

@ -68,7 +68,9 @@ func resetFilePV(privValKeyFile, privValStateFile string, logger log.Logger) err
if err != nil { if err != nil {
return err return err
} }
pv.Reset()
if err := pv.Reset(); err != nil {
return err
}
logger.Info("Reset private validator file to genesis state", "keyFile", privValKeyFile, logger.Info("Reset private validator file to genesis state", "keyFile", privValKeyFile,
"stateFile", privValStateFile) "stateFile", privValStateFile)
} else { } else {
@ -76,7 +78,9 @@ func resetFilePV(privValKeyFile, privValStateFile string, logger log.Logger) err
if err != nil { if err != nil {
return err return err
} }
pv.Save()
if err := pv.Save(); err != nil {
return err
}
logger.Info("Generated private validator file", "keyFile", privValKeyFile, logger.Info("Generated private validator file", "keyFile", privValKeyFile,
"stateFile", privValStateFile) "stateFile", privValStateFile)
} }


+ 3
- 5
internal/consensus/common_test.go View File

@ -487,15 +487,13 @@ func newStateWithConfigAndBlockStore(
return cs return cs
} }
func loadPrivValidator(cfg *config.Config) *privval.FilePV {
func loadPrivValidator(t *testing.T, cfg *config.Config) *privval.FilePV {
privValidatorKeyFile := cfg.PrivValidator.KeyFile() privValidatorKeyFile := cfg.PrivValidator.KeyFile()
ensureDir(filepath.Dir(privValidatorKeyFile), 0700) ensureDir(filepath.Dir(privValidatorKeyFile), 0700)
privValidatorStateFile := cfg.PrivValidator.StateFile() privValidatorStateFile := cfg.PrivValidator.StateFile()
privValidator, err := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile) privValidator, err := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile)
if err != nil {
panic(err)
}
privValidator.Reset()
require.NoError(t, err)
require.NoError(t, privValidator.Reset())
return privValidator return privValidator
} }


+ 2
- 2
internal/consensus/replay_test.go View File

@ -60,7 +60,7 @@ func startNewStateAndWaitForBlock(ctx context.Context, t *testing.T, consensusRe
logger := log.TestingLogger() logger := log.TestingLogger()
state, err := sm.MakeGenesisStateFromFile(consensusReplayConfig.GenesisFile()) state, err := sm.MakeGenesisStateFromFile(consensusReplayConfig.GenesisFile())
require.NoError(t, err) require.NoError(t, err)
privValidator := loadPrivValidator(consensusReplayConfig)
privValidator := loadPrivValidator(t, consensusReplayConfig)
blockStore := store.NewBlockStore(dbm.NewMemDB()) blockStore := store.NewBlockStore(dbm.NewMemDB())
cs := newStateWithConfigAndBlockStore( cs := newStateWithConfigAndBlockStore(
ctx, ctx,
@ -165,7 +165,7 @@ LOOP:
blockStore := store.NewBlockStore(blockDB) blockStore := store.NewBlockStore(blockDB)
state, err := sm.MakeGenesisStateFromFile(consensusReplayConfig.GenesisFile()) state, err := sm.MakeGenesisStateFromFile(consensusReplayConfig.GenesisFile())
require.NoError(t, err) require.NoError(t, err)
privValidator := loadPrivValidator(consensusReplayConfig)
privValidator := loadPrivValidator(t, consensusReplayConfig)
cs := newStateWithConfigAndBlockStore( cs := newStateWithConfigAndBlockStore(
rctx, rctx,
logger, logger,


+ 75
- 56
privval/file.go View File

@ -32,14 +32,14 @@ const (
) )
// A vote is either stepPrevote or stepPrecommit. // A vote is either stepPrevote or stepPrecommit.
func voteToStep(vote *tmproto.Vote) int8 {
func voteToStep(vote *tmproto.Vote) (int8, error) {
switch vote.Type { switch vote.Type {
case tmproto.PrevoteType: case tmproto.PrevoteType:
return stepPrevote
return stepPrevote, nil
case tmproto.PrecommitType: case tmproto.PrecommitType:
return stepPrecommit
return stepPrecommit, nil
default: default:
panic(fmt.Sprintf("Unknown vote type: %v", vote.Type))
return 0, fmt.Errorf("unknown vote type: %v", vote.Type)
} }
} }
@ -55,21 +55,17 @@ type FilePVKey struct {
} }
// Save persists the FilePVKey to its filePath. // Save persists the FilePVKey to its filePath.
func (pvKey FilePVKey) Save() {
func (pvKey FilePVKey) Save() error {
outFile := pvKey.filePath outFile := pvKey.filePath
if outFile == "" { if outFile == "" {
panic("cannot save PrivValidator key: filePath not set")
return errors.New("cannot save PrivValidator key: filePath not set")
} }
jsonBytes, err := tmjson.MarshalIndent(pvKey, "", " ") jsonBytes, err := tmjson.MarshalIndent(pvKey, "", " ")
if err != nil { if err != nil {
panic(err)
}
err = tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
if err != nil {
panic(err)
return err
} }
return tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
} }
//------------------------------------------------------------------------------- //-------------------------------------------------------------------------------
@ -127,19 +123,16 @@ func (lss *FilePVLastSignState) CheckHRS(height int64, round int32, step int8) (
} }
// Save persists the FilePvLastSignState to its filePath. // Save persists the FilePvLastSignState to its filePath.
func (lss *FilePVLastSignState) Save() {
func (lss *FilePVLastSignState) Save() error {
outFile := lss.filePath outFile := lss.filePath
if outFile == "" { if outFile == "" {
panic("cannot save FilePVLastSignState: filePath not set")
return errors.New("cannot save FilePVLastSignState: filePath not set")
} }
jsonBytes, err := tmjson.MarshalIndent(lss, "", " ") jsonBytes, err := tmjson.MarshalIndent(lss, "", " ")
if err != nil { if err != nil {
panic(err)
}
err = tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
if err != nil {
panic(err)
return err
} }
return tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
} }
//------------------------------------------------------------------------------- //-------------------------------------------------------------------------------
@ -239,17 +232,23 @@ func loadFilePV(keyFilePath, stateFilePath string, loadState bool) (*FilePV, err
// LoadOrGenFilePV loads a FilePV from the given filePaths // LoadOrGenFilePV loads a FilePV from the given filePaths
// or else generates a new one and saves it to the filePaths. // or else generates a new one and saves it to the filePaths.
func LoadOrGenFilePV(keyFilePath, stateFilePath string) (*FilePV, error) { func LoadOrGenFilePV(keyFilePath, stateFilePath string) (*FilePV, error) {
var (
pv *FilePV
err error
)
if tmos.FileExists(keyFilePath) { if tmos.FileExists(keyFilePath) {
pv, err = LoadFilePV(keyFilePath, stateFilePath)
} else {
pv, err = GenFilePV(keyFilePath, stateFilePath, "")
pv.Save()
pv, err := LoadFilePV(keyFilePath, stateFilePath)
if err != nil {
return nil, err
}
return pv, nil
}
pv, err := GenFilePV(keyFilePath, stateFilePath, "")
if err != nil {
return nil, err
}
if err := pv.Save(); err != nil {
return nil, err
} }
return pv, err
return pv, nil
} }
// GetAddress returns the address of the validator. // GetAddress returns the address of the validator.
@ -283,21 +282,23 @@ func (pv *FilePV) SignProposal(ctx context.Context, chainID string, proposal *tm
} }
// Save persists the FilePV to disk. // Save persists the FilePV to disk.
func (pv *FilePV) Save() {
pv.Key.Save()
pv.LastSignState.Save()
func (pv *FilePV) Save() error {
if err := pv.Key.Save(); err != nil {
return err
}
return pv.LastSignState.Save()
} }
// Reset resets all fields in the FilePV. // Reset resets all fields in the FilePV.
// NOTE: Unsafe! // NOTE: Unsafe!
func (pv *FilePV) Reset() {
func (pv *FilePV) Reset() error {
var sig []byte var sig []byte
pv.LastSignState.Height = 0 pv.LastSignState.Height = 0
pv.LastSignState.Round = 0 pv.LastSignState.Round = 0
pv.LastSignState.Step = 0 pv.LastSignState.Step = 0
pv.LastSignState.Signature = sig pv.LastSignState.Signature = sig
pv.LastSignState.SignBytes = nil pv.LastSignState.SignBytes = nil
pv.Save()
return pv.Save()
} }
// String returns a string representation of the FilePV. // String returns a string representation of the FilePV.
@ -317,8 +318,13 @@ func (pv *FilePV) String() string {
// It may need to set the timestamp as well if the vote is otherwise the same as // It may need to set the timestamp as well if the vote is otherwise the same as
// a previously signed vote (ie. we crashed after signing but before the vote hit the WAL). // a previously signed vote (ie. we crashed after signing but before the vote hit the WAL).
func (pv *FilePV) signVote(chainID string, vote *tmproto.Vote) error { func (pv *FilePV) signVote(chainID string, vote *tmproto.Vote) error {
height, round, step := vote.Height, vote.Round, voteToStep(vote)
step, err := voteToStep(vote)
if err != nil {
return err
}
height := vote.Height
round := vote.Round
lss := pv.LastSignState lss := pv.LastSignState
sameHRS, err := lss.CheckHRS(height, round, step) sameHRS, err := lss.CheckHRS(height, round, step)
@ -336,13 +342,19 @@ func (pv *FilePV) signVote(chainID string, vote *tmproto.Vote) error {
if sameHRS { if sameHRS {
if bytes.Equal(signBytes, lss.SignBytes) { if bytes.Equal(signBytes, lss.SignBytes) {
vote.Signature = lss.Signature vote.Signature = lss.Signature
} else if timestamp, ok := checkVotesOnlyDifferByTimestamp(lss.SignBytes, signBytes); ok {
} else {
timestamp, ok, err := checkVotesOnlyDifferByTimestamp(lss.SignBytes, signBytes)
if err != nil {
return err
}
if !ok {
return errors.New("conflicting data")
}
vote.Timestamp = timestamp vote.Timestamp = timestamp
vote.Signature = lss.Signature vote.Signature = lss.Signature
} else {
err = fmt.Errorf("conflicting data")
return nil
} }
return err
} }
// It passed the checks. Sign the vote // It passed the checks. Sign the vote
@ -350,7 +362,9 @@ func (pv *FilePV) signVote(chainID string, vote *tmproto.Vote) error {
if err != nil { if err != nil {
return err return err
} }
pv.saveSigned(height, round, step, signBytes, sig)
if err := pv.saveSigned(height, round, step, signBytes, sig); err != nil {
return err
}
vote.Signature = sig vote.Signature = sig
return nil return nil
} }
@ -378,13 +392,18 @@ func (pv *FilePV) signProposal(chainID string, proposal *tmproto.Proposal) error
if sameHRS { if sameHRS {
if bytes.Equal(signBytes, lss.SignBytes) { if bytes.Equal(signBytes, lss.SignBytes) {
proposal.Signature = lss.Signature proposal.Signature = lss.Signature
} else if timestamp, ok := checkProposalsOnlyDifferByTimestamp(lss.SignBytes, signBytes); ok {
} else {
timestamp, ok, err := checkProposalsOnlyDifferByTimestamp(lss.SignBytes, signBytes)
if err != nil {
return err
}
if !ok {
return errors.New("conflicting data")
}
proposal.Timestamp = timestamp proposal.Timestamp = timestamp
proposal.Signature = lss.Signature proposal.Signature = lss.Signature
} else {
err = fmt.Errorf("conflicting data")
return nil
} }
return err
} }
// It passed the checks. Sign the proposal // It passed the checks. Sign the proposal
@ -392,34 +411,34 @@ func (pv *FilePV) signProposal(chainID string, proposal *tmproto.Proposal) error
if err != nil { if err != nil {
return err return err
} }
pv.saveSigned(height, round, step, signBytes, sig)
if err := pv.saveSigned(height, round, step, signBytes, sig); err != nil {
return err
}
proposal.Signature = sig proposal.Signature = sig
return nil return nil
} }
// Persist height/round/step and signature // Persist height/round/step and signature
func (pv *FilePV) saveSigned(height int64, round int32, step int8,
signBytes []byte, sig []byte) {
func (pv *FilePV) saveSigned(height int64, round int32, step int8, signBytes []byte, sig []byte) error {
pv.LastSignState.Height = height pv.LastSignState.Height = height
pv.LastSignState.Round = round pv.LastSignState.Round = round
pv.LastSignState.Step = step pv.LastSignState.Step = step
pv.LastSignState.Signature = sig pv.LastSignState.Signature = sig
pv.LastSignState.SignBytes = signBytes pv.LastSignState.SignBytes = signBytes
pv.LastSignState.Save()
return pv.LastSignState.Save()
} }
//----------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------
// returns the timestamp from the lastSignBytes. // returns the timestamp from the lastSignBytes.
// returns true if the only difference in the votes is their timestamp. // returns true if the only difference in the votes is their timestamp.
func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool, error) {
var lastVote, newVote tmproto.CanonicalVote var lastVote, newVote tmproto.CanonicalVote
if err := protoio.UnmarshalDelimited(lastSignBytes, &lastVote); err != nil { if err := protoio.UnmarshalDelimited(lastSignBytes, &lastVote); err != nil {
panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into vote: %v", err))
return time.Time{}, false, fmt.Errorf("LastSignBytes cannot be unmarshalled into vote: %v", err)
} }
if err := protoio.UnmarshalDelimited(newSignBytes, &newVote); err != nil { if err := protoio.UnmarshalDelimited(newSignBytes, &newVote); err != nil {
panic(fmt.Sprintf("signBytes cannot be unmarshalled into vote: %v", err))
return time.Time{}, false, fmt.Errorf("signBytes cannot be unmarshalled into vote: %v", err)
} }
lastTime := lastVote.Timestamp lastTime := lastVote.Timestamp
@ -428,18 +447,18 @@ func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.T
lastVote.Timestamp = now lastVote.Timestamp = now
newVote.Timestamp = now newVote.Timestamp = now
return lastTime, proto.Equal(&newVote, &lastVote)
return lastTime, proto.Equal(&newVote, &lastVote), nil
} }
// returns the timestamp from the lastSignBytes. // returns the timestamp from the lastSignBytes.
// returns true if the only difference in the proposals is their timestamp // returns true if the only difference in the proposals is their timestamp
func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool, error) {
var lastProposal, newProposal tmproto.CanonicalProposal var lastProposal, newProposal tmproto.CanonicalProposal
if err := protoio.UnmarshalDelimited(lastSignBytes, &lastProposal); err != nil { if err := protoio.UnmarshalDelimited(lastSignBytes, &lastProposal); err != nil {
panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into proposal: %v", err))
return time.Time{}, false, fmt.Errorf("LastSignBytes cannot be unmarshalled into proposal: %v", err)
} }
if err := protoio.UnmarshalDelimited(newSignBytes, &newProposal); err != nil { if err := protoio.UnmarshalDelimited(newSignBytes, &newProposal); err != nil {
panic(fmt.Sprintf("signBytes cannot be unmarshalled into proposal: %v", err))
return time.Time{}, false, fmt.Errorf("signBytes cannot be unmarshalled into proposal: %v", err)
} }
lastTime := lastProposal.Timestamp lastTime := lastProposal.Timestamp
@ -448,5 +467,5 @@ func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (ti
lastProposal.Timestamp = now lastProposal.Timestamp = now
newProposal.Timestamp = now newProposal.Timestamp = now
return lastTime, proto.Equal(&newProposal, &lastProposal)
return lastTime, proto.Equal(&newProposal, &lastProposal), nil
} }

+ 13
- 10
privval/file_test.go View File

@ -33,7 +33,7 @@ func TestGenLoadValidator(t *testing.T) {
height := int64(100) height := int64(100)
privVal.LastSignState.Height = height privVal.LastSignState.Height = height
privVal.Save()
require.NoError(t, privVal.Save())
addr := privVal.GetAddress() addr := privVal.GetAddress()
privVal, err = LoadFilePV(tempKeyFile.Name(), tempStateFile.Name()) privVal, err = LoadFilePV(tempKeyFile.Name(), tempStateFile.Name())
@ -68,7 +68,7 @@ func TestResetValidator(t *testing.T) {
assert.NotEqual(t, privVal.LastSignState, emptyState) assert.NotEqual(t, privVal.LastSignState, emptyState)
// priv val after AcceptNewConnection is same as empty // priv val after AcceptNewConnection is same as empty
privVal.Reset()
require.NoError(t, privVal.Reset())
assert.Equal(t, privVal.LastSignState, emptyState) assert.Equal(t, privVal.LastSignState, emptyState)
} }
@ -267,6 +267,9 @@ func TestSignProposal(t *testing.T) {
} }
func TestDifferByTimestamp(t *testing.T) { func TestDifferByTimestamp(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tempKeyFile, err := os.CreateTemp("", "priv_validator_key_") tempKeyFile, err := os.CreateTemp("", "priv_validator_key_")
require.Nil(t, err) require.Nil(t, err)
tempStateFile, err := os.CreateTemp("", "priv_validator_state_") tempStateFile, err := os.CreateTemp("", "priv_validator_state_")
@ -283,8 +286,8 @@ func TestDifferByTimestamp(t *testing.T) {
{ {
proposal := newProposal(height, round, block1) proposal := newProposal(height, round, block1)
pb := proposal.ToProto() pb := proposal.ToProto()
err := privVal.SignProposal(context.Background(), chainID, pb)
assert.NoError(t, err, "expected no error signing proposal")
err := privVal.SignProposal(ctx, chainID, pb)
require.NoError(t, err, "expected no error signing proposal")
signBytes := types.ProposalSignBytes(chainID, pb) signBytes := types.ProposalSignBytes(chainID, pb)
sig := proposal.Signature sig := proposal.Signature
@ -294,8 +297,8 @@ func TestDifferByTimestamp(t *testing.T) {
pb.Timestamp = pb.Timestamp.Add(time.Millisecond) pb.Timestamp = pb.Timestamp.Add(time.Millisecond)
var emptySig []byte var emptySig []byte
proposal.Signature = emptySig proposal.Signature = emptySig
err = privVal.SignProposal(context.Background(), "mychainid", pb)
assert.NoError(t, err, "expected no error on signing same proposal")
err = privVal.SignProposal(ctx, "mychainid", pb)
require.NoError(t, err, "expected no error on signing same proposal")
assert.Equal(t, timeStamp, pb.Timestamp) assert.Equal(t, timeStamp, pb.Timestamp)
assert.Equal(t, signBytes, types.ProposalSignBytes(chainID, pb)) assert.Equal(t, signBytes, types.ProposalSignBytes(chainID, pb))
@ -308,8 +311,8 @@ func TestDifferByTimestamp(t *testing.T) {
blockID := types.BlockID{Hash: randbytes, PartSetHeader: types.PartSetHeader{}} blockID := types.BlockID{Hash: randbytes, PartSetHeader: types.PartSetHeader{}}
vote := newVote(privVal.Key.Address, 0, height, round, voteType, blockID) vote := newVote(privVal.Key.Address, 0, height, round, voteType, blockID)
v := vote.ToProto() v := vote.ToProto()
err := privVal.SignVote(context.Background(), "mychainid", v)
assert.NoError(t, err, "expected no error signing vote")
err := privVal.SignVote(ctx, "mychainid", v)
require.NoError(t, err, "expected no error signing vote")
signBytes := types.VoteSignBytes(chainID, v) signBytes := types.VoteSignBytes(chainID, v)
sig := v.Signature sig := v.Signature
@ -319,8 +322,8 @@ func TestDifferByTimestamp(t *testing.T) {
v.Timestamp = v.Timestamp.Add(time.Millisecond) v.Timestamp = v.Timestamp.Add(time.Millisecond)
var emptySig []byte var emptySig []byte
v.Signature = emptySig v.Signature = emptySig
err = privVal.SignVote(context.Background(), "mychainid", v)
assert.NoError(t, err, "expected no error on signing same vote")
err = privVal.SignVote(ctx, "mychainid", v)
require.NoError(t, err, "expected no error on signing same vote")
assert.Equal(t, timeStamp, v.Timestamp) assert.Equal(t, timeStamp, v.Timestamp)
assert.Equal(t, signBytes, types.VoteSignBytes(chainID, v)) assert.Equal(t, signBytes, types.VoteSignBytes(chainID, v))


+ 20
- 26
privval/grpc/client_test.go View File

@ -24,7 +24,7 @@ import (
const chainID = "chain-id" const chainID = "chain-id"
func dialer(pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(context.Context, string) (net.Conn, error)) {
func dialer(t *testing.T, pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(context.Context, string) (net.Conn, error)) {
listener := bufconn.Listen(1024 * 1024) listener := bufconn.Listen(1024 * 1024)
server := grpc.NewServer() server := grpc.NewServer()
@ -33,11 +33,7 @@ func dialer(pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(conte
privvalproto.RegisterPrivValidatorAPIServer(server, s) privvalproto.RegisterPrivValidatorAPIServer(server, s)
go func() {
if err := server.Serve(listener); err != nil {
panic(err)
}
}()
go func() { require.NoError(t, server.Serve(listener)) }()
return server, func(context.Context, string) (net.Conn, error) { return server, func(context.Context, string) (net.Conn, error) {
return listener.Dial() return listener.Dial()
@ -46,44 +42,43 @@ func dialer(pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(conte
func TestSignerClient_GetPubKey(t *testing.T) { func TestSignerClient_GetPubKey(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mockPV := types.NewMockPV() mockPV := types.NewMockPV()
logger := log.TestingLogger() logger := log.TestingLogger()
srv, dialer := dialer(mockPV, logger)
srv, dialer := dialer(t, mockPV, logger)
defer srv.Stop() defer srv.Stop()
conn, err := grpc.DialContext(ctx, "", conn, err := grpc.DialContext(ctx, "",
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(dialer), grpc.WithContextDialer(dialer),
) )
if err != nil {
panic(err)
}
require.NoError(t, err)
defer conn.Close() defer conn.Close()
client, err := tmgrpc.NewSignerClient(conn, chainID, logger) client, err := tmgrpc.NewSignerClient(conn, chainID, logger)
require.NoError(t, err) require.NoError(t, err)
pk, err := client.GetPubKey(context.Background())
pk, err := client.GetPubKey(ctx)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, mockPV.PrivKey.PubKey(), pk) assert.Equal(t, mockPV.PrivKey.PubKey(), pk)
} }
func TestSignerClient_SignVote(t *testing.T) { func TestSignerClient_SignVote(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := context.Background()
mockPV := types.NewMockPV() mockPV := types.NewMockPV()
logger := log.TestingLogger() logger := log.TestingLogger()
srv, dialer := dialer(mockPV, logger)
srv, dialer := dialer(t, mockPV, logger)
defer srv.Stop() defer srv.Stop()
conn, err := grpc.DialContext(ctx, "", conn, err := grpc.DialContext(ctx, "",
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(dialer), grpc.WithContextDialer(dialer),
) )
if err != nil {
panic(err)
}
require.NoError(t, err)
defer conn.Close() defer conn.Close()
client, err := tmgrpc.NewSignerClient(conn, chainID, logger) client, err := tmgrpc.NewSignerClient(conn, chainID, logger)
@ -115,31 +110,30 @@ func TestSignerClient_SignVote(t *testing.T) {
pbHave := have.ToProto() pbHave := have.ToProto()
err = client.SignVote(context.Background(), chainID, pbHave)
err = client.SignVote(ctx, chainID, pbHave)
require.NoError(t, err) require.NoError(t, err)
pbWant := want.ToProto() pbWant := want.ToProto()
require.NoError(t, mockPV.SignVote(context.Background(), chainID, pbWant))
require.NoError(t, mockPV.SignVote(ctx, chainID, pbWant))
assert.Equal(t, pbWant.Signature, pbHave.Signature) assert.Equal(t, pbWant.Signature, pbHave.Signature)
} }
func TestSignerClient_SignProposal(t *testing.T) { func TestSignerClient_SignProposal(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := context.Background()
mockPV := types.NewMockPV() mockPV := types.NewMockPV()
logger := log.TestingLogger() logger := log.TestingLogger()
srv, dialer := dialer(mockPV, logger)
srv, dialer := dialer(t, mockPV, logger)
defer srv.Stop() defer srv.Stop()
conn, err := grpc.DialContext(ctx, "", conn, err := grpc.DialContext(ctx, "",
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(dialer), grpc.WithContextDialer(dialer),
) )
if err != nil {
panic(err)
}
require.NoError(t, err)
defer conn.Close() defer conn.Close()
client, err := tmgrpc.NewSignerClient(conn, chainID, logger) client, err := tmgrpc.NewSignerClient(conn, chainID, logger)
@ -167,12 +161,12 @@ func TestSignerClient_SignProposal(t *testing.T) {
pbHave := have.ToProto() pbHave := have.ToProto()
err = client.SignProposal(context.Background(), chainID, pbHave)
err = client.SignProposal(ctx, chainID, pbHave)
require.NoError(t, err) require.NoError(t, err)
pbWant := want.ToProto() pbWant := want.ToProto()
require.NoError(t, mockPV.SignProposal(context.Background(), chainID, pbWant))
require.NoError(t, mockPV.SignProposal(ctx, chainID, pbWant))
assert.Equal(t, pbWant.Signature, pbHave.Signature) assert.Equal(t, pbWant.Signature, pbHave.Signature)
} }

+ 23
- 12
privval/secret_connection.go View File

@ -99,7 +99,10 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKey) (*
) )
// Generate ephemeral keys for perfect forward secrecy. // Generate ephemeral keys for perfect forward secrecy.
locEphPub, locEphPriv := genEphKeys()
locEphPub, locEphPriv, err := genEphKeys()
if err != nil {
return nil, err
}
// Write local ephemeral pubkey and receive one too. // Write local ephemeral pubkey and receive one too.
// NOTE: every 32-byte string is accepted as a Curve25519 public key (see // NOTE: every 32-byte string is accepted as a Curve25519 public key (see
@ -132,7 +135,10 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKey) (*
// Generate the secret used for receiving, sending, challenge via HKDF-SHA2 // Generate the secret used for receiving, sending, challenge via HKDF-SHA2
// on the transcript state (which itself also uses HKDF-SHA2 to derive a key // on the transcript state (which itself also uses HKDF-SHA2 to derive a key
// from the dhSecret). // from the dhSecret).
recvSecret, sendSecret := deriveSecrets(dhSecret, locIsLeast)
recvSecret, sendSecret, err := deriveSecrets(dhSecret, locIsLeast)
if err != nil {
return nil, err
}
const challengeSize = 32 const challengeSize = 32
var challenge [challengeSize]byte var challenge [challengeSize]byte
@ -214,7 +220,10 @@ func (sc *SecretConnection) Write(data []byte) (n int, err error) {
// encrypt the frame // encrypt the frame
sc.sendAead.Seal(sealedFrame[:0], sc.sendNonce[:], frame, nil) sc.sendAead.Seal(sealedFrame[:0], sc.sendNonce[:], frame, nil)
incrNonce(sc.sendNonce)
if err := incrNonce(sc.sendNonce); err != nil {
return err
}
// end encryption // end encryption
_, err = sc.conn.Write(sealedFrame) _, err = sc.conn.Write(sealedFrame)
@ -258,7 +267,9 @@ func (sc *SecretConnection) Read(data []byte) (n int, err error) {
if err != nil { if err != nil {
return n, fmt.Errorf("failed to decrypt SecretConnection: %w", err) return n, fmt.Errorf("failed to decrypt SecretConnection: %w", err)
} }
incrNonce(sc.recvNonce)
if err = incrNonce(sc.recvNonce); err != nil {
return
}
// end decryption // end decryption
// copy checkLength worth into data, // copy checkLength worth into data,
@ -288,14 +299,13 @@ func (sc *SecretConnection) SetWriteDeadline(t time.Time) error {
return sc.conn.(net.Conn).SetWriteDeadline(t) return sc.conn.(net.Conn).SetWriteDeadline(t)
} }
func genEphKeys() (ephPub, ephPriv *[32]byte) {
var err error
func genEphKeys() (ephPub, ephPriv *[32]byte, err error) {
// TODO: Probably not a problem but ask Tony: different from the rust implementation (uses x25519-dalek), // TODO: Probably not a problem but ask Tony: different from the rust implementation (uses x25519-dalek),
// we do not "clamp" the private key scalar: // we do not "clamp" the private key scalar:
// see: https://github.com/dalek-cryptography/x25519-dalek/blob/34676d336049df2bba763cc076a75e47ae1f170f/src/x25519.rs#L56-L74 // see: https://github.com/dalek-cryptography/x25519-dalek/blob/34676d336049df2bba763cc076a75e47ae1f170f/src/x25519.rs#L56-L74
ephPub, ephPriv, err = box.GenerateKey(crand.Reader) ephPub, ephPriv, err = box.GenerateKey(crand.Reader)
if err != nil { if err != nil {
panic("Could not generate ephemeral key-pair")
return
} }
return return
} }
@ -339,14 +349,14 @@ func shareEphPubKey(conn io.ReadWriter, locEphPub *[32]byte) (remEphPub *[32]byt
func deriveSecrets( func deriveSecrets(
dhSecret *[32]byte, dhSecret *[32]byte,
locIsLeast bool, locIsLeast bool,
) (recvSecret, sendSecret *[aeadKeySize]byte) {
) (recvSecret, sendSecret *[aeadKeySize]byte, err error) {
hash := sha256.New hash := sha256.New
hkdf := hkdf.New(hash, dhSecret[:], nil, secretConnKeyAndChallengeGen) hkdf := hkdf.New(hash, dhSecret[:], nil, secretConnKeyAndChallengeGen)
// get enough data for 2 aead keys, and a 32 byte challenge // get enough data for 2 aead keys, and a 32 byte challenge
res := new([2*aeadKeySize + 32]byte) res := new([2*aeadKeySize + 32]byte)
_, err := io.ReadFull(hkdf, res[:])
_, err = io.ReadFull(hkdf, res[:])
if err != nil { if err != nil {
panic(err)
return nil, nil, err
} }
recvSecret = new([aeadKeySize]byte) recvSecret = new([aeadKeySize]byte)
@ -454,13 +464,14 @@ func shareAuthSignature(sc io.ReadWriter, pubKey crypto.PubKey, signature []byte
// Due to chacha20poly1305 expecting a 12 byte nonce we do not use the first four // Due to chacha20poly1305 expecting a 12 byte nonce we do not use the first four
// bytes. We only increment a 64 bit unsigned int in the remaining 8 bytes // bytes. We only increment a 64 bit unsigned int in the remaining 8 bytes
// (little-endian in nonce[4:]). // (little-endian in nonce[4:]).
func incrNonce(nonce *[aeadNonceSize]byte) {
func incrNonce(nonce *[aeadNonceSize]byte) error {
counter := binary.LittleEndian.Uint64(nonce[4:]) counter := binary.LittleEndian.Uint64(nonce[4:])
if counter == math.MaxUint64 { if counter == math.MaxUint64 {
// Terminates the session and makes sure the nonce would not re-used. // Terminates the session and makes sure the nonce would not re-used.
// See https://github.com/tendermint/tendermint/issues/3531 // See https://github.com/tendermint/tendermint/issues/3531
panic("can't increase nonce without overflow")
return errors.New("can't increase nonce without overflow")
} }
counter++ counter++
binary.LittleEndian.PutUint64(nonce[4:], counter) binary.LittleEndian.PutUint64(nonce[4:], counter)
return nil
} }

+ 4
- 6
privval/signer_listener_endpoint_test.go View File

@ -98,7 +98,7 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
mockPV = types.NewMockPV() mockPV = types.NewMockPV()
endpointIsOpenCh = make(chan struct{}) endpointIsOpenCh = make(chan struct{})
thisConnTimeout = testTimeoutReadWrite thisConnTimeout = testTimeoutReadWrite
listenerEndpoint = newSignerListenerEndpoint(logger, tc.addr, thisConnTimeout)
listenerEndpoint = newSignerListenerEndpoint(t, logger, tc.addr, thisConnTimeout)
) )
dialerEndpoint := NewSignerDialerEndpoint( dialerEndpoint := NewSignerDialerEndpoint(
@ -138,14 +138,12 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
} }
} }
func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
func newSignerListenerEndpoint(t *testing.T, logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
proto, address := tmnet.ProtocolAndAddress(addr) proto, address := tmnet.ProtocolAndAddress(addr)
ln, err := net.Listen(proto, address) ln, err := net.Listen(proto, address)
logger.Info("SignerListener: Listening", "proto", proto, "address", address) logger.Info("SignerListener: Listening", "proto", proto, "address", address)
if err != nil {
panic(err)
}
require.NoError(t, err)
var listener net.Listener var listener net.Listener
@ -199,7 +197,7 @@ func getMockEndpoints(
socketDialer, socketDialer,
) )
listenerEndpoint = newSignerListenerEndpoint(logger, addr, testTimeoutReadWrite)
listenerEndpoint = newSignerListenerEndpoint(t, logger, addr, testTimeoutReadWrite)
) )
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint) SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)


+ 13
- 3
privval/socket_dialers_test.go View File

@ -9,10 +9,20 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/crypto/ed25519"
tmnet "github.com/tendermint/tendermint/libs/net"
) )
// getFreeLocalhostAddrPort returns a free localhost:port address
func getFreeLocalhostAddrPort(t *testing.T) string {
t.Helper()
port, err := tmnet.GetFreePort()
require.NoError(t, err)
return fmt.Sprintf("127.0.0.1:%d", port)
}
func getDialerTestCases(t *testing.T) []dialerTestCase { func getDialerTestCases(t *testing.T) []dialerTestCase {
tcpAddr := GetFreeLocalhostAddrPort()
tcpAddr := getFreeLocalhostAddrPort(t)
unixFilePath, err := testUnixAddr() unixFilePath, err := testUnixAddr()
require.NoError(t, err) require.NoError(t, err)
unixAddr := fmt.Sprintf("unix://%s", unixFilePath) unixAddr := fmt.Sprintf("unix://%s", unixFilePath)
@ -31,7 +41,7 @@ func getDialerTestCases(t *testing.T) []dialerTestCase {
func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) {
// Generate a networking timeout // Generate a networking timeout
tcpAddr := GetFreeLocalhostAddrPort()
tcpAddr := getFreeLocalhostAddrPort(t)
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey()) dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey())
_, err := dialer() _, err := dialer()
assert.Error(t, err) assert.Error(t, err)
@ -39,7 +49,7 @@ func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) {
} }
func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) {
tcpAddr := GetFreeLocalhostAddrPort()
tcpAddr := getFreeLocalhostAddrPort(t)
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey()) dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey())
_, err := dialer() _, err := dialer()
assert.Error(t, err) assert.Error(t, err)


+ 2
- 3
privval/socket_listeners_test.go View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/crypto/ed25519"
) )
@ -107,9 +108,7 @@ func TestListenerTimeoutReadWrite(t *testing.T) {
for _, tc := range listenerTestCases(t, timeoutAccept, timeoutReadWrite) { for _, tc := range listenerTestCases(t, timeoutAccept, timeoutReadWrite) {
go func(dialer SocketDialer) { go func(dialer SocketDialer) {
_, err := dialer() _, err := dialer()
if err != nil {
panic(err)
}
require.NoError(t, err)
}(tc.dialer) }(tc.dialer)
c, err := tc.listener.Accept() c, err := tc.listener.Accept()


+ 0
- 9
privval/utils.go View File

@ -51,12 +51,3 @@ func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEnd
return pve, nil return pve, nil
} }
// GetFreeLocalhostAddrPort returns a free localhost:port address
func GetFreeLocalhostAddrPort() string {
port, err := tmnet.GetFreePort()
if err != nil {
panic(err)
}
return fmt.Sprintf("127.0.0.1:%d", port)
}

+ 8
- 2
test/e2e/runner/setup.go View File

@ -111,17 +111,23 @@ func Setup(testnet *e2e.Testnet) error {
return err return err
} }
(privval.NewFilePV(node.PrivvalKey,
err = (privval.NewFilePV(node.PrivvalKey,
filepath.Join(nodeDir, PrivvalKeyFile), filepath.Join(nodeDir, PrivvalKeyFile),
filepath.Join(nodeDir, PrivvalStateFile), filepath.Join(nodeDir, PrivvalStateFile),
)).Save() )).Save()
if err != nil {
return err
}
// Set up a dummy validator. Tendermint requires a file PV even when not used, so we // Set up a dummy validator. Tendermint requires a file PV even when not used, so we
// give it a dummy such that it will fail if it actually tries to use it. // give it a dummy such that it will fail if it actually tries to use it.
(privval.NewFilePV(ed25519.GenPrivKey(),
err = (privval.NewFilePV(ed25519.GenPrivKey(),
filepath.Join(nodeDir, PrivvalDummyKeyFile), filepath.Join(nodeDir, PrivvalDummyKeyFile),
filepath.Join(nodeDir, PrivvalDummyStateFile), filepath.Join(nodeDir, PrivvalDummyStateFile),
)).Save() )).Save()
if err != nil {
return err
}
} }
return nil return nil


Loading…
Cancel
Save