diff --git a/cmd/priv_val_server/main.go b/cmd/priv_val_server/main.go index 6d5406924..c86bced81 100644 --- a/cmd/priv_val_server/main.go +++ b/cmd/priv_val_server/main.go @@ -35,7 +35,7 @@ func main() { pv := privval.LoadFilePV(*privValKeyPath, *privValStatePath) - var dialer privval.Dialer + var dialer privval.SocketDialer protocol, address := cmn.ProtocolAndAddress(*addr) switch protocol { case "unix": @@ -48,7 +48,7 @@ func main() { os.Exit(1) } - rs := privval.NewRemoteSigner(logger, *chainID, pv, dialer) + rs := privval.NewSignerServiceEndpoint(logger, *chainID, pv, dialer) err := rs.Start() if err != nil { panic(err) diff --git a/node/node.go b/node/node.go index e5ddd09cc..2b803502f 100644 --- a/node/node.go +++ b/node/node.go @@ -914,7 +914,7 @@ func createAndStartPrivValidatorSocketClient( ) } - pvsc := privval.NewSocketVal(logger.With("module", "privval"), listener) + pvsc := privval.NewSignerValidatorEndpoint(logger.With("module", "privval"), listener) if err := pvsc.Start(); err != nil { return nil, errors.Wrap(err, "failed to start private validator") } diff --git a/node/node_test.go b/node/node_test.go index 0fce0dd96..ebc3f2102 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -132,13 +132,13 @@ func TestNodeSetPrivValTCP(t *testing.T) { config.BaseConfig.PrivValidatorListenAddr = addr dialer := privval.DialTCPFn(addr, 100*time.Millisecond, ed25519.GenPrivKey()) - pvsc := privval.NewRemoteSigner( + pvsc := privval.NewSignerServiceEndpoint( log.TestingLogger(), config.ChainID(), types.NewMockPV(), dialer, ) - privval.RemoteSignerConnDeadline(100 * time.Millisecond)(pvsc) + privval.SignerServiceEndpointTimeoutReadWrite(100 * time.Millisecond)(pvsc) go func() { err := pvsc.Start() @@ -150,7 +150,7 @@ func TestNodeSetPrivValTCP(t *testing.T) { n, err := DefaultNewNode(config, log.TestingLogger()) require.NoError(t, err) - assert.IsType(t, &privval.SocketVal{}, n.PrivValidator()) + assert.IsType(t, &privval.SignerValidatorEndpoint{}, n.PrivValidator()) } // address without a protocol must result in error @@ -174,13 +174,13 @@ func TestNodeSetPrivValIPC(t *testing.T) { config.BaseConfig.PrivValidatorListenAddr = "unix://" + tmpfile dialer := privval.DialUnixFn(tmpfile) - pvsc := privval.NewRemoteSigner( + pvsc := privval.NewSignerServiceEndpoint( log.TestingLogger(), config.ChainID(), types.NewMockPV(), dialer, ) - privval.RemoteSignerConnDeadline(100 * time.Millisecond)(pvsc) + privval.SignerServiceEndpointTimeoutReadWrite(100 * time.Millisecond)(pvsc) go func() { err := pvsc.Start() @@ -190,7 +190,7 @@ func TestNodeSetPrivValIPC(t *testing.T) { n, err := DefaultNewNode(config, log.TestingLogger()) require.NoError(t, err) - assert.IsType(t, &privval.SocketVal{}, n.PrivValidator()) + assert.IsType(t, &privval.SignerValidatorEndpoint{}, n.PrivValidator()) } diff --git a/privval/client.go b/privval/client.go deleted file mode 100644 index 11151fee3..000000000 --- a/privval/client.go +++ /dev/null @@ -1,240 +0,0 @@ -package privval - -import ( - "errors" - "fmt" - "net" - "sync" - "time" - - "github.com/tendermint/tendermint/crypto" - cmn "github.com/tendermint/tendermint/libs/common" - "github.com/tendermint/tendermint/libs/log" - "github.com/tendermint/tendermint/types" -) - -const ( - defaultConnHeartBeatSeconds = 2 - defaultDialRetries = 10 -) - -// Socket errors. -var ( - ErrUnexpectedResponse = errors.New("received unexpected response") -) - -var ( - connHeartbeat = time.Second * defaultConnHeartBeatSeconds -) - -// SocketValOption sets an optional parameter on the SocketVal. -type SocketValOption func(*SocketVal) - -// SocketValHeartbeat sets the period on which to check the liveness of the -// connected Signer connections. -func SocketValHeartbeat(period time.Duration) SocketValOption { - return func(sc *SocketVal) { sc.connHeartbeat = period } -} - -// SocketVal implements PrivValidator. -// It listens for an external process to dial in and uses -// the socket to request signatures. -type SocketVal struct { - cmn.BaseService - - listener net.Listener - - // ping - cancelPing chan struct{} - pingTicker *time.Ticker - connHeartbeat time.Duration - - // signer is mutable since it can be - // reset if the connection fails. - // failures are detected by a background - // ping routine. - // All messages are request/response, so we hold the mutex - // so only one request/response pair can happen at a time. - // Methods on the underlying net.Conn itself - // are already gorountine safe. - mtx sync.Mutex - signer *RemoteSignerClient -} - -// Check that SocketVal implements PrivValidator. -var _ types.PrivValidator = (*SocketVal)(nil) - -// NewSocketVal returns an instance of SocketVal. -func NewSocketVal( - logger log.Logger, - listener net.Listener, -) *SocketVal { - sc := &SocketVal{ - listener: listener, - connHeartbeat: connHeartbeat, - } - - sc.BaseService = *cmn.NewBaseService(logger, "SocketVal", sc) - - return sc -} - -//-------------------------------------------------------- -// Implement PrivValidator - -// GetPubKey implements PrivValidator. -func (sc *SocketVal) GetPubKey() crypto.PubKey { - sc.mtx.Lock() - defer sc.mtx.Unlock() - return sc.signer.GetPubKey() -} - -// SignVote implements PrivValidator. -func (sc *SocketVal) SignVote(chainID string, vote *types.Vote) error { - sc.mtx.Lock() - defer sc.mtx.Unlock() - return sc.signer.SignVote(chainID, vote) -} - -// SignProposal implements PrivValidator. -func (sc *SocketVal) SignProposal(chainID string, proposal *types.Proposal) error { - sc.mtx.Lock() - defer sc.mtx.Unlock() - return sc.signer.SignProposal(chainID, proposal) -} - -//-------------------------------------------------------- -// More thread safe methods proxied to the signer - -// Ping is used to check connection health. -func (sc *SocketVal) Ping() error { - sc.mtx.Lock() - defer sc.mtx.Unlock() - return sc.signer.Ping() -} - -// Close closes the underlying net.Conn. -func (sc *SocketVal) Close() { - sc.mtx.Lock() - defer sc.mtx.Unlock() - if sc.signer != nil { - if err := sc.signer.Close(); err != nil { - sc.Logger.Error("OnStop", "err", err) - } - } - - if sc.listener != nil { - if err := sc.listener.Close(); err != nil { - sc.Logger.Error("OnStop", "err", err) - } - } -} - -//-------------------------------------------------------- -// Service start and stop - -// OnStart implements cmn.Service. -func (sc *SocketVal) OnStart() error { - if closed, err := sc.reset(); err != nil { - sc.Logger.Error("OnStart", "err", err) - return err - } else if closed { - return fmt.Errorf("listener is closed") - } - - // Start a routine to keep the connection alive - sc.cancelPing = make(chan struct{}, 1) - sc.pingTicker = time.NewTicker(sc.connHeartbeat) - go func() { - for { - select { - case <-sc.pingTicker.C: - err := sc.Ping() - if err != nil { - sc.Logger.Error("Ping", "err", err) - if err == ErrUnexpectedResponse { - return - } - - closed, err := sc.reset() - if err != nil { - sc.Logger.Error("Reconnecting to remote signer failed", "err", err) - continue - } - if closed { - sc.Logger.Info("listener is closing") - return - } - - sc.Logger.Info("Re-created connection to remote signer", "impl", sc) - } - case <-sc.cancelPing: - sc.pingTicker.Stop() - return - } - } - }() - - return nil -} - -// OnStop implements cmn.Service. -func (sc *SocketVal) OnStop() { - if sc.cancelPing != nil { - close(sc.cancelPing) - } - sc.Close() -} - -//-------------------------------------------------------- -// Connection and signer management - -// waits to accept and sets a new connection. -// connection is closed in OnStop. -// returns true if the listener is closed -// (ie. it returns a nil conn). -func (sc *SocketVal) reset() (closed bool, err error) { - sc.mtx.Lock() - defer sc.mtx.Unlock() - - // first check if the conn already exists and close it. - if sc.signer != nil { - if err := sc.signer.Close(); err != nil { - sc.Logger.Error("error closing socket val connection during reset", "err", err) - } - } - - // wait for a new conn - conn, err := sc.acceptConnection() - if err != nil { - return false, err - } - - // listener is closed - if conn == nil { - return true, nil - } - - sc.signer, err = NewRemoteSignerClient(conn) - if err != nil { - // failed to fetch the pubkey. close out the connection. - if err := conn.Close(); err != nil { - sc.Logger.Error("error closing connection", "err", err) - } - return false, err - } - return false, nil -} - -// Attempt to accept a connection. -// Times out after the listener's acceptDeadline -func (sc *SocketVal) acceptConnection() (net.Conn, error) { - conn, err := sc.listener.Accept() - if err != nil { - if !sc.IsRunning() { - return nil, nil // Ignore error from listener closing. - } - return nil, err - } - return conn, nil -} diff --git a/privval/client_test.go b/privval/client_test.go deleted file mode 100644 index 1aea58cf0..000000000 --- a/privval/client_test.go +++ /dev/null @@ -1,461 +0,0 @@ -package privval - -import ( - "fmt" - "net" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/tendermint/tendermint/crypto/ed25519" - cmn "github.com/tendermint/tendermint/libs/common" - "github.com/tendermint/tendermint/libs/log" - - "github.com/tendermint/tendermint/types" -) - -var ( - testAcceptDeadline = defaultAcceptDeadlineSeconds * time.Second - - testConnDeadline = 100 * time.Millisecond - testConnDeadline2o3 = 66 * time.Millisecond // 2/3 of the other one - - testHeartbeatTimeout = 10 * time.Millisecond - testHeartbeatTimeout3o2 = 6 * time.Millisecond // 3/2 of the other one -) - -type socketTestCase struct { - addr string - dialer Dialer -} - -func socketTestCases(t *testing.T) []socketTestCase { - tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) - unixFilePath, err := testUnixAddr() - require.NoError(t, err) - unixAddr := fmt.Sprintf("unix://%s", unixFilePath) - return []socketTestCase{ - { - addr: tcpAddr, - dialer: DialTCPFn(tcpAddr, testConnDeadline, ed25519.GenPrivKey()), - }, - { - addr: unixAddr, - dialer: DialUnixFn(unixFilePath), - }, - } -} - -func TestSocketPVAddress(t *testing.T) { - for _, tc := range socketTestCases(t) { - // Execute the test within a closure to ensure the deferred statements - // are called between each for loop iteration, for isolated test cases. - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) - ) - defer sc.Stop() - defer rs.Stop() - - serverAddr := rs.privVal.GetPubKey().Address() - clientAddr := sc.GetPubKey().Address() - - assert.Equal(t, serverAddr, clientAddr) - }() - } -} - -func TestSocketPVPubKey(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) - ) - defer sc.Stop() - defer rs.Stop() - - clientKey := sc.GetPubKey() - - privvalPubKey := rs.privVal.GetPubKey() - - assert.Equal(t, privvalPubKey, clientKey) - }() - } -} - -func TestSocketPVProposal(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) - - ts = time.Now() - privProposal = &types.Proposal{Timestamp: ts} - clientProposal = &types.Proposal{Timestamp: ts} - ) - defer sc.Stop() - defer rs.Stop() - - require.NoError(t, rs.privVal.SignProposal(chainID, privProposal)) - require.NoError(t, sc.SignProposal(chainID, clientProposal)) - assert.Equal(t, privProposal.Signature, clientProposal.Signature) - }() - } -} - -func TestSocketPVVote(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) - - ts = time.Now() - vType = types.PrecommitType - want = &types.Vote{Timestamp: ts, Type: vType} - have = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) - }() - } -} - -func TestSocketPVVoteResetDeadline(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) - - ts = time.Now() - vType = types.PrecommitType - want = &types.Vote{Timestamp: ts, Type: vType} - have = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - time.Sleep(testConnDeadline2o3) - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) - - // This would exceed the deadline if it was not extended by the previous message - time.Sleep(testConnDeadline2o3) - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) - }() - } -} - -func TestSocketPVVoteKeepalive(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) - - ts = time.Now() - vType = types.PrecommitType - want = &types.Vote{Timestamp: ts, Type: vType} - have = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - time.Sleep(testConnDeadline * 2) - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) - }() - } -} - -func TestSocketPVDeadline(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - listenc = make(chan struct{}) - thisConnTimeout = 100 * time.Millisecond - sc = newSocketVal(log.TestingLogger(), tc.addr, thisConnTimeout) - ) - - go func(sc *SocketVal) { - defer close(listenc) - - // Note: the TCP connection times out at the accept() phase, - // whereas the Unix domain sockets connection times out while - // attempting to fetch the remote signer's public key. - assert.True(t, IsConnTimeout(sc.Start())) - - assert.False(t, sc.IsRunning()) - }(sc) - - for { - _, err := cmn.Connect(tc.addr) - if err == nil { - break - } - } - - <-listenc - }() - } -} - -func TestRemoteSignVoteErrors(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer) - - ts = time.Now() - vType = types.PrecommitType - vote = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - err := sc.SignVote("", vote) - require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) - - err = rs.privVal.SignVote(chainID, vote) - require.Error(t, err) - err = sc.SignVote(chainID, vote) - require.Error(t, err) - }() - } -} - -func TestRemoteSignProposalErrors(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer) - - ts = time.Now() - proposal = &types.Proposal{Timestamp: ts} - ) - defer sc.Stop() - defer rs.Stop() - - err := sc.SignProposal("", proposal) - require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) - - err = rs.privVal.SignProposal(chainID, proposal) - require.Error(t, err) - - err = sc.SignProposal(chainID, proposal) - require.Error(t, err) - }() - } -} - -func TestErrUnexpectedResponse(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - logger = log.TestingLogger() - chainID = cmn.RandStr(12) - readyc = make(chan struct{}) - errc = make(chan error, 1) - - rs = NewRemoteSigner( - logger, - chainID, - types.NewMockPV(), - tc.dialer, - ) - sc = newSocketVal(logger, tc.addr, testConnDeadline) - ) - - testStartSocketPV(t, readyc, sc) - defer sc.Stop() - RemoteSignerConnDeadline(time.Millisecond)(rs) - RemoteSignerConnRetries(100)(rs) - // we do not want to Start() the remote signer here and instead use the connection to - // reply with intentionally wrong replies below: - rsConn, err := rs.connect() - defer rsConn.Close() - require.NoError(t, err) - require.NotNil(t, rsConn) - // send over public key to get the remote signer running: - go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) - <-readyc - - // Proposal: - go func(errc chan error) { - errc <- sc.SignProposal(chainID, &types.Proposal{}) - }(errc) - // read request and write wrong response: - go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) - err = <-errc - require.Error(t, err) - require.Equal(t, err, ErrUnexpectedResponse) - - // Vote: - go func(errc chan error) { - errc <- sc.SignVote(chainID, &types.Vote{}) - }(errc) - // read request and write wrong response: - go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) - err = <-errc - require.Error(t, err) - require.Equal(t, err, ErrUnexpectedResponse) - }() - } -} - -func TestRetryConnToRemoteSigner(t *testing.T) { - for _, tc := range socketTestCases(t) { - func() { - var ( - logger = log.TestingLogger() - chainID = cmn.RandStr(12) - readyc = make(chan struct{}) - - rs = NewRemoteSigner( - logger, - chainID, - types.NewMockPV(), - tc.dialer, - ) - thisConnTimeout = testConnDeadline - sc = newSocketVal(logger, tc.addr, thisConnTimeout) - ) - // Ping every: - SocketValHeartbeat(testHeartbeatTimeout)(sc) - - RemoteSignerConnDeadline(testConnDeadline)(rs) - RemoteSignerConnRetries(10)(rs) - - testStartSocketPV(t, readyc, sc) - defer sc.Stop() - require.NoError(t, rs.Start()) - assert.True(t, rs.IsRunning()) - - <-readyc - time.Sleep(testHeartbeatTimeout * 2) - - rs.Stop() - rs2 := NewRemoteSigner( - logger, - chainID, - types.NewMockPV(), - tc.dialer, - ) - // let some pings pass - time.Sleep(testHeartbeatTimeout3o2) - require.NoError(t, rs2.Start()) - assert.True(t, rs2.IsRunning()) - defer rs2.Stop() - - // give the client some time to re-establish the conn to the remote signer - // should see sth like this in the logs: - // - // E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" - // I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal - time.Sleep(testConnDeadline * 2) - }() - } -} - -func newSocketVal(logger log.Logger, addr string, connDeadline time.Duration) *SocketVal { - proto, address := cmn.ProtocolAndAddress(addr) - ln, err := net.Listen(proto, address) - logger.Info("Listening at", "proto", proto, "address", address) - if err != nil { - panic(err) - } - var svln net.Listener - if proto == "unix" { - unixLn := NewUnixListener(ln) - UnixListenerAcceptDeadline(testAcceptDeadline)(unixLn) - UnixListenerConnDeadline(connDeadline)(unixLn) - svln = unixLn - } else { - tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) - TCPListenerAcceptDeadline(testAcceptDeadline)(tcpLn) - TCPListenerConnDeadline(connDeadline)(tcpLn) - svln = tcpLn - } - return NewSocketVal(logger, svln) -} - -func testSetupSocketPair( - t *testing.T, - chainID string, - privValidator types.PrivValidator, - addr string, - dialer Dialer, -) (*SocketVal, *RemoteSigner) { - var ( - logger = log.TestingLogger() - privVal = privValidator - readyc = make(chan struct{}) - rs = NewRemoteSigner( - logger, - chainID, - privVal, - dialer, - ) - - thisConnTimeout = testConnDeadline - sc = newSocketVal(logger, addr, thisConnTimeout) - ) - - SocketValHeartbeat(testHeartbeatTimeout)(sc) - RemoteSignerConnDeadline(testConnDeadline)(rs) - RemoteSignerConnRetries(1e6)(rs) - - testStartSocketPV(t, readyc, sc) - - require.NoError(t, rs.Start()) - assert.True(t, rs.IsRunning()) - - <-readyc - - return sc, rs -} - -func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { - _, err := readMsg(rsConn) - require.NoError(t, err) - - err = writeMsg(rsConn, resp) - require.NoError(t, err) -} - -func testStartSocketPV(t *testing.T, readyc chan struct{}, sc *SocketVal) { - go func(sc *SocketVal) { - require.NoError(t, sc.Start()) - assert.True(t, sc.IsRunning()) - - readyc <- struct{}{} - }(sc) -} - -// testFreeTCPAddr claims a free port so we don't block on listener being ready. -func testFreeTCPAddr(t *testing.T) string { - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer ln.Close() - - return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) -} diff --git a/privval/doc.go b/privval/doc.go index ed378c190..80869a6a7 100644 --- a/privval/doc.go +++ b/privval/doc.go @@ -6,16 +6,16 @@ FilePV FilePV is the simplest implementation and developer default. It uses one file for the private key and another to store state. -SocketVal +SignerValidatorEndpoint -SocketVal establishes a connection to an external process, like a Key Management Server (KMS), using a socket. -SocketVal listens for the external KMS process to dial in. -SocketVal takes a listener, which determines the type of connection +SignerValidatorEndpoint establishes a connection to an external process, like a Key Management Server (KMS), using a socket. +SignerValidatorEndpoint listens for the external KMS process to dial in. +SignerValidatorEndpoint takes a listener, which determines the type of connection (ie. encrypted over tcp, or unencrypted over unix). -RemoteSigner +SignerServiceEndpoint -RemoteSigner is a simple wrapper around a net.Conn. It's used by both IPCVal and TCPVal. +SignerServiceEndpoint is a simple wrapper around a net.Conn. It's used by both IPCVal and TCPVal. */ package privval diff --git a/privval/errors.go b/privval/errors.go new file mode 100644 index 000000000..75fb25fc6 --- /dev/null +++ b/privval/errors.go @@ -0,0 +1,22 @@ +package privval + +import ( + "fmt" +) + +// Socket errors. +var ( + ErrUnexpectedResponse = fmt.Errorf("received unexpected response") + ErrConnTimeout = fmt.Errorf("remote signer timed out") +) + +// RemoteSignerError allows (remote) validators to include meaningful error descriptions in their reply. +type RemoteSignerError struct { + // TODO(ismail): create an enum of known errors + Code int + Description string +} + +func (e *RemoteSignerError) Error() string { + return fmt.Sprintf("signerServiceEndpoint returned error #%d: %s", e.Code, e.Description) +} diff --git a/privval/file.go b/privval/file.go index 8eb38e806..1cb88f7c0 100644 --- a/privval/file.go +++ b/privval/file.go @@ -49,7 +49,7 @@ type FilePVKey struct { func (pvKey FilePVKey) Save() { outFile := pvKey.filePath if outFile == "" { - panic("Cannot save PrivValidator key: filePath not set") + panic("cannot save PrivValidator key: filePath not set") } jsonBytes, err := cdc.MarshalJSONIndent(pvKey, "", " ") @@ -86,17 +86,17 @@ type FilePVLastSignState struct { func (lss *FilePVLastSignState) CheckHRS(height int64, round int, step int8) (bool, error) { if lss.Height > height { - return false, fmt.Errorf("Height regression. Got %v, last height %v", height, lss.Height) + return false, fmt.Errorf("height regression. Got %v, last height %v", height, lss.Height) } if lss.Height == height { if lss.Round > round { - return false, fmt.Errorf("Round regression at height %v. Got %v, last round %v", height, round, lss.Round) + return false, fmt.Errorf("round regression at height %v. Got %v, last round %v", height, round, lss.Round) } if lss.Round == round { if lss.Step > step { - return false, fmt.Errorf("Step regression at height %v round %v. Got %v, last step %v", height, round, step, lss.Step) + return false, fmt.Errorf("step regression at height %v round %v. Got %v, last step %v", height, round, step, lss.Step) } else if lss.Step == step { if lss.SignBytes != nil { if lss.Signature == nil { @@ -104,7 +104,7 @@ func (lss *FilePVLastSignState) CheckHRS(height int64, round int, step int8) (bo } return true, nil } - return false, errors.New("No SignBytes found") + return false, errors.New("no SignBytes found") } } } @@ -115,7 +115,7 @@ func (lss *FilePVLastSignState) CheckHRS(height int64, round int, step int8) (bo func (lss *FilePVLastSignState) Save() { outFile := lss.filePath if outFile == "" { - panic("Cannot save FilePVLastSignState: filePath not set") + panic("cannot save FilePVLastSignState: filePath not set") } jsonBytes, err := cdc.MarshalJSONIndent(lss, "", " ") if err != nil { @@ -237,7 +237,7 @@ func (pv *FilePV) GetPubKey() crypto.PubKey { // chainID. Implements PrivValidator. func (pv *FilePV) SignVote(chainID string, vote *types.Vote) error { if err := pv.signVote(chainID, vote); err != nil { - return fmt.Errorf("Error signing vote: %v", err) + return fmt.Errorf("error signing vote: %v", err) } return nil } @@ -246,7 +246,7 @@ func (pv *FilePV) SignVote(chainID string, vote *types.Vote) error { // the chainID. Implements PrivValidator. func (pv *FilePV) SignProposal(chainID string, proposal *types.Proposal) error { if err := pv.signProposal(chainID, proposal); err != nil { - return fmt.Errorf("Error signing proposal: %v", err) + return fmt.Errorf("error signing proposal: %v", err) } return nil } @@ -303,7 +303,7 @@ func (pv *FilePV) signVote(chainID string, vote *types.Vote) error { vote.Timestamp = timestamp vote.Signature = lss.Signature } else { - err = fmt.Errorf("Conflicting data") + err = fmt.Errorf("conflicting data") } return err } @@ -345,7 +345,7 @@ func (pv *FilePV) signProposal(chainID string, proposal *types.Proposal) error { proposal.Timestamp = timestamp proposal.Signature = lss.Signature } else { - err = fmt.Errorf("Conflicting data") + err = fmt.Errorf("conflicting data") } return err } diff --git a/privval/old_file.go b/privval/file_deprecated.go similarity index 98% rename from privval/old_file.go rename to privval/file_deprecated.go index ec72c1834..d010de763 100644 --- a/privval/old_file.go +++ b/privval/file_deprecated.go @@ -10,6 +10,7 @@ import ( ) // OldFilePV is the old version of the FilePV, pre v0.28.0. +// Deprecated: Use FilePV instead. type OldFilePV struct { Address types.Address `json:"address"` PubKey crypto.PubKey `json:"pub_key"` diff --git a/privval/old_file_test.go b/privval/file_deprecated_test.go similarity index 100% rename from privval/old_file_test.go rename to privval/file_deprecated_test.go diff --git a/privval/messages.go b/privval/messages.go new file mode 100644 index 000000000..6774a2795 --- /dev/null +++ b/privval/messages.go @@ -0,0 +1,61 @@ +package privval + +import ( + amino "github.com/tendermint/go-amino" + "github.com/tendermint/tendermint/crypto" + "github.com/tendermint/tendermint/types" +) + +// RemoteSignerMsg is sent between SignerServiceEndpoint and the SignerServiceEndpoint client. +type RemoteSignerMsg interface{} + +func RegisterRemoteSignerMsg(cdc *amino.Codec) { + cdc.RegisterInterface((*RemoteSignerMsg)(nil), nil) + cdc.RegisterConcrete(&PubKeyRequest{}, "tendermint/remotesigner/PubKeyRequest", nil) + cdc.RegisterConcrete(&PubKeyResponse{}, "tendermint/remotesigner/PubKeyResponse", nil) + cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil) + cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil) + cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil) + cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil) + cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil) + cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil) +} + +// PubKeyRequest requests the consensus public key from the remote signer. +type PubKeyRequest struct{} + +// PubKeyResponse is a PrivValidatorSocket message containing the public key. +type PubKeyResponse struct { + PubKey crypto.PubKey + Error *RemoteSignerError +} + +// SignVoteRequest is a PrivValidatorSocket message containing a vote. +type SignVoteRequest struct { + Vote *types.Vote +} + +// SignedVoteResponse is a PrivValidatorSocket message containing a signed vote along with a potenial error message. +type SignedVoteResponse struct { + Vote *types.Vote + Error *RemoteSignerError +} + +// SignProposalRequest is a PrivValidatorSocket message containing a Proposal. +type SignProposalRequest struct { + Proposal *types.Proposal +} + +// SignedProposalResponse is a PrivValidatorSocket message containing a proposal response +type SignedProposalResponse struct { + Proposal *types.Proposal + Error *RemoteSignerError +} + +// PingRequest is a PrivValidatorSocket message to keep the connection alive. +type PingRequest struct { +} + +// PingRequest is a PrivValidatorSocket response to keep the connection alive. +type PingResponse struct { +} diff --git a/privval/remote_signer_test.go b/privval/remote_signer_test.go deleted file mode 100644 index cb2a600db..000000000 --- a/privval/remote_signer_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package privval - -import ( - "net" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tendermint/tendermint/crypto/ed25519" - cmn "github.com/tendermint/tendermint/libs/common" - "github.com/tendermint/tendermint/libs/log" - "github.com/tendermint/tendermint/types" -) - -// TestRemoteSignerRetryTCPOnly will test connection retry attempts over TCP. We -// don't need this for Unix sockets because the OS instantly knows the state of -// both ends of the socket connection. This basically causes the -// RemoteSigner.dialer() call inside RemoteSigner.connect() to return -// successfully immediately, putting an instant stop to any retry attempts. -func TestRemoteSignerRetryTCPOnly(t *testing.T) { - var ( - attemptc = make(chan int) - retries = 2 - ) - - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - go func(ln net.Listener, attemptc chan<- int) { - attempts := 0 - - for { - conn, err := ln.Accept() - require.NoError(t, err) - - err = conn.Close() - require.NoError(t, err) - - attempts++ - - if attempts == retries { - attemptc <- attempts - break - } - } - }(ln, attemptc) - - rs := NewRemoteSigner( - log.TestingLogger(), - cmn.RandStr(12), - types.NewMockPV(), - DialTCPFn(ln.Addr().String(), testConnDeadline, ed25519.GenPrivKey()), - ) - defer rs.Stop() - - RemoteSignerConnDeadline(time.Millisecond)(rs) - RemoteSignerConnRetries(retries)(rs) - - assert.Equal(t, rs.Start(), ErrDialRetryMax) - - select { - case attempts := <-attemptc: - assert.Equal(t, retries, attempts) - case <-time.After(100 * time.Millisecond): - t.Error("expected remote to observe connection attempts") - } -} - -func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { - // Generate a networking timeout - dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) - _, err := dialer() - assert.Error(t, err) - assert.True(t, IsConnTimeout(err)) -} - -func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { - dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) - _, err := dialer() - assert.Error(t, err) - err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) - assert.True(t, IsConnTimeout(err)) -} - -func TestIsConnTimeoutForNonTimeoutErrors(t *testing.T) { - assert.False(t, IsConnTimeout(cmn.ErrorWrap(ErrDialRetryMax, "max retries exceeded"))) - assert.False(t, IsConnTimeout(errors.New("completely irrelevant error"))) -} diff --git a/privval/server.go b/privval/server.go deleted file mode 100644 index cce659525..000000000 --- a/privval/server.go +++ /dev/null @@ -1,168 +0,0 @@ -package privval - -import ( - "io" - "net" - "time" - - "github.com/pkg/errors" - "github.com/tendermint/tendermint/crypto/ed25519" - cmn "github.com/tendermint/tendermint/libs/common" - "github.com/tendermint/tendermint/libs/log" - p2pconn "github.com/tendermint/tendermint/p2p/conn" - "github.com/tendermint/tendermint/types" -) - -// Socket errors. -var ( - ErrDialRetryMax = errors.New("dialed maximum retries") -) - -// RemoteSignerOption sets an optional parameter on the RemoteSigner. -type RemoteSignerOption func(*RemoteSigner) - -// RemoteSignerConnDeadline sets the read and write deadline for connections -// from external signing processes. -func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption { - return func(ss *RemoteSigner) { ss.connDeadline = deadline } -} - -// RemoteSignerConnRetries sets the amount of attempted retries to connect. -func RemoteSignerConnRetries(retries int) RemoteSignerOption { - return func(ss *RemoteSigner) { ss.connRetries = retries } -} - -// RemoteSigner dials using its dialer and responds to any -// signature requests using its privVal. -type RemoteSigner struct { - cmn.BaseService - - chainID string - connDeadline time.Duration - connRetries int - privVal types.PrivValidator - - dialer Dialer - conn net.Conn -} - -// Dialer dials a remote address and returns a net.Conn or an error. -type Dialer func() (net.Conn, error) - -// DialTCPFn dials the given tcp addr, using the given connTimeout and privKey for the -// authenticated encryption handshake. -func DialTCPFn(addr string, connTimeout time.Duration, privKey ed25519.PrivKeyEd25519) Dialer { - return func() (net.Conn, error) { - conn, err := cmn.Connect(addr) - if err == nil { - err = conn.SetDeadline(time.Now().Add(connTimeout)) - } - if err == nil { - conn, err = p2pconn.MakeSecretConnection(conn, privKey) - } - return conn, err - } -} - -// DialUnixFn dials the given unix socket. -func DialUnixFn(addr string) Dialer { - return func() (net.Conn, error) { - unixAddr := &net.UnixAddr{Name: addr, Net: "unix"} - return net.DialUnix("unix", nil, unixAddr) - } -} - -// NewRemoteSigner return a RemoteSigner that will dial using the given -// dialer and respond to any signature requests over the connection -// using the given privVal. -func NewRemoteSigner( - logger log.Logger, - chainID string, - privVal types.PrivValidator, - dialer Dialer, -) *RemoteSigner { - rs := &RemoteSigner{ - chainID: chainID, - connDeadline: time.Second * defaultConnDeadlineSeconds, - connRetries: defaultDialRetries, - privVal: privVal, - dialer: dialer, - } - - rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs) - return rs -} - -// OnStart implements cmn.Service. -func (rs *RemoteSigner) OnStart() error { - conn, err := rs.connect() - if err != nil { - rs.Logger.Error("OnStart", "err", err) - return err - } - rs.conn = conn - - go rs.handleConnection(conn) - - return nil -} - -// OnStop implements cmn.Service. -func (rs *RemoteSigner) OnStop() { - if rs.conn == nil { - return - } - - if err := rs.conn.Close(); err != nil { - rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) - } -} - -func (rs *RemoteSigner) connect() (net.Conn, error) { - for retries := rs.connRetries; retries > 0; retries-- { - // Don't sleep if it is the first retry. - if retries != rs.connRetries { - time.Sleep(rs.connDeadline) - } - conn, err := rs.dialer() - if err != nil { - rs.Logger.Error("dialing", "err", err) - continue - } - return conn, nil - } - - return nil, ErrDialRetryMax -} - -func (rs *RemoteSigner) handleConnection(conn net.Conn) { - for { - if !rs.IsRunning() { - return // Ignore error from listener closing. - } - - // Reset the connection deadline - conn.SetDeadline(time.Now().Add(rs.connDeadline)) - - req, err := readMsg(conn) - if err != nil { - if err != io.EOF { - rs.Logger.Error("handleConnection readMsg", "err", err) - } - return - } - - res, err := handleRequest(req, rs.chainID, rs.privVal) - - if err != nil { - // only log the error; we'll reply with an error in res - rs.Logger.Error("handleConnection handleRequest", "err", err) - } - - err = writeMsg(conn, res) - if err != nil { - rs.Logger.Error("handleConnection writeMsg", "err", err) - return - } - } -} diff --git a/privval/remote_signer.go b/privval/signer_remote.go similarity index 50% rename from privval/remote_signer.go rename to privval/signer_remote.go index a5b8cac64..53b0cb773 100644 --- a/privval/remote_signer.go +++ b/privval/signer_remote.go @@ -7,51 +7,44 @@ import ( "github.com/pkg/errors" - amino "github.com/tendermint/go-amino" "github.com/tendermint/tendermint/crypto" cmn "github.com/tendermint/tendermint/libs/common" "github.com/tendermint/tendermint/types" ) -// Socket errors. -var ( - ErrConnTimeout = errors.New("remote signer timed out") -) - -// RemoteSignerClient implements PrivValidator. -// It uses a net.Conn to request signatures -// from an external process. -type RemoteSignerClient struct { +// SignerRemote implements PrivValidator. +// It uses a net.Conn to request signatures from an external process. +type SignerRemote struct { conn net.Conn // memoized consensusPubKey crypto.PubKey } -// Check that RemoteSignerClient implements PrivValidator. -var _ types.PrivValidator = (*RemoteSignerClient)(nil) +// Check that SignerRemote implements PrivValidator. +var _ types.PrivValidator = (*SignerRemote)(nil) -// NewRemoteSignerClient returns an instance of RemoteSignerClient. -func NewRemoteSignerClient(conn net.Conn) (*RemoteSignerClient, error) { +// NewSignerRemote returns an instance of SignerRemote. +func NewSignerRemote(conn net.Conn) (*SignerRemote, error) { // retrieve and memoize the consensus public key once. pubKey, err := getPubKey(conn) if err != nil { return nil, cmn.ErrorWrap(err, "error while retrieving public key for remote signer") } - return &RemoteSignerClient{ + return &SignerRemote{ conn: conn, consensusPubKey: pubKey, }, nil } // Close calls Close on the underlying net.Conn. -func (sc *RemoteSignerClient) Close() error { +func (sc *SignerRemote) Close() error { return sc.conn.Close() } // GetPubKey implements PrivValidator. -func (sc *RemoteSignerClient) GetPubKey() crypto.PubKey { +func (sc *SignerRemote) GetPubKey() crypto.PubKey { return sc.consensusPubKey } @@ -66,6 +59,7 @@ func getPubKey(conn net.Conn) (crypto.PubKey, error) { if err != nil { return nil, err } + pubKeyResp, ok := res.(*PubKeyResponse) if !ok { return nil, errors.Wrap(ErrUnexpectedResponse, "response is not PubKeyResponse") @@ -79,7 +73,7 @@ func getPubKey(conn net.Conn) (crypto.PubKey, error) { } // SignVote implements PrivValidator. -func (sc *RemoteSignerClient) SignVote(chainID string, vote *types.Vote) error { +func (sc *SignerRemote) SignVote(chainID string, vote *types.Vote) error { err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote}) if err != nil { return err @@ -103,10 +97,7 @@ func (sc *RemoteSignerClient) SignVote(chainID string, vote *types.Vote) error { } // SignProposal implements PrivValidator. -func (sc *RemoteSignerClient) SignProposal( - chainID string, - proposal *types.Proposal, -) error { +func (sc *SignerRemote) SignProposal(chainID string, proposal *types.Proposal) error { err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal}) if err != nil { return err @@ -129,7 +120,7 @@ func (sc *RemoteSignerClient) SignProposal( } // Ping is used to check connection health. -func (sc *RemoteSignerClient) Ping() error { +func (sc *SignerRemote) Ping() error { err := writeMsg(sc.conn, &PingRequest{}) if err != nil { return err @@ -147,69 +138,6 @@ func (sc *RemoteSignerClient) Ping() error { return nil } -// RemoteSignerMsg is sent between RemoteSigner and the RemoteSigner client. -type RemoteSignerMsg interface{} - -func RegisterRemoteSignerMsg(cdc *amino.Codec) { - cdc.RegisterInterface((*RemoteSignerMsg)(nil), nil) - cdc.RegisterConcrete(&PubKeyRequest{}, "tendermint/remotesigner/PubKeyRequest", nil) - cdc.RegisterConcrete(&PubKeyResponse{}, "tendermint/remotesigner/PubKeyResponse", nil) - cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil) - cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil) - cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil) - cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil) - cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil) - cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil) -} - -// PubKeyRequest requests the consensus public key from the remote signer. -type PubKeyRequest struct{} - -// PubKeyResponse is a PrivValidatorSocket message containing the public key. -type PubKeyResponse struct { - PubKey crypto.PubKey - Error *RemoteSignerError -} - -// SignVoteRequest is a PrivValidatorSocket message containing a vote. -type SignVoteRequest struct { - Vote *types.Vote -} - -// SignedVoteResponse is a PrivValidatorSocket message containing a signed vote along with a potenial error message. -type SignedVoteResponse struct { - Vote *types.Vote - Error *RemoteSignerError -} - -// SignProposalRequest is a PrivValidatorSocket message containing a Proposal. -type SignProposalRequest struct { - Proposal *types.Proposal -} - -type SignedProposalResponse struct { - Proposal *types.Proposal - Error *RemoteSignerError -} - -// PingRequest is a PrivValidatorSocket message to keep the connection alive. -type PingRequest struct { -} - -type PingResponse struct { -} - -// RemoteSignerError allows (remote) validators to include meaningful error descriptions in their reply. -type RemoteSignerError struct { - // TODO(ismail): create an enum of known errors - Code int - Description string -} - -func (e *RemoteSignerError) Error() string { - return fmt.Sprintf("RemoteSigner returned error #%d: %s", e.Code, e.Description) -} - func readMsg(r io.Reader) (msg RemoteSignerMsg, err error) { const maxRemoteSignerMsgSize = 1024 * 10 _, err = cdc.UnmarshalBinaryLengthPrefixedReader(r, &msg, maxRemoteSignerMsgSize) @@ -236,6 +164,7 @@ func handleRequest(req RemoteSignerMsg, chainID string, privVal types.PrivValida var p crypto.PubKey p = privVal.GetPubKey() res = &PubKeyResponse{p, nil} + case *SignVoteRequest: err = privVal.SignVote(chainID, r.Vote) if err != nil { @@ -243,6 +172,7 @@ func handleRequest(req RemoteSignerMsg, chainID string, privVal types.PrivValida } else { res = &SignedVoteResponse{r.Vote, nil} } + case *SignProposalRequest: err = privVal.SignProposal(chainID, r.Proposal) if err != nil { @@ -250,26 +180,13 @@ func handleRequest(req RemoteSignerMsg, chainID string, privVal types.PrivValida } else { res = &SignedProposalResponse{r.Proposal, nil} } + case *PingRequest: res = &PingResponse{} + default: err = fmt.Errorf("unknown msg: %v", r) } return res, err } - -// IsConnTimeout returns a boolean indicating whether the error is known to -// report that a connection timeout occurred. This detects both fundamental -// network timeouts, as well as ErrConnTimeout errors. -func IsConnTimeout(err error) bool { - if cmnErr, ok := err.(cmn.Error); ok { - if cmnErr.Data() == ErrConnTimeout { - return true - } - } - if _, ok := err.(timeoutError); ok { - return true - } - return false -} diff --git a/privval/signer_remote_test.go b/privval/signer_remote_test.go new file mode 100644 index 000000000..28230b803 --- /dev/null +++ b/privval/signer_remote_test.go @@ -0,0 +1,68 @@ +package privval + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/crypto/ed25519" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +// TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We +// don't need this for Unix sockets because the OS instantly knows the state of +// both ends of the socket connection. This basically causes the +// SignerServiceEndpoint.dialer() call inside SignerServiceEndpoint.connect() to return +// successfully immediately, putting an instant stop to any retry attempts. +func TestSignerRemoteRetryTCPOnly(t *testing.T) { + var ( + attemptCh = make(chan int) + retries = 2 + ) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + go func(ln net.Listener, attemptCh chan<- int) { + attempts := 0 + + for { + conn, err := ln.Accept() + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + attempts++ + + if attempts == retries { + attemptCh <- attempts + break + } + } + }(ln, attemptCh) + + serviceEndpoint := NewSignerServiceEndpoint( + log.TestingLogger(), + cmn.RandStr(12), + types.NewMockPV(), + DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()), + ) + defer serviceEndpoint.Stop() + + SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) + SignerServiceEndpointConnRetries(retries)(serviceEndpoint) + + assert.Equal(t, serviceEndpoint.Start(), ErrDialRetryMax) + + select { + case attempts := <-attemptCh: + assert.Equal(t, retries, attempts) + case <-time.After(100 * time.Millisecond): + t.Error("expected remote to observe connection attempts") + } +} diff --git a/privval/signer_service_endpoint.go b/privval/signer_service_endpoint.go new file mode 100644 index 000000000..1b37d5fc6 --- /dev/null +++ b/privval/signer_service_endpoint.go @@ -0,0 +1,139 @@ +package privval + +import ( + "io" + "net" + "time" + + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +// SignerServiceEndpointOption sets an optional parameter on the SignerServiceEndpoint. +type SignerServiceEndpointOption func(*SignerServiceEndpoint) + +// SignerServiceEndpointTimeoutReadWrite sets the read and write timeout for connections +// from external signing processes. +func SignerServiceEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption { + return func(ss *SignerServiceEndpoint) { ss.timeoutReadWrite = timeout } +} + +// SignerServiceEndpointConnRetries sets the amount of attempted retries to connect. +func SignerServiceEndpointConnRetries(retries int) SignerServiceEndpointOption { + return func(ss *SignerServiceEndpoint) { ss.connRetries = retries } +} + +// SignerServiceEndpoint dials using its dialer and responds to any +// signature requests using its privVal. +type SignerServiceEndpoint struct { + cmn.BaseService + + chainID string + timeoutReadWrite time.Duration + connRetries int + privVal types.PrivValidator + + dialer SocketDialer + conn net.Conn +} + +// NewSignerServiceEndpoint returns a SignerServiceEndpoint that will dial using the given +// dialer and respond to any signature requests over the connection +// using the given privVal. +func NewSignerServiceEndpoint( + logger log.Logger, + chainID string, + privVal types.PrivValidator, + dialer SocketDialer, +) *SignerServiceEndpoint { + se := &SignerServiceEndpoint{ + chainID: chainID, + timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds, + connRetries: defaultMaxDialRetries, + privVal: privVal, + dialer: dialer, + } + + se.BaseService = *cmn.NewBaseService(logger, "SignerServiceEndpoint", se) + return se +} + +// OnStart implements cmn.Service. +func (se *SignerServiceEndpoint) OnStart() error { + conn, err := se.connect() + if err != nil { + se.Logger.Error("OnStart", "err", err) + return err + } + + se.conn = conn + go se.handleConnection(conn) + + return nil +} + +// OnStop implements cmn.Service. +func (se *SignerServiceEndpoint) OnStop() { + if se.conn == nil { + return + } + + if err := se.conn.Close(); err != nil { + se.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) + } +} + +func (se *SignerServiceEndpoint) connect() (net.Conn, error) { + for retries := 0; retries < se.connRetries; retries++ { + // Don't sleep if it is the first retry. + if retries > 0 { + time.Sleep(se.timeoutReadWrite) + } + + conn, err := se.dialer() + if err == nil { + return conn, nil + } + + se.Logger.Error("dialing", "err", err) + } + + return nil, ErrDialRetryMax +} + +func (se *SignerServiceEndpoint) handleConnection(conn net.Conn) { + for { + if !se.IsRunning() { + return // Ignore error from listener closing. + } + + // Reset the connection deadline + deadline := time.Now().Add(se.timeoutReadWrite) + err := conn.SetDeadline(deadline) + if err != nil { + return + } + + req, err := readMsg(conn) + if err != nil { + if err != io.EOF { + se.Logger.Error("handleConnection readMsg", "err", err) + } + return + } + + res, err := handleRequest(req, se.chainID, se.privVal) + + if err != nil { + // only log the error; we'll reply with an error in res + se.Logger.Error("handleConnection handleRequest", "err", err) + } + + err = writeMsg(conn, res) + if err != nil { + se.Logger.Error("handleConnection writeMsg", "err", err) + return + } + } +} diff --git a/privval/signer_validator_endpoint.go b/privval/signer_validator_endpoint.go new file mode 100644 index 000000000..6dc7f99d5 --- /dev/null +++ b/privval/signer_validator_endpoint.go @@ -0,0 +1,230 @@ +package privval + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/tendermint/tendermint/crypto" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +const ( + defaultHeartbeatSeconds = 2 + defaultMaxDialRetries = 10 +) + +var ( + heartbeatPeriod = time.Second * defaultHeartbeatSeconds +) + +// SignerValidatorEndpointOption sets an optional parameter on the SocketVal. +type SignerValidatorEndpointOption func(*SignerValidatorEndpoint) + +// SignerValidatorEndpointSetHeartbeat sets the period on which to check the liveness of the +// connected Signer connections. +func SignerValidatorEndpointSetHeartbeat(period time.Duration) SignerValidatorEndpointOption { + return func(sc *SignerValidatorEndpoint) { sc.heartbeatPeriod = period } +} + +// SocketVal implements PrivValidator. +// It listens for an external process to dial in and uses +// the socket to request signatures. +type SignerValidatorEndpoint struct { + cmn.BaseService + + listener net.Listener + + // ping + cancelPingCh chan struct{} + pingTicker *time.Ticker + heartbeatPeriod time.Duration + + // signer is mutable since it can be reset if the connection fails. + // failures are detected by a background ping routine. + // All messages are request/response, so we hold the mutex + // so only one request/response pair can happen at a time. + // Methods on the underlying net.Conn itself are already goroutine safe. + mtx sync.Mutex + + // TODO: Signer should encapsulate and hide the endpoint completely. Invert the relation + signer *SignerRemote +} + +// Check that SignerValidatorEndpoint implements PrivValidator. +var _ types.PrivValidator = (*SignerValidatorEndpoint)(nil) + +// NewSignerValidatorEndpoint returns an instance of SignerValidatorEndpoint. +func NewSignerValidatorEndpoint(logger log.Logger, listener net.Listener) *SignerValidatorEndpoint { + sc := &SignerValidatorEndpoint{ + listener: listener, + heartbeatPeriod: heartbeatPeriod, + } + + sc.BaseService = *cmn.NewBaseService(logger, "SignerValidatorEndpoint", sc) + + return sc +} + +//-------------------------------------------------------- +// Implement PrivValidator + +// GetPubKey implements PrivValidator. +func (ve *SignerValidatorEndpoint) GetPubKey() crypto.PubKey { + ve.mtx.Lock() + defer ve.mtx.Unlock() + return ve.signer.GetPubKey() +} + +// SignVote implements PrivValidator. +func (ve *SignerValidatorEndpoint) SignVote(chainID string, vote *types.Vote) error { + ve.mtx.Lock() + defer ve.mtx.Unlock() + return ve.signer.SignVote(chainID, vote) +} + +// SignProposal implements PrivValidator. +func (ve *SignerValidatorEndpoint) SignProposal(chainID string, proposal *types.Proposal) error { + ve.mtx.Lock() + defer ve.mtx.Unlock() + return ve.signer.SignProposal(chainID, proposal) +} + +//-------------------------------------------------------- +// More thread safe methods proxied to the signer + +// Ping is used to check connection health. +func (ve *SignerValidatorEndpoint) Ping() error { + ve.mtx.Lock() + defer ve.mtx.Unlock() + return ve.signer.Ping() +} + +// Close closes the underlying net.Conn. +func (ve *SignerValidatorEndpoint) Close() { + ve.mtx.Lock() + defer ve.mtx.Unlock() + if ve.signer != nil { + if err := ve.signer.Close(); err != nil { + ve.Logger.Error("OnStop", "err", err) + } + } + + if ve.listener != nil { + if err := ve.listener.Close(); err != nil { + ve.Logger.Error("OnStop", "err", err) + } + } +} + +//-------------------------------------------------------- +// Service start and stop + +// OnStart implements cmn.Service. +func (ve *SignerValidatorEndpoint) OnStart() error { + if closed, err := ve.reset(); err != nil { + ve.Logger.Error("OnStart", "err", err) + return err + } else if closed { + return fmt.Errorf("listener is closed") + } + + // Start a routine to keep the connection alive + ve.cancelPingCh = make(chan struct{}, 1) + ve.pingTicker = time.NewTicker(ve.heartbeatPeriod) + go func() { + for { + select { + case <-ve.pingTicker.C: + err := ve.Ping() + if err != nil { + ve.Logger.Error("Ping", "err", err) + if err == ErrUnexpectedResponse { + return + } + + closed, err := ve.reset() + if err != nil { + ve.Logger.Error("Reconnecting to remote signer failed", "err", err) + continue + } + if closed { + ve.Logger.Info("listener is closing") + return + } + + ve.Logger.Info("Re-created connection to remote signer", "impl", ve) + } + case <-ve.cancelPingCh: + ve.pingTicker.Stop() + return + } + } + }() + + return nil +} + +// OnStop implements cmn.Service. +func (ve *SignerValidatorEndpoint) OnStop() { + if ve.cancelPingCh != nil { + close(ve.cancelPingCh) + } + ve.Close() +} + +//-------------------------------------------------------- +// Connection and signer management + +// waits to accept and sets a new connection. +// connection is closed in OnStop. +// returns true if the listener is closed +// (ie. it returns a nil conn). +func (ve *SignerValidatorEndpoint) reset() (closed bool, err error) { + ve.mtx.Lock() + defer ve.mtx.Unlock() + + // first check if the conn already exists and close it. + if ve.signer != nil { + if tmpErr := ve.signer.Close(); tmpErr != nil { + ve.Logger.Error("error closing socket val connection during reset", "err", tmpErr) + } + } + + // wait for a new conn + conn, err := ve.acceptConnection() + if err != nil { + return false, err + } + + // listener is closed + if conn == nil { + return true, nil + } + + ve.signer, err = NewSignerRemote(conn) + if err != nil { + // failed to fetch the pubkey. close out the connection. + if tmpErr := conn.Close(); tmpErr != nil { + ve.Logger.Error("error closing connection", "err", tmpErr) + } + return false, err + } + return false, nil +} + +// Attempt to accept a connection. +// Times out after the listener's timeoutAccept +func (ve *SignerValidatorEndpoint) acceptConnection() (net.Conn, error) { + conn, err := ve.listener.Accept() + if err != nil { + if !ve.IsRunning() { + return nil, nil // Ignore error from listener closing. + } + return nil, err + } + return conn, nil +} diff --git a/privval/signer_validator_endpoint_test.go b/privval/signer_validator_endpoint_test.go new file mode 100644 index 000000000..bf4c29930 --- /dev/null +++ b/privval/signer_validator_endpoint_test.go @@ -0,0 +1,505 @@ +package privval + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/tendermint/tendermint/crypto/ed25519" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + + "github.com/tendermint/tendermint/types" +) + +var ( + testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second + + testTimeoutReadWrite = 100 * time.Millisecond + testTimeoutReadWrite2o3 = 66 * time.Millisecond // 2/3 of the other one + + testTimeoutHeartbeat = 10 * time.Millisecond + testTimeoutHeartbeat3o2 = 6 * time.Millisecond // 3/2 of the other one +) + +type socketTestCase struct { + addr string + dialer SocketDialer +} + +func socketTestCases(t *testing.T) []socketTestCase { + tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) + unixFilePath, err := testUnixAddr() + require.NoError(t, err) + unixAddr := fmt.Sprintf("unix://%s", unixFilePath) + return []socketTestCase{ + { + addr: tcpAddr, + dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()), + }, + { + addr: unixAddr, + dialer: DialUnixFn(unixFilePath), + }, + } +} + +func TestSocketPVAddress(t *testing.T) { + for _, tc := range socketTestCases(t) { + // Execute the test within a closure to ensure the deferred statements + // are called between each for loop iteration, for isolated test cases. + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + serviceAddr := serviceEndpoint.privVal.GetPubKey().Address() + validatorAddr := validatorEndpoint.GetPubKey().Address() + + assert.Equal(t, serviceAddr, validatorAddr) + }() + } +} + +func TestSocketPVPubKey(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair( + t, + chainID, + types.NewMockPV(), + tc.addr, + tc.dialer) + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + clientKey := validatorEndpoint.GetPubKey() + privvalPubKey := serviceEndpoint.privVal.GetPubKey() + + assert.Equal(t, privvalPubKey, clientKey) + }() + } +} + +func TestSocketPVProposal(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair( + t, + chainID, + types.NewMockPV(), + tc.addr, + tc.dialer) + + ts = time.Now() + privProposal = &types.Proposal{Timestamp: ts} + clientProposal = &types.Proposal{Timestamp: ts} + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, privProposal)) + require.NoError(t, validatorEndpoint.SignProposal(chainID, clientProposal)) + + assert.Equal(t, privProposal.Signature, clientProposal.Signature) + }() + } +} + +func TestSocketPVVote(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair( + t, + chainID, + types.NewMockPV(), + tc.addr, + tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) + require.NoError(t, validatorEndpoint.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + }() + } +} + +func TestSocketPVVoteResetDeadline(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair( + t, + chainID, + types.NewMockPV(), + tc.addr, + tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + time.Sleep(testTimeoutReadWrite2o3) + + require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) + require.NoError(t, validatorEndpoint.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + + // This would exceed the deadline if it was not extended by the previous message + time.Sleep(testTimeoutReadWrite2o3) + + require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) + require.NoError(t, validatorEndpoint.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + }() + } +} + +func TestSocketPVVoteKeepalive(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair( + t, + chainID, + types.NewMockPV(), + tc.addr, + tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + time.Sleep(testTimeoutReadWrite * 2) + + require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) + require.NoError(t, validatorEndpoint.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + }() + } +} + +func TestSocketPVDeadline(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + listenc = make(chan struct{}) + thisConnTimeout = 100 * time.Millisecond + validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout) + ) + + go func(sc *SignerValidatorEndpoint) { + defer close(listenc) + + // Note: the TCP connection times out at the accept() phase, + // whereas the Unix domain sockets connection times out while + // attempting to fetch the remote signer's public key. + assert.True(t, IsConnTimeout(sc.Start())) + + assert.False(t, sc.IsRunning()) + }(validatorEndpoint) + + for { + _, err := cmn.Connect(tc.addr) + if err == nil { + break + } + } + + <-listenc + }() + } +} + +func TestRemoteSignVoteErrors(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair( + t, + chainID, + types.NewErroringMockPV(), + tc.addr, + tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + vote = &types.Vote{Timestamp: ts, Type: vType} + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + err := validatorEndpoint.SignVote("", vote) + require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) + + err = serviceEndpoint.privVal.SignVote(chainID, vote) + require.Error(t, err) + err = validatorEndpoint.SignVote(chainID, vote) + require.Error(t, err) + }() + } +} + +func TestRemoteSignProposalErrors(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + validatorEndpoint, serviceEndpoint = testSetupSocketPair( + t, + chainID, + types.NewErroringMockPV(), + tc.addr, + tc.dialer) + + ts = time.Now() + proposal = &types.Proposal{Timestamp: ts} + ) + defer validatorEndpoint.Stop() + defer serviceEndpoint.Stop() + + err := validatorEndpoint.SignProposal("", proposal) + require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) + + err = serviceEndpoint.privVal.SignProposal(chainID, proposal) + require.Error(t, err) + + err = validatorEndpoint.SignProposal(chainID, proposal) + require.Error(t, err) + }() + } +} + +func TestErrUnexpectedResponse(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + logger = log.TestingLogger() + chainID = cmn.RandStr(12) + readyCh = make(chan struct{}) + errCh = make(chan error, 1) + + serviceEndpoint = NewSignerServiceEndpoint( + logger, + chainID, + types.NewMockPV(), + tc.dialer, + ) + + validatorEndpoint = newSignerValidatorEndpoint( + logger, + tc.addr, + testTimeoutReadWrite) + ) + + testStartEndpoint(t, readyCh, validatorEndpoint) + defer validatorEndpoint.Stop() + SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) + SignerServiceEndpointConnRetries(100)(serviceEndpoint) + // we do not want to Start() the remote signer here and instead use the connection to + // reply with intentionally wrong replies below: + rsConn, err := serviceEndpoint.connect() + defer rsConn.Close() + require.NoError(t, err) + require.NotNil(t, rsConn) + // send over public key to get the remote signer running: + go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) + <-readyCh + + // Proposal: + go func(errc chan error) { + errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{}) + }(errCh) + + // read request and write wrong response: + go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) + err = <-errCh + require.Error(t, err) + require.Equal(t, err, ErrUnexpectedResponse) + + // Vote: + go func(errc chan error) { + errc <- validatorEndpoint.SignVote(chainID, &types.Vote{}) + }(errCh) + // read request and write wrong response: + go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) + err = <-errCh + require.Error(t, err) + require.Equal(t, err, ErrUnexpectedResponse) + }() + } +} + +func TestRetryConnToRemoteSigner(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + logger = log.TestingLogger() + chainID = cmn.RandStr(12) + readyCh = make(chan struct{}) + + serviceEndpoint = NewSignerServiceEndpoint( + logger, + chainID, + types.NewMockPV(), + tc.dialer, + ) + thisConnTimeout = testTimeoutReadWrite + validatorEndpoint = newSignerValidatorEndpoint(logger, tc.addr, thisConnTimeout) + ) + // Ping every: + SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) + + SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) + SignerServiceEndpointConnRetries(10)(serviceEndpoint) + + testStartEndpoint(t, readyCh, validatorEndpoint) + defer validatorEndpoint.Stop() + require.NoError(t, serviceEndpoint.Start()) + assert.True(t, serviceEndpoint.IsRunning()) + + <-readyCh + time.Sleep(testTimeoutHeartbeat * 2) + + serviceEndpoint.Stop() + rs2 := NewSignerServiceEndpoint( + logger, + chainID, + types.NewMockPV(), + tc.dialer, + ) + // let some pings pass + time.Sleep(testTimeoutHeartbeat3o2) + require.NoError(t, rs2.Start()) + assert.True(t, rs2.IsRunning()) + defer rs2.Stop() + + // give the client some time to re-establish the conn to the remote signer + // should see sth like this in the logs: + // + // E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" + // I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal + time.Sleep(testTimeoutReadWrite * 2) + }() + } +} + +func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint { + proto, address := cmn.ProtocolAndAddress(addr) + + ln, err := net.Listen(proto, address) + logger.Info("Listening at", "proto", proto, "address", address) + if err != nil { + panic(err) + } + + var listener net.Listener + + if proto == "unix" { + unixLn := NewUnixListener(ln) + UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn) + UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) + listener = unixLn + } else { + tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) + TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn) + TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) + listener = tcpLn + } + + return NewSignerValidatorEndpoint(logger, listener) +} + +func testSetupSocketPair( + t *testing.T, + chainID string, + privValidator types.PrivValidator, + addr string, + socketDialer SocketDialer, +) (*SignerValidatorEndpoint, *SignerServiceEndpoint) { + var ( + logger = log.TestingLogger() + privVal = privValidator + readyc = make(chan struct{}) + serviceEndpoint = NewSignerServiceEndpoint( + logger, + chainID, + privVal, + socketDialer, + ) + + thisConnTimeout = testTimeoutReadWrite + validatorEndpoint = newSignerValidatorEndpoint(logger, addr, thisConnTimeout) + ) + + SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) + SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) + SignerServiceEndpointConnRetries(1e6)(serviceEndpoint) + + testStartEndpoint(t, readyc, validatorEndpoint) + + require.NoError(t, serviceEndpoint.Start()) + assert.True(t, serviceEndpoint.IsRunning()) + + <-readyc + + return validatorEndpoint, serviceEndpoint +} + +func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { + _, err := readMsg(rsConn) + require.NoError(t, err) + + err = writeMsg(rsConn, resp) + require.NoError(t, err) +} + +func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) { + go func(sc *SignerValidatorEndpoint) { + require.NoError(t, sc.Start()) + assert.True(t, sc.IsRunning()) + + readyCh <- struct{}{} + }(sc) +} + +// testFreeTCPAddr claims a free port so we don't block on listener being ready. +func testFreeTCPAddr(t *testing.T) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) +} diff --git a/privval/socket_dialers.go b/privval/socket_dialers.go new file mode 100644 index 000000000..c92a1c8cc --- /dev/null +++ b/privval/socket_dialers.go @@ -0,0 +1,43 @@ +package privval + +import ( + "net" + "time" + + "github.com/pkg/errors" + "github.com/tendermint/tendermint/crypto/ed25519" + cmn "github.com/tendermint/tendermint/libs/common" + p2pconn "github.com/tendermint/tendermint/p2p/conn" +) + +// Socket errors. +var ( + ErrDialRetryMax = errors.New("dialed maximum retries") +) + +// SocketDialer dials a remote address and returns a net.Conn or an error. +type SocketDialer func() (net.Conn, error) + +// DialTCPFn dials the given tcp addr, using the given timeoutReadWrite and +// privKey for the authenticated encryption handshake. +func DialTCPFn(addr string, timeoutReadWrite time.Duration, privKey ed25519.PrivKeyEd25519) SocketDialer { + return func() (net.Conn, error) { + conn, err := cmn.Connect(addr) + if err == nil { + deadline := time.Now().Add(timeoutReadWrite) + err = conn.SetDeadline(deadline) + } + if err == nil { + conn, err = p2pconn.MakeSecretConnection(conn, privKey) + } + return conn, err + } +} + +// DialUnixFn dials the given unix socket. +func DialUnixFn(addr string) SocketDialer { + return func() (net.Conn, error) { + unixAddr := &net.UnixAddr{Name: addr, Net: "unix"} + return net.DialUnix("unix", nil, unixAddr) + } +} diff --git a/privval/socket_dialers_test.go b/privval/socket_dialers_test.go new file mode 100644 index 000000000..9d5d5cc2b --- /dev/null +++ b/privval/socket_dialers_test.go @@ -0,0 +1,26 @@ +package privval + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/tendermint/tendermint/crypto/ed25519" + cmn "github.com/tendermint/tendermint/libs/common" +) + +func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { + // Generate a networking timeout + dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) + _, err := dialer() + assert.Error(t, err) + assert.True(t, IsConnTimeout(err)) +} + +func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { + dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) + _, err := dialer() + assert.Error(t, err) + err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) + assert.True(t, IsConnTimeout(err)) +} diff --git a/privval/socket.go b/privval/socket_listeners.go similarity index 60% rename from privval/socket.go rename to privval/socket_listeners.go index bd9cd9209..7c8835791 100644 --- a/privval/socket.go +++ b/privval/socket_listeners.go @@ -9,8 +9,8 @@ import ( ) const ( - defaultAcceptDeadlineSeconds = 3 - defaultConnDeadlineSeconds = 3 + defaultTimeoutAcceptSeconds = 3 + defaultTimeoutReadWriteSeconds = 3 ) // timeoutError can be used to check if an error returned from the netp package @@ -25,16 +25,16 @@ type timeoutError interface { // TCPListenerOption sets an optional parameter on the tcpListener. type TCPListenerOption func(*tcpListener) -// TCPListenerAcceptDeadline sets the deadline for the listener. -// A zero time value disables the deadline. -func TCPListenerAcceptDeadline(deadline time.Duration) TCPListenerOption { - return func(tl *tcpListener) { tl.acceptDeadline = deadline } +// TCPListenerTimeoutAccept sets the timeout for the listener. +// A zero time value disables the timeout. +func TCPListenerTimeoutAccept(timeout time.Duration) TCPListenerOption { + return func(tl *tcpListener) { tl.timeoutAccept = timeout } } -// TCPListenerConnDeadline sets the read and write deadline for connections +// TCPListenerTimeoutReadWrite sets the read and write timeout for connections // from external signing processes. -func TCPListenerConnDeadline(deadline time.Duration) TCPListenerOption { - return func(tl *tcpListener) { tl.connDeadline = deadline } +func TCPListenerTimeoutReadWrite(timeout time.Duration) TCPListenerOption { + return func(tl *tcpListener) { tl.timeoutReadWrite = timeout } } // tcpListener implements net.Listener. @@ -47,24 +47,25 @@ type tcpListener struct { secretConnKey ed25519.PrivKeyEd25519 - acceptDeadline time.Duration - connDeadline time.Duration + timeoutAccept time.Duration + timeoutReadWrite time.Duration } // NewTCPListener returns a listener that accepts authenticated encrypted connections // using the given secretConnKey and the default timeout values. func NewTCPListener(ln net.Listener, secretConnKey ed25519.PrivKeyEd25519) *tcpListener { return &tcpListener{ - TCPListener: ln.(*net.TCPListener), - secretConnKey: secretConnKey, - acceptDeadline: time.Second * defaultAcceptDeadlineSeconds, - connDeadline: time.Second * defaultConnDeadlineSeconds, + TCPListener: ln.(*net.TCPListener), + secretConnKey: secretConnKey, + timeoutAccept: time.Second * defaultTimeoutAcceptSeconds, + timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds, } } // Accept implements net.Listener. func (ln *tcpListener) Accept() (net.Conn, error) { - err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline)) + deadline := time.Now().Add(ln.timeoutAccept) + err := ln.SetDeadline(deadline) if err != nil { return nil, err } @@ -75,7 +76,7 @@ func (ln *tcpListener) Accept() (net.Conn, error) { } // Wrap the conn in our timeout and encryption wrappers - timeoutConn := newTimeoutConn(tc, ln.connDeadline) + timeoutConn := newTimeoutConn(tc, ln.timeoutReadWrite) secretConn, err := p2pconn.MakeSecretConnection(timeoutConn, ln.secretConnKey) if err != nil { return nil, err @@ -92,16 +93,16 @@ var _ net.Listener = (*unixListener)(nil) type UnixListenerOption func(*unixListener) -// UnixListenerAcceptDeadline sets the deadline for the listener. -// A zero time value disables the deadline. -func UnixListenerAcceptDeadline(deadline time.Duration) UnixListenerOption { - return func(ul *unixListener) { ul.acceptDeadline = deadline } +// UnixListenerTimeoutAccept sets the timeout for the listener. +// A zero time value disables the timeout. +func UnixListenerTimeoutAccept(timeout time.Duration) UnixListenerOption { + return func(ul *unixListener) { ul.timeoutAccept = timeout } } -// UnixListenerConnDeadline sets the read and write deadline for connections +// UnixListenerTimeoutReadWrite sets the read and write timeout for connections // from external signing processes. -func UnixListenerConnDeadline(deadline time.Duration) UnixListenerOption { - return func(ul *unixListener) { ul.connDeadline = deadline } +func UnixListenerTimeoutReadWrite(timeout time.Duration) UnixListenerOption { + return func(ul *unixListener) { ul.timeoutReadWrite = timeout } } // unixListener wraps a *net.UnixListener to standardise protocol timeouts @@ -109,23 +110,24 @@ func UnixListenerConnDeadline(deadline time.Duration) UnixListenerOption { type unixListener struct { *net.UnixListener - acceptDeadline time.Duration - connDeadline time.Duration + timeoutAccept time.Duration + timeoutReadWrite time.Duration } // NewUnixListener returns a listener that accepts unencrypted connections // using the default timeout values. func NewUnixListener(ln net.Listener) *unixListener { return &unixListener{ - UnixListener: ln.(*net.UnixListener), - acceptDeadline: time.Second * defaultAcceptDeadlineSeconds, - connDeadline: time.Second * defaultConnDeadlineSeconds, + UnixListener: ln.(*net.UnixListener), + timeoutAccept: time.Second * defaultTimeoutAcceptSeconds, + timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds, } } // Accept implements net.Listener. func (ln *unixListener) Accept() (net.Conn, error) { - err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline)) + deadline := time.Now().Add(ln.timeoutAccept) + err := ln.SetDeadline(deadline) if err != nil { return nil, err } @@ -136,7 +138,7 @@ func (ln *unixListener) Accept() (net.Conn, error) { } // Wrap the conn in our timeout wrapper - conn := newTimeoutConn(tc, ln.connDeadline) + conn := newTimeoutConn(tc, ln.timeoutReadWrite) // TODO: wrap in something that authenticates // with a MAC - https://github.com/tendermint/tendermint/issues/3099 @@ -153,24 +155,25 @@ var _ net.Conn = (*timeoutConn)(nil) // timeoutConn wraps a net.Conn to standardise protocol timeouts / deadline resets. type timeoutConn struct { net.Conn - - connDeadline time.Duration + timeout time.Duration } // newTimeoutConn returns an instance of timeoutConn. -func newTimeoutConn( - conn net.Conn, - connDeadline time.Duration) *timeoutConn { +func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn { return &timeoutConn{ conn, - connDeadline, + timeout, } } // Read implements net.Conn. func (c timeoutConn) Read(b []byte) (n int, err error) { // Reset deadline - c.Conn.SetReadDeadline(time.Now().Add(c.connDeadline)) + deadline := time.Now().Add(c.timeout) + err = c.Conn.SetReadDeadline(deadline) + if err != nil { + return + } return c.Conn.Read(b) } @@ -178,7 +181,11 @@ func (c timeoutConn) Read(b []byte) (n int, err error) { // Write implements net.Conn. func (c timeoutConn) Write(b []byte) (n int, err error) { // Reset deadline - c.Conn.SetWriteDeadline(time.Now().Add(c.connDeadline)) + deadline := time.Now().Add(c.timeout) + err = c.Conn.SetWriteDeadline(deadline) + if err != nil { + return + } return c.Conn.Write(b) } diff --git a/privval/socket_test.go b/privval/socket_listeners_test.go similarity index 71% rename from privval/socket_test.go rename to privval/socket_listeners_test.go index 7f7bbd892..498ef79c0 100644 --- a/privval/socket_test.go +++ b/privval/socket_listeners_test.go @@ -23,7 +23,7 @@ func newPrivKey() ed25519.PrivKeyEd25519 { type listenerTestCase struct { description string // For test reporting purposes. listener net.Listener - dialer Dialer + dialer SocketDialer } // testUnixAddr will attempt to obtain a platform-independent temporary file @@ -39,23 +39,23 @@ func testUnixAddr() (string, error) { return addr, nil } -func tcpListenerTestCase(t *testing.T, acceptDeadline, connectDeadline time.Duration) listenerTestCase { +func tcpListenerTestCase(t *testing.T, timeoutAccept, timeoutReadWrite time.Duration) listenerTestCase { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } tcpLn := NewTCPListener(ln, newPrivKey()) - TCPListenerAcceptDeadline(acceptDeadline)(tcpLn) - TCPListenerConnDeadline(connectDeadline)(tcpLn) + TCPListenerTimeoutAccept(timeoutAccept)(tcpLn) + TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) return listenerTestCase{ description: "TCP", listener: tcpLn, - dialer: DialTCPFn(ln.Addr().String(), testConnDeadline, newPrivKey()), + dialer: DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, newPrivKey()), } } -func unixListenerTestCase(t *testing.T, acceptDeadline, connectDeadline time.Duration) listenerTestCase { +func unixListenerTestCase(t *testing.T, timeoutAccept, timeoutReadWrite time.Duration) listenerTestCase { addr, err := testUnixAddr() if err != nil { t.Fatal(err) @@ -66,8 +66,8 @@ func unixListenerTestCase(t *testing.T, acceptDeadline, connectDeadline time.Dur } unixLn := NewUnixListener(ln) - UnixListenerAcceptDeadline(acceptDeadline)(unixLn) - UnixListenerConnDeadline(connectDeadline)(unixLn) + UnixListenerTimeoutAccept(timeoutAccept)(unixLn) + UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) return listenerTestCase{ description: "Unix", listener: unixLn, @@ -75,14 +75,14 @@ func unixListenerTestCase(t *testing.T, acceptDeadline, connectDeadline time.Dur } } -func listenerTestCases(t *testing.T, acceptDeadline, connectDeadline time.Duration) []listenerTestCase { +func listenerTestCases(t *testing.T, timeoutAccept, timeoutReadWrite time.Duration) []listenerTestCase { return []listenerTestCase{ - tcpListenerTestCase(t, acceptDeadline, connectDeadline), - unixListenerTestCase(t, acceptDeadline, connectDeadline), + tcpListenerTestCase(t, timeoutAccept, timeoutReadWrite), + unixListenerTestCase(t, timeoutAccept, timeoutReadWrite), } } -func TestListenerAcceptDeadlines(t *testing.T) { +func TestListenerTimeoutAccept(t *testing.T) { for _, tc := range listenerTestCases(t, time.Millisecond, time.Second) { _, err := tc.listener.Accept() opErr, ok := err.(*net.OpError) @@ -96,9 +96,9 @@ func TestListenerAcceptDeadlines(t *testing.T) { } } -func TestListenerConnectDeadlines(t *testing.T) { +func TestListenerTimeoutReadWrite(t *testing.T) { for _, tc := range listenerTestCases(t, time.Second, time.Millisecond) { - go func(dialer Dialer) { + go func(dialer SocketDialer) { _, err := dialer() if err != nil { panic(err) diff --git a/privval/utils.go b/privval/utils.go new file mode 100644 index 000000000..d8837bdf0 --- /dev/null +++ b/privval/utils.go @@ -0,0 +1,20 @@ +package privval + +import ( + cmn "github.com/tendermint/tendermint/libs/common" +) + +// IsConnTimeout returns a boolean indicating whether the error is known to +// report that a connection timeout occurred. This detects both fundamental +// network timeouts, as well as ErrConnTimeout errors. +func IsConnTimeout(err error) bool { + if cmnErr, ok := err.(cmn.Error); ok { + if cmnErr.Data() == ErrConnTimeout { + return true + } + } + if _, ok := err.(timeoutError); ok { + return true + } + return false +} diff --git a/privval/utils_test.go b/privval/utils_test.go new file mode 100644 index 000000000..23f6f6a3b --- /dev/null +++ b/privval/utils_test.go @@ -0,0 +1,14 @@ +package privval + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + cmn "github.com/tendermint/tendermint/libs/common" +) + +func TestIsConnTimeoutForNonTimeoutErrors(t *testing.T) { + assert.False(t, IsConnTimeout(cmn.ErrorWrap(ErrDialRetryMax, "max retries exceeded"))) + assert.False(t, IsConnTimeout(fmt.Errorf("completely irrelevant error"))) +} diff --git a/tools/tm-signer-harness/internal/test_harness.go b/tools/tm-signer-harness/internal/test_harness.go index b961f2384..005489133 100644 --- a/tools/tm-signer-harness/internal/test_harness.go +++ b/tools/tm-signer-harness/internal/test_harness.go @@ -49,7 +49,7 @@ var _ error = (*TestHarnessError)(nil) // with this version of Tendermint. type TestHarness struct { addr string - spv *privval.SocketVal + spv *privval.SignerValidatorEndpoint fpv *privval.FilePV chainID string acceptRetries int @@ -314,7 +314,7 @@ func (th *TestHarness) Shutdown(err error) { // newTestHarnessSocketVal creates our client instance which we will use for // testing. -func newTestHarnessSocketVal(logger log.Logger, cfg TestHarnessConfig) (*privval.SocketVal, error) { +func newTestHarnessSocketVal(logger log.Logger, cfg TestHarnessConfig) (*privval.SignerValidatorEndpoint, error) { proto, addr := cmn.ProtocolAndAddress(cfg.BindAddr) if proto == "unix" { // make sure the socket doesn't exist - if so, try to delete it @@ -334,20 +334,20 @@ func newTestHarnessSocketVal(logger log.Logger, cfg TestHarnessConfig) (*privval switch proto { case "unix": unixLn := privval.NewUnixListener(ln) - privval.UnixListenerAcceptDeadline(cfg.AcceptDeadline)(unixLn) - privval.UnixListenerConnDeadline(cfg.ConnDeadline)(unixLn) + privval.UnixListenerTimeoutAccept(cfg.AcceptDeadline)(unixLn) + privval.UnixListenerTimeoutReadWrite(cfg.ConnDeadline)(unixLn) svln = unixLn case "tcp": tcpLn := privval.NewTCPListener(ln, cfg.SecretConnKey) - privval.TCPListenerAcceptDeadline(cfg.AcceptDeadline)(tcpLn) - privval.TCPListenerConnDeadline(cfg.ConnDeadline)(tcpLn) + privval.TCPListenerTimeoutAccept(cfg.AcceptDeadline)(tcpLn) + privval.TCPListenerTimeoutReadWrite(cfg.ConnDeadline)(tcpLn) logger.Info("Resolved TCP address for listener", "addr", tcpLn.Addr()) svln = tcpLn default: logger.Error("Unsupported protocol (must be unix:// or tcp://)", "proto", proto) return nil, newTestHarnessError(ErrInvalidParameters, nil, fmt.Sprintf("Unsupported protocol: %s", proto)) } - return privval.NewSocketVal(logger, svln), nil + return privval.NewSignerValidatorEndpoint(logger, svln), nil } func newTestHarnessError(code int, err error, info string) *TestHarnessError { diff --git a/tools/tm-signer-harness/internal/test_harness_test.go b/tools/tm-signer-harness/internal/test_harness_test.go index 804aca45e..adb818b0b 100644 --- a/tools/tm-signer-harness/internal/test_harness_test.go +++ b/tools/tm-signer-harness/internal/test_harness_test.go @@ -84,7 +84,7 @@ func TestRemoteSignerTestHarnessMaxAcceptRetriesReached(t *testing.T) { func TestRemoteSignerTestHarnessSuccessfulRun(t *testing.T) { harnessTest( t, - func(th *TestHarness) *privval.RemoteSigner { + func(th *TestHarness) *privval.SignerServiceEndpoint { return newMockRemoteSigner(t, th, th.fpv.Key.PrivKey, false, false) }, NoError, @@ -94,7 +94,7 @@ func TestRemoteSignerTestHarnessSuccessfulRun(t *testing.T) { func TestRemoteSignerPublicKeyCheckFailed(t *testing.T) { harnessTest( t, - func(th *TestHarness) *privval.RemoteSigner { + func(th *TestHarness) *privval.SignerServiceEndpoint { return newMockRemoteSigner(t, th, ed25519.GenPrivKey(), false, false) }, ErrTestPublicKeyFailed, @@ -104,7 +104,7 @@ func TestRemoteSignerPublicKeyCheckFailed(t *testing.T) { func TestRemoteSignerProposalSigningFailed(t *testing.T) { harnessTest( t, - func(th *TestHarness) *privval.RemoteSigner { + func(th *TestHarness) *privval.SignerServiceEndpoint { return newMockRemoteSigner(t, th, th.fpv.Key.PrivKey, true, false) }, ErrTestSignProposalFailed, @@ -114,15 +114,15 @@ func TestRemoteSignerProposalSigningFailed(t *testing.T) { func TestRemoteSignerVoteSigningFailed(t *testing.T) { harnessTest( t, - func(th *TestHarness) *privval.RemoteSigner { + func(th *TestHarness) *privval.SignerServiceEndpoint { return newMockRemoteSigner(t, th, th.fpv.Key.PrivKey, false, true) }, ErrTestSignVoteFailed, ) } -func newMockRemoteSigner(t *testing.T, th *TestHarness, privKey crypto.PrivKey, breakProposalSigning bool, breakVoteSigning bool) *privval.RemoteSigner { - return privval.NewRemoteSigner( +func newMockRemoteSigner(t *testing.T, th *TestHarness, privKey crypto.PrivKey, breakProposalSigning bool, breakVoteSigning bool) *privval.SignerServiceEndpoint { + return privval.NewSignerServiceEndpoint( th.logger, th.chainID, types.NewMockPVWithParams(privKey, breakProposalSigning, breakVoteSigning), @@ -135,7 +135,7 @@ func newMockRemoteSigner(t *testing.T, th *TestHarness, privKey crypto.PrivKey, } // For running relatively standard tests. -func harnessTest(t *testing.T, rsMaker func(th *TestHarness) *privval.RemoteSigner, expectedExitCode int) { +func harnessTest(t *testing.T, rsMaker func(th *TestHarness) *privval.SignerServiceEndpoint, expectedExitCode int) { cfg := makeConfig(t, 100, 3) defer cleanup(cfg)