This PR is related to #3107 and a continuation of #3351
It is important to emphasise that in the original privval design, the client/server and listening/dialing roles are inverted and do not follow a conventional interaction.
Given two hosts A and B:
Host A is listener/client
Host B is dialer/server (contains the secret key)
When A requires a signature, it needs to wait for B to dial in before it can issue a request.
A only accepts a single connection and any failure leads to dropping the connection and waiting for B to reconnect.
The original rationale behind this design was based on security.
Host B only allows outbound connections to a list of whitelisted hosts.
It is not possible to reach B unless B dials in. There are no listening/open ports in B.
This PR results in the following changes:
Refactors ping/heartbeat to avoid previously existing race conditions.
Separates transport (dialer/listener) from signing (client/server) concerns to simplify workflow.
Unifies and abstracts away the differences between unix and tcp sockets.
A single signer endpoint implementation unifies connection handling code (read/write/close/connection obj)
The signer request handler (server side) is customizable to increase testability.
Updates and extends unit tests
A high level overview of the classes is as follows:
Transport (endpoints): The following classes take care of establishing a connection
SignerDialerEndpoint
SignerListeningEndpoint
SignerEndpoint groups common functionality (read/write/timeouts/etc.)
Signing (client/server): The following classes take care of exchanging request/responses
SignerClient
SignerServer
This PR also closes #3601
Commits:
* refactoring - work in progress
* reworking unit tests
* Encapsulating and fixing unit tests
* Improve tests
* Clean up
* Fix/improve unit tests
* clean up tests
* Improving service endpoint
* fixing unit test
* fix linter issues
* avoid invalid cache values (improve later?)
* complete implementation
* wip
* improved connection loop
* Improve reconnections + fixing unit tests
* addressing comments
* small formatting changes
* clean up
* Update node/node.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_client.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_client_test.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* check during initialization
* dropping connecting when writing fails
* removing break
* use t.log instead
* unifying and using cmn.GetFreePort()
* review fixes
* reordering and unifying drop connection
* closing instead of signalling
* refactored service loop
* removed superfluous brackets
* GetPubKey can return errors
* Revert "GetPubKey can return errors"
This reverts commit 68c06f19b4.
* adding entry to changelog
* Update CHANGELOG_PENDING.md
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_client.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_listener_endpoint_test.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* updating node.go
* review fixes
* fixes linter
* fixing unit test
* small fixes in comments
* addressing review comments
* addressing review comments 2
* reverting suggestion
* Update privval/signer_client_test.go
Co-Authored-By: Anton Kaliaev <anton.kalyaev@gmail.com>
* Update privval/signer_client_test.go
Co-Authored-By: Anton Kaliaev <anton.kalyaev@gmail.com>
* Update privval/signer_listener_endpoint_test.go
Co-Authored-By: Anton Kaliaev <anton.kalyaev@gmail.com>
* do not expose brokenSignerDialerEndpoint
* clean up logging
* unifying methods
shorten test time
signer also drops
* reenabling pings
* improving testability + unit test
* fixing go fmt + unit test
* remove unused code
* Addressing review comments
* simplifying connection workflow
* fix linter/go import issue
* using base service quit
* updating comment
* Simplifying design + adjusting names
* fixing linter issues
* refactoring test harness + fixes
* Addressing review comments
* cleaning up
* adding additional error check
pull/3912/head
@ -1,61 +1,64 @@ | |||
package privval | |||
import ( | |||
amino "github.com/tendermint/go-amino" | |||
"github.com/tendermint/go-amino" | |||
"github.com/tendermint/tendermint/crypto" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
// Envelope interface types for messages exchanged over a privval connection.
type (
	// RemoteSignerMsg is sent between SignerServiceEndpoint and the
	// SignerServiceEndpoint client.
	RemoteSignerMsg interface{}

	// SignerMessage is sent between signer clients and signer servers.
	SignerMessage interface{}
)
func RegisterRemoteSignerMsg(cdc *amino.Codec) { | |||
cdc.RegisterInterface((*RemoteSignerMsg)(nil), nil) | |||
cdc.RegisterInterface((*SignerMessage)(nil), nil) | |||
cdc.RegisterConcrete(&PubKeyRequest{}, "tendermint/remotesigner/PubKeyRequest", nil) | |||
cdc.RegisterConcrete(&PubKeyResponse{}, "tendermint/remotesigner/PubKeyResponse", nil) | |||
cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil) | |||
cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil) | |||
cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil) | |||
cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil) | |||
cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil) | |||
cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil) | |||
} | |||
// TODO: Add ChainIDRequest

