This issue is related to #3107 This is a first renaming/refactoring step before reworking and removing heartbeats. As discussed with @Liamsi , we preferred to go for a couple of independent and separate PRs to simplify review work. The changes: Help to clarify the relation between the validator and remote signer endpoints Differentiate between timeouts and deadlines Prepare to encapsulate networking related code behind RemoteSigner in the next PR My intention is to separate and encapsulate the "network related" code from the actual signer. SignerRemote ---(uses/contains)--> SignerValidatorEndpoint <--(connects to)--> SignerServiceEndpoint ---> SignerService (future.. not here yet but would like to decouple too) All reconnection/heartbeat/whatever code goes in the endpoints. Signer[Remote/Service] do not need to know about that. I agree Endpoint may not be the perfect name. I tried to find something "Go-ish" enough. It is a common name in go-kit, kubernetes, etc. Right now: SignerValidatorEndpoint: handles the listener contains SignerRemote Implements the PrivValidator interface connects and sets a connection object in a contained SignerRemote delegates PrivValidator some calls to SignerRemote which in turn uses the conn object that was set externally SignerRemote: Implements the PrivValidator interface read/writes from a connection object directly handles heartbeats SignerServiceEndpoint: Does most things in a single place delegates to a PrivValidator IIRC. * cleanup * Refactoring step 1 * Refactoring step 2 * move messages to another file * mark for future work / next steps * mark deprecated classes in docs * Fix linter problems * additional linter fixespull/3358/head v0.31.0-dev0
@ -1,240 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"errors" | |||||
"fmt" | |||||
"net" | |||||
"sync" | |||||
"time" | |||||
"github.com/tendermint/tendermint/crypto" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
const ( | |||||
defaultConnHeartBeatSeconds = 2 | |||||
defaultDialRetries = 10 | |||||
) | |||||
// Socket errors. | |||||
var ( | |||||
ErrUnexpectedResponse = errors.New("received unexpected response") | |||||
) | |||||
var ( | |||||
connHeartbeat = time.Second * defaultConnHeartBeatSeconds | |||||
) | |||||
// SocketValOption sets an optional parameter on the SocketVal. | |||||
type SocketValOption func(*SocketVal) | |||||
// SocketValHeartbeat sets the period on which to check the liveness of the | |||||
// connected Signer connections. | |||||
func SocketValHeartbeat(period time.Duration) SocketValOption { | |||||
return func(sc *SocketVal) { sc.connHeartbeat = period } | |||||
} | |||||
// SocketVal implements PrivValidator. | |||||
// It listens for an external process to dial in and uses | |||||
// the socket to request signatures. | |||||
type SocketVal struct { | |||||
cmn.BaseService | |||||
listener net.Listener | |||||
// ping | |||||
cancelPing chan struct{} | |||||
pingTicker *time.Ticker | |||||
connHeartbeat time.Duration | |||||
// signer is mutable since it can be | |||||
// reset if the connection fails. | |||||
// failures are detected by a background | |||||
// ping routine. | |||||
// All messages are request/response, so we hold the mutex | |||||
// so only one request/response pair can happen at a time. | |||||
// Methods on the underlying net.Conn itself | |||||
// are already gorountine safe. | |||||
mtx sync.Mutex | |||||
signer *RemoteSignerClient | |||||
} | |||||
// Check that SocketVal implements PrivValidator. | |||||
var _ types.PrivValidator = (*SocketVal)(nil) | |||||
// NewSocketVal returns an instance of SocketVal. | |||||
func NewSocketVal( | |||||
logger log.Logger, | |||||
listener net.Listener, | |||||
) *SocketVal { | |||||
sc := &SocketVal{ | |||||
listener: listener, | |||||
connHeartbeat: connHeartbeat, | |||||
} | |||||
sc.BaseService = *cmn.NewBaseService(logger, "SocketVal", sc) | |||||
return sc | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Implement PrivValidator | |||||
// GetPubKey implements PrivValidator. | |||||
func (sc *SocketVal) GetPubKey() crypto.PubKey { | |||||
sc.mtx.Lock() | |||||
defer sc.mtx.Unlock() | |||||
return sc.signer.GetPubKey() | |||||
} | |||||
// SignVote implements PrivValidator. | |||||
func (sc *SocketVal) SignVote(chainID string, vote *types.Vote) error { | |||||
sc.mtx.Lock() | |||||
defer sc.mtx.Unlock() | |||||
return sc.signer.SignVote(chainID, vote) | |||||
} | |||||
// SignProposal implements PrivValidator. | |||||
func (sc *SocketVal) SignProposal(chainID string, proposal *types.Proposal) error { | |||||
sc.mtx.Lock() | |||||
defer sc.mtx.Unlock() | |||||
return sc.signer.SignProposal(chainID, proposal) | |||||
} | |||||
//-------------------------------------------------------- | |||||
// More thread safe methods proxied to the signer | |||||
// Ping is used to check connection health. | |||||
func (sc *SocketVal) Ping() error { | |||||
sc.mtx.Lock() | |||||
defer sc.mtx.Unlock() | |||||
return sc.signer.Ping() | |||||
} | |||||
// Close closes the underlying net.Conn. | |||||
func (sc *SocketVal) Close() { | |||||
sc.mtx.Lock() | |||||
defer sc.mtx.Unlock() | |||||
if sc.signer != nil { | |||||
if err := sc.signer.Close(); err != nil { | |||||
sc.Logger.Error("OnStop", "err", err) | |||||
} | |||||
} | |||||
if sc.listener != nil { | |||||
if err := sc.listener.Close(); err != nil { | |||||
sc.Logger.Error("OnStop", "err", err) | |||||
} | |||||
} | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Service start and stop | |||||
// OnStart implements cmn.Service. | |||||
func (sc *SocketVal) OnStart() error { | |||||
if closed, err := sc.reset(); err != nil { | |||||
sc.Logger.Error("OnStart", "err", err) | |||||
return err | |||||
} else if closed { | |||||
return fmt.Errorf("listener is closed") | |||||
} | |||||
// Start a routine to keep the connection alive | |||||
sc.cancelPing = make(chan struct{}, 1) | |||||
sc.pingTicker = time.NewTicker(sc.connHeartbeat) | |||||
go func() { | |||||
for { | |||||
select { | |||||
case <-sc.pingTicker.C: | |||||
err := sc.Ping() | |||||
if err != nil { | |||||
sc.Logger.Error("Ping", "err", err) | |||||
if err == ErrUnexpectedResponse { | |||||
return | |||||
} | |||||
closed, err := sc.reset() | |||||
if err != nil { | |||||
sc.Logger.Error("Reconnecting to remote signer failed", "err", err) | |||||
continue | |||||
} | |||||
if closed { | |||||
sc.Logger.Info("listener is closing") | |||||
return | |||||
} | |||||
sc.Logger.Info("Re-created connection to remote signer", "impl", sc) | |||||
} | |||||
case <-sc.cancelPing: | |||||
sc.pingTicker.Stop() | |||||
return | |||||
} | |||||
} | |||||
}() | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service. | |||||
func (sc *SocketVal) OnStop() { | |||||
if sc.cancelPing != nil { | |||||
close(sc.cancelPing) | |||||
} | |||||
sc.Close() | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Connection and signer management | |||||
// waits to accept and sets a new connection. | |||||
// connection is closed in OnStop. | |||||
// returns true if the listener is closed | |||||
// (ie. it returns a nil conn). | |||||
func (sc *SocketVal) reset() (closed bool, err error) { | |||||
sc.mtx.Lock() | |||||
defer sc.mtx.Unlock() | |||||
// first check if the conn already exists and close it. | |||||
if sc.signer != nil { | |||||
if err := sc.signer.Close(); err != nil { | |||||
sc.Logger.Error("error closing socket val connection during reset", "err", err) | |||||
} | |||||
} | |||||
// wait for a new conn | |||||
conn, err := sc.acceptConnection() | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
// listener is closed | |||||
if conn == nil { | |||||
return true, nil | |||||
} | |||||
sc.signer, err = NewRemoteSignerClient(conn) | |||||
if err != nil { | |||||
// failed to fetch the pubkey. close out the connection. | |||||
if err := conn.Close(); err != nil { | |||||
sc.Logger.Error("error closing connection", "err", err) | |||||
} | |||||
return false, err | |||||
} | |||||
return false, nil | |||||
} | |||||
// Attempt to accept a connection. | |||||
// Times out after the listener's acceptDeadline | |||||
func (sc *SocketVal) acceptConnection() (net.Conn, error) { | |||||
conn, err := sc.listener.Accept() | |||||
if err != nil { | |||||
if !sc.IsRunning() { | |||||
return nil, nil // Ignore error from listener closing. | |||||
} | |||||
return nil, err | |||||
} | |||||
return conn, nil | |||||
} |
@ -1,461 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
var ( | |||||
testAcceptDeadline = defaultAcceptDeadlineSeconds * time.Second | |||||
testConnDeadline = 100 * time.Millisecond | |||||
testConnDeadline2o3 = 66 * time.Millisecond // 2/3 of the other one | |||||
testHeartbeatTimeout = 10 * time.Millisecond | |||||
testHeartbeatTimeout3o2 = 6 * time.Millisecond // 3/2 of the other one | |||||
) | |||||
type socketTestCase struct { | |||||
addr string | |||||
dialer Dialer | |||||
} | |||||
func socketTestCases(t *testing.T) []socketTestCase { | |||||
tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) | |||||
unixFilePath, err := testUnixAddr() | |||||
require.NoError(t, err) | |||||
unixAddr := fmt.Sprintf("unix://%s", unixFilePath) | |||||
return []socketTestCase{ | |||||
{ | |||||
addr: tcpAddr, | |||||
dialer: DialTCPFn(tcpAddr, testConnDeadline, ed25519.GenPrivKey()), | |||||
}, | |||||
{ | |||||
addr: unixAddr, | |||||
dialer: DialUnixFn(unixFilePath), | |||||
}, | |||||
} | |||||
} | |||||
func TestSocketPVAddress(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
// Execute the test within a closure to ensure the deferred statements | |||||
// are called between each for loop iteration, for isolated test cases. | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
serverAddr := rs.privVal.GetPubKey().Address() | |||||
clientAddr := sc.GetPubKey().Address() | |||||
assert.Equal(t, serverAddr, clientAddr) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVPubKey(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
clientKey := sc.GetPubKey() | |||||
privvalPubKey := rs.privVal.GetPubKey() | |||||
assert.Equal(t, privvalPubKey, clientKey) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVProposal(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
ts = time.Now() | |||||
privProposal = &types.Proposal{Timestamp: ts} | |||||
clientProposal = &types.Proposal{Timestamp: ts} | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
require.NoError(t, rs.privVal.SignProposal(chainID, privProposal)) | |||||
require.NoError(t, sc.SignProposal(chainID, clientProposal)) | |||||
assert.Equal(t, privProposal.Signature, clientProposal.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVote(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
require.NoError(t, rs.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, sc.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVoteResetDeadline(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
time.Sleep(testConnDeadline2o3) | |||||
require.NoError(t, rs.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, sc.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
// This would exceed the deadline if it was not extended by the previous message | |||||
time.Sleep(testConnDeadline2o3) | |||||
require.NoError(t, rs.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, sc.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVoteKeepalive(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
time.Sleep(testConnDeadline * 2) | |||||
require.NoError(t, rs.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, sc.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVDeadline(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
listenc = make(chan struct{}) | |||||
thisConnTimeout = 100 * time.Millisecond | |||||
sc = newSocketVal(log.TestingLogger(), tc.addr, thisConnTimeout) | |||||
) | |||||
go func(sc *SocketVal) { | |||||
defer close(listenc) | |||||
// Note: the TCP connection times out at the accept() phase, | |||||
// whereas the Unix domain sockets connection times out while | |||||
// attempting to fetch the remote signer's public key. | |||||
assert.True(t, IsConnTimeout(sc.Start())) | |||||
assert.False(t, sc.IsRunning()) | |||||
}(sc) | |||||
for { | |||||
_, err := cmn.Connect(tc.addr) | |||||
if err == nil { | |||||
break | |||||
} | |||||
} | |||||
<-listenc | |||||
}() | |||||
} | |||||
} | |||||
func TestRemoteSignVoteErrors(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
vote = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
err := sc.SignVote("", vote) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = rs.privVal.SignVote(chainID, vote) | |||||
require.Error(t, err) | |||||
err = sc.SignVote(chainID, vote) | |||||
require.Error(t, err) | |||||
}() | |||||
} | |||||
} | |||||
func TestRemoteSignProposalErrors(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer) | |||||
ts = time.Now() | |||||
proposal = &types.Proposal{Timestamp: ts} | |||||
) | |||||
defer sc.Stop() | |||||
defer rs.Stop() | |||||
err := sc.SignProposal("", proposal) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = rs.privVal.SignProposal(chainID, proposal) | |||||
require.Error(t, err) | |||||
err = sc.SignProposal(chainID, proposal) | |||||
require.Error(t, err) | |||||
}() | |||||
} | |||||
} | |||||
func TestErrUnexpectedResponse(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
chainID = cmn.RandStr(12) | |||||
readyc = make(chan struct{}) | |||||
errc = make(chan error, 1) | |||||
rs = NewRemoteSigner( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
sc = newSocketVal(logger, tc.addr, testConnDeadline) | |||||
) | |||||
testStartSocketPV(t, readyc, sc) | |||||
defer sc.Stop() | |||||
RemoteSignerConnDeadline(time.Millisecond)(rs) | |||||
RemoteSignerConnRetries(100)(rs) | |||||
// we do not want to Start() the remote signer here and instead use the connection to | |||||
// reply with intentionally wrong replies below: | |||||
rsConn, err := rs.connect() | |||||
defer rsConn.Close() | |||||
require.NoError(t, err) | |||||
require.NotNil(t, rsConn) | |||||
// send over public key to get the remote signer running: | |||||
go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) | |||||
<-readyc | |||||
// Proposal: | |||||
go func(errc chan error) { | |||||
errc <- sc.SignProposal(chainID, &types.Proposal{}) | |||||
}(errc) | |||||
// read request and write wrong response: | |||||
go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) | |||||
err = <-errc | |||||
require.Error(t, err) | |||||
require.Equal(t, err, ErrUnexpectedResponse) | |||||
// Vote: | |||||
go func(errc chan error) { | |||||
errc <- sc.SignVote(chainID, &types.Vote{}) | |||||
}(errc) | |||||
// read request and write wrong response: | |||||
go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) | |||||
err = <-errc | |||||
require.Error(t, err) | |||||
require.Equal(t, err, ErrUnexpectedResponse) | |||||
}() | |||||
} | |||||
} | |||||
func TestRetryConnToRemoteSigner(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
chainID = cmn.RandStr(12) | |||||
readyc = make(chan struct{}) | |||||
rs = NewRemoteSigner( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
thisConnTimeout = testConnDeadline | |||||
sc = newSocketVal(logger, tc.addr, thisConnTimeout) | |||||
) | |||||
// Ping every: | |||||
SocketValHeartbeat(testHeartbeatTimeout)(sc) | |||||
RemoteSignerConnDeadline(testConnDeadline)(rs) | |||||
RemoteSignerConnRetries(10)(rs) | |||||
testStartSocketPV(t, readyc, sc) | |||||
defer sc.Stop() | |||||
require.NoError(t, rs.Start()) | |||||
assert.True(t, rs.IsRunning()) | |||||
<-readyc | |||||
time.Sleep(testHeartbeatTimeout * 2) | |||||
rs.Stop() | |||||
rs2 := NewRemoteSigner( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
// let some pings pass | |||||
time.Sleep(testHeartbeatTimeout3o2) | |||||
require.NoError(t, rs2.Start()) | |||||
assert.True(t, rs2.IsRunning()) | |||||
defer rs2.Stop() | |||||
// give the client some time to re-establish the conn to the remote signer | |||||
// should see sth like this in the logs: | |||||
// | |||||
// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" | |||||
// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal | |||||
time.Sleep(testConnDeadline * 2) | |||||
}() | |||||
} | |||||
} | |||||
func newSocketVal(logger log.Logger, addr string, connDeadline time.Duration) *SocketVal { | |||||
proto, address := cmn.ProtocolAndAddress(addr) | |||||
ln, err := net.Listen(proto, address) | |||||
logger.Info("Listening at", "proto", proto, "address", address) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
var svln net.Listener | |||||
if proto == "unix" { | |||||
unixLn := NewUnixListener(ln) | |||||
UnixListenerAcceptDeadline(testAcceptDeadline)(unixLn) | |||||
UnixListenerConnDeadline(connDeadline)(unixLn) | |||||
svln = unixLn | |||||
} else { | |||||
tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) | |||||
TCPListenerAcceptDeadline(testAcceptDeadline)(tcpLn) | |||||
TCPListenerConnDeadline(connDeadline)(tcpLn) | |||||
svln = tcpLn | |||||
} | |||||
return NewSocketVal(logger, svln) | |||||
} | |||||
func testSetupSocketPair( | |||||
t *testing.T, | |||||
chainID string, | |||||
privValidator types.PrivValidator, | |||||
addr string, | |||||
dialer Dialer, | |||||
) (*SocketVal, *RemoteSigner) { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
privVal = privValidator | |||||
readyc = make(chan struct{}) | |||||
rs = NewRemoteSigner( | |||||
logger, | |||||
chainID, | |||||
privVal, | |||||
dialer, | |||||
) | |||||
thisConnTimeout = testConnDeadline | |||||
sc = newSocketVal(logger, addr, thisConnTimeout) | |||||
) | |||||
SocketValHeartbeat(testHeartbeatTimeout)(sc) | |||||
RemoteSignerConnDeadline(testConnDeadline)(rs) | |||||
RemoteSignerConnRetries(1e6)(rs) | |||||
testStartSocketPV(t, readyc, sc) | |||||
require.NoError(t, rs.Start()) | |||||
assert.True(t, rs.IsRunning()) | |||||
<-readyc | |||||
return sc, rs | |||||
} | |||||
func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { | |||||
_, err := readMsg(rsConn) | |||||
require.NoError(t, err) | |||||
err = writeMsg(rsConn, resp) | |||||
require.NoError(t, err) | |||||
} | |||||
func testStartSocketPV(t *testing.T, readyc chan struct{}, sc *SocketVal) { | |||||
go func(sc *SocketVal) { | |||||
require.NoError(t, sc.Start()) | |||||
assert.True(t, sc.IsRunning()) | |||||
readyc <- struct{}{} | |||||
}(sc) | |||||
} | |||||
// testFreeTCPAddr claims a free port so we don't block on listener being ready. | |||||
func testFreeTCPAddr(t *testing.T) string { | |||||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||||
require.NoError(t, err) | |||||
defer ln.Close() | |||||
return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) | |||||
} |
@ -0,0 +1,22 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
) | |||||
// Socket errors. | |||||
var ( | |||||
ErrUnexpectedResponse = fmt.Errorf("received unexpected response") | |||||
ErrConnTimeout = fmt.Errorf("remote signer timed out") | |||||
) | |||||
// RemoteSignerError allows (remote) validators to include meaningful error descriptions in their reply. | |||||
type RemoteSignerError struct { | |||||
// TODO(ismail): create an enum of known errors | |||||
Code int | |||||
Description string | |||||
} | |||||
func (e *RemoteSignerError) Error() string { | |||||
return fmt.Sprintf("signerServiceEndpoint returned error #%d: %s", e.Code, e.Description) | |||||
} |
@ -0,0 +1,61 @@ | |||||
package privval | |||||
import ( | |||||
amino "github.com/tendermint/go-amino" | |||||
"github.com/tendermint/tendermint/crypto" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// RemoteSignerMsg is sent between SignerServiceEndpoint and the SignerServiceEndpoint client. | |||||
type RemoteSignerMsg interface{} | |||||
func RegisterRemoteSignerMsg(cdc *amino.Codec) { | |||||
cdc.RegisterInterface((*RemoteSignerMsg)(nil), nil) | |||||
cdc.RegisterConcrete(&PubKeyRequest{}, "tendermint/remotesigner/PubKeyRequest", nil) | |||||
cdc.RegisterConcrete(&PubKeyResponse{}, "tendermint/remotesigner/PubKeyResponse", nil) | |||||
cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil) | |||||
cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil) | |||||
cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil) | |||||
cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil) | |||||
cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil) | |||||
cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil) | |||||
} | |||||
// PubKeyRequest requests the consensus public key from the remote signer. | |||||
type PubKeyRequest struct{} | |||||
// PubKeyResponse is a PrivValidatorSocket message containing the public key. | |||||
type PubKeyResponse struct { | |||||
PubKey crypto.PubKey | |||||
Error *RemoteSignerError | |||||
} | |||||
// SignVoteRequest is a PrivValidatorSocket message containing a vote. | |||||
type SignVoteRequest struct { | |||||
Vote *types.Vote | |||||
} | |||||
// SignedVoteResponse is a PrivValidatorSocket message containing a signed vote along with a potenial error message. | |||||
type SignedVoteResponse struct { | |||||
Vote *types.Vote | |||||
Error *RemoteSignerError | |||||
} | |||||
// SignProposalRequest is a PrivValidatorSocket message containing a Proposal. | |||||
type SignProposalRequest struct { | |||||
Proposal *types.Proposal | |||||
} | |||||
// SignedProposalResponse is a PrivValidatorSocket message containing a proposal response | |||||
type SignedProposalResponse struct { | |||||
Proposal *types.Proposal | |||||
Error *RemoteSignerError | |||||
} | |||||
// PingRequest is a PrivValidatorSocket message to keep the connection alive. | |||||
type PingRequest struct { | |||||
} | |||||
// PingRequest is a PrivValidatorSocket response to keep the connection alive. | |||||
type PingResponse struct { | |||||
} |
@ -1,90 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/pkg/errors" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// TestRemoteSignerRetryTCPOnly will test connection retry attempts over TCP. We | |||||
// don't need this for Unix sockets because the OS instantly knows the state of | |||||
// both ends of the socket connection. This basically causes the | |||||
// RemoteSigner.dialer() call inside RemoteSigner.connect() to return | |||||
// successfully immediately, putting an instant stop to any retry attempts. | |||||
func TestRemoteSignerRetryTCPOnly(t *testing.T) { | |||||
var ( | |||||
attemptc = make(chan int) | |||||
retries = 2 | |||||
) | |||||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||||
require.NoError(t, err) | |||||
go func(ln net.Listener, attemptc chan<- int) { | |||||
attempts := 0 | |||||
for { | |||||
conn, err := ln.Accept() | |||||
require.NoError(t, err) | |||||
err = conn.Close() | |||||
require.NoError(t, err) | |||||
attempts++ | |||||
if attempts == retries { | |||||
attemptc <- attempts | |||||
break | |||||
} | |||||
} | |||||
}(ln, attemptc) | |||||
rs := NewRemoteSigner( | |||||
log.TestingLogger(), | |||||
cmn.RandStr(12), | |||||
types.NewMockPV(), | |||||
DialTCPFn(ln.Addr().String(), testConnDeadline, ed25519.GenPrivKey()), | |||||
) | |||||
defer rs.Stop() | |||||
RemoteSignerConnDeadline(time.Millisecond)(rs) | |||||
RemoteSignerConnRetries(retries)(rs) | |||||
assert.Equal(t, rs.Start(), ErrDialRetryMax) | |||||
select { | |||||
case attempts := <-attemptc: | |||||
assert.Equal(t, retries, attempts) | |||||
case <-time.After(100 * time.Millisecond): | |||||
t.Error("expected remote to observe connection attempts") | |||||
} | |||||
} | |||||
func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { | |||||
// Generate a networking timeout | |||||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||||
_, err := dialer() | |||||
assert.Error(t, err) | |||||
assert.True(t, IsConnTimeout(err)) | |||||
} | |||||
func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { | |||||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||||
_, err := dialer() | |||||
assert.Error(t, err) | |||||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||||
assert.True(t, IsConnTimeout(err)) | |||||
} | |||||
func TestIsConnTimeoutForNonTimeoutErrors(t *testing.T) { | |||||
assert.False(t, IsConnTimeout(cmn.ErrorWrap(ErrDialRetryMax, "max retries exceeded"))) | |||||
assert.False(t, IsConnTimeout(errors.New("completely irrelevant error"))) | |||||
} |
@ -1,168 +0,0 @@ | |||||
package privval | |||||
import ( | |||||
"io" | |||||
"net" | |||||
"time" | |||||
"github.com/pkg/errors" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
p2pconn "github.com/tendermint/tendermint/p2p/conn" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// Socket errors. | |||||
var ( | |||||
ErrDialRetryMax = errors.New("dialed maximum retries") | |||||
) | |||||
// RemoteSignerOption sets an optional parameter on the RemoteSigner. | |||||
type RemoteSignerOption func(*RemoteSigner) | |||||
// RemoteSignerConnDeadline sets the read and write deadline for connections | |||||
// from external signing processes. | |||||
func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption { | |||||
return func(ss *RemoteSigner) { ss.connDeadline = deadline } | |||||
} | |||||
// RemoteSignerConnRetries sets the amount of attempted retries to connect. | |||||
func RemoteSignerConnRetries(retries int) RemoteSignerOption { | |||||
return func(ss *RemoteSigner) { ss.connRetries = retries } | |||||
} | |||||
// RemoteSigner dials using its dialer and responds to any | |||||
// signature requests using its privVal. | |||||
type RemoteSigner struct { | |||||
cmn.BaseService | |||||
chainID string | |||||
connDeadline time.Duration | |||||
connRetries int | |||||
privVal types.PrivValidator | |||||
dialer Dialer | |||||
conn net.Conn | |||||
} | |||||
// Dialer dials a remote address and returns a net.Conn or an error. | |||||
type Dialer func() (net.Conn, error) | |||||
// DialTCPFn dials the given tcp addr, using the given connTimeout and privKey for the | |||||
// authenticated encryption handshake. | |||||
func DialTCPFn(addr string, connTimeout time.Duration, privKey ed25519.PrivKeyEd25519) Dialer { | |||||
return func() (net.Conn, error) { | |||||
conn, err := cmn.Connect(addr) | |||||
if err == nil { | |||||
err = conn.SetDeadline(time.Now().Add(connTimeout)) | |||||
} | |||||
if err == nil { | |||||
conn, err = p2pconn.MakeSecretConnection(conn, privKey) | |||||
} | |||||
return conn, err | |||||
} | |||||
} | |||||
// DialUnixFn dials the given unix socket. | |||||
func DialUnixFn(addr string) Dialer { | |||||
return func() (net.Conn, error) { | |||||
unixAddr := &net.UnixAddr{Name: addr, Net: "unix"} | |||||
return net.DialUnix("unix", nil, unixAddr) | |||||
} | |||||
} | |||||
// NewRemoteSigner return a RemoteSigner that will dial using the given | |||||
// dialer and respond to any signature requests over the connection | |||||
// using the given privVal. | |||||
func NewRemoteSigner( | |||||
logger log.Logger, | |||||
chainID string, | |||||
privVal types.PrivValidator, | |||||
dialer Dialer, | |||||
) *RemoteSigner { | |||||
rs := &RemoteSigner{ | |||||
chainID: chainID, | |||||
connDeadline: time.Second * defaultConnDeadlineSeconds, | |||||
connRetries: defaultDialRetries, | |||||
privVal: privVal, | |||||
dialer: dialer, | |||||
} | |||||
rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs) | |||||
return rs | |||||
} | |||||
// OnStart implements cmn.Service. | |||||
func (rs *RemoteSigner) OnStart() error { | |||||
conn, err := rs.connect() | |||||
if err != nil { | |||||
rs.Logger.Error("OnStart", "err", err) | |||||
return err | |||||
} | |||||
rs.conn = conn | |||||
go rs.handleConnection(conn) | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service. | |||||
func (rs *RemoteSigner) OnStop() { | |||||
if rs.conn == nil { | |||||
return | |||||
} | |||||
if err := rs.conn.Close(); err != nil { | |||||
rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) | |||||
} | |||||
} | |||||
func (rs *RemoteSigner) connect() (net.Conn, error) { | |||||
for retries := rs.connRetries; retries > 0; retries-- { | |||||
// Don't sleep if it is the first retry. | |||||
if retries != rs.connRetries { | |||||
time.Sleep(rs.connDeadline) | |||||
} | |||||
conn, err := rs.dialer() | |||||
if err != nil { | |||||
rs.Logger.Error("dialing", "err", err) | |||||
continue | |||||
} | |||||
return conn, nil | |||||
} | |||||
return nil, ErrDialRetryMax | |||||
} | |||||
func (rs *RemoteSigner) handleConnection(conn net.Conn) { | |||||
for { | |||||
if !rs.IsRunning() { | |||||
return // Ignore error from listener closing. | |||||
} | |||||
// Reset the connection deadline | |||||
conn.SetDeadline(time.Now().Add(rs.connDeadline)) | |||||
req, err := readMsg(conn) | |||||
if err != nil { | |||||
if err != io.EOF { | |||||
rs.Logger.Error("handleConnection readMsg", "err", err) | |||||
} | |||||
return | |||||
} | |||||
res, err := handleRequest(req, rs.chainID, rs.privVal) | |||||
if err != nil { | |||||
// only log the error; we'll reply with an error in res | |||||
rs.Logger.Error("handleConnection handleRequest", "err", err) | |||||
} | |||||
err = writeMsg(conn, res) | |||||
if err != nil { | |||||
rs.Logger.Error("handleConnection writeMsg", "err", err) | |||||
return | |||||
} | |||||
} | |||||
} |
@ -0,0 +1,68 @@ | |||||
package privval | |||||
import ( | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We | |||||
// don't need this for Unix sockets because the OS instantly knows the state of | |||||
// both ends of the socket connection. This basically causes the | |||||
// SignerServiceEndpoint.dialer() call inside SignerServiceEndpoint.connect() to return | |||||
// successfully immediately, putting an instant stop to any retry attempts. | |||||
func TestSignerRemoteRetryTCPOnly(t *testing.T) { | |||||
var ( | |||||
attemptCh = make(chan int) | |||||
retries = 2 | |||||
) | |||||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||||
require.NoError(t, err) | |||||
go func(ln net.Listener, attemptCh chan<- int) { | |||||
attempts := 0 | |||||
for { | |||||
conn, err := ln.Accept() | |||||
require.NoError(t, err) | |||||
err = conn.Close() | |||||
require.NoError(t, err) | |||||
attempts++ | |||||
if attempts == retries { | |||||
attemptCh <- attempts | |||||
break | |||||
} | |||||
} | |||||
}(ln, attemptCh) | |||||
serviceEndpoint := NewSignerServiceEndpoint( | |||||
log.TestingLogger(), | |||||
cmn.RandStr(12), | |||||
types.NewMockPV(), | |||||
DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()), | |||||
) | |||||
defer serviceEndpoint.Stop() | |||||
SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(retries)(serviceEndpoint) | |||||
assert.Equal(t, serviceEndpoint.Start(), ErrDialRetryMax) | |||||
select { | |||||
case attempts := <-attemptCh: | |||||
assert.Equal(t, retries, attempts) | |||||
case <-time.After(100 * time.Millisecond): | |||||
t.Error("expected remote to observe connection attempts") | |||||
} | |||||
} |
@ -0,0 +1,139 @@ | |||||
package privval | |||||
import ( | |||||
"io" | |||||
"net" | |||||
"time" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
// SignerServiceEndpointOption sets an optional parameter on the SignerServiceEndpoint. | |||||
type SignerServiceEndpointOption func(*SignerServiceEndpoint) | |||||
// SignerServiceEndpointTimeoutReadWrite sets the read and write timeout for connections | |||||
// from external signing processes. | |||||
func SignerServiceEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption { | |||||
return func(ss *SignerServiceEndpoint) { ss.timeoutReadWrite = timeout } | |||||
} | |||||
// SignerServiceEndpointConnRetries sets the amount of attempted retries to connect. | |||||
func SignerServiceEndpointConnRetries(retries int) SignerServiceEndpointOption { | |||||
return func(ss *SignerServiceEndpoint) { ss.connRetries = retries } | |||||
} | |||||
// SignerServiceEndpoint dials using its dialer and responds to any | |||||
// signature requests using its privVal. | |||||
type SignerServiceEndpoint struct { | |||||
cmn.BaseService | |||||
chainID string | |||||
timeoutReadWrite time.Duration | |||||
connRetries int | |||||
privVal types.PrivValidator | |||||
dialer SocketDialer | |||||
conn net.Conn | |||||
} | |||||
// NewSignerServiceEndpoint returns a SignerServiceEndpoint that will dial using the given | |||||
// dialer and respond to any signature requests over the connection | |||||
// using the given privVal. | |||||
func NewSignerServiceEndpoint( | |||||
logger log.Logger, | |||||
chainID string, | |||||
privVal types.PrivValidator, | |||||
dialer SocketDialer, | |||||
) *SignerServiceEndpoint { | |||||
se := &SignerServiceEndpoint{ | |||||
chainID: chainID, | |||||
timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds, | |||||
connRetries: defaultMaxDialRetries, | |||||
privVal: privVal, | |||||
dialer: dialer, | |||||
} | |||||
se.BaseService = *cmn.NewBaseService(logger, "SignerServiceEndpoint", se) | |||||
return se | |||||
} | |||||
// OnStart implements cmn.Service. | |||||
func (se *SignerServiceEndpoint) OnStart() error { | |||||
conn, err := se.connect() | |||||
if err != nil { | |||||
se.Logger.Error("OnStart", "err", err) | |||||
return err | |||||
} | |||||
se.conn = conn | |||||
go se.handleConnection(conn) | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service. | |||||
func (se *SignerServiceEndpoint) OnStop() { | |||||
if se.conn == nil { | |||||
return | |||||
} | |||||
if err := se.conn.Close(); err != nil { | |||||
se.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) | |||||
} | |||||
} | |||||
func (se *SignerServiceEndpoint) connect() (net.Conn, error) { | |||||
for retries := 0; retries < se.connRetries; retries++ { | |||||
// Don't sleep if it is the first retry. | |||||
if retries > 0 { | |||||
time.Sleep(se.timeoutReadWrite) | |||||
} | |||||
conn, err := se.dialer() | |||||
if err == nil { | |||||
return conn, nil | |||||
} | |||||
se.Logger.Error("dialing", "err", err) | |||||
} | |||||
return nil, ErrDialRetryMax | |||||
} | |||||
func (se *SignerServiceEndpoint) handleConnection(conn net.Conn) { | |||||
for { | |||||
if !se.IsRunning() { | |||||
return // Ignore error from listener closing. | |||||
} | |||||
// Reset the connection deadline | |||||
deadline := time.Now().Add(se.timeoutReadWrite) | |||||
err := conn.SetDeadline(deadline) | |||||
if err != nil { | |||||
return | |||||
} | |||||
req, err := readMsg(conn) | |||||
if err != nil { | |||||
if err != io.EOF { | |||||
se.Logger.Error("handleConnection readMsg", "err", err) | |||||
} | |||||
return | |||||
} | |||||
res, err := handleRequest(req, se.chainID, se.privVal) | |||||
if err != nil { | |||||
// only log the error; we'll reply with an error in res | |||||
se.Logger.Error("handleConnection handleRequest", "err", err) | |||||
} | |||||
err = writeMsg(conn, res) | |||||
if err != nil { | |||||
se.Logger.Error("handleConnection writeMsg", "err", err) | |||||
return | |||||
} | |||||
} | |||||
} |
@ -0,0 +1,230 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"sync" | |||||
"time" | |||||
"github.com/tendermint/tendermint/crypto" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
const ( | |||||
defaultHeartbeatSeconds = 2 | |||||
defaultMaxDialRetries = 10 | |||||
) | |||||
var ( | |||||
heartbeatPeriod = time.Second * defaultHeartbeatSeconds | |||||
) | |||||
// SignerValidatorEndpointOption sets an optional parameter on the SocketVal. | |||||
type SignerValidatorEndpointOption func(*SignerValidatorEndpoint) | |||||
// SignerValidatorEndpointSetHeartbeat sets the period on which to check the liveness of the | |||||
// connected Signer connections. | |||||
func SignerValidatorEndpointSetHeartbeat(period time.Duration) SignerValidatorEndpointOption { | |||||
return func(sc *SignerValidatorEndpoint) { sc.heartbeatPeriod = period } | |||||
} | |||||
// SocketVal implements PrivValidator. | |||||
// It listens for an external process to dial in and uses | |||||
// the socket to request signatures. | |||||
type SignerValidatorEndpoint struct { | |||||
cmn.BaseService | |||||
listener net.Listener | |||||
// ping | |||||
cancelPingCh chan struct{} | |||||
pingTicker *time.Ticker | |||||
heartbeatPeriod time.Duration | |||||
// signer is mutable since it can be reset if the connection fails. | |||||
// failures are detected by a background ping routine. | |||||
// All messages are request/response, so we hold the mutex | |||||
// so only one request/response pair can happen at a time. | |||||
// Methods on the underlying net.Conn itself are already goroutine safe. | |||||
mtx sync.Mutex | |||||
// TODO: Signer should encapsulate and hide the endpoint completely. Invert the relation | |||||
signer *SignerRemote | |||||
} | |||||
// Check that SignerValidatorEndpoint implements PrivValidator. | |||||
var _ types.PrivValidator = (*SignerValidatorEndpoint)(nil) | |||||
// NewSignerValidatorEndpoint returns an instance of SignerValidatorEndpoint. | |||||
func NewSignerValidatorEndpoint(logger log.Logger, listener net.Listener) *SignerValidatorEndpoint { | |||||
sc := &SignerValidatorEndpoint{ | |||||
listener: listener, | |||||
heartbeatPeriod: heartbeatPeriod, | |||||
} | |||||
sc.BaseService = *cmn.NewBaseService(logger, "SignerValidatorEndpoint", sc) | |||||
return sc | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Implement PrivValidator | |||||
// GetPubKey implements PrivValidator. | |||||
func (ve *SignerValidatorEndpoint) GetPubKey() crypto.PubKey { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.GetPubKey() | |||||
} | |||||
// SignVote implements PrivValidator. | |||||
func (ve *SignerValidatorEndpoint) SignVote(chainID string, vote *types.Vote) error { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.SignVote(chainID, vote) | |||||
} | |||||
// SignProposal implements PrivValidator. | |||||
func (ve *SignerValidatorEndpoint) SignProposal(chainID string, proposal *types.Proposal) error { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.SignProposal(chainID, proposal) | |||||
} | |||||
//-------------------------------------------------------- | |||||
// More thread safe methods proxied to the signer | |||||
// Ping is used to check connection health. | |||||
func (ve *SignerValidatorEndpoint) Ping() error { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
return ve.signer.Ping() | |||||
} | |||||
// Close closes the underlying net.Conn. | |||||
func (ve *SignerValidatorEndpoint) Close() { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
if ve.signer != nil { | |||||
if err := ve.signer.Close(); err != nil { | |||||
ve.Logger.Error("OnStop", "err", err) | |||||
} | |||||
} | |||||
if ve.listener != nil { | |||||
if err := ve.listener.Close(); err != nil { | |||||
ve.Logger.Error("OnStop", "err", err) | |||||
} | |||||
} | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Service start and stop | |||||
// OnStart implements cmn.Service. | |||||
func (ve *SignerValidatorEndpoint) OnStart() error { | |||||
if closed, err := ve.reset(); err != nil { | |||||
ve.Logger.Error("OnStart", "err", err) | |||||
return err | |||||
} else if closed { | |||||
return fmt.Errorf("listener is closed") | |||||
} | |||||
// Start a routine to keep the connection alive | |||||
ve.cancelPingCh = make(chan struct{}, 1) | |||||
ve.pingTicker = time.NewTicker(ve.heartbeatPeriod) | |||||
go func() { | |||||
for { | |||||
select { | |||||
case <-ve.pingTicker.C: | |||||
err := ve.Ping() | |||||
if err != nil { | |||||
ve.Logger.Error("Ping", "err", err) | |||||
if err == ErrUnexpectedResponse { | |||||
return | |||||
} | |||||
closed, err := ve.reset() | |||||
if err != nil { | |||||
ve.Logger.Error("Reconnecting to remote signer failed", "err", err) | |||||
continue | |||||
} | |||||
if closed { | |||||
ve.Logger.Info("listener is closing") | |||||
return | |||||
} | |||||
ve.Logger.Info("Re-created connection to remote signer", "impl", ve) | |||||
} | |||||
case <-ve.cancelPingCh: | |||||
ve.pingTicker.Stop() | |||||
return | |||||
} | |||||
} | |||||
}() | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.Service. | |||||
func (ve *SignerValidatorEndpoint) OnStop() { | |||||
if ve.cancelPingCh != nil { | |||||
close(ve.cancelPingCh) | |||||
} | |||||
ve.Close() | |||||
} | |||||
//-------------------------------------------------------- | |||||
// Connection and signer management | |||||
// waits to accept and sets a new connection. | |||||
// connection is closed in OnStop. | |||||
// returns true if the listener is closed | |||||
// (ie. it returns a nil conn). | |||||
func (ve *SignerValidatorEndpoint) reset() (closed bool, err error) { | |||||
ve.mtx.Lock() | |||||
defer ve.mtx.Unlock() | |||||
// first check if the conn already exists and close it. | |||||
if ve.signer != nil { | |||||
if tmpErr := ve.signer.Close(); tmpErr != nil { | |||||
ve.Logger.Error("error closing socket val connection during reset", "err", tmpErr) | |||||
} | |||||
} | |||||
// wait for a new conn | |||||
conn, err := ve.acceptConnection() | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
// listener is closed | |||||
if conn == nil { | |||||
return true, nil | |||||
} | |||||
ve.signer, err = NewSignerRemote(conn) | |||||
if err != nil { | |||||
// failed to fetch the pubkey. close out the connection. | |||||
if tmpErr := conn.Close(); tmpErr != nil { | |||||
ve.Logger.Error("error closing connection", "err", tmpErr) | |||||
} | |||||
return false, err | |||||
} | |||||
return false, nil | |||||
} | |||||
// Attempt to accept a connection. | |||||
// Times out after the listener's timeoutAccept | |||||
func (ve *SignerValidatorEndpoint) acceptConnection() (net.Conn, error) { | |||||
conn, err := ve.listener.Accept() | |||||
if err != nil { | |||||
if !ve.IsRunning() { | |||||
return nil, nil // Ignore error from listener closing. | |||||
} | |||||
return nil, err | |||||
} | |||||
return conn, nil | |||||
} |
@ -0,0 +1,505 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
"github.com/tendermint/tendermint/libs/log" | |||||
"github.com/tendermint/tendermint/types" | |||||
) | |||||
var ( | |||||
testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second | |||||
testTimeoutReadWrite = 100 * time.Millisecond | |||||
testTimeoutReadWrite2o3 = 66 * time.Millisecond // 2/3 of the other one | |||||
testTimeoutHeartbeat = 10 * time.Millisecond | |||||
testTimeoutHeartbeat3o2 = 6 * time.Millisecond // 3/2 of the other one | |||||
) | |||||
type socketTestCase struct { | |||||
addr string | |||||
dialer SocketDialer | |||||
} | |||||
func socketTestCases(t *testing.T) []socketTestCase { | |||||
tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) | |||||
unixFilePath, err := testUnixAddr() | |||||
require.NoError(t, err) | |||||
unixAddr := fmt.Sprintf("unix://%s", unixFilePath) | |||||
return []socketTestCase{ | |||||
{ | |||||
addr: tcpAddr, | |||||
dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()), | |||||
}, | |||||
{ | |||||
addr: unixAddr, | |||||
dialer: DialUnixFn(unixFilePath), | |||||
}, | |||||
} | |||||
} | |||||
func TestSocketPVAddress(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
// Execute the test within a closure to ensure the deferred statements | |||||
// are called between each for loop iteration, for isolated test cases. | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
serviceAddr := serviceEndpoint.privVal.GetPubKey().Address() | |||||
validatorAddr := validatorEndpoint.GetPubKey().Address() | |||||
assert.Equal(t, serviceAddr, validatorAddr) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVPubKey(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
clientKey := validatorEndpoint.GetPubKey() | |||||
privvalPubKey := serviceEndpoint.privVal.GetPubKey() | |||||
assert.Equal(t, privvalPubKey, clientKey) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVProposal(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
privProposal = &types.Proposal{Timestamp: ts} | |||||
clientProposal = &types.Proposal{Timestamp: ts} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, privProposal)) | |||||
require.NoError(t, validatorEndpoint.SignProposal(chainID, clientProposal)) | |||||
assert.Equal(t, privProposal.Signature, clientProposal.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVote(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVoteResetDeadline(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
time.Sleep(testTimeoutReadWrite2o3) | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
// This would exceed the deadline if it was not extended by the previous message | |||||
time.Sleep(testTimeoutReadWrite2o3) | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVVoteKeepalive(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
want = &types.Vote{Timestamp: ts, Type: vType} | |||||
have = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
time.Sleep(testTimeoutReadWrite * 2) | |||||
require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) | |||||
require.NoError(t, validatorEndpoint.SignVote(chainID, have)) | |||||
assert.Equal(t, want.Signature, have.Signature) | |||||
}() | |||||
} | |||||
} | |||||
func TestSocketPVDeadline(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
listenc = make(chan struct{}) | |||||
thisConnTimeout = 100 * time.Millisecond | |||||
validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout) | |||||
) | |||||
go func(sc *SignerValidatorEndpoint) { | |||||
defer close(listenc) | |||||
// Note: the TCP connection times out at the accept() phase, | |||||
// whereas the Unix domain sockets connection times out while | |||||
// attempting to fetch the remote signer's public key. | |||||
assert.True(t, IsConnTimeout(sc.Start())) | |||||
assert.False(t, sc.IsRunning()) | |||||
}(validatorEndpoint) | |||||
for { | |||||
_, err := cmn.Connect(tc.addr) | |||||
if err == nil { | |||||
break | |||||
} | |||||
} | |||||
<-listenc | |||||
}() | |||||
} | |||||
} | |||||
func TestRemoteSignVoteErrors(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewErroringMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
vType = types.PrecommitType | |||||
vote = &types.Vote{Timestamp: ts, Type: vType} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
err := validatorEndpoint.SignVote("", vote) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = serviceEndpoint.privVal.SignVote(chainID, vote) | |||||
require.Error(t, err) | |||||
err = validatorEndpoint.SignVote(chainID, vote) | |||||
require.Error(t, err) | |||||
}() | |||||
} | |||||
} | |||||
func TestRemoteSignProposalErrors(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
chainID = cmn.RandStr(12) | |||||
validatorEndpoint, serviceEndpoint = testSetupSocketPair( | |||||
t, | |||||
chainID, | |||||
types.NewErroringMockPV(), | |||||
tc.addr, | |||||
tc.dialer) | |||||
ts = time.Now() | |||||
proposal = &types.Proposal{Timestamp: ts} | |||||
) | |||||
defer validatorEndpoint.Stop() | |||||
defer serviceEndpoint.Stop() | |||||
err := validatorEndpoint.SignProposal("", proposal) | |||||
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) | |||||
err = serviceEndpoint.privVal.SignProposal(chainID, proposal) | |||||
require.Error(t, err) | |||||
err = validatorEndpoint.SignProposal(chainID, proposal) | |||||
require.Error(t, err) | |||||
}() | |||||
} | |||||
} | |||||
func TestErrUnexpectedResponse(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
chainID = cmn.RandStr(12) | |||||
readyCh = make(chan struct{}) | |||||
errCh = make(chan error, 1) | |||||
serviceEndpoint = NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
validatorEndpoint = newSignerValidatorEndpoint( | |||||
logger, | |||||
tc.addr, | |||||
testTimeoutReadWrite) | |||||
) | |||||
testStartEndpoint(t, readyCh, validatorEndpoint) | |||||
defer validatorEndpoint.Stop() | |||||
SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(100)(serviceEndpoint) | |||||
// we do not want to Start() the remote signer here and instead use the connection to | |||||
// reply with intentionally wrong replies below: | |||||
rsConn, err := serviceEndpoint.connect() | |||||
defer rsConn.Close() | |||||
require.NoError(t, err) | |||||
require.NotNil(t, rsConn) | |||||
// send over public key to get the remote signer running: | |||||
go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) | |||||
<-readyCh | |||||
// Proposal: | |||||
go func(errc chan error) { | |||||
errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{}) | |||||
}(errCh) | |||||
// read request and write wrong response: | |||||
go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) | |||||
err = <-errCh | |||||
require.Error(t, err) | |||||
require.Equal(t, err, ErrUnexpectedResponse) | |||||
// Vote: | |||||
go func(errc chan error) { | |||||
errc <- validatorEndpoint.SignVote(chainID, &types.Vote{}) | |||||
}(errCh) | |||||
// read request and write wrong response: | |||||
go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) | |||||
err = <-errCh | |||||
require.Error(t, err) | |||||
require.Equal(t, err, ErrUnexpectedResponse) | |||||
}() | |||||
} | |||||
} | |||||
func TestRetryConnToRemoteSigner(t *testing.T) { | |||||
for _, tc := range socketTestCases(t) { | |||||
func() { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
chainID = cmn.RandStr(12) | |||||
readyCh = make(chan struct{}) | |||||
serviceEndpoint = NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
thisConnTimeout = testTimeoutReadWrite | |||||
validatorEndpoint = newSignerValidatorEndpoint(logger, tc.addr, thisConnTimeout) | |||||
) | |||||
// Ping every: | |||||
SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) | |||||
SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(10)(serviceEndpoint) | |||||
testStartEndpoint(t, readyCh, validatorEndpoint) | |||||
defer validatorEndpoint.Stop() | |||||
require.NoError(t, serviceEndpoint.Start()) | |||||
assert.True(t, serviceEndpoint.IsRunning()) | |||||
<-readyCh | |||||
time.Sleep(testTimeoutHeartbeat * 2) | |||||
serviceEndpoint.Stop() | |||||
rs2 := NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
types.NewMockPV(), | |||||
tc.dialer, | |||||
) | |||||
// let some pings pass | |||||
time.Sleep(testTimeoutHeartbeat3o2) | |||||
require.NoError(t, rs2.Start()) | |||||
assert.True(t, rs2.IsRunning()) | |||||
defer rs2.Stop() | |||||
// give the client some time to re-establish the conn to the remote signer | |||||
// should see sth like this in the logs: | |||||
// | |||||
// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" | |||||
// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal | |||||
time.Sleep(testTimeoutReadWrite * 2) | |||||
}() | |||||
} | |||||
} | |||||
func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint { | |||||
proto, address := cmn.ProtocolAndAddress(addr) | |||||
ln, err := net.Listen(proto, address) | |||||
logger.Info("Listening at", "proto", proto, "address", address) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
var listener net.Listener | |||||
if proto == "unix" { | |||||
unixLn := NewUnixListener(ln) | |||||
UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn) | |||||
UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) | |||||
listener = unixLn | |||||
} else { | |||||
tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) | |||||
TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn) | |||||
TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) | |||||
listener = tcpLn | |||||
} | |||||
return NewSignerValidatorEndpoint(logger, listener) | |||||
} | |||||
func testSetupSocketPair( | |||||
t *testing.T, | |||||
chainID string, | |||||
privValidator types.PrivValidator, | |||||
addr string, | |||||
socketDialer SocketDialer, | |||||
) (*SignerValidatorEndpoint, *SignerServiceEndpoint) { | |||||
var ( | |||||
logger = log.TestingLogger() | |||||
privVal = privValidator | |||||
readyc = make(chan struct{}) | |||||
serviceEndpoint = NewSignerServiceEndpoint( | |||||
logger, | |||||
chainID, | |||||
privVal, | |||||
socketDialer, | |||||
) | |||||
thisConnTimeout = testTimeoutReadWrite | |||||
validatorEndpoint = newSignerValidatorEndpoint(logger, addr, thisConnTimeout) | |||||
) | |||||
SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) | |||||
SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) | |||||
SignerServiceEndpointConnRetries(1e6)(serviceEndpoint) | |||||
testStartEndpoint(t, readyc, validatorEndpoint) | |||||
require.NoError(t, serviceEndpoint.Start()) | |||||
assert.True(t, serviceEndpoint.IsRunning()) | |||||
<-readyc | |||||
return validatorEndpoint, serviceEndpoint | |||||
} | |||||
func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { | |||||
_, err := readMsg(rsConn) | |||||
require.NoError(t, err) | |||||
err = writeMsg(rsConn, resp) | |||||
require.NoError(t, err) | |||||
} | |||||
func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) { | |||||
go func(sc *SignerValidatorEndpoint) { | |||||
require.NoError(t, sc.Start()) | |||||
assert.True(t, sc.IsRunning()) | |||||
readyCh <- struct{}{} | |||||
}(sc) | |||||
} | |||||
// testFreeTCPAddr claims a free port so we don't block on listener being ready. | |||||
func testFreeTCPAddr(t *testing.T) string { | |||||
ln, err := net.Listen("tcp", "127.0.0.1:0") | |||||
require.NoError(t, err) | |||||
defer ln.Close() | |||||
return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) | |||||
} |
@ -0,0 +1,43 @@ | |||||
package privval | |||||
import ( | |||||
"net" | |||||
"time" | |||||
"github.com/pkg/errors" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
p2pconn "github.com/tendermint/tendermint/p2p/conn" | |||||
) | |||||
// Socket errors. | |||||
var ( | |||||
ErrDialRetryMax = errors.New("dialed maximum retries") | |||||
) | |||||
// SocketDialer dials a remote address and returns a net.Conn or an error. | |||||
type SocketDialer func() (net.Conn, error) | |||||
// DialTCPFn dials the given tcp addr, using the given timeoutReadWrite and | |||||
// privKey for the authenticated encryption handshake. | |||||
func DialTCPFn(addr string, timeoutReadWrite time.Duration, privKey ed25519.PrivKeyEd25519) SocketDialer { | |||||
return func() (net.Conn, error) { | |||||
conn, err := cmn.Connect(addr) | |||||
if err == nil { | |||||
deadline := time.Now().Add(timeoutReadWrite) | |||||
err = conn.SetDeadline(deadline) | |||||
} | |||||
if err == nil { | |||||
conn, err = p2pconn.MakeSecretConnection(conn, privKey) | |||||
} | |||||
return conn, err | |||||
} | |||||
} | |||||
// DialUnixFn dials the given unix socket. | |||||
func DialUnixFn(addr string) SocketDialer { | |||||
return func() (net.Conn, error) { | |||||
unixAddr := &net.UnixAddr{Name: addr, Net: "unix"} | |||||
return net.DialUnix("unix", nil, unixAddr) | |||||
} | |||||
} |
@ -0,0 +1,26 @@ | |||||
package privval | |||||
import ( | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/tendermint/tendermint/crypto/ed25519" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
) | |||||
func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) { | |||||
// Generate a networking timeout | |||||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||||
_, err := dialer() | |||||
assert.Error(t, err) | |||||
assert.True(t, IsConnTimeout(err)) | |||||
} | |||||
func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) { | |||||
dialer := DialTCPFn(testFreeTCPAddr(t), time.Millisecond, ed25519.GenPrivKey()) | |||||
_, err := dialer() | |||||
assert.Error(t, err) | |||||
err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) | |||||
assert.True(t, IsConnTimeout(err)) | |||||
} |
@ -0,0 +1,20 @@ | |||||
package privval | |||||
import ( | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
) | |||||
// IsConnTimeout returns a boolean indicating whether the error is known to | |||||
// report that a connection timeout occurred. This detects both fundamental | |||||
// network timeouts, as well as ErrConnTimeout errors. | |||||
func IsConnTimeout(err error) bool { | |||||
if cmnErr, ok := err.(cmn.Error); ok { | |||||
if cmnErr.Data() == ErrConnTimeout { | |||||
return true | |||||
} | |||||
} | |||||
if _, ok := err.(timeoutError); ok { | |||||
return true | |||||
} | |||||
return false | |||||
} |
@ -0,0 +1,14 @@ | |||||
package privval | |||||
import ( | |||||
"fmt" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
cmn "github.com/tendermint/tendermint/libs/common" | |||||
) | |||||
func TestIsConnTimeoutForNonTimeoutErrors(t *testing.T) { | |||||
assert.False(t, IsConnTimeout(cmn.ErrorWrap(ErrDialRetryMax, "max retries exceeded"))) | |||||
assert.False(t, IsConnTimeout(fmt.Errorf("completely irrelevant error"))) | |||||
} |