This PR is related to #3107 and a continuation of #3351
It is important to emphasise that in the original privval design, the client/server and listening/dialing roles are inverted and do not follow a conventional interaction.
Given two hosts A and B:
Host A is listener/client
Host B is dialer/server (contains the secret key)
When A requires a signature, it needs to wait for B to dial in before it can issue a request.
A only accepts a single connection and any failure leads to dropping the connection and waiting for B to reconnect.
The original rationale behind this design was based on security.
Host B only allows outbound connections to a list of whitelisted hosts.
It is not possible to reach B unless B dials in. There are no listening/open ports in B.
This PR results in the following changes:
Refactors ping/heartbeat to avoid previously existing race conditions.
Separates transport (dialer/listener) from signing (client/server) concerns to simplify workflow.
Unifies and abstracts away the differences between unix and tcp sockets.
A single signer endpoint implementation unifies connection handling code (read/write/close/connection obj)
The signer request handler (server side) is customizable to increase testability.
Updates and extends unit tests
A high level overview of the classes is as follows:
Transport (endpoints): The following classes take care of establishing a connection
SignerDialerEndpoint
SignerListeningEndpoint
SignerEndpoint groups common functionality (read/write/timeouts/etc.)
Signing (client/server): The following classes take care of exchanging request/responses
SignerClient
SignerServer
This PR also closes #3601
Commits:
* refactoring - work in progress
* reworking unit tests
* Encapsulating and fixing unit tests
* Improve tests
* Clean up
* Fix/improve unit tests
* clean up tests
* Improving service endpoint
* fixing unit test
* fix linter issues
* avoid invalid cache values (improve later?)
* complete implementation
* wip
* improved connection loop
* Improve reconnections + fixing unit tests
* addressing comments
* small formatting changes
* clean up
* Update node/node.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_client.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_client_test.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* check during initialization
* dropping connecting when writing fails
* removing break
* use t.log instead
* unifying and using cmn.GetFreePort()
* review fixes
* reordering and unifying drop connection
* closing instead of signalling
* refactored service loop
* removed superfluous brackets
* GetPubKey can return errors
* Revert "GetPubKey can return errors"
This reverts commit 68c06f19b4.
* adding entry to changelog
* Update CHANGELOG_PENDING.md
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_client.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_dialer_endpoint.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* Update privval/signer_listener_endpoint_test.go
Co-Authored-By: jleni <juan.leni@zondax.ch>
* updating node.go
* review fixes
* fixes linter
* fixing unit test
* small fixes in comments
* addressing review comments
* addressing review comments 2
* reverting suggestion
* Update privval/signer_client_test.go
Co-Authored-By: Anton Kaliaev <anton.kalyaev@gmail.com>
* Update privval/signer_client_test.go
Co-Authored-By: Anton Kaliaev <anton.kalyaev@gmail.com>
* Update privval/signer_listener_endpoint_test.go
Co-Authored-By: Anton Kaliaev <anton.kalyaev@gmail.com>
* do not expose brokenSignerDialerEndpoint
* clean up logging
* unifying methods
shorten test time
signer also drops
* reenabling pings
* improving testability + unit test
* fixing go fmt + unit test
* remove unused code
* Addressing review comments
* simplifying connection workflow
* fix linter/go import issue
* using base service quit
* updating comment
* Simplifying design + adjusting names
* fixing linter issues
* refactoring test harness + fixes
* Addressing review comments
* cleaning up
* adding additional error check
pull/3912/head
@ -1,61 +1,64 @@ | |||||
package privval | package privval | ||||
import ( | import ( | ||||
amino "github.com/tendermint/go-amino" | |||||
"github.com/tendermint/go-amino" | |||||
"github.com/tendermint/tendermint/crypto" | "github.com/tendermint/tendermint/crypto" | ||||
"github.com/tendermint/tendermint/types" | "github.com/tendermint/tendermint/types" | ||||
) | ) | ||||
// RemoteSignerMsg is sent between SignerServiceEndpoint and the SignerServiceEndpoint client. | |||||
type RemoteSignerMsg interface{} | |||||
// SignerMessage is sent between Signer Clients and Servers. | |||||
type SignerMessage interface{} | |||||
func RegisterRemoteSignerMsg(cdc *amino.Codec) { | func RegisterRemoteSignerMsg(cdc *amino.Codec) { | ||||
cdc.RegisterInterface((*RemoteSignerMsg)(nil), nil) | |||||
cdc.RegisterInterface((*SignerMessage)(nil), nil) | |||||
cdc.RegisterConcrete(&PubKeyRequest{}, "tendermint/remotesigner/PubKeyRequest", nil) | cdc.RegisterConcrete(&PubKeyRequest{}, "tendermint/remotesigner/PubKeyRequest", nil) | ||||
cdc.RegisterConcrete(&PubKeyResponse{}, "tendermint/remotesigner/PubKeyResponse", nil) | cdc.RegisterConcrete(&PubKeyResponse{}, "tendermint/remotesigner/PubKeyResponse", nil) | ||||
cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil) | cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil) | ||||
cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil) | cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil) | ||||
cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil) | cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil) | ||||
cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil) | cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil) | ||||
cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil) | cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil) | ||||
cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil) | cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil) | ||||
} | } | ||||
// TODO: Add ChainIDRequest | |||||
// PubKeyRequest requests the consensus public key from the remote signer. | // PubKeyRequest requests the consensus public key from the remote signer. | ||||
type PubKeyRequest struct{} | type PubKeyRequest struct{} | ||||
// PubKeyResponse is a PrivValidatorSocket message containing the public key. | |||||
// PubKeyResponse is a response message containing the public key. | |||||
type PubKeyResponse struct { | type PubKeyResponse struct { | ||||
PubKey crypto.PubKey | PubKey crypto.PubKey | ||||
Error *RemoteSignerError | Error *RemoteSignerError | ||||
} | } | ||||
// SignVoteRequest is a PrivValidatorSocket message containing a vote. | |||||
// SignVoteRequest is a request to sign a vote | |||||
type SignVoteRequest struct { | type SignVoteRequest struct { | ||||
Vote *types.Vote | Vote *types.Vote | ||||
} | } | ||||
// SignedVoteResponse is a PrivValidatorSocket message containing a signed vote along with a potenial error message. | |||||
// SignedVoteResponse is a response containing a signed vote or an error | |||||
type SignedVoteResponse struct { | type SignedVoteResponse struct { | ||||
Vote *types.Vote | Vote *types.Vote | ||||
Error *RemoteSignerError | Error *RemoteSignerError | ||||
} | } | ||||
// SignProposalRequest is a PrivValidatorSocket message containing a Proposal. | |||||
// SignProposalRequest is a request to sign a proposal | |||||
type SignProposalRequest struct { | type SignProposalRequest struct { | ||||
Proposal *types.Proposal | Proposal *types.Proposal | ||||
} | } | ||||
// SignedProposalResponse is a PrivValidatorSocket message containing a proposal response | |||||
// SignedProposalResponse is a response containing a signed proposal or an error | |||||
type SignedProposalResponse struct { | type SignedProposalResponse struct { | ||||
Proposal *types.Proposal | Proposal *types.Proposal | ||||
Error *RemoteSignerError | Error *RemoteSignerError | ||||
} | } | ||||
// PingRequest is a PrivValidatorSocket message to keep the connection alive. | |||||
// PingRequest is a request to confirm that the connection is alive. | |||||
type PingRequest struct { | type PingRequest struct { | ||||
} | } | ||||
// PingRequest is a PrivValidatorSocket response to keep the connection alive. | |||||
// PingResponse is a response to confirm that the connection is alive. | |||||
type PingResponse struct { | type PingResponse struct { | ||||
} | } |
@ -0,0 +1,131 @@ | |||||
package privval | |||||
import ( | |||||
"time" | |||||
"github.com/pkg/errors" | |||||
"github.com/tendermint/tendermint/crypto" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// SignerClient implements PrivValidator. | |||||
// Handles remote validator connections that provide signing services | |||||
type SignerClient struct { | |||||
endpoint *SignerListenerEndpoint | |||||
} | |||||
var _ types.PrivValidator = (*SignerClient)(nil) | |||||
// NewSignerClient returns an instance of SignerClient. | |||||
// it will start the endpoint (if not already started) | |||||
func NewSignerClient(endpoint *SignerListenerEndpoint) (*SignerClient, error) { | |||||
if !endpoint.IsRunning() { | |||||
if err := endpoint.Start(); err != nil { | |||||
return nil, errors.Wrap(err, "failed to start listener endpoint") | |||||
} | |||||
} | |||||
return &SignerClient{endpoint: endpoint}, nil | |||||
} | |||||
// Close closes the underlying connection | |||||
func (sc *SignerClient) Close() error { | |||||
return sc.endpoint.Close() | |||||
} | |||||
// IsConnected indicates with the signer is connected to a remote signing service | |||||
func (sc *SignerClient) IsConnected() bool { | |||||
return sc.endpoint.IsConnected() | |||||
} | |||||
// WaitForConnection waits maxWait for a connection or returns a timeout error | |||||
func (sc *SignerClient) WaitForConnection(maxWait time.Duration) error { | |||||
return sc.endpoint.WaitForConnection(maxWait) | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Implement PrivValidator | |||||
// Ping sends a ping request to the remote signer | |||||
func (sc *SignerClient) Ping() error { | |||||
response, err := sc.endpoint.SendRequest(&PingRequest{}) | |||||
if err != nil { | |||||
sc.endpoint.Logger.Error("SignerClient::Ping", "err", err) | |||||
return nil | |||||
} | |||||
_, ok := response.(*PingResponse) | |||||
if !ok { | |||||
sc.endpoint.Logger.Error("SignerClient::Ping", "err", "response != PingResponse") | |||||
return err | |||||
} | |||||
return nil | |||||
} | |||||
// GetPubKey retrieves a public key from a remote signer | |||||
func (sc *SignerClient) GetPubKey() crypto.PubKey { | |||||
response, err := sc.endpoint.SendRequest(&PubKeyRequest{}) | |||||
if err != nil { | |||||
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", err) | |||||
return nil | |||||
} | |||||
pubKeyResp, ok := response.(*PubKeyResponse) | |||||
if !ok { | |||||
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", "response != PubKeyResponse") | |||||
return nil | |||||
} | |||||
if pubKeyResp.Error != nil { | |||||
sc.endpoint.Logger.Error("failed to get private validator's public key", "err", pubKeyResp.Error) | |||||
return nil | |||||
} | |||||
return pubKeyResp.PubKey | |||||
} | |||||
// SignVote requests a remote signer to sign a vote | |||||
func (sc *SignerClient) SignVote(chainID string, vote *types.Vote) error { | |||||
response, err := sc.endpoint.SendRequest(&SignVoteRequest{Vote: vote}) | |||||
if err != nil { | |||||
sc.endpoint.Logger.Error("SignerClient::SignVote", "err", err) | |||||
return err | |||||
} | |||||
resp, ok := response.(*SignedVoteResponse) | |||||
if !ok { | |||||
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", "response != SignedVoteResponse") | |||||
return ErrUnexpectedResponse | |||||
} | |||||
if resp.Error != nil { | |||||
return resp.Error | |||||
} | |||||
*vote = *resp.Vote | |||||
return nil | |||||
} | |||||
// SignProposal requests a remote signer to sign a proposal | |||||
func (sc *SignerClient) SignProposal(chainID string, proposal *types.Proposal) error { | |||||
response, err := sc.endpoint.SendRequest(&SignProposalRequest{Proposal: proposal}) | |||||
if err != nil { | |||||
sc.endpoint.Logger.Error("SignerClient::SignProposal", "err", err) | |||||
return err | |||||
} | |||||
resp, ok := response.(*SignedProposalResponse) | |||||
if !ok { | |||||
sc.endpoint.Logger.Error("SignerClient::SignProposal", "err", "response != SignedProposalResponse") | |||||
return ErrUnexpectedResponse | |||||
} | |||||
if resp.Error != nil { | |||||
return resp.Error | |||||
} | |||||
*proposal = *resp.Proposal | |||||
return nil | |||||
} |
@ -0,0 +1,257 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
type signerTestCase struct { | |||||
chainID string | |||||
mockPV types.PrivValidator | |||||
signerClient *SignerClient | |||||
signerServer *SignerServer | |||||
} | |||||
func getSignerTestCases(t *testing.T) []signerTestCase { | |||||
testCases := make([]signerTestCase, 0) | |||||
// Get test cases for each possible dialer (DialTCP / DialUnix / etc) | |||||
for _, dtc := range getDialerTestCases(t) { | |||||
chainID := common.RandStr(12) | |||||
mockPV := types.NewMockPV() | |||||
// get a pair of signer listener, signer dialer endpoints | |||||
sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer) | |||||
sc, err := NewSignerClient(sl) | |||||
require.NoError(t, err) | |||||
ss := NewSignerServer(sd, chainID, mockPV) | |||||
err = ss.Start() | |||||
require.NoError(t, err) | |||||
tc := signerTestCase{ | |||||
chainID: chainID, | |||||
mockPV: mockPV, | |||||
signerClient: sc, | |||||
signerServer: ss, | |||||
} | |||||
testCases = append(testCases, tc) | |||||
} | |||||
return testCases | |||||
} | |||||
func TestSignerClose(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
err := tc.signerClient.Close() | |||||
assert.NoError(t, err) | |||||
err = tc.signerServer.Stop() | |||||
assert.NoError(t, err) | |||||
} | |||||
} | |||||
func TestSignerPing(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
err := tc.signerClient.Ping() | |||||
assert.NoError(t, err) | |||||
} | |||||
} | |||||
func TestSignerGetPubKey(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
pubKey := tc.signerClient.GetPubKey() | |||||
expectedPubKey := tc.mockPV.GetPubKey() | |||||
assert.Equal(t, expectedPubKey, pubKey) | |||||
addr := tc.signerClient.GetPubKey().Address() | |||||
expectedAddr := tc.mockPV.GetPubKey().Address() | |||||
assert.Equal(t, expectedAddr, addr) | |||||
} | |||||
} | |||||
func TestSignerProposal(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
ts := time.Now() | |||||
want := &types.Proposal{Timestamp: ts} | |||||
have := &types.Proposal{Timestamp: ts} | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
require.NoError(t, tc.mockPV.SignProposal(tc.chainID, want)) | |||||
require.NoError(t, tc.signerClient.SignProposal(tc.chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
} | |||||
} | |||||
func TestSignerVote(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
ts := time.Now() | |||||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
have := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
} | |||||
} | |||||
func TestSignerVoteResetDeadline(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
ts := time.Now() | |||||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
have := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
time.Sleep(testTimeoutReadWrite2o3) | |||||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
// TODO(jleni): Clarify what is actually being tested | |||||
// This would exceed the deadline if it was not extended by the previous message | |||||
time.Sleep(testTimeoutReadWrite2o3) | |||||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
} | |||||
} | |||||
func TestSignerVoteKeepAlive(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
ts := time.Now() | |||||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
have := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
// Check that even if the client does not request a | |||||
// signature for a long time. The service is still available | |||||
// in this particular case, we use the dialer logger to ensure that | |||||
// test messages are properly interleaved in the test logs | |||||
tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------") | |||||
time.Sleep(testTimeoutReadWrite * 3) | |||||
tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------") | |||||
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want)) | |||||
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
} | |||||
} | |||||
func TestSignerSignProposalErrors(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
// Replace service with a mock that always fails | |||||
tc.signerServer.privVal = types.NewErroringMockPV() | |||||
tc.mockPV = types.NewErroringMockPV() | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
ts := time.Now() | |||||
proposal := &types.Proposal{Timestamp: ts} | |||||
err := tc.signerClient.SignProposal(tc.chainID, proposal) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = tc.mockPV.SignProposal(tc.chainID, proposal) | |||||
require.Error(t, err) | |||||
err = tc.signerClient.SignProposal(tc.chainID, proposal) | |||||
require.Error(t, err) | |||||
} | |||||
} | |||||
func TestSignerSignVoteErrors(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
ts := time.Now() | |||||
vote := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
// Replace signer service privval with one that always fails | |||||
tc.signerServer.privVal = types.NewErroringMockPV() | |||||
tc.mockPV = types.NewErroringMockPV() | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
err := tc.signerClient.SignVote(tc.chainID, vote) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = tc.mockPV.SignVote(tc.chainID, vote) | |||||
require.Error(t, err) | |||||
err = tc.signerClient.SignVote(tc.chainID, vote) | |||||
require.Error(t, err) | |||||
} | |||||
} | |||||
func brokenHandler(privVal types.PrivValidator, request SignerMessage, chainID string) (SignerMessage, error) { | |||||
var res SignerMessage | |||||
var err error | |||||
switch r := request.(type) { | |||||
// This is broken and will answer most requests with a pubkey response | |||||
case *PubKeyRequest: | |||||
res = &PubKeyResponse{nil, nil} | |||||
case *SignVoteRequest: | |||||
res = &PubKeyResponse{nil, nil} | |||||
case *SignProposalRequest: | |||||
res = &PubKeyResponse{nil, nil} | |||||
case *PingRequest: | |||||
err, res = nil, &PingResponse{} | |||||
default: | |||||
err = fmt.Errorf("unknown msg: %v", r) | |||||
} | |||||
return res, err | |||||
} | |||||
func TestSignerUnexpectedResponse(t *testing.T) { | |||||
for _, tc := range getSignerTestCases(t) { | |||||
tc.signerServer.privVal = types.NewMockPV() | |||||
tc.mockPV = types.NewMockPV() | |||||
tc.signerServer.SetRequestHandler(brokenHandler) | |||||
defer tc.signerServer.Stop() | |||||
defer tc.signerClient.Close() | |||||
ts := time.Now() | |||||
want := &types.Vote{Timestamp: ts, Type: types.PrecommitType} | |||||
e := tc.signerClient.SignVote(tc.chainID, want) | |||||
assert.EqualError(t, e, "received unexpected response") | |||||
} | |||||
} |
@ -0,0 +1,84 @@ | |||||
package privval | |||||
import ( | |||||
"time" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
) | |||||
const ( | |||||
defaultMaxDialRetries = 10 | |||||
defaultRetryWaitMilliseconds = 100 | |||||
) | |||||
// SignerServiceEndpointOption sets an optional parameter on the SignerDialerEndpoint. | |||||
type SignerServiceEndpointOption func(*SignerDialerEndpoint) | |||||
// SignerDialerEndpointTimeoutReadWrite sets the read and write timeout for connections | |||||
// from external signing processes. | |||||
func SignerDialerEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption { | |||||
return func(ss *SignerDialerEndpoint) { ss.timeoutReadWrite = timeout } | |||||
} | |||||
// SignerDialerEndpointConnRetries sets the amount of attempted retries to acceptNewConnection. | |||||
func SignerDialerEndpointConnRetries(retries int) SignerServiceEndpointOption { | |||||
return func(ss *SignerDialerEndpoint) { ss.maxConnRetries = retries } | |||||
} | |||||
// SignerDialerEndpoint dials using its dialer and responds to any | |||||
// signature requests using its privVal. | |||||
type SignerDialerEndpoint struct { | |||||
signerEndpoint | |||||
dialer SocketDialer | |||||
retryWait time.Duration | |||||
maxConnRetries int | |||||
} | |||||
// NewSignerDialerEndpoint returns a SignerDialerEndpoint that will dial using the given | |||||
// dialer and respond to any signature requests over the connection | |||||
// using the given privVal. | |||||
func NewSignerDialerEndpoint( | |||||
logger log.Logger, | |||||
dialer SocketDialer, | |||||
) *SignerDialerEndpoint { | |||||
sd := &SignerDialerEndpoint{ | |||||
dialer: dialer, | |||||
retryWait: defaultRetryWaitMilliseconds * time.Millisecond, | |||||
maxConnRetries: defaultMaxDialRetries, | |||||
} | |||||
sd.BaseService = *cmn.NewBaseService(logger, "SignerDialerEndpoint", sd) | |||||
sd.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second | |||||
return sd | |||||
} | |||||
func (sd *SignerDialerEndpoint) ensureConnection() error { | |||||
if sd.IsConnected() { | |||||
return nil | |||||
} | |||||
retries := 0 | |||||
for retries < sd.maxConnRetries { | |||||
conn, err := sd.dialer() | |||||
if err != nil { | |||||
retries++ | |||||
sd.Logger.Debug("SignerDialer: Reconnection failed", "retries", retries, "max", sd.maxConnRetries, "err", err) | |||||
// Wait between retries | |||||
time.Sleep(sd.retryWait) | |||||
} else { | |||||
sd.SetConnection(conn) | |||||
sd.Logger.Debug("SignerDialer: Connection Ready") | |||||
return nil | |||||
} | |||||
} | |||||
sd.Logger.Debug("SignerDialer: Max retries exceeded", "retries", retries, "max", sd.maxConnRetries) | |||||
return ErrNoConnection | |||||
} |
@ -0,0 +1,156 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"sync" | |||||
"time" | |||||
"github.com/pkg/errors" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
) | |||||
const ( | |||||
defaultTimeoutReadWriteSeconds = 3 | |||||
) | |||||
type signerEndpoint struct { | |||||
cmn.BaseService | |||||
connMtx sync.Mutex | |||||
conn net.Conn | |||||
timeoutReadWrite time.Duration | |||||
} | |||||
// Close closes the underlying net.Conn. | |||||
func (se *signerEndpoint) Close() error { | |||||
se.DropConnection() | |||||
return nil | |||||
} | |||||
// IsConnected indicates if there is an active connection | |||||
func (se *signerEndpoint) IsConnected() bool { | |||||
se.connMtx.Lock() | |||||
defer se.connMtx.Unlock() | |||||
return se.isConnected() | |||||
} | |||||
// TryGetConnection retrieves a connection if it is already available | |||||
func (se *signerEndpoint) GetAvailableConnection(connectionAvailableCh chan net.Conn) bool { | |||||
se.connMtx.Lock() | |||||
defer se.connMtx.Unlock() | |||||
// Is there a connection ready? | |||||
select { | |||||
case se.conn = <-connectionAvailableCh: | |||||
return true | |||||
default: | |||||
} | |||||
return false | |||||
} | |||||
// TryGetConnection retrieves a connection if it is already available | |||||
func (se *signerEndpoint) WaitConnection(connectionAvailableCh chan net.Conn, maxWait time.Duration) error { | |||||
se.connMtx.Lock() | |||||
defer se.connMtx.Unlock() | |||||
select { | |||||
case se.conn = <-connectionAvailableCh: | |||||
case <-time.After(maxWait): | |||||
return ErrConnectionTimeout | |||||
} | |||||
return nil | |||||
} | |||||
// SetConnection replaces the current connection object | |||||
func (se *signerEndpoint) SetConnection(newConnection net.Conn) { | |||||
se.connMtx.Lock() | |||||
defer se.connMtx.Unlock() | |||||
se.conn = newConnection | |||||
} | |||||
// IsConnected indicates if there is an active connection | |||||
func (se *signerEndpoint) DropConnection() { | |||||
se.connMtx.Lock() | |||||
defer se.connMtx.Unlock() | |||||
se.dropConnection() | |||||
} | |||||
// ReadMessage reads a message from the endpoint | |||||
func (se *signerEndpoint) ReadMessage() (msg SignerMessage, err error) { | |||||
se.connMtx.Lock() | |||||
defer se.connMtx.Unlock() | |||||
if !se.isConnected() { | |||||
return nil, fmt.Errorf("endpoint is not connected") | |||||
} | |||||
// Reset read deadline | |||||
deadline := time.Now().Add(se.timeoutReadWrite) | |||||
err = se.conn.SetReadDeadline(deadline) | |||||
if err != nil { | |||||
return | |||||
} | |||||
const maxRemoteSignerMsgSize = 1024 * 10 | |||||
_, err = cdc.UnmarshalBinaryLengthPrefixedReader(se.conn, &msg, maxRemoteSignerMsgSize) | |||||
if _, ok := err.(timeoutError); ok { | |||||
if err != nil { | |||||
err = errors.Wrap(ErrReadTimeout, err.Error()) | |||||
} else { | |||||
err = errors.Wrap(ErrReadTimeout, "Empty error") | |||||
} | |||||
se.Logger.Debug("Dropping [read]", "obj", se) | |||||
se.dropConnection() | |||||
} | |||||
return | |||||
} | |||||
// WriteMessage writes a message from the endpoint | |||||
func (se *signerEndpoint) WriteMessage(msg SignerMessage) (err error) { | |||||
se.connMtx.Lock() | |||||
defer se.connMtx.Unlock() | |||||
if !se.isConnected() { | |||||
return errors.Wrap(ErrNoConnection, "endpoint is not connected") | |||||
} | |||||
// Reset read deadline | |||||
deadline := time.Now().Add(se.timeoutReadWrite) | |||||
se.Logger.Debug("Write::Error Resetting deadline", "obj", se) | |||||
err = se.conn.SetWriteDeadline(deadline) | |||||
if err != nil { | |||||
return | |||||
} | |||||
_, err = cdc.MarshalBinaryLengthPrefixedWriter(se.conn, msg) | |||||
if _, ok := err.(timeoutError); ok { | |||||
if err != nil { | |||||
err = errors.Wrap(ErrWriteTimeout, err.Error()) | |||||
} else { | |||||
err = errors.Wrap(ErrWriteTimeout, "Empty error") | |||||
} | |||||
se.dropConnection() | |||||
} | |||||
return | |||||
} | |||||
func (se *signerEndpoint) isConnected() bool { | |||||
return se.conn != nil | |||||
} | |||||
func (se *signerEndpoint) dropConnection() { | |||||
if se.conn != nil { | |||||
if err := se.conn.Close(); err != nil { | |||||
se.Logger.Error("signerEndpoint::dropConnection", "err", err) | |||||
} | |||||
se.conn = nil | |||||
} | |||||
} |
@ -0,0 +1,198 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"sync" | |||||
"time" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
) | |||||
// SignerValidatorEndpointOption sets an optional parameter on the SignerListenerEndpoint. | |||||
type SignerValidatorEndpointOption func(*SignerListenerEndpoint) | |||||
// SignerListenerEndpoint listens for an external process to dial in | |||||
// and keeps the connection alive by dropping and reconnecting | |||||
type SignerListenerEndpoint struct { | |||||
signerEndpoint | |||||
listener net.Listener | |||||
connectRequestCh chan struct{} | |||||
connectionAvailableCh chan net.Conn | |||||
timeoutAccept time.Duration | |||||
pingTimer *time.Ticker | |||||
instanceMtx sync.Mutex // Ensures instance public methods access, i.e. SendRequest | |||||
} | |||||
// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. | |||||
func NewSignerListenerEndpoint( | |||||
logger log.Logger, | |||||
listener net.Listener, | |||||
) *SignerListenerEndpoint { | |||||
sc := &SignerListenerEndpoint{ | |||||
listener: listener, | |||||
timeoutAccept: defaultTimeoutAcceptSeconds * time.Second, | |||||
} | |||||
sc.BaseService = *cmn.NewBaseService(logger, "SignerListenerEndpoint", sc) | |||||
sc.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second | |||||
return sc | |||||
} | |||||
// OnStart implements cmn.Service. | |||||
func (sl *SignerListenerEndpoint) OnStart() error { | |||||
sl.connectRequestCh = make(chan struct{}) | |||||
sl.connectionAvailableCh = make(chan net.Conn) | |||||
sl.pingTimer = time.NewTicker(defaultPingPeriodMilliseconds * time.Millisecond) | |||||
go sl.serviceLoop() | |||||
go sl.pingLoop() | |||||
sl.connectRequestCh <- struct{}{} | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service | |||||
func (sl *SignerListenerEndpoint) OnStop() { | |||||
sl.instanceMtx.Lock() | |||||
defer sl.instanceMtx.Unlock() | |||||
_ = sl.Close() | |||||
// Stop listening | |||||
if sl.listener != nil { | |||||
if err := sl.listener.Close(); err != nil { | |||||
sl.Logger.Error("Closing Listener", "err", err) | |||||
sl.listener = nil | |||||
} | |||||
} | |||||
sl.pingTimer.Stop() | |||||
} | |||||
// WaitForConnection waits maxWait for a connection or returns a timeout error | |||||
func (sl *SignerListenerEndpoint) WaitForConnection(maxWait time.Duration) error { | |||||
sl.instanceMtx.Lock() | |||||
defer sl.instanceMtx.Unlock() | |||||
return sl.ensureConnection(maxWait) | |||||
} | |||||
// SendRequest ensures there is a connection, sends a request and waits for a response | |||||
func (sl *SignerListenerEndpoint) SendRequest(request SignerMessage) (SignerMessage, error) { | |||||
sl.instanceMtx.Lock() | |||||
defer sl.instanceMtx.Unlock() | |||||
err := sl.ensureConnection(sl.timeoutAccept) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
err = sl.WriteMessage(request) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
res, err := sl.ReadMessage() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return res, nil | |||||
} | |||||
func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error { | |||||
if sl.IsConnected() { | |||||
return nil | |||||
} | |||||
// Is there a connection ready? then use it | |||||
if sl.GetAvailableConnection(sl.connectionAvailableCh) { | |||||
return nil | |||||
} | |||||
// block until connected or timeout | |||||
sl.triggerConnect() | |||||
err := sl.WaitConnection(sl.connectionAvailableCh, maxWait) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
return nil | |||||
} | |||||
func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) { | |||||
if !sl.IsRunning() || sl.listener == nil { | |||||
return nil, fmt.Errorf("endpoint is closing") | |||||
} | |||||
// wait for a new conn | |||||
sl.Logger.Info("SignerListener: Listening for new connection") | |||||
conn, err := sl.listener.Accept() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return conn, nil | |||||
} | |||||
func (sl *SignerListenerEndpoint) triggerConnect() { | |||||
select { | |||||
case sl.connectRequestCh <- struct{}{}: | |||||
default: | |||||
} | |||||
} | |||||
func (sl *SignerListenerEndpoint) triggerReconnect() { | |||||
sl.DropConnection() | |||||
sl.triggerConnect() | |||||
} | |||||
func (sl *SignerListenerEndpoint) serviceLoop() { | |||||
for { | |||||
select { | |||||
case <-sl.connectRequestCh: | |||||
{ | |||||
conn, err := sl.acceptNewConnection() | |||||
if err == nil { | |||||
sl.Logger.Info("SignerListener: Connected") | |||||
// We have a good connection, wait for someone that needs one otherwise cancellation | |||||
select { | |||||
case sl.connectionAvailableCh <- conn: | |||||
case <-sl.Quit(): | |||||
return | |||||
} | |||||
} | |||||
select { | |||||
case sl.connectRequestCh <- struct{}{}: | |||||
default: | |||||
} | |||||
} | |||||
case <-sl.Quit(): | |||||
return | |||||
} | |||||
} | |||||
} | |||||
func (sl *SignerListenerEndpoint) pingLoop() { | |||||
for { | |||||
select { | |||||
case <-sl.pingTimer.C: | |||||
{ | |||||
_, err := sl.SendRequest(&PingRequest{}) | |||||
if err != nil { | |||||
sl.Logger.Error("SignerListener: Ping timeout") | |||||
sl.triggerReconnect() | |||||
} | |||||
} | |||||
case <-sl.Quit(): | |||||
return | |||||
} | |||||
} | |||||
} |
@ -0,0 +1,198 @@ | |||||
package privval | |||||
import ( | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
var ( | |||||
testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second | |||||
testTimeoutReadWrite = 100 * time.Millisecond | |||||
testTimeoutReadWrite2o3 = 60 * time.Millisecond // 2/3 of the other one | |||||
) | |||||
type dialerTestCase struct { | |||||
addr string | |||||
dialer SocketDialer | |||||
} | |||||
// TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We | |||||
// don't need this for Unix sockets because the OS instantly knows the state of | |||||
// both ends of the socket connection. This basically causes the | |||||
// SignerDialerEndpoint.dialer() call inside SignerDialerEndpoint.acceptNewConnection() to return | |||||
// successfully immediately, putting an instant stop to any retry attempts. | |||||
func TestSignerRemoteRetryTCPOnly(t *testing.T) { | |||||
var ( | |||||
attemptCh = make(chan int) | |||||
retries = 10 | |||||
) | |||||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||||
require.NoError(t, err) | |||||
// Continuously Accept connection and close {attempts} times | |||||
go func(ln net.Listener, attemptCh chan<- int) { | |||||
attempts := 0 | |||||
for { | |||||
conn, err := ln.Accept() | |||||
require.NoError(t, err) | |||||
err = conn.Close() | |||||
require.NoError(t, err) | |||||
attempts++ | |||||
if attempts == retries { | |||||
attemptCh <- attempts | |||||
break | |||||
} | |||||
} | |||||
}(ln, attemptCh) | |||||
dialerEndpoint := NewSignerDialerEndpoint( | |||||
log.TestingLogger(), | |||||
DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()), | |||||
) | |||||
SignerDialerEndpointTimeoutReadWrite(time.Millisecond)(dialerEndpoint) | |||||
SignerDialerEndpointConnRetries(retries)(dialerEndpoint) | |||||
chainId := cmn.RandStr(12) | |||||
mockPV := types.NewMockPV() | |||||
signerServer := NewSignerServer(dialerEndpoint, chainId, mockPV) | |||||
err = signerServer.Start() | |||||
require.NoError(t, err) | |||||
defer signerServer.Stop() | |||||
select { | |||||
case attempts := <-attemptCh: | |||||
assert.Equal(t, retries, attempts) | |||||
case <-time.After(1500 * time.Millisecond): | |||||
t.Error("expected remote to observe connection attempts") | |||||
} | |||||
} | |||||
func TestRetryConnToRemoteSigner(t *testing.T) { | |||||
for _, tc := range getDialerTestCases(t) { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
chainID = cmn.RandStr(12) | |||||
mockPV = types.NewMockPV() | |||||
endpointIsOpenCh = make(chan struct{}) | |||||
thisConnTimeout = testTimeoutReadWrite | |||||
listenerEndpoint = newSignerListenerEndpoint(logger, tc.addr, thisConnTimeout) | |||||
) | |||||
dialerEndpoint := NewSignerDialerEndpoint( | |||||
logger, | |||||
tc.dialer, | |||||
) | |||||
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint) | |||||
SignerDialerEndpointConnRetries(10)(dialerEndpoint) | |||||
signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV) | |||||
startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh) | |||||
defer listenerEndpoint.Stop() | |||||
require.NoError(t, signerServer.Start()) | |||||
assert.True(t, signerServer.IsRunning()) | |||||
<-endpointIsOpenCh | |||||
signerServer.Stop() | |||||
dialerEndpoint2 := NewSignerDialerEndpoint( | |||||
logger, | |||||
tc.dialer, | |||||
) | |||||
signerServer2 := NewSignerServer(dialerEndpoint2, chainID, mockPV) | |||||
// let some pings pass | |||||
require.NoError(t, signerServer2.Start()) | |||||
assert.True(t, signerServer2.IsRunning()) | |||||
defer signerServer2.Stop() | |||||
// give the client some time to re-establish the conn to the remote signer | |||||
// should see sth like this in the logs: | |||||
// | |||||
// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" | |||||
// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal | |||||
time.Sleep(testTimeoutReadWrite * 2) | |||||
} | |||||
} | |||||
/////////////////////////////////// | |||||
func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint { | |||||
proto, address := cmn.ProtocolAndAddress(addr) | |||||
ln, err := net.Listen(proto, address) | |||||
logger.Info("SignerListener: Listening", "proto", proto, "address", address) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
var listener net.Listener | |||||
if proto == "unix" { | |||||
unixLn := NewUnixListener(ln) | |||||
UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn) | |||||
UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) | |||||
listener = unixLn | |||||
} else { | |||||
tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) | |||||
TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn) | |||||
TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) | |||||
listener = tcpLn | |||||
} | |||||
return NewSignerListenerEndpoint(logger, listener) | |||||
} | |||||
func startListenerEndpointAsync(t *testing.T, sle *SignerListenerEndpoint, endpointIsOpenCh chan struct{}) { | |||||
go func(sle *SignerListenerEndpoint) { | |||||
require.NoError(t, sle.Start()) | |||||
assert.True(t, sle.IsRunning()) | |||||
close(endpointIsOpenCh) | |||||
}(sle) | |||||
} | |||||
func getMockEndpoints( | |||||
t *testing.T, | |||||
addr string, | |||||
socketDialer SocketDialer, | |||||
) (*SignerListenerEndpoint, *SignerDialerEndpoint) { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
endpointIsOpenCh = make(chan struct{}) | |||||
dialerEndpoint = NewSignerDialerEndpoint( | |||||
logger, | |||||
socketDialer, | |||||
) | |||||
listenerEndpoint = newSignerListenerEndpoint(logger, addr, testTimeoutReadWrite) | |||||
) | |||||
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint) | |||||
SignerDialerEndpointConnRetries(1e6)(dialerEndpoint) | |||||
startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh) | |||||
require.NoError(t, dialerEndpoint.Start()) | |||||
assert.True(t, dialerEndpoint.IsRunning()) | |||||
<-endpointIsOpenCh | |||||
return listenerEndpoint, dialerEndpoint | |||||
} |
@ -1,192 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"io" | |||||
"net" | |||||
"github.com/pkg/errors" | |||||
"github.com/tendermint/tendermint/crypto" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// SignerRemote implements PrivValidator. | |||||
// It uses a net.Conn to request signatures from an external process. | |||||
type SignerRemote struct { | |||||
conn net.Conn | |||||
// memoized | |||||
consensusPubKey crypto.PubKey | |||||
} | |||||
// Check that SignerRemote implements PrivValidator. | |||||
var _ types.PrivValidator = (*SignerRemote)(nil) | |||||
// NewSignerRemote returns an instance of SignerRemote. | |||||
func NewSignerRemote(conn net.Conn) (*SignerRemote, error) { | |||||
// retrieve and memoize the consensus public key once. | |||||
pubKey, err := getPubKey(conn) | |||||
if err != nil { | |||||
return nil, cmn.ErrorWrap(err, "error while retrieving public key for remote signer") | |||||
} | |||||
return &SignerRemote{ | |||||
conn: conn, | |||||
consensusPubKey: pubKey, | |||||
}, nil | |||||
} | |||||
// Close calls Close on the underlying net.Conn. | |||||
func (sc *SignerRemote) Close() error { | |||||
return sc.conn.Close() | |||||
} | |||||
// GetPubKey implements PrivValidator. | |||||
func (sc *SignerRemote) GetPubKey() crypto.PubKey { | |||||
return sc.consensusPubKey | |||||
} | |||||
// not thread-safe (only called on startup). | |||||
func getPubKey(conn net.Conn) (crypto.PubKey, error) { | |||||
err := writeMsg(conn, &PubKeyRequest{}) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
res, err := readMsg(conn) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
pubKeyResp, ok := res.(*PubKeyResponse) | |||||
if !ok { | |||||
return nil, errors.Wrap(ErrUnexpectedResponse, "response is not PubKeyResponse") | |||||
} | |||||
if pubKeyResp.Error != nil { | |||||
return nil, errors.Wrap(pubKeyResp.Error, "failed to get private validator's public key") | |||||
} | |||||
return pubKeyResp.PubKey, nil | |||||
} | |||||
// SignVote implements PrivValidator. | |||||
func (sc *SignerRemote) SignVote(chainID string, vote *types.Vote) error { | |||||
err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote}) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
res, err := readMsg(sc.conn) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
resp, ok := res.(*SignedVoteResponse) | |||||
if !ok { | |||||
return ErrUnexpectedResponse | |||||
} | |||||
if resp.Error != nil { | |||||
return resp.Error | |||||
} | |||||
*vote = *resp.Vote | |||||
return nil | |||||
} | |||||
// SignProposal implements PrivValidator. | |||||
func (sc *SignerRemote) SignProposal(chainID string, proposal *types.Proposal) error { | |||||
err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal}) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
res, err := readMsg(sc.conn) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
resp, ok := res.(*SignedProposalResponse) | |||||
if !ok { | |||||
return ErrUnexpectedResponse | |||||
} | |||||
if resp.Error != nil { | |||||
return resp.Error | |||||
} | |||||
*proposal = *resp.Proposal | |||||
return nil | |||||
} | |||||
// Ping is used to check connection health. | |||||
func (sc *SignerRemote) Ping() error { | |||||
err := writeMsg(sc.conn, &PingRequest{}) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
res, err := readMsg(sc.conn) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
_, ok := res.(*PingResponse) | |||||
if !ok { | |||||
return ErrUnexpectedResponse | |||||
} | |||||
return nil | |||||
} | |||||
func readMsg(r io.Reader) (msg RemoteSignerMsg, err error) { | |||||
const maxRemoteSignerMsgSize = 1024 * 10 | |||||
_, err = cdc.UnmarshalBinaryLengthPrefixedReader(r, &msg, maxRemoteSignerMsgSize) | |||||
if _, ok := err.(timeoutError); ok { | |||||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||||
} | |||||
return | |||||
} | |||||
func writeMsg(w io.Writer, msg interface{}) (err error) { | |||||
_, err = cdc.MarshalBinaryLengthPrefixedWriter(w, msg) | |||||
if _, ok := err.(timeoutError); ok { | |||||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||||
} | |||||
return | |||||
} | |||||
func handleRequest(req RemoteSignerMsg, chainID string, privVal types.PrivValidator) (RemoteSignerMsg, error) { | |||||
var res RemoteSignerMsg | |||||
var err error | |||||
switch r := req.(type) { | |||||
case *PubKeyRequest: | |||||
var p crypto.PubKey | |||||
p = privVal.GetPubKey() | |||||
res = &PubKeyResponse{p, nil} | |||||
case *SignVoteRequest: | |||||
err = privVal.SignVote(chainID, r.Vote) | |||||
if err != nil { | |||||
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} | |||||
} else { | |||||
res = &SignedVoteResponse{r.Vote, nil} | |||||
} | |||||
case *SignProposalRequest: | |||||
err = privVal.SignProposal(chainID, r.Proposal) | |||||
if err != nil { | |||||
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} | |||||
} else { | |||||
res = &SignedProposalResponse{r.Proposal, nil} | |||||
} | |||||
case *PingRequest: | |||||
res = &PingResponse{} | |||||
default: | |||||
err = fmt.Errorf("unknown msg: %v", r) | |||||
} | |||||
return res, err | |||||
} |
@ -1,68 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We | |||||
// don't need this for Unix sockets because the OS instantly knows the state of | |||||
// both ends of the socket connection. This basically causes the | |||||
// SignerServiceEndpoint.dialer() call inside SignerServiceEndpoint.connect() to return | |||||
// successfully immediately, putting an instant stop to any retry attempts. | |||||
func TestSignerRemoteRetryTCPOnly(t *testing.T) { | |||||
var ( | |||||
attemptCh = make(chan int) | |||||
retries = 2 | |||||
) | |||||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||||
require.NoError(t, err) | |||||
go func(ln net.Listener, attemptCh chan<- int) { | |||||
attempts := 0 | |||||
for { | |||||
conn, err := ln.Accept() | |||||
require.NoError(t, err) | |||||
err = conn.Close() | |||||
require.NoError(t, err) | |||||
attempts++ | |||||
if attempts == retries { | |||||
attemptCh <- attempts | |||||
break | |||||
} | |||||
} | |||||
}(ln, attemptCh) | |||||
serviceEndpoint := NewSignerServiceEndpoint( | |||||
log.TestingLogger(), | |||||
cmn.RandStr(12), | |||||
types.NewMockPV(), | |||||
DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()), | |||||
) | |||||
defer serviceEndpoint.Stop() | |||||
SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(retries)(serviceEndpoint) | |||||
assert.Equal(t, serviceEndpoint.Start(), ErrDialRetryMax) | |||||
select { | |||||
case attempts := <-attemptCh: | |||||
assert.Equal(t, retries, attempts) | |||||
case <-time.After(100 * time.Millisecond): | |||||
t.Error("expected remote to observe connection attempts") | |||||
} | |||||
} |
@ -0,0 +1,44 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"github.com/tendermint/tendermint/crypto" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
func DefaultValidationRequestHandler(privVal types.PrivValidator, req SignerMessage, chainID string) (SignerMessage, error) { | |||||
var res SignerMessage | |||||
var err error | |||||
switch r := req.(type) { | |||||
case *PubKeyRequest: | |||||
var p crypto.PubKey | |||||
p = privVal.GetPubKey() | |||||
res = &PubKeyResponse{p, nil} | |||||
case *SignVoteRequest: | |||||
err = privVal.SignVote(chainID, r.Vote) | |||||
if err != nil { | |||||
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} | |||||
} else { | |||||
res = &SignedVoteResponse{r.Vote, nil} | |||||
} | |||||
case *SignProposalRequest: | |||||
err = privVal.SignProposal(chainID, r.Proposal) | |||||
if err != nil { | |||||
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} | |||||
} else { | |||||
res = &SignedProposalResponse{r.Proposal, nil} | |||||
} | |||||
case *PingRequest: | |||||
err, res = nil, &PingResponse{} | |||||
default: | |||||
err = fmt.Errorf("unknown msg: %v", r) | |||||
} | |||||
return res, err | |||||
} |
@ -0,0 +1,107 @@ | |||||
package privval | |||||
import ( | |||||
"io" | |||||
"sync" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// ValidationRequestHandlerFunc handles different remoteSigner requests | |||||
type ValidationRequestHandlerFunc func( | |||||
privVal types.PrivValidator, | |||||
requestMessage SignerMessage, | |||||
chainID string) (SignerMessage, error) | |||||
type SignerServer struct { | |||||
cmn.BaseService | |||||
endpoint *SignerDialerEndpoint | |||||
chainID string | |||||
privVal types.PrivValidator | |||||
handlerMtx sync.Mutex | |||||
validationRequestHandler ValidationRequestHandlerFunc | |||||
} | |||||
func NewSignerServer(endpoint *SignerDialerEndpoint, chainID string, privVal types.PrivValidator) *SignerServer { | |||||
ss := &SignerServer{ | |||||
endpoint: endpoint, | |||||
chainID: chainID, | |||||
privVal: privVal, | |||||
validationRequestHandler: DefaultValidationRequestHandler, | |||||
} | |||||
ss.BaseService = *cmn.NewBaseService(endpoint.Logger, "SignerServer", ss) | |||||
return ss | |||||
} | |||||
// OnStart implements cmn.Service. | |||||
func (ss *SignerServer) OnStart() error { | |||||
go ss.serviceLoop() | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service. | |||||
func (ss *SignerServer) OnStop() { | |||||
ss.endpoint.Logger.Debug("SignerServer: OnStop calling Close") | |||||
_ = ss.endpoint.Close() | |||||
} | |||||
// SetRequestHandler override the default function that is used to service requests | |||||
func (ss *SignerServer) SetRequestHandler(validationRequestHandler ValidationRequestHandlerFunc) { | |||||
ss.handlerMtx.Lock() | |||||
defer ss.handlerMtx.Unlock() | |||||
ss.validationRequestHandler = validationRequestHandler | |||||
} | |||||
func (ss *SignerServer) servicePendingRequest() { | |||||
if !ss.IsRunning() { | |||||
return // Ignore error from closing. | |||||
} | |||||
req, err := ss.endpoint.ReadMessage() | |||||
if err != nil { | |||||
if err != io.EOF { | |||||
ss.Logger.Error("SignerServer: HandleMessage", "err", err) | |||||
} | |||||
return | |||||
} | |||||
var res SignerMessage | |||||
{ | |||||
// limit the scope of the lock | |||||
ss.handlerMtx.Lock() | |||||
defer ss.handlerMtx.Unlock() | |||||
res, err = ss.validationRequestHandler(ss.privVal, req, ss.chainID) | |||||
if err != nil { | |||||
// only log the error; we'll reply with an error in res | |||||
ss.Logger.Error("SignerServer: handleMessage", "err", err) | |||||
} | |||||
} | |||||
if res != nil { | |||||
err = ss.endpoint.WriteMessage(res) | |||||
if err != nil { | |||||
ss.Logger.Error("SignerServer: writeMessage", "err", err) | |||||
} | |||||
} | |||||
} | |||||
func (ss *SignerServer) serviceLoop() { | |||||
for { | |||||
select { | |||||
default: | |||||
err := ss.endpoint.ensureConnection() | |||||
if err != nil { | |||||
return | |||||
} | |||||
ss.servicePendingRequest() | |||||
case <-ss.Quit(): | |||||
return | |||||
} | |||||
} | |||||
} |
@ -1,139 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"io" | |||||
"net" | |||||
"time" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// SignerServiceEndpointOption sets an optional parameter on the SignerServiceEndpoint. | |||||
type SignerServiceEndpointOption func(*SignerServiceEndpoint) | |||||
// SignerServiceEndpointTimeoutReadWrite sets the read and write timeout for connections | |||||
// from external signing processes. | |||||
func SignerServiceEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption { | |||||
return func(ss *SignerServiceEndpoint) { ss.timeoutReadWrite = timeout } | |||||
} | |||||
// SignerServiceEndpointConnRetries sets the amount of attempted retries to connect. | |||||
func SignerServiceEndpointConnRetries(retries int) SignerServiceEndpointOption { | |||||
return func(ss *SignerServiceEndpoint) { ss.connRetries = retries } | |||||
} | |||||
// SignerServiceEndpoint dials using its dialer and responds to any | |||||
// signature requests using its privVal. | |||||
type SignerServiceEndpoint struct { | |||||
cmn.BaseService | |||||
chainID string | |||||
timeoutReadWrite time.Duration | |||||
connRetries int | |||||
privVal types.PrivValidator | |||||
dialer SocketDialer | |||||
conn net.Conn | |||||
} | |||||
// NewSignerServiceEndpoint returns a SignerServiceEndpoint that will dial using the given | |||||
// dialer and respond to any signature requests over the connection | |||||
// using the given privVal. | |||||
func NewSignerServiceEndpoint( | |||||
logger log.Logger, | |||||
chainID string, | |||||
privVal types.PrivValidator, | |||||
dialer SocketDialer, | |||||
) *SignerServiceEndpoint { | |||||
se := &SignerServiceEndpoint{ | |||||
chainID: chainID, | |||||
timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds, | |||||
connRetries: defaultMaxDialRetries, | |||||
privVal: privVal, | |||||
dialer: dialer, | |||||
} | |||||
se.BaseService = *cmn.NewBaseService(logger, "SignerServiceEndpoint", se) | |||||
return se | |||||
} | |||||
// OnStart implements cmn.Service. | |||||
func (se *SignerServiceEndpoint) OnStart() error { | |||||
conn, err := se.connect() | |||||
if err != nil { | |||||
se.Logger.Error("OnStart", "err", err) | |||||
return err | |||||
} | |||||
se.conn = conn | |||||
go se.handleConnection(conn) | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service. | |||||
func (se *SignerServiceEndpoint) OnStop() { | |||||
if se.conn == nil { | |||||
return | |||||
} | |||||
if err := se.conn.Close(); err != nil { | |||||
se.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) | |||||
} | |||||
} | |||||
func (se *SignerServiceEndpoint) connect() (net.Conn, error) { | |||||
for retries := 0; retries < se.connRetries; retries++ { | |||||
// Don't sleep if it is the first retry. | |||||
if retries > 0 { | |||||
time.Sleep(se.timeoutReadWrite) | |||||
} | |||||
conn, err := se.dialer() | |||||
if err == nil { | |||||
return conn, nil | |||||
} | |||||
se.Logger.Error("dialing", "err", err) | |||||
} | |||||
return nil, ErrDialRetryMax | |||||
} | |||||
func (se *SignerServiceEndpoint) handleConnection(conn net.Conn) { | |||||
for { | |||||
if !se.IsRunning() { | |||||
return // Ignore error from listener closing. | |||||
} | |||||
// Reset the connection deadline | |||||
deadline := time.Now().Add(se.timeoutReadWrite) | |||||
err := conn.SetDeadline(deadline) | |||||
if err != nil { | |||||
return | |||||
} | |||||
req, err := readMsg(conn) | |||||
if err != nil { | |||||
if err != io.EOF { | |||||
se.Logger.Error("handleConnection readMsg", "err", err) | |||||
} | |||||
return | |||||
} | |||||
res, err := handleRequest(req, se.chainID, se.privVal) | |||||
if err != nil { | |||||
// only log the error; we'll reply with an error in res | |||||
se.Logger.Error("handleConnection handleRequest", "err", err) | |||||
} | |||||
err = writeMsg(conn, res) | |||||
if err != nil { | |||||
se.Logger.Error("handleConnection writeMsg", "err", err) | |||||
return | |||||
} | |||||
} | |||||
} |
@ -1,230 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"sync" | |||||
"time" | |||||
"github.com/tendermint/tendermint/crypto" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
const (
	// defaultHeartbeatSeconds is the default heartbeat period, in seconds.
	defaultHeartbeatSeconds = 2
	// defaultMaxDialRetries is the default maximum number of dial attempts.
	defaultMaxDialRetries = 10
)

var (
	// heartbeatPeriod is defaultHeartbeatSeconds expressed as a Duration.
	heartbeatPeriod = time.Second * defaultHeartbeatSeconds
)
// SignerValidatorEndpointOption sets an optional parameter on the SocketVal. | |||||
type SignerValidatorEndpointOption func(*SignerValidatorEndpoint) | |||||
// SignerValidatorEndpointSetHeartbeat sets the period on which to check the liveness of the | |||||
// connected Signer connections. | |||||
func SignerValidatorEndpointSetHeartbeat(period time.Duration) SignerValidatorEndpointOption { | |||||
return func(sc *SignerValidatorEndpoint) { sc.heartbeatPeriod = period } | |||||
} | |||||
// SocketVal implements PrivValidator. | |||||
// It listens for an external process to dial in and uses | |||||
// the socket to request signatures. | |||||
type SignerValidatorEndpoint struct { | |||||
cmn.BaseService | |||||
listener net.Listener | |||||
// ping | |||||
cancelPingCh chan struct{} | |||||
pingTicker *time.Ticker | |||||
heartbeatPeriod time.Duration | |||||
// signer is mutable since it can be reset if the connection fails. | |||||
// failures are detected by a background ping routine. | |||||
// All messages are request/response, so we hold the mutex | |||||
// so only one request/response pair can happen at a time. | |||||
// Methods on the underlying net.Conn itself are already goroutine safe. | |||||
mtx sync.Mutex | |||||
// TODO: Signer should encapsulate and hide the endpoint completely. Invert the relation | |||||
signer *SignerRemote | |||||
} | |||||
// Check that SignerValidatorEndpoint implements PrivValidator. | |||||
var _ types.PrivValidator = (*SignerValidatorEndpoint)(nil) | |||||
// NewSignerValidatorEndpoint returns an instance of SignerValidatorEndpoint. | |||||
func NewSignerValidatorEndpoint(logger log.Logger, listener net.Listener) *SignerValidatorEndpoint { | |||||
sc := &SignerValidatorEndpoint{ | |||||
listener: listener, | |||||
heartbeatPeriod: heartbeatPeriod, | |||||
} | |||||
sc.BaseService = *cmn.NewBaseService(logger, "SignerValidatorEndpoint", sc) | |||||
return sc | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Implement PrivValidator | |||||
// GetPubKey implements PrivValidator. | |||||
func (ve *SignerValidatorEndpoint) GetPubKey() crypto.PubKey { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.GetPubKey() | |||||
} | |||||
// SignVote implements PrivValidator. | |||||
func (ve *SignerValidatorEndpoint) SignVote(chainID string, vote *types.Vote) error { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.SignVote(chainID, vote) | |||||
} | |||||
// SignProposal implements PrivValidator. | |||||
func (ve *SignerValidatorEndpoint) SignProposal(chainID string, proposal *types.Proposal) error { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.SignProposal(chainID, proposal) | |||||
} | |||||
//-------------------------------------------------------- | |||||
// More thread safe methods proxied to the signer | |||||
// Ping is used to check connection health. | |||||
func (ve *SignerValidatorEndpoint) Ping() error { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.Ping() | |||||
} | |||||
// Close closes the underlying net.Conn. | |||||
func (ve *SignerValidatorEndpoint) Close() { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
if ve.signer != nil { | |||||
if err := ve.signer.Close(); err != nil { | |||||
ve.Logger.Error("OnStop", "err", err) | |||||
} | |||||
} | |||||
if ve.listener != nil { | |||||
if err := ve.listener.Close(); err != nil { | |||||
ve.Logger.Error("OnStop", "err", err) | |||||
} | |||||
} | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Service start and stop | |||||
// OnStart implements cmn.Service. | |||||
func (ve *SignerValidatorEndpoint) OnStart() error { | |||||
if closed, err := ve.reset(); err != nil { | |||||
ve.Logger.Error("OnStart", "err", err) | |||||
return err | |||||
} else if closed { | |||||
return fmt.Errorf("listener is closed") | |||||
} | |||||
// Start a routine to keep the connection alive | |||||
ve.cancelPingCh = make(chan struct{}, 1) | |||||
ve.pingTicker = time.NewTicker(ve.heartbeatPeriod) | |||||
go func() { | |||||
for { | |||||
select { | |||||
case <-ve.pingTicker.C: | |||||
err := ve.Ping() | |||||
if err != nil { | |||||
ve.Logger.Error("Ping", "err", err) | |||||
if err == ErrUnexpectedResponse { | |||||
return | |||||
} | |||||
closed, err := ve.reset() | |||||
if err != nil { | |||||
ve.Logger.Error("Reconnecting to remote signer failed", "err", err) | |||||
continue | |||||
} | |||||
if closed { | |||||
ve.Logger.Info("listener is closing") | |||||
return | |||||
} | |||||
ve.Logger.Info("Re-created connection to remote signer", "impl", ve) | |||||
} | |||||
case <-ve.cancelPingCh: | |||||
ve.pingTicker.Stop() | |||||
return | |||||
} | |||||
} | |||||
}() | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service. | |||||
func (ve *SignerValidatorEndpoint) OnStop() { | |||||
if ve.cancelPingCh != nil { | |||||
close(ve.cancelPingCh) | |||||
} | |||||
ve.Close() | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Connection and signer management | |||||
// waits to accept and sets a new connection. | |||||
// connection is closed in OnStop. | |||||
// returns true if the listener is closed | |||||
// (ie. it returns a nil conn). | |||||
func (ve *SignerValidatorEndpoint) reset() (closed bool, err error) { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
// first check if the conn already exists and close it. | |||||
if ve.signer != nil { | |||||
if tmpErr := ve.signer.Close(); tmpErr != nil { | |||||
ve.Logger.Error("error closing socket val connection during reset", "err", tmpErr) | |||||
} | |||||
} | |||||
// wait for a new conn | |||||
conn, err := ve.acceptConnection() | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
// listener is closed | |||||
if conn == nil { | |||||
return true, nil | |||||
} | |||||
ve.signer, err = NewSignerRemote(conn) | |||||
if err != nil { | |||||
// failed to fetch the pubkey. close out the connection. | |||||
if tmpErr := conn.Close(); tmpErr != nil { | |||||
ve.Logger.Error("error closing connection", "err", tmpErr) | |||||
} | |||||
return false, err | |||||
} | |||||
return false, nil | |||||
} | |||||
// Attempt to accept a connection. | |||||
// Times out after the listener's timeoutAccept | |||||
func (ve *SignerValidatorEndpoint) acceptConnection() (net.Conn, error) { | |||||
conn, err := ve.listener.Accept() | |||||
if err != nil { | |||||
if !ve.IsRunning() { | |||||
return nil, nil // Ignore error from listener closing. | |||||
} | |||||
return nil, err | |||||
} | |||||
return conn, nil | |||||
} |
@ -1,506 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
var ( | |||||
testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second | |||||
testTimeoutReadWrite = 100 * time.Millisecond | |||||
testTimeoutReadWrite2o3 = 66 * time.Millisecond // 2/3 of the other one | |||||
testTimeoutHeartbeat = 10 * time.Millisecond | |||||
testTimeoutHeartbeat3o2 = 6 * time.Millisecond // 3/2 of the other one | |||||
) | |||||
type socketTestCase struct { | |||||
addr string | |||||
dialer SocketDialer | |||||
} | |||||
func socketTestCases(t *testing.T) []socketTestCase { | |||||
tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) | |||||
unixFilePath, err := testUnixAddr() | |||||
require.NoError(t, err) | |||||
unixAddr := fmt.Sprintf("unix://%s", unixFilePath) | |||||
return []socketTestCase{ | |||||
{ | |||||
addr: tcpAddr, | |||||
dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()), | |||||
}, | |||||
{ | |||||
addr: unixAddr, | |||||
dialer: DialUnixFn(unixFilePath), | |||||
}, | |||||
} | |||||
} | |||||
func TestSocketPVAddress(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
// Execute the test within a closure to ensure the deferred statements | |||||
// are called between each for loop iteration, for isolated test cases. | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
serviceAddr := serviceEndpoint.privVal.GetPubKey().Address() | |||||
validatorAddr := validatorEndpoint.GetPubKey().Address() | |||||
assert.Equal(t, serviceAddr, validatorAddr) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVPubKey(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
clientKey := validatorEndpoint.GetPubKey() | |||||
privvalPubKey := serviceEndpoint.privVal.GetPubKey() | |||||
assert.Equal(t, privvalPubKey, clientKey) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVProposal(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
privProposal = &types.Proposal{Timestamp: ts} | |||||
clientProposal = &types.Proposal{Timestamp: ts} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, privProposal)) | |||||
require.NoError(t, validatorEndpoint.SignProposal(chainID, clientProposal)) | |||||
assert.Equal(t, privProposal.Signature, clientProposal.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVote(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVoteResetDeadline(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
time.Sleep(testTimeoutReadWrite2o3) | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
// This would exceed the deadline if it was not extended by the previous message | |||||
time.Sleep(testTimeoutReadWrite2o3) | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVoteKeepalive(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
time.Sleep(testTimeoutReadWrite * 2) | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVDeadline(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
listenc = make(chan struct{}) | |||||
thisConnTimeout = 100 * time.Millisecond | |||||
validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout) | |||||
) | |||||
go func(sc *SignerValidatorEndpoint) { | |||||
defer close(listenc) | |||||
// Note: the TCP connection times out at the accept() phase, | |||||
// whereas the Unix domain sockets connection times out while | |||||
// attempting to fetch the remote signer's public key. | |||||
assert.True(t, IsConnTimeout(sc.Start())) | |||||
assert.False(t, sc.IsRunning()) | |||||
}(validatorEndpoint) | |||||
for { | |||||
_, err := cmn.Connect(tc.addr) | |||||
if err == nil { | |||||
break | |||||
} | |||||
} | |||||
<-listenc | |||||
}() | |||||
} | |||||
} | |||||
func TestRemoteSignVoteErrors(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewErroringMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
vote = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
err := validatorEndpoint.SignVote("", vote) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = serviceEndpoint.privVal.SignVote(chainID, vote) | |||||
require.Error(t, err) | |||||
err = validatorEndpoint.SignVote(chainID, vote) | |||||
require.Error(t, err) | |||||
}() | |||||
} | |||||
} | |||||
func TestRemoteSignProposalErrors(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewErroringMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
proposal = &types.Proposal{Timestamp: ts} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
err := validatorEndpoint.SignProposal("", proposal) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = serviceEndpoint.privVal.SignProposal(chainID, proposal) | |||||
require.Error(t, err) | |||||
err = validatorEndpoint.SignProposal(chainID, proposal) | |||||
require.Error(t, err) | |||||
}() | |||||
} | |||||
} | |||||
func TestErrUnexpectedResponse(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
chainID = cmn.RandStr(12) | |||||
readyCh = make(chan struct{}) | |||||
errCh = make(chan error, 1) | |||||
serviceEndpoint = NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
validatorEndpoint = newSignerValidatorEndpoint( | |||||
logger, | |||||
tc.addr, | |||||
testTimeoutReadWrite) | |||||
) | |||||
testStartEndpoint(t, readyCh, validatorEndpoint) | |||||
defer validatorEndpoint.Stop() | |||||
SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(100)(serviceEndpoint) | |||||
// we do not want to Start() the remote signer here and instead use the connection to | |||||
// reply with intentionally wrong replies below: | |||||
rsConn, err := serviceEndpoint.connect() | |||||
require.NoError(t, err) | |||||
require.NotNil(t, rsConn) | |||||
defer rsConn.Close() | |||||
// send over public key to get the remote signer running: | |||||
go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) | |||||
<-readyCh | |||||
// Proposal: | |||||
go func(errc chan error) { | |||||
errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{}) | |||||
}(errCh) | |||||
// read request and write wrong response: | |||||
go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) | |||||
err = <-errCh | |||||
require.Error(t, err) | |||||
require.Equal(t, err, ErrUnexpectedResponse) | |||||
// Vote: | |||||
go func(errc chan error) { | |||||
errc <- validatorEndpoint.SignVote(chainID, &types.Vote{}) | |||||
}(errCh) | |||||
// read request and write wrong response: | |||||
go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) | |||||
err = <-errCh | |||||
require.Error(t, err) | |||||
require.Equal(t, err, ErrUnexpectedResponse) | |||||
}() | |||||
} | |||||
} | |||||
func TestRetryConnToRemoteSigner(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
chainID = cmn.RandStr(12) | |||||
readyCh = make(chan struct{}) | |||||
serviceEndpoint = NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
thisConnTimeout = testTimeoutReadWrite | |||||
validatorEndpoint = newSignerValidatorEndpoint(logger, tc.addr, thisConnTimeout) | |||||
) | |||||
// Ping every: | |||||
SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) | |||||
SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(10)(serviceEndpoint) | |||||
testStartEndpoint(t, readyCh, validatorEndpoint) | |||||
defer validatorEndpoint.Stop() | |||||
require.NoError(t, serviceEndpoint.Start()) | |||||
assert.True(t, serviceEndpoint.IsRunning()) | |||||
<-readyCh | |||||
time.Sleep(testTimeoutHeartbeat * 2) | |||||
serviceEndpoint.Stop() | |||||
rs2 := NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
// let some pings pass | |||||
time.Sleep(testTimeoutHeartbeat3o2) | |||||
require.NoError(t, rs2.Start()) | |||||
assert.True(t, rs2.IsRunning()) | |||||
defer rs2.Stop() | |||||
// give the client some time to re-establish the conn to the remote signer | |||||
// should see sth like this in the logs: | |||||
// | |||||
// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" | |||||
// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal | |||||
time.Sleep(testTimeoutReadWrite * 2) | |||||
}() | |||||
} | |||||
} | |||||
func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint { | |||||
proto, address := cmn.ProtocolAndAddress(addr) | |||||
ln, err := net.Listen(proto, address) | |||||
logger.Info("Listening at", "proto", proto, "address", address) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
var listener net.Listener | |||||
if proto == "unix" { | |||||
unixLn := NewUnixListener(ln) | |||||
UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn) | |||||
UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) | |||||
listener = unixLn | |||||
} else { | |||||
tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) | |||||
TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn) | |||||
TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) | |||||
listener = tcpLn | |||||
} | |||||
return NewSignerValidatorEndpoint(logger, listener) | |||||
} | |||||
func testSetupSocketPair( | |||||
t *testing.T, | |||||
chainID string, | |||||
privValidator types.PrivValidator, | |||||
addr string, | |||||
socketDialer SocketDialer, | |||||
) (*SignerValidatorEndpoint, *SignerServiceEndpoint) { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
privVal = privValidator | |||||
readyc = make(chan struct{}) | |||||
serviceEndpoint = NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
privVal, | |||||
socketDialer, | |||||
) | |||||
thisConnTimeout = testTimeoutReadWrite | |||||
validatorEndpoint = newSignerValidatorEndpoint(logger, addr, thisConnTimeout) | |||||
) | |||||
SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) | |||||
SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(1e6)(serviceEndpoint) | |||||
testStartEndpoint(t, readyc, validatorEndpoint) | |||||
require.NoError(t, serviceEndpoint.Start()) | |||||
assert.True(t, serviceEndpoint.IsRunning()) | |||||
<-readyc | |||||
return validatorEndpoint, serviceEndpoint | |||||
} | |||||
func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { | |||||
_, err := readMsg(rsConn) | |||||
require.NoError(t, err) | |||||
err = writeMsg(rsConn, resp) | |||||
require.NoError(t, err) | |||||
} | |||||
func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) { | |||||
go func(sc *SignerValidatorEndpoint) { | |||||
require.NoError(t, sc.Start()) | |||||
assert.True(t, sc.IsRunning()) | |||||
readyCh <- struct{}{} | |||||
}(sc) | |||||
} | |||||
// testFreeTCPAddr claims a free port so we don't block on listener being ready. | |||||
func testFreeTCPAddr(t *testing.T) string { | |||||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||||
require.NoError(t, err) | |||||
defer ln.Close() | |||||
return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) | |||||
} |
@ -1,26 +1,49 @@ | |||||
package privval | package privval | ||||
import ( | import ( | ||||
"fmt" | |||||
"testing" | "testing" | ||||
"time" | "time" | ||||
"github.com/stretchr/testify/assert" | "github.com/stretchr/testify/assert" | ||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | "github.com/tendermint/tendermint/crypto/ed25519" | ||||
cmn "github.com/tendermint/tendermint/libs/common" | cmn "github.com/tendermint/tendermint/libs/common" | ||||
) | ) | ||||
func getDialerTestCases(t *testing.T) []dialerTestCase { | |||||
tcpAddr := GetFreeLocalhostAddrPort() | |||||
unixFilePath, err := testUnixAddr() | |||||
require.NoError(t, err) | |||||
unixAddr := fmt.Sprintf("unix://%s", unixFilePath) | |||||
return []dialerTestCase{ | |||||
{ | |||||
addr: tcpAddr, | |||||
dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()), | |||||
}, | |||||
{ | |||||
addr: unixAddr, | |||||
dialer: DialUnixFn(unixFilePath), | |||||
}, | |||||
} | |||||
} | |||||
func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { | func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { | ||||
// Generate a networking timeout | // Generate a networking timeout | ||||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||||
tcpAddr := GetFreeLocalhostAddrPort() | |||||
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey()) | |||||
_, err := dialer() | _, err := dialer() | ||||
assert.Error(t, err) | assert.Error(t, err) | ||||
assert.True(t, IsConnTimeout(err)) | assert.True(t, IsConnTimeout(err)) | ||||
} | } | ||||
func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { | func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { | ||||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||||
tcpAddr := GetFreeLocalhostAddrPort() | |||||
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey()) | |||||
_, err := dialer() | _, err := dialer() | ||||
assert.Error(t, err) | assert.Error(t, err) | ||||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||||
err = cmn.ErrorWrap(ErrConnectionTimeout, err.Error()) | |||||
assert.True(t, IsConnTimeout(err)) | assert.True(t, IsConnTimeout(err)) | ||||
} | } |