// PubKeyRequest asks the remote signer for its consensus public key.
// It carries no payload.
type PubKeyRequest struct{}
// PubKeyResponse is a PrivValidatorSocket message containing the public key. | |||
// PubKeyResponse is a response message containing the public key. | |||
type PubKeyResponse struct { | |||
PubKey crypto.PubKey | |||
Error *RemoteSignerError | |||
} | |||
// SignVoteRequest is a PrivValidatorSocket message containing a vote. | |||
// SignVoteRequest is a request to sign a vote | |||
type SignVoteRequest struct { | |||
Vote *types.Vote | |||
} | |||
// SignedVoteResponse is a PrivValidatorSocket message containing a signed vote along with a potenial error message. | |||
// SignedVoteResponse is a response containing a signed vote or an error | |||
type SignedVoteResponse struct { | |||
Vote *types.Vote | |||
Error *RemoteSignerError | |||
} | |||
// SignProposalRequest is a PrivValidatorSocket message containing a Proposal. | |||
// SignProposalRequest is a request to sign a proposal | |||
type SignProposalRequest struct { | |||
Proposal *types.Proposal | |||
} | |||
// SignedProposalResponse is a PrivValidatorSocket message containing a proposal response | |||
// SignedProposalResponse is response containing a signed proposal or an error | |||
type SignedProposalResponse struct { | |||
Proposal *types.Proposal | |||
Error *RemoteSignerError | |||
} | |||
// PingRequest is sent to confirm that the connection is still alive.
// It carries no payload.
type PingRequest struct{}
// PingResponse is the reply to a PingRequest, confirming that the
// connection is still alive. It carries no payload.
type PingResponse struct{}
@ -0,0 +1,131 @@ | |||
package privval | |||
import ( | |||
"time" | |||
"github.com/pkg/errors" | |||
"github.com/tendermint/tendermint/crypto" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
// SignerClient implements PrivValidator. | |||
// Handles remote validator connections that provide signing services | |||
type SignerClient struct { | |||
endpoint *SignerListenerEndpoint | |||
} | |||
var _ types.PrivValidator = (*SignerClient)(nil) | |||
// NewSignerClient returns an instance of SignerClient. | |||
// it will start the endpoint (if not already started) | |||
func NewSignerClient(endpoint *SignerListenerEndpoint) (*SignerClient, error) { | |||
if !endpoint.IsRunning() { | |||
if err := endpoint.Start(); err != nil { | |||
return nil, errors.Wrap(err, "failed to start listener endpoint") | |||
} | |||
} | |||
return &SignerClient{endpoint: endpoint}, nil | |||
} | |||
// Close closes the underlying connection | |||
func (sc *SignerClient) Close() error { | |||
return sc.endpoint.Close() | |||
} | |||
// IsConnected indicates with the signer is connected to a remote signing service | |||
func (sc *SignerClient) IsConnected() bool { | |||
return sc.endpoint.IsConnected() | |||
} | |||
// WaitForConnection waits maxWait for a connection or returns a timeout error | |||
func (sc *SignerClient) WaitForConnection(maxWait time.Duration) error { | |||
return sc.endpoint.WaitForConnection(maxWait) | |||
} | |||
//-------------------------------------------------------- | |||
// Implement PrivValidator | |||
// Ping sends a ping request to the remote signer | |||
func (sc *SignerClient) Ping() error { | |||
response, err := sc.endpoint.SendRequest(&PingRequest{}) | |||
if err != nil { | |||
sc.endpoint.Logger.Error("SignerClient::Ping", "err", err) | |||
return nil | |||
} | |||
_, ok := response.(*PingResponse) | |||
if !ok { | |||
sc.endpoint.Logger.Error("SignerClient::Ping", "err", "response != PingResponse") | |||
return err | |||
} | |||
return nil | |||
} | |||
// GetPubKey retrieves a public key from a remote signer | |||
func (sc *SignerClient) GetPubKey() crypto.PubKey { | |||
response, err := sc.endpoint.SendRequest(&PubKeyRequest{}) | |||
if err != nil { | |||
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", err) | |||
return nil | |||
} | |||
pubKeyResp, ok := response.(*PubKeyResponse) | |||
if !ok { | |||
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", "response != PubKeyResponse") | |||
return nil | |||
} | |||
if pubKeyResp.Error != nil { | |||
sc.endpoint.Logger.Error("failed to get private validator's public key", "err", pubKeyResp.Error) | |||
return nil | |||
} | |||
return pubKeyResp.PubKey | |||
} | |||
// SignVote requests a remote signer to sign a vote | |||
func (sc *SignerClient) SignVote(chainID string, vote *types.Vote) error { | |||
response, err := sc.endpoint.SendRequest(&SignVoteRequest{Vote: vote}) | |||
if err != nil { | |||
sc.endpoint.Logger.Error("SignerClient::SignVote", "err", err) | |||
return err | |||
} | |||
resp, ok := response.(*SignedVoteResponse) | |||
if !ok { | |||
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", "response != SignedVoteResponse") | |||
return ErrUnexpectedResponse | |||
} | |||
if resp.Error != nil { | |||
return resp.Error | |||
} | |||
*vote = *resp.Vote | |||
return nil | |||
} | |||
// SignProposal requests a remote signer to sign a proposal | |||
func (sc *SignerClient) SignProposal(chainID string, proposal *types.Proposal) error { | |||
response, err := sc.endpoint.SendRequest(&SignProposalRequest{Proposal: proposal}) | |||
if err != nil { | |||
sc.endpoint.Logger.Error("SignerClient::SignProposal", "err", err) | |||
return err | |||
} | |||
resp, ok := response.(*SignedProposalResponse) | |||
if !ok { | |||
sc.endpoint.Logger.Error("SignerClient::SignProposal", "err", "response != SignedProposalResponse") | |||
return ErrUnexpectedResponse | |||
} | |||
if resp.Error != nil { | |||
return resp.Error | |||
} | |||
*proposal = *resp.Proposal | |||
return nil | |||
} |
@ -0,0 +1,257 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
"github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
type signerTestCase struct { | |||
chainID string | |||
mockPV types.PrivValidator | |||
signerClient *SignerClient | |||
signerServer *SignerServer | |||
} | |||
func getSignerTestCases(t *testing.T) []signerTestCase { | |||
testCases := make([]signerTestCase, 0) | |||
// Get test cases for each possible dialer (DialTCP / DialUnix / etc) | |||
for _, dtc := range getDialerTestCases(t) { | |||
chainID := common.RandStr(12) | |||
mockPV := types.NewMockPV() | |||
// get a pair of signer listener, signer dialer endpoints | |||
sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer) | |||
sc, err := NewSignerClient(sl) | |||
require.NoError(t, err) | |||
ss := NewSignerServer(sd, chainID, mockPV) | |||
err = ss.Start() | |||
require.NoError(t, err) | |||
tc := signerTestCase{ | |||
chainID: chainID, | |||
mockPV: mockPV, | |||
signerClient: sc, | |||
signerServer: ss, | |||
} | |||
testCases = append(testCases, tc) | |||
} | |||
return testCases | |||
} | |||
func TestSignerClose(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
err := tc.signerClient.Close() | |||
assert.NoError(t, err) | |||
err = tc.signerServer.Stop() | |||
assert.NoError(t, err) | |||
} | |||
} | |||
func TestSignerPing(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
err := tc.signerClient.Ping() | |||
assert.NoError(t, err) | |||
} | |||
} | |||
func TestSignerGetPubKey(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
pubKey := tc.signerClient.GetPubKey() | |||
expectedPubKey := tc.mockPV.GetPubKey() | |||
assert.Equal(t, expectedPubKey, pubKey) | |||
addr := tc.signerClient.GetPubKey().Address() | |||
expectedAddr := tc.mockPV.GetPubKey().Address() | |||
assert.Equal(t, expectedAddr, addr) | |||
} | |||
} | |||
func TestSignerProposal(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
ts := time.Now() | |||
want := &types.Proposal{Timestamp: ts} | |||
have := &types.Proposal{Timestamp: ts} | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
require.NoError(t, tc.mockPV.SignProposal(tc.chainID, want)) | |||
require.NoError(t, tc.signerClient.SignProposal(tc.chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
} | |||
} | |||
func TestSignerVote(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
ts := time.Now() | |||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
have := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
} | |||
} | |||
func TestSignerVoteResetDeadline(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
ts := time.Now() | |||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
have := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
time.Sleep(testTimeoutReadWrite2o3) | |||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
// TODO(jleni): Clarify what is actually being tested | |||
// This would exceed the deadline if it was not extended by the previous message | |||
time.Sleep(testTimeoutReadWrite2o3) | |||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
} | |||
} | |||
func TestSignerVoteKeepAlive(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
ts := time.Now() | |||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
have := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
// Check that even if the client does not request a | |||
// signature for a long time. The service is still available | |||
// in this particular case, we use the dialer logger to ensure that | |||
// test messages are properly interleaved in the test logs | |||
tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------") | |||
time.Sleep(testTimeoutReadWrite * 3) | |||
tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------") | |||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
} | |||
} | |||
func TestSignerSignProposalErrors(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
// Replace service with a mock that always fails | |||
tc.signerServer.privVal = types.NewErroringMockPV() | |||
tc.mockPV = types.NewErroringMockPV() | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
ts := time.Now() | |||
proposal := &types.Proposal{Timestamp: ts} | |||
err := tc.signerClient.SignProposal(tc.chainID, proposal) | |||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||
err = tc.mockPV.SignProposal(tc.chainID, proposal) | |||
require.Error(t, err) | |||
err = tc.signerClient.SignProposal(tc.chainID, proposal) | |||
require.Error(t, err) | |||
} | |||
} | |||
func TestSignerSignVoteErrors(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
ts := time.Now() | |||
vote := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
// Replace signer service privval with one that always fails | |||
tc.signerServer.privVal = types.NewErroringMockPV() | |||
tc.mockPV = types.NewErroringMockPV() | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
err := tc.signerClient.SignVote(tc.chainID, vote) | |||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||
err = tc.mockPV.SignVote(tc.chainID, vote) | |||
require.Error(t, err) | |||
err = tc.signerClient.SignVote(tc.chainID, vote) | |||
require.Error(t, err) | |||
} | |||
} | |||
func brokenHandler(privVal types.PrivValidator, request SignerMessage, chainID string) (SignerMessage, error) { | |||
var res SignerMessage | |||
var err error | |||
switch r := request.(type) { | |||
// This is broken and will answer most requests with a pubkey response | |||
case *PubKeyRequest: | |||
res = &PubKeyResponse{nil, nil} | |||
case *SignVoteRequest: | |||
res = &PubKeyResponse{nil, nil} | |||
case *SignProposalRequest: | |||
res = &PubKeyResponse{nil, nil} | |||
case *PingRequest: | |||
err, res = nil, &PingResponse{} | |||
default: | |||
err = fmt.Errorf("unknown msg: %v", r) | |||
} | |||
return res, err | |||
} | |||
func TestSignerUnexpectedResponse(t *testing.T) { | |||
for _, tc := range getSignerTestCases(t) { | |||
tc.signerServer.privVal = types.NewMockPV() | |||
tc.mockPV = types.NewMockPV() | |||
tc.signerServer.SetRequestHandler(brokenHandler) | |||
defer tc.signerServer.Stop() | |||
defer tc.signerClient.Close() | |||
ts := time.Now() | |||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||
e := tc.signerClient.SignVote(tc.chainID, want) | |||
assert.EqualError(t, e, "received unexpected response") | |||
} | |||
} |
@ -0,0 +1,84 @@ | |||
package privval | |||
import ( | |||
"time" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/libs/log" | |||
) | |||
// Defaults applied by NewSignerDialerEndpoint.
const (
	// defaultMaxDialRetries is the default number of dial attempts.
	defaultMaxDialRetries = 10
	// defaultRetryWaitMilliseconds is the default pause between dial
	// attempts, in milliseconds.
	defaultRetryWaitMilliseconds = 100
)
// SignerServiceEndpointOption sets an optional parameter on the SignerDialerEndpoint. | |||
type SignerServiceEndpointOption func(*SignerDialerEndpoint) | |||
// SignerDialerEndpointTimeoutReadWrite sets the read and write timeout for connections | |||
// from external signing processes. | |||
func SignerDialerEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption { | |||
return func(ss *SignerDialerEndpoint) { ss.timeoutReadWrite = timeout } | |||
} | |||
// SignerDialerEndpointConnRetries sets the amount of attempted retries to acceptNewConnection. | |||
func SignerDialerEndpointConnRetries(retries int) SignerServiceEndpointOption { | |||
return func(ss *SignerDialerEndpoint) { ss.maxConnRetries = retries } | |||
} | |||
// SignerDialerEndpoint dials using its dialer and responds to any | |||
// signature requests using its privVal. | |||
type SignerDialerEndpoint struct { | |||
signerEndpoint | |||
dialer SocketDialer | |||
retryWait time.Duration | |||
maxConnRetries int | |||
} | |||
// NewSignerDialerEndpoint returns a SignerDialerEndpoint that will dial using the given | |||
// dialer and respond to any signature requests over the connection | |||
// using the given privVal. | |||
func NewSignerDialerEndpoint( | |||
logger log.Logger, | |||
dialer SocketDialer, | |||
) *SignerDialerEndpoint { | |||
sd := &SignerDialerEndpoint{ | |||
dialer: dialer, | |||
retryWait: defaultRetryWaitMilliseconds * time.Millisecond, | |||
maxConnRetries: defaultMaxDialRetries, | |||
} | |||
sd.BaseService = *cmn.NewBaseService(logger, "SignerDialerEndpoint", sd) | |||
sd.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second | |||
return sd | |||
} | |||
func (sd *SignerDialerEndpoint) ensureConnection() error { | |||
if sd.IsConnected() { | |||
return nil | |||
} | |||
retries := 0 | |||
for retries < sd.maxConnRetries { | |||
conn, err := sd.dialer() | |||
if err != nil { | |||
retries++ | |||
sd.Logger.Debug("SignerDialer: Reconnection failed", "retries", retries, "max", sd.maxConnRetries, "err", err) | |||
// Wait between retries | |||
time.Sleep(sd.retryWait) | |||
} else { | |||
sd.SetConnection(conn) | |||
sd.Logger.Debug("SignerDialer: Connection Ready") | |||
return nil | |||
} | |||
} | |||
sd.Logger.Debug("SignerDialer: Max retries exceeded", "retries", retries, "max", sd.maxConnRetries) | |||
return ErrNoConnection | |||
} |
@ -0,0 +1,156 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"net" | |||
"sync" | |||
"time" | |||
"github.com/pkg/errors" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
) | |||
// defaultTimeoutReadWriteSeconds is the default read/write timeout of a
// signer endpoint, in seconds.
const defaultTimeoutReadWriteSeconds = 3
type signerEndpoint struct { | |||
cmn.BaseService | |||
connMtx sync.Mutex | |||
conn net.Conn | |||
timeoutReadWrite time.Duration | |||
} | |||
// Close closes the underlying net.Conn. | |||
func (se *signerEndpoint) Close() error { | |||
se.DropConnection() | |||
return nil | |||
} | |||
// IsConnected indicates if there is an active connection | |||
func (se *signerEndpoint) IsConnected() bool { | |||
se.connMtx.Lock() | |||
defer se.connMtx.Unlock() | |||
return se.isConnected() | |||
} | |||
// TryGetConnection retrieves a connection if it is already available | |||
func (se *signerEndpoint) GetAvailableConnection(connectionAvailableCh chan net.Conn) bool { | |||
se.connMtx.Lock() | |||
defer se.connMtx.Unlock() | |||
// Is there a connection ready? | |||
select { | |||
case se.conn = <-connectionAvailableCh: | |||
return true | |||
default: | |||
} | |||
return false | |||
} | |||
// TryGetConnection retrieves a connection if it is already available | |||
func (se *signerEndpoint) WaitConnection(connectionAvailableCh chan net.Conn, maxWait time.Duration) error { | |||
se.connMtx.Lock() | |||
defer se.connMtx.Unlock() | |||
select { | |||
case se.conn = <-connectionAvailableCh: | |||
case <-time.After(maxWait): | |||
return ErrConnectionTimeout | |||
} | |||
return nil | |||
} | |||
// SetConnection replaces the current connection object | |||
func (se *signerEndpoint) SetConnection(newConnection net.Conn) { | |||
se.connMtx.Lock() | |||
defer se.connMtx.Unlock() | |||
se.conn = newConnection | |||
} | |||
// IsConnected indicates if there is an active connection | |||
func (se *signerEndpoint) DropConnection() { | |||
se.connMtx.Lock() | |||
defer se.connMtx.Unlock() | |||
se.dropConnection() | |||
} | |||
// ReadMessage reads a message from the endpoint | |||
func (se *signerEndpoint) ReadMessage() (msg SignerMessage, err error) { | |||
se.connMtx.Lock() | |||
defer se.connMtx.Unlock() | |||
if !se.isConnected() { | |||
return nil, fmt.Errorf("endpoint is not connected") | |||
} | |||
// Reset read deadline | |||
deadline := time.Now().Add(se.timeoutReadWrite) | |||
err = se.conn.SetReadDeadline(deadline) | |||
if err != nil { | |||
return | |||
} | |||
const maxRemoteSignerMsgSize = 1024 * 10 | |||
_, err = cdc.UnmarshalBinaryLengthPrefixedReader(se.conn, &msg, maxRemoteSignerMsgSize) | |||
if _, ok := err.(timeoutError); ok { | |||
if err != nil { | |||
err = errors.Wrap(ErrReadTimeout, err.Error()) | |||
} else { | |||
err = errors.Wrap(ErrReadTimeout, "Empty error") | |||
} | |||
se.Logger.Debug("Dropping [read]", "obj", se) | |||
se.dropConnection() | |||
} | |||
return | |||
} | |||
// WriteMessage writes a message from the endpoint | |||
func (se *signerEndpoint) WriteMessage(msg SignerMessage) (err error) { | |||
se.connMtx.Lock() | |||
defer se.connMtx.Unlock() | |||
if !se.isConnected() { | |||
return errors.Wrap(ErrNoConnection, "endpoint is not connected") | |||
} | |||
// Reset read deadline | |||
deadline := time.Now().Add(se.timeoutReadWrite) | |||
se.Logger.Debug("Write::Error Resetting deadline", "obj", se) | |||
err = se.conn.SetWriteDeadline(deadline) | |||
if err != nil { | |||
return | |||
} | |||
_, err = cdc.MarshalBinaryLengthPrefixedWriter(se.conn, msg) | |||
if _, ok := err.(timeoutError); ok { | |||
if err != nil { | |||
err = errors.Wrap(ErrWriteTimeout, err.Error()) | |||
} else { | |||
err = errors.Wrap(ErrWriteTimeout, "Empty error") | |||
} | |||
se.dropConnection() | |||
} | |||
return | |||
} | |||
func (se *signerEndpoint) isConnected() bool { | |||
return se.conn != nil | |||
} | |||
func (se *signerEndpoint) dropConnection() { | |||
if se.conn != nil { | |||
if err := se.conn.Close(); err != nil { | |||
se.Logger.Error("signerEndpoint::dropConnection", "err", err) | |||
} | |||
se.conn = nil | |||
} | |||
} |
@ -0,0 +1,198 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"net" | |||
"sync" | |||
"time" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/libs/log" | |||
) | |||
// SignerValidatorEndpointOption sets an optional parameter on the SocketVal. | |||
type SignerValidatorEndpointOption func(*SignerListenerEndpoint) | |||
// SignerListenerEndpoint listens for an external process to dial in | |||
// and keeps the connection alive by dropping and reconnecting | |||
type SignerListenerEndpoint struct { | |||
signerEndpoint | |||
listener net.Listener | |||
connectRequestCh chan struct{} | |||
connectionAvailableCh chan net.Conn | |||
timeoutAccept time.Duration | |||
pingTimer *time.Ticker | |||
instanceMtx sync.Mutex // Ensures instance public methods access, i.e. SendRequest | |||
} | |||
// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. | |||
func NewSignerListenerEndpoint( | |||
logger log.Logger, | |||
listener net.Listener, | |||
) *SignerListenerEndpoint { | |||
sc := &SignerListenerEndpoint{ | |||
listener: listener, | |||
timeoutAccept: defaultTimeoutAcceptSeconds * time.Second, | |||
} | |||
sc.BaseService = *cmn.NewBaseService(logger, "SignerListenerEndpoint", sc) | |||
sc.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second | |||
return sc | |||
} | |||
// OnStart implements cmn.Service. | |||
func (sl *SignerListenerEndpoint) OnStart() error { | |||
sl.connectRequestCh = make(chan struct{}) | |||
sl.connectionAvailableCh = make(chan net.Conn) | |||
sl.pingTimer = time.NewTicker(defaultPingPeriodMilliseconds * time.Millisecond) | |||
go sl.serviceLoop() | |||
go sl.pingLoop() | |||
sl.connectRequestCh <- struct{}{} | |||
return nil | |||
} | |||
// OnStop implements cmn.Service | |||
func (sl *SignerListenerEndpoint) OnStop() { | |||
sl.instanceMtx.Lock() | |||
defer sl.instanceMtx.Unlock() | |||
_ = sl.Close() | |||
// Stop listening | |||
if sl.listener != nil { | |||
if err := sl.listener.Close(); err != nil { | |||
sl.Logger.Error("Closing Listener", "err", err) | |||
sl.listener = nil | |||
} | |||
} | |||
sl.pingTimer.Stop() | |||
} | |||
// WaitForConnection waits maxWait for a connection or returns a timeout error | |||
func (sl *SignerListenerEndpoint) WaitForConnection(maxWait time.Duration) error { | |||
sl.instanceMtx.Lock() | |||
defer sl.instanceMtx.Unlock() | |||
return sl.ensureConnection(maxWait) | |||
} | |||
// SendRequest ensures there is a connection, sends a request and waits for a response | |||
func (sl *SignerListenerEndpoint) SendRequest(request SignerMessage) (SignerMessage, error) { | |||
sl.instanceMtx.Lock() | |||
defer sl.instanceMtx.Unlock() | |||
err := sl.ensureConnection(sl.timeoutAccept) | |||
if err != nil { | |||
return nil, err | |||
} | |||
err = sl.WriteMessage(request) | |||
if err != nil { | |||
return nil, err | |||
} | |||
res, err := sl.ReadMessage() | |||
if err != nil { | |||
return nil, err | |||
} | |||
return res, nil | |||
} | |||
func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error { | |||
if sl.IsConnected() { | |||
return nil | |||
} | |||
// Is there a connection ready? then use it | |||
if sl.GetAvailableConnection(sl.connectionAvailableCh) { | |||
return nil | |||
} | |||
// block until connected or timeout | |||
sl.triggerConnect() | |||
err := sl.WaitConnection(sl.connectionAvailableCh, maxWait) | |||
if err != nil { | |||
return err | |||
} | |||
return nil | |||
} | |||
func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) { | |||
if !sl.IsRunning() || sl.listener == nil { | |||
return nil, fmt.Errorf("endpoint is closing") | |||
} | |||
// wait for a new conn | |||
sl.Logger.Info("SignerListener: Listening for new connection") | |||
conn, err := sl.listener.Accept() | |||
if err != nil { | |||
return nil, err | |||
} | |||
return conn, nil | |||
} | |||
func (sl *SignerListenerEndpoint) triggerConnect() { | |||
select { | |||
case sl.connectRequestCh <- struct{}{}: | |||
default: | |||
} | |||
} | |||
func (sl *SignerListenerEndpoint) triggerReconnect() { | |||
sl.DropConnection() | |||
sl.triggerConnect() | |||
} | |||
func (sl *SignerListenerEndpoint) serviceLoop() { | |||
for { | |||
select { | |||
case <-sl.connectRequestCh: | |||
{ | |||
conn, err := sl.acceptNewConnection() | |||
if err == nil { | |||
sl.Logger.Info("SignerListener: Connected") | |||
// We have a good connection, wait for someone that needs one otherwise cancellation | |||
select { | |||
case sl.connectionAvailableCh <- conn: | |||
case <-sl.Quit(): | |||
return | |||
} | |||
} | |||
select { | |||
case sl.connectRequestCh <- struct{}{}: | |||
default: | |||
} | |||
} | |||
case <-sl.Quit(): | |||
return | |||
} | |||
} | |||
} | |||
func (sl *SignerListenerEndpoint) pingLoop() { | |||
for { | |||
select { | |||
case <-sl.pingTimer.C: | |||
{ | |||
_, err := sl.SendRequest(&PingRequest{}) | |||
if err != nil { | |||
sl.Logger.Error("SignerListener: Ping timeout") | |||
sl.triggerReconnect() | |||
} | |||
} | |||
case <-sl.Quit(): | |||
return | |||
} | |||
} | |||
} |
@ -0,0 +1,198 @@ | |||
package privval | |||
import ( | |||
"net" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
"github.com/tendermint/tendermint/crypto/ed25519" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/libs/log" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
var ( | |||
testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second | |||
testTimeoutReadWrite = 100 * time.Millisecond | |||
testTimeoutReadWrite2o3 = 60 * time.Millisecond // 2/3 of the other one | |||
) | |||
type dialerTestCase struct { | |||
addr string | |||
dialer SocketDialer | |||
} | |||
// TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We | |||
// don't need this for Unix sockets because the OS instantly knows the state of | |||
// both ends of the socket connection. This basically causes the | |||
// SignerDialerEndpoint.dialer() call inside SignerDialerEndpoint.acceptNewConnection() to return | |||
// successfully immediately, putting an instant stop to any retry attempts. | |||
func TestSignerRemoteRetryTCPOnly(t *testing.T) { | |||
var ( | |||
attemptCh = make(chan int) | |||
retries = 10 | |||
) | |||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||
require.NoError(t, err) | |||
// Continuously Accept connection and close {attempts} times | |||
go func(ln net.Listener, attemptCh chan<- int) { | |||
attempts := 0 | |||
for { | |||
conn, err := ln.Accept() | |||
require.NoError(t, err) | |||
err = conn.Close() | |||
require.NoError(t, err) | |||
attempts++ | |||
if attempts == retries { | |||
attemptCh <- attempts | |||
break | |||
} | |||
} | |||
}(ln, attemptCh) | |||
dialerEndpoint := NewSignerDialerEndpoint( | |||
log.TestingLogger(), | |||
DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()), | |||
) | |||
SignerDialerEndpointTimeoutReadWrite(time.Millisecond)(dialerEndpoint) | |||
SignerDialerEndpointConnRetries(retries)(dialerEndpoint) | |||
chainId := cmn.RandStr(12) | |||
mockPV := types.NewMockPV() | |||
signerServer := NewSignerServer(dialerEndpoint, chainId, mockPV) | |||
err = signerServer.Start() | |||
require.NoError(t, err) | |||
defer signerServer.Stop() | |||
select { | |||
case attempts := <-attemptCh: | |||
assert.Equal(t, retries, attempts) | |||
case <-time.After(1500 * time.Millisecond): | |||
t.Error("expected remote to observe connection attempts") | |||
} | |||
} | |||
func TestRetryConnToRemoteSigner(t *testing.T) { | |||
for _, tc := range getDialerTestCases(t) { | |||
var ( | |||
logger = log.TestingLogger() | |||
chainID = cmn.RandStr(12) | |||
mockPV = types.NewMockPV() | |||
endpointIsOpenCh = make(chan struct{}) | |||
thisConnTimeout = testTimeoutReadWrite | |||
listenerEndpoint = newSignerListenerEndpoint(logger, tc.addr, thisConnTimeout) | |||
) | |||
dialerEndpoint := NewSignerDialerEndpoint( | |||
logger, | |||
tc.dialer, | |||
) | |||
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint) | |||
SignerDialerEndpointConnRetries(10)(dialerEndpoint) | |||
signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV) | |||
startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh) | |||
defer listenerEndpoint.Stop() | |||
require.NoError(t, signerServer.Start()) | |||
assert.True(t, signerServer.IsRunning()) | |||
<-endpointIsOpenCh | |||
signerServer.Stop() | |||
dialerEndpoint2 := NewSignerDialerEndpoint( | |||
logger, | |||
tc.dialer, | |||
) | |||
signerServer2 := NewSignerServer(dialerEndpoint2, chainID, mockPV) | |||
// let some pings pass | |||
require.NoError(t, signerServer2.Start()) | |||
assert.True(t, signerServer2.IsRunning()) | |||
defer signerServer2.Stop() | |||
// give the client some time to re-establish the conn to the remote signer | |||
// should see sth like this in the logs: | |||
// | |||
// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" | |||
// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal | |||
time.Sleep(testTimeoutReadWrite * 2) | |||
} | |||
} | |||
/////////////////////////////////// | |||
func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint { | |||
proto, address := cmn.ProtocolAndAddress(addr) | |||
ln, err := net.Listen(proto, address) | |||
logger.Info("SignerListener: Listening", "proto", proto, "address", address) | |||
if err != nil { | |||
panic(err) | |||
} | |||
var listener net.Listener | |||
if proto == "unix" { | |||
unixLn := NewUnixListener(ln) | |||
UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn) | |||
UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) | |||
listener = unixLn | |||
} else { | |||
tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) | |||
TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn) | |||
TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) | |||
listener = tcpLn | |||
} | |||
return NewSignerListenerEndpoint(logger, listener) | |||
} | |||
func startListenerEndpointAsync(t *testing.T, sle *SignerListenerEndpoint, endpointIsOpenCh chan struct{}) { | |||
go func(sle *SignerListenerEndpoint) { | |||
require.NoError(t, sle.Start()) | |||
assert.True(t, sle.IsRunning()) | |||
close(endpointIsOpenCh) | |||
}(sle) | |||
} | |||
func getMockEndpoints( | |||
t *testing.T, | |||
addr string, | |||
socketDialer SocketDialer, | |||
) (*SignerListenerEndpoint, *SignerDialerEndpoint) { | |||
var ( | |||
logger = log.TestingLogger() | |||
endpointIsOpenCh = make(chan struct{}) | |||
dialerEndpoint = NewSignerDialerEndpoint( | |||
logger, | |||
socketDialer, | |||
) | |||
listenerEndpoint = newSignerListenerEndpoint(logger, addr, testTimeoutReadWrite) | |||
) | |||
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint) | |||
SignerDialerEndpointConnRetries(1e6)(dialerEndpoint) | |||
startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh) | |||
require.NoError(t, dialerEndpoint.Start()) | |||
assert.True(t, dialerEndpoint.IsRunning()) | |||
<-endpointIsOpenCh | |||
return listenerEndpoint, dialerEndpoint | |||
} |
@ -1,192 +0,0 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"io" | |||
"net" | |||
"github.com/pkg/errors" | |||
"github.com/tendermint/tendermint/crypto" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
// SignerRemote implements PrivValidator. | |||
// It uses a net.Conn to request signatures from an external process. | |||
type SignerRemote struct { | |||
conn net.Conn | |||
// memoized | |||
consensusPubKey crypto.PubKey | |||
} | |||
// Check that SignerRemote implements PrivValidator. | |||
var _ types.PrivValidator = (*SignerRemote)(nil) | |||
// NewSignerRemote returns an instance of SignerRemote. | |||
func NewSignerRemote(conn net.Conn) (*SignerRemote, error) { | |||
// retrieve and memoize the consensus public key once. | |||
pubKey, err := getPubKey(conn) | |||
if err != nil { | |||
return nil, cmn.ErrorWrap(err, "error while retrieving public key for remote signer") | |||
} | |||
return &SignerRemote{ | |||
conn: conn, | |||
consensusPubKey: pubKey, | |||
}, nil | |||
} | |||
// Close calls Close on the underlying net.Conn. | |||
func (sc *SignerRemote) Close() error { | |||
return sc.conn.Close() | |||
} | |||
// GetPubKey implements PrivValidator. | |||
func (sc *SignerRemote) GetPubKey() crypto.PubKey { | |||
return sc.consensusPubKey | |||
} | |||
// not thread-safe (only called on startup). | |||
func getPubKey(conn net.Conn) (crypto.PubKey, error) { | |||
err := writeMsg(conn, &PubKeyRequest{}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
res, err := readMsg(conn) | |||
if err != nil { | |||
return nil, err | |||
} | |||
pubKeyResp, ok := res.(*PubKeyResponse) | |||
if !ok { | |||
return nil, errors.Wrap(ErrUnexpectedResponse, "response is not PubKeyResponse") | |||
} | |||
if pubKeyResp.Error != nil { | |||
return nil, errors.Wrap(pubKeyResp.Error, "failed to get private validator's public key") | |||
} | |||
return pubKeyResp.PubKey, nil | |||
} | |||
// SignVote implements PrivValidator. | |||
func (sc *SignerRemote) SignVote(chainID string, vote *types.Vote) error { | |||
err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote}) | |||
if err != nil { | |||
return err | |||
} | |||
res, err := readMsg(sc.conn) | |||
if err != nil { | |||
return err | |||
} | |||
resp, ok := res.(*SignedVoteResponse) | |||
if !ok { | |||
return ErrUnexpectedResponse | |||
} | |||
if resp.Error != nil { | |||
return resp.Error | |||
} | |||
*vote = *resp.Vote | |||
return nil | |||
} | |||
// SignProposal implements PrivValidator. | |||
func (sc *SignerRemote) SignProposal(chainID string, proposal *types.Proposal) error { | |||
err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal}) | |||
if err != nil { | |||
return err | |||
} | |||
res, err := readMsg(sc.conn) | |||
if err != nil { | |||
return err | |||
} | |||
resp, ok := res.(*SignedProposalResponse) | |||
if !ok { | |||
return ErrUnexpectedResponse | |||
} | |||
if resp.Error != nil { | |||
return resp.Error | |||
} | |||
*proposal = *resp.Proposal | |||
return nil | |||
} | |||
// Ping is used to check connection health. | |||
func (sc *SignerRemote) Ping() error { | |||
err := writeMsg(sc.conn, &PingRequest{}) | |||
if err != nil { | |||
return err | |||
} | |||
res, err := readMsg(sc.conn) | |||
if err != nil { | |||
return err | |||
} | |||
_, ok := res.(*PingResponse) | |||
if !ok { | |||
return ErrUnexpectedResponse | |||
} | |||
return nil | |||
} | |||
func readMsg(r io.Reader) (msg RemoteSignerMsg, err error) { | |||
const maxRemoteSignerMsgSize = 1024 * 10 | |||
_, err = cdc.UnmarshalBinaryLengthPrefixedReader(r, &msg, maxRemoteSignerMsgSize) | |||
if _, ok := err.(timeoutError); ok { | |||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||
} | |||
return | |||
} | |||
func writeMsg(w io.Writer, msg interface{}) (err error) { | |||
_, err = cdc.MarshalBinaryLengthPrefixedWriter(w, msg) | |||
if _, ok := err.(timeoutError); ok { | |||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||
} | |||
return | |||
} | |||
func handleRequest(req RemoteSignerMsg, chainID string, privVal types.PrivValidator) (RemoteSignerMsg, error) { | |||
var res RemoteSignerMsg | |||
var err error | |||
switch r := req.(type) { | |||
case *PubKeyRequest: | |||
var p crypto.PubKey | |||
p = privVal.GetPubKey() | |||
res = &PubKeyResponse{p, nil} | |||
case *SignVoteRequest: | |||
err = privVal.SignVote(chainID, r.Vote) | |||
if err != nil { | |||
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} | |||
} else { | |||
res = &SignedVoteResponse{r.Vote, nil} | |||
} | |||
case *SignProposalRequest: | |||
err = privVal.SignProposal(chainID, r.Proposal) | |||
if err != nil { | |||
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} | |||
} else { | |||
res = &SignedProposalResponse{r.Proposal, nil} | |||
} | |||
case *PingRequest: | |||
res = &PingResponse{} | |||
default: | |||
err = fmt.Errorf("unknown msg: %v", r) | |||
} | |||
return res, err | |||
} |
@ -1,68 +0,0 @@ | |||
package privval | |||
import ( | |||
"net" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
"github.com/tendermint/tendermint/crypto/ed25519" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/libs/log" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
// TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We | |||
// don't need this for Unix sockets because the OS instantly knows the state of | |||
// both ends of the socket connection. This basically causes the | |||
// SignerServiceEndpoint.dialer() call inside SignerServiceEndpoint.connect() to return | |||
// successfully immediately, putting an instant stop to any retry attempts. | |||
func TestSignerRemoteRetryTCPOnly(t *testing.T) { | |||
var ( | |||
attemptCh = make(chan int) | |||
retries = 2 | |||
) | |||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||
require.NoError(t, err) | |||
go func(ln net.Listener, attemptCh chan<- int) { | |||
attempts := 0 | |||
for { | |||
conn, err := ln.Accept() | |||
require.NoError(t, err) | |||
err = conn.Close() | |||
require.NoError(t, err) | |||
attempts++ | |||
if attempts == retries { | |||
attemptCh <- attempts | |||
break | |||
} | |||
} | |||
}(ln, attemptCh) | |||
serviceEndpoint := NewSignerServiceEndpoint( | |||
log.TestingLogger(), | |||
cmn.RandStr(12), | |||
types.NewMockPV(), | |||
DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()), | |||
) | |||
defer serviceEndpoint.Stop() | |||
SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) | |||
SignerServiceEndpointConnRetries(retries)(serviceEndpoint) | |||
assert.Equal(t, serviceEndpoint.Start(), ErrDialRetryMax) | |||
select { | |||
case attempts := <-attemptCh: | |||
assert.Equal(t, retries, attempts) | |||
case <-time.After(100 * time.Millisecond): | |||
t.Error("expected remote to observe connection attempts") | |||
} | |||
} |
@ -0,0 +1,44 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"github.com/tendermint/tendermint/crypto" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
func DefaultValidationRequestHandler(privVal types.PrivValidator, req SignerMessage, chainID string) (SignerMessage, error) { | |||
var res SignerMessage | |||
var err error | |||
switch r := req.(type) { | |||
case *PubKeyRequest: | |||
var p crypto.PubKey | |||
p = privVal.GetPubKey() | |||
res = &PubKeyResponse{p, nil} | |||
case *SignVoteRequest: | |||
err = privVal.SignVote(chainID, r.Vote) | |||
if err != nil { | |||
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} | |||
} else { | |||
res = &SignedVoteResponse{r.Vote, nil} | |||
} | |||
case *SignProposalRequest: | |||
err = privVal.SignProposal(chainID, r.Proposal) | |||
if err != nil { | |||
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} | |||
} else { | |||
res = &SignedProposalResponse{r.Proposal, nil} | |||
} | |||
case *PingRequest: | |||
err, res = nil, &PingResponse{} | |||
default: | |||
err = fmt.Errorf("unknown msg: %v", r) | |||
} | |||
return res, err | |||
} |
@ -0,0 +1,107 @@ | |||
package privval | |||
import ( | |||
"io" | |||
"sync" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
// ValidationRequestHandlerFunc handles different remoteSigner requests | |||
type ValidationRequestHandlerFunc func( | |||
privVal types.PrivValidator, | |||
requestMessage SignerMessage, | |||
chainID string) (SignerMessage, error) | |||
type SignerServer struct { | |||
cmn.BaseService | |||
endpoint *SignerDialerEndpoint | |||
chainID string | |||
privVal types.PrivValidator | |||
handlerMtx sync.Mutex | |||
validationRequestHandler ValidationRequestHandlerFunc | |||
} | |||
func NewSignerServer(endpoint *SignerDialerEndpoint, chainID string, privVal types.PrivValidator) *SignerServer { | |||
ss := &SignerServer{ | |||
endpoint: endpoint, | |||
chainID: chainID, | |||
privVal: privVal, | |||
validationRequestHandler: DefaultValidationRequestHandler, | |||
} | |||
ss.BaseService = *cmn.NewBaseService(endpoint.Logger, "SignerServer", ss) | |||
return ss | |||
} | |||
// OnStart implements cmn.Service. | |||
func (ss *SignerServer) OnStart() error { | |||
go ss.serviceLoop() | |||
return nil | |||
} | |||
// OnStop implements cmn.Service. | |||
func (ss *SignerServer) OnStop() { | |||
ss.endpoint.Logger.Debug("SignerServer: OnStop calling Close") | |||
_ = ss.endpoint.Close() | |||
} | |||
// SetRequestHandler override the default function that is used to service requests | |||
func (ss *SignerServer) SetRequestHandler(validationRequestHandler ValidationRequestHandlerFunc) { | |||
ss.handlerMtx.Lock() | |||
defer ss.handlerMtx.Unlock() | |||
ss.validationRequestHandler = validationRequestHandler | |||
} | |||
func (ss *SignerServer) servicePendingRequest() { | |||
if !ss.IsRunning() { | |||
return // Ignore error from closing. | |||
} | |||
req, err := ss.endpoint.ReadMessage() | |||
if err != nil { | |||
if err != io.EOF { | |||
ss.Logger.Error("SignerServer: HandleMessage", "err", err) | |||
} | |||
return | |||
} | |||
var res SignerMessage | |||
{ | |||
// limit the scope of the lock | |||
ss.handlerMtx.Lock() | |||
defer ss.handlerMtx.Unlock() | |||
res, err = ss.validationRequestHandler(ss.privVal, req, ss.chainID) | |||
if err != nil { | |||
// only log the error; we'll reply with an error in res | |||
ss.Logger.Error("SignerServer: handleMessage", "err", err) | |||
} | |||
} | |||
if res != nil { | |||
err = ss.endpoint.WriteMessage(res) | |||
if err != nil { | |||
ss.Logger.Error("SignerServer: writeMessage", "err", err) | |||
} | |||
} | |||
} | |||
func (ss *SignerServer) serviceLoop() { | |||
for { | |||
select { | |||
default: | |||
err := ss.endpoint.ensureConnection() | |||
if err != nil { | |||
return | |||
} | |||
ss.servicePendingRequest() | |||
case <-ss.Quit(): | |||
return | |||
} | |||
} | |||
} |
@ -1,139 +0,0 @@ | |||
package privval | |||
import ( | |||
"io" | |||
"net" | |||
"time" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/libs/log" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
// SignerServiceEndpointOption sets an optional parameter on the SignerServiceEndpoint. | |||
type SignerServiceEndpointOption func(*SignerServiceEndpoint) | |||
// SignerServiceEndpointTimeoutReadWrite sets the read and write timeout for connections | |||
// from external signing processes. | |||
func SignerServiceEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption { | |||
return func(ss *SignerServiceEndpoint) { ss.timeoutReadWrite = timeout } | |||
} | |||
// SignerServiceEndpointConnRetries sets the amount of attempted retries to connect. | |||
func SignerServiceEndpointConnRetries(retries int) SignerServiceEndpointOption { | |||
return func(ss *SignerServiceEndpoint) { ss.connRetries = retries } | |||
} | |||
// SignerServiceEndpoint dials using its dialer and responds to any | |||
// signature requests using its privVal. | |||
type SignerServiceEndpoint struct { | |||
cmn.BaseService | |||
chainID string | |||
timeoutReadWrite time.Duration | |||
connRetries int | |||
privVal types.PrivValidator | |||
dialer SocketDialer | |||
conn net.Conn | |||
} | |||
// NewSignerServiceEndpoint returns a SignerServiceEndpoint that will dial using the given | |||
// dialer and respond to any signature requests over the connection | |||
// using the given privVal. | |||
func NewSignerServiceEndpoint( | |||
logger log.Logger, | |||
chainID string, | |||
privVal types.PrivValidator, | |||
dialer SocketDialer, | |||
) *SignerServiceEndpoint { | |||
se := &SignerServiceEndpoint{ | |||
chainID: chainID, | |||
timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds, | |||
connRetries: defaultMaxDialRetries, | |||
privVal: privVal, | |||
dialer: dialer, | |||
} | |||
se.BaseService = *cmn.NewBaseService(logger, "SignerServiceEndpoint", se) | |||
return se | |||
} | |||
// OnStart implements cmn.Service. | |||
func (se *SignerServiceEndpoint) OnStart() error { | |||
conn, err := se.connect() | |||
if err != nil { | |||
se.Logger.Error("OnStart", "err", err) | |||
return err | |||
} | |||
se.conn = conn | |||
go se.handleConnection(conn) | |||
return nil | |||
} | |||
// OnStop implements cmn.Service. | |||
func (se *SignerServiceEndpoint) OnStop() { | |||
if se.conn == nil { | |||
return | |||
} | |||
if err := se.conn.Close(); err != nil { | |||
se.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) | |||
} | |||
} | |||
func (se *SignerServiceEndpoint) connect() (net.Conn, error) { | |||
for retries := 0; retries < se.connRetries; retries++ { | |||
// Don't sleep if it is the first retry. | |||
if retries > 0 { | |||
time.Sleep(se.timeoutReadWrite) | |||
} | |||
conn, err := se.dialer() | |||
if err == nil { | |||
return conn, nil | |||
} | |||
se.Logger.Error("dialing", "err", err) | |||
} | |||
return nil, ErrDialRetryMax | |||
} | |||
func (se *SignerServiceEndpoint) handleConnection(conn net.Conn) { | |||
for { | |||
if !se.IsRunning() { | |||
return // Ignore error from listener closing. | |||
} | |||
// Reset the connection deadline | |||
deadline := time.Now().Add(se.timeoutReadWrite) | |||
err := conn.SetDeadline(deadline) | |||
if err != nil { | |||
return | |||
} | |||
req, err := readMsg(conn) | |||
if err != nil { | |||
if err != io.EOF { | |||
se.Logger.Error("handleConnection readMsg", "err", err) | |||
} | |||
return | |||
} | |||
res, err := handleRequest(req, se.chainID, se.privVal) | |||
if err != nil { | |||
// only log the error; we'll reply with an error in res | |||
se.Logger.Error("handleConnection handleRequest", "err", err) | |||
} | |||
err = writeMsg(conn, res) | |||
if err != nil { | |||
se.Logger.Error("handleConnection writeMsg", "err", err) | |||
return | |||
} | |||
} | |||
} |
@ -1,230 +0,0 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"net" | |||
"sync" | |||
"time" | |||
"github.com/tendermint/tendermint/crypto" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/libs/log" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
// defaultHeartbeatSeconds is the default ping period, in seconds.
const defaultHeartbeatSeconds = 2

// defaultMaxDialRetries is the default number of dial attempts.
const defaultMaxDialRetries = 10

// heartbeatPeriod is the default ping period given to new validator endpoints.
var heartbeatPeriod = time.Second * defaultHeartbeatSeconds
// SignerValidatorEndpointOption sets an optional parameter on the SocketVal. | |||
type SignerValidatorEndpointOption func(*SignerValidatorEndpoint) | |||
// SignerValidatorEndpointSetHeartbeat sets the period on which to check the liveness of the | |||
// connected Signer connections. | |||
func SignerValidatorEndpointSetHeartbeat(period time.Duration) SignerValidatorEndpointOption { | |||
return func(sc *SignerValidatorEndpoint) { sc.heartbeatPeriod = period } | |||
} | |||
// SocketVal implements PrivValidator. | |||
// It listens for an external process to dial in and uses | |||
// the socket to request signatures. | |||
type SignerValidatorEndpoint struct { | |||
cmn.BaseService | |||
listener net.Listener | |||
// ping | |||
cancelPingCh chan struct{} | |||
pingTicker *time.Ticker | |||
heartbeatPeriod time.Duration | |||
// signer is mutable since it can be reset if the connection fails. | |||
// failures are detected by a background ping routine. | |||
// All messages are request/response, so we hold the mutex | |||
// so only one request/response pair can happen at a time. | |||
// Methods on the underlying net.Conn itself are already goroutine safe. | |||
mtx sync.Mutex | |||
// TODO: Signer should encapsulate and hide the endpoint completely. Invert the relation | |||
signer *SignerRemote | |||
} | |||
// Check that SignerValidatorEndpoint implements PrivValidator. | |||
var _ types.PrivValidator = (*SignerValidatorEndpoint)(nil) | |||
// NewSignerValidatorEndpoint returns an instance of SignerValidatorEndpoint. | |||
func NewSignerValidatorEndpoint(logger log.Logger, listener net.Listener) *SignerValidatorEndpoint { | |||
sc := &SignerValidatorEndpoint{ | |||
listener: listener, | |||
heartbeatPeriod: heartbeatPeriod, | |||
} | |||
sc.BaseService = *cmn.NewBaseService(logger, "SignerValidatorEndpoint", sc) | |||
return sc | |||
} | |||
//-------------------------------------------------------- | |||
// Implement PrivValidator | |||
// GetPubKey implements PrivValidator. | |||
func (ve *SignerValidatorEndpoint) GetPubKey() crypto.PubKey { | |||
ve.mtx.Lock() | |||
defer ve.mtx.Unlock() | |||
return ve.signer.GetPubKey() | |||
} | |||
// SignVote implements PrivValidator. | |||
func (ve *SignerValidatorEndpoint) SignVote(chainID string, vote *types.Vote) error { | |||
ve.mtx.Lock() | |||
defer ve.mtx.Unlock() | |||
return ve.signer.SignVote(chainID, vote) | |||
} | |||
// SignProposal implements PrivValidator. | |||
func (ve *SignerValidatorEndpoint) SignProposal(chainID string, proposal *types.Proposal) error { | |||
ve.mtx.Lock() | |||
defer ve.mtx.Unlock() | |||
return ve.signer.SignProposal(chainID, proposal) | |||
} | |||
//-------------------------------------------------------- | |||
// More thread safe methods proxied to the signer | |||
// Ping is used to check connection health. | |||
func (ve *SignerValidatorEndpoint) Ping() error { | |||
ve.mtx.Lock() | |||
defer ve.mtx.Unlock() | |||
return ve.signer.Ping() | |||
} | |||
// Close closes the underlying net.Conn. | |||
func (ve *SignerValidatorEndpoint) Close() { | |||
ve.mtx.Lock() | |||
defer ve.mtx.Unlock() | |||
if ve.signer != nil { | |||
if err := ve.signer.Close(); err != nil { | |||
ve.Logger.Error("OnStop", "err", err) | |||
} | |||
} | |||
if ve.listener != nil { | |||
if err := ve.listener.Close(); err != nil { | |||
ve.Logger.Error("OnStop", "err", err) | |||
} | |||
} | |||
} | |||
//-------------------------------------------------------- | |||
// Service start and stop | |||
// OnStart implements cmn.Service. | |||
func (ve *SignerValidatorEndpoint) OnStart() error { | |||
if closed, err := ve.reset(); err != nil { | |||
ve.Logger.Error("OnStart", "err", err) | |||
return err | |||
} else if closed { | |||
return fmt.Errorf("listener is closed") | |||
} | |||
// Start a routine to keep the connection alive | |||
ve.cancelPingCh = make(chan struct{}, 1) | |||
ve.pingTicker = time.NewTicker(ve.heartbeatPeriod) | |||
go func() { | |||
for { | |||
select { | |||
case <-ve.pingTicker.C: | |||
err := ve.Ping() | |||
if err != nil { | |||
ve.Logger.Error("Ping", "err", err) | |||
if err == ErrUnexpectedResponse { | |||
return | |||
} | |||
closed, err := ve.reset() | |||
if err != nil { | |||
ve.Logger.Error("Reconnecting to remote signer failed", "err", err) | |||
continue | |||
} | |||
if closed { | |||
ve.Logger.Info("listener is closing") | |||
return | |||
} | |||
ve.Logger.Info("Re-created connection to remote signer", "impl", ve) | |||
} | |||
case <-ve.cancelPingCh: | |||
ve.pingTicker.Stop() | |||
return | |||
} | |||
} | |||
}() | |||
return nil | |||
} | |||
// OnStop implements cmn.Service. | |||
func (ve *SignerValidatorEndpoint) OnStop() { | |||
if ve.cancelPingCh != nil { | |||
close(ve.cancelPingCh) | |||
} | |||
ve.Close() | |||
} | |||
//-------------------------------------------------------- | |||
// Connection and signer management | |||
// waits to accept and sets a new connection. | |||
// connection is closed in OnStop. | |||
// returns true if the listener is closed | |||
// (ie. it returns a nil conn). | |||
func (ve *SignerValidatorEndpoint) reset() (closed bool, err error) { | |||
ve.mtx.Lock() | |||
defer ve.mtx.Unlock() | |||
// first check if the conn already exists and close it. | |||
if ve.signer != nil { | |||
if tmpErr := ve.signer.Close(); tmpErr != nil { | |||
ve.Logger.Error("error closing socket val connection during reset", "err", tmpErr) | |||
} | |||
} | |||
// wait for a new conn | |||
conn, err := ve.acceptConnection() | |||
if err != nil { | |||
return false, err | |||
} | |||
// listener is closed | |||
if conn == nil { | |||
return true, nil | |||
} | |||
ve.signer, err = NewSignerRemote(conn) | |||
if err != nil { | |||
// failed to fetch the pubkey. close out the connection. | |||
if tmpErr := conn.Close(); tmpErr != nil { | |||
ve.Logger.Error("error closing connection", "err", tmpErr) | |||
} | |||
return false, err | |||
} | |||
return false, nil | |||
} | |||
// Attempt to accept a connection. | |||
// Times out after the listener's timeoutAccept | |||
func (ve *SignerValidatorEndpoint) acceptConnection() (net.Conn, error) { | |||
conn, err := ve.listener.Accept() | |||
if err != nil { | |||
if !ve.IsRunning() { | |||
return nil, nil // Ignore error from listener closing. | |||
} | |||
return nil, err | |||
} | |||
return conn, nil | |||
} |
@ -1,506 +0,0 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"net" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
"github.com/tendermint/tendermint/crypto/ed25519" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
"github.com/tendermint/tendermint/libs/log" | |||
"github.com/tendermint/tendermint/types" | |||
) | |||
var ( | |||
testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second | |||
testTimeoutReadWrite = 100 * time.Millisecond | |||
testTimeoutReadWrite2o3 = 66 * time.Millisecond // 2/3 of the other one | |||
testTimeoutHeartbeat = 10 * time.Millisecond | |||
testTimeoutHeartbeat3o2 = 6 * time.Millisecond // 3/2 of the other one | |||
) | |||
type socketTestCase struct { | |||
addr string | |||
dialer SocketDialer | |||
} | |||
func socketTestCases(t *testing.T) []socketTestCase { | |||
tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) | |||
unixFilePath, err := testUnixAddr() | |||
require.NoError(t, err) | |||
unixAddr := fmt.Sprintf("unix://%s", unixFilePath) | |||
return []socketTestCase{ | |||
{ | |||
addr: tcpAddr, | |||
dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()), | |||
}, | |||
{ | |||
addr: unixAddr, | |||
dialer: DialUnixFn(unixFilePath), | |||
}, | |||
} | |||
} | |||
func TestSocketPVAddress(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
// Execute the test within a closure to ensure the deferred statements | |||
// are called between each for loop iteration, for isolated test cases. | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
serviceAddr := serviceEndpoint.privVal.GetPubKey().Address() | |||
validatorAddr := validatorEndpoint.GetPubKey().Address() | |||
assert.Equal(t, serviceAddr, validatorAddr) | |||
}() | |||
} | |||
} | |||
func TestSocketPVPubKey(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||
t, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.addr, | |||
tc.dialer) | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
clientKey := validatorEndpoint.GetPubKey() | |||
privvalPubKey := serviceEndpoint.privVal.GetPubKey() | |||
assert.Equal(t, privvalPubKey, clientKey) | |||
}() | |||
} | |||
} | |||
func TestSocketPVProposal(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||
t, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.addr, | |||
tc.dialer) | |||
ts = time.Now() | |||
privProposal = &types.Proposal{Timestamp: ts} | |||
clientProposal = &types.Proposal{Timestamp: ts} | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, privProposal)) | |||
require.NoError(t, validatorEndpoint.SignProposal(chainID, clientProposal)) | |||
assert.Equal(t, privProposal.Signature, clientProposal.Signature) | |||
}() | |||
} | |||
} | |||
func TestSocketPVVote(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||
t, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.addr, | |||
tc.dialer) | |||
ts = time.Now() | |||
vType = types.PrecommitType | |||
want = &types.Vote{Timestamp: ts, Type: vType} | |||
have = &types.Vote{Timestamp: ts, Type: vType} | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
}() | |||
} | |||
} | |||
func TestSocketPVVoteResetDeadline(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||
t, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.addr, | |||
tc.dialer) | |||
ts = time.Now() | |||
vType = types.PrecommitType | |||
want = &types.Vote{Timestamp: ts, Type: vType} | |||
have = &types.Vote{Timestamp: ts, Type: vType} | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
time.Sleep(testTimeoutReadWrite2o3) | |||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
// This would exceed the deadline if it was not extended by the previous message | |||
time.Sleep(testTimeoutReadWrite2o3) | |||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
}() | |||
} | |||
} | |||
func TestSocketPVVoteKeepalive(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||
t, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.addr, | |||
tc.dialer) | |||
ts = time.Now() | |||
vType = types.PrecommitType | |||
want = &types.Vote{Timestamp: ts, Type: vType} | |||
have = &types.Vote{Timestamp: ts, Type: vType} | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
time.Sleep(testTimeoutReadWrite * 2) | |||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||
assert.Equal(t, want.Signature, have.Signature) | |||
}() | |||
} | |||
} | |||
func TestSocketPVDeadline(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
listenc = make(chan struct{}) | |||
thisConnTimeout = 100 * time.Millisecond | |||
validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout) | |||
) | |||
go func(sc *SignerValidatorEndpoint) { | |||
defer close(listenc) | |||
// Note: the TCP connection times out at the accept() phase, | |||
// whereas the Unix domain sockets connection times out while | |||
// attempting to fetch the remote signer's public key. | |||
assert.True(t, IsConnTimeout(sc.Start())) | |||
assert.False(t, sc.IsRunning()) | |||
}(validatorEndpoint) | |||
for { | |||
_, err := cmn.Connect(tc.addr) | |||
if err == nil { | |||
break | |||
} | |||
} | |||
<-listenc | |||
}() | |||
} | |||
} | |||
func TestRemoteSignVoteErrors(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||
t, | |||
chainID, | |||
types.NewErroringMockPV(), | |||
tc.addr, | |||
tc.dialer) | |||
ts = time.Now() | |||
vType = types.PrecommitType | |||
vote = &types.Vote{Timestamp: ts, Type: vType} | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
err := validatorEndpoint.SignVote("", vote) | |||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||
err = serviceEndpoint.privVal.SignVote(chainID, vote) | |||
require.Error(t, err) | |||
err = validatorEndpoint.SignVote(chainID, vote) | |||
require.Error(t, err) | |||
}() | |||
} | |||
} | |||
func TestRemoteSignProposalErrors(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
chainID = cmn.RandStr(12) | |||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||
t, | |||
chainID, | |||
types.NewErroringMockPV(), | |||
tc.addr, | |||
tc.dialer) | |||
ts = time.Now() | |||
proposal = &types.Proposal{Timestamp: ts} | |||
) | |||
defer validatorEndpoint.Stop() | |||
defer serviceEndpoint.Stop() | |||
err := validatorEndpoint.SignProposal("", proposal) | |||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||
err = serviceEndpoint.privVal.SignProposal(chainID, proposal) | |||
require.Error(t, err) | |||
err = validatorEndpoint.SignProposal(chainID, proposal) | |||
require.Error(t, err) | |||
}() | |||
} | |||
} | |||
func TestErrUnexpectedResponse(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
logger = log.TestingLogger() | |||
chainID = cmn.RandStr(12) | |||
readyCh = make(chan struct{}) | |||
errCh = make(chan error, 1) | |||
serviceEndpoint = NewSignerServiceEndpoint( | |||
logger, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.dialer, | |||
) | |||
validatorEndpoint = newSignerValidatorEndpoint( | |||
logger, | |||
tc.addr, | |||
testTimeoutReadWrite) | |||
) | |||
testStartEndpoint(t, readyCh, validatorEndpoint) | |||
defer validatorEndpoint.Stop() | |||
SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) | |||
SignerServiceEndpointConnRetries(100)(serviceEndpoint) | |||
// we do not want to Start() the remote signer here and instead use the connection to | |||
// reply with intentionally wrong replies below: | |||
rsConn, err := serviceEndpoint.connect() | |||
require.NoError(t, err) | |||
require.NotNil(t, rsConn) | |||
defer rsConn.Close() | |||
// send over public key to get the remote signer running: | |||
go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) | |||
<-readyCh | |||
// Proposal: | |||
go func(errc chan error) { | |||
errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{}) | |||
}(errCh) | |||
// read request and write wrong response: | |||
go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) | |||
err = <-errCh | |||
require.Error(t, err) | |||
require.Equal(t, err, ErrUnexpectedResponse) | |||
// Vote: | |||
go func(errc chan error) { | |||
errc <- validatorEndpoint.SignVote(chainID, &types.Vote{}) | |||
}(errCh) | |||
// read request and write wrong response: | |||
go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) | |||
err = <-errCh | |||
require.Error(t, err) | |||
require.Equal(t, err, ErrUnexpectedResponse) | |||
}() | |||
} | |||
} | |||
func TestRetryConnToRemoteSigner(t *testing.T) { | |||
for _, tc := range socketTestCases(t) { | |||
func() { | |||
var ( | |||
logger = log.TestingLogger() | |||
chainID = cmn.RandStr(12) | |||
readyCh = make(chan struct{}) | |||
serviceEndpoint = NewSignerServiceEndpoint( | |||
logger, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.dialer, | |||
) | |||
thisConnTimeout = testTimeoutReadWrite | |||
validatorEndpoint = newSignerValidatorEndpoint(logger, tc.addr, thisConnTimeout) | |||
) | |||
// Ping every: | |||
SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) | |||
SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) | |||
SignerServiceEndpointConnRetries(10)(serviceEndpoint) | |||
testStartEndpoint(t, readyCh, validatorEndpoint) | |||
defer validatorEndpoint.Stop() | |||
require.NoError(t, serviceEndpoint.Start()) | |||
assert.True(t, serviceEndpoint.IsRunning()) | |||
<-readyCh | |||
time.Sleep(testTimeoutHeartbeat * 2) | |||
serviceEndpoint.Stop() | |||
rs2 := NewSignerServiceEndpoint( | |||
logger, | |||
chainID, | |||
types.NewMockPV(), | |||
tc.dialer, | |||
) | |||
// let some pings pass | |||
time.Sleep(testTimeoutHeartbeat3o2) | |||
require.NoError(t, rs2.Start()) | |||
assert.True(t, rs2.IsRunning()) | |||
defer rs2.Stop() | |||
// give the client some time to re-establish the conn to the remote signer | |||
// should see sth like this in the logs: | |||
// | |||
// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" | |||
// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal | |||
time.Sleep(testTimeoutReadWrite * 2) | |||
}() | |||
} | |||
} | |||
func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint { | |||
proto, address := cmn.ProtocolAndAddress(addr) | |||
ln, err := net.Listen(proto, address) | |||
logger.Info("Listening at", "proto", proto, "address", address) | |||
if err != nil { | |||
panic(err) | |||
} | |||
var listener net.Listener | |||
if proto == "unix" { | |||
unixLn := NewUnixListener(ln) | |||
UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn) | |||
UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) | |||
listener = unixLn | |||
} else { | |||
tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) | |||
TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn) | |||
TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) | |||
listener = tcpLn | |||
} | |||
return NewSignerValidatorEndpoint(logger, listener) | |||
} | |||
func testSetupSocketPair( | |||
t *testing.T, | |||
chainID string, | |||
privValidator types.PrivValidator, | |||
addr string, | |||
socketDialer SocketDialer, | |||
) (*SignerValidatorEndpoint, *SignerServiceEndpoint) { | |||
var ( | |||
logger = log.TestingLogger() | |||
privVal = privValidator | |||
readyc = make(chan struct{}) | |||
serviceEndpoint = NewSignerServiceEndpoint( | |||
logger, | |||
chainID, | |||
privVal, | |||
socketDialer, | |||
) | |||
thisConnTimeout = testTimeoutReadWrite | |||
validatorEndpoint = newSignerValidatorEndpoint(logger, addr, thisConnTimeout) | |||
) | |||
SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) | |||
SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) | |||
SignerServiceEndpointConnRetries(1e6)(serviceEndpoint) | |||
testStartEndpoint(t, readyc, validatorEndpoint) | |||
require.NoError(t, serviceEndpoint.Start()) | |||
assert.True(t, serviceEndpoint.IsRunning()) | |||
<-readyc | |||
return validatorEndpoint, serviceEndpoint | |||
} | |||
func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { | |||
_, err := readMsg(rsConn) | |||
require.NoError(t, err) | |||
err = writeMsg(rsConn, resp) | |||
require.NoError(t, err) | |||
} | |||
func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) { | |||
go func(sc *SignerValidatorEndpoint) { | |||
require.NoError(t, sc.Start()) | |||
assert.True(t, sc.IsRunning()) | |||
readyCh <- struct{}{} | |||
}(sc) | |||
} | |||
// testFreeTCPAddr claims a free port so we don't block on listener being ready. | |||
func testFreeTCPAddr(t *testing.T) string { | |||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||
require.NoError(t, err) | |||
defer ln.Close() | |||
return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) | |||
} |
@ -1,26 +1,49 @@ | |||
package privval | |||
import ( | |||
"fmt" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
"github.com/tendermint/tendermint/crypto/ed25519" | |||
cmn "github.com/tendermint/tendermint/libs/common" | |||
) | |||
func getDialerTestCases(t *testing.T) []dialerTestCase { | |||
tcpAddr := GetFreeLocalhostAddrPort() | |||
unixFilePath, err := testUnixAddr() | |||
require.NoError(t, err) | |||
unixAddr := fmt.Sprintf("unix://%s", unixFilePath) | |||
return []dialerTestCase{ | |||
{ | |||
addr: tcpAddr, | |||
dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()), | |||
}, | |||
{ | |||
addr: unixAddr, | |||
dialer: DialUnixFn(unixFilePath), | |||
}, | |||
} | |||
} | |||
func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { | |||
// Generate a networking timeout | |||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||
tcpAddr := GetFreeLocalhostAddrPort() | |||
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey()) | |||
_, err := dialer() | |||
assert.Error(t, err) | |||
assert.True(t, IsConnTimeout(err)) | |||
} | |||
func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { | |||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||
tcpAddr := GetFreeLocalhostAddrPort() | |||
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey()) | |||
_, err := dialer() | |||
assert.Error(t, err) | |||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||
err = cmn.ErrorWrap(ErrConnectionTimeout, err.Error()) | |||
assert.True(t, IsConnTimeout(err)) | |||
} |