diff --git a/Gopkg.lock b/Gopkg.lock index 10739e8eb..91e0b41e2 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -305,7 +305,6 @@ "idna", "internal/timeseries", "lex/httplex", - "netutil", "trace" ] revision = "cbe0f9307d0156177f9dd5dc85da1a31abc5f2fb" @@ -373,6 +372,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "fe167dd9055ba9a4016e7bdad88da263372bca7ebdcebf5c81c609f396e605a3" + inputs-digest = "ed9db0be72a900f4812675f683db20eff9d64ef4511dc00ad29a810da65909c2" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index b963fe13c..61406ad66 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -90,10 +90,6 @@ name = "google.golang.org/grpc" version = "1.7.3" -[[constraint]] - branch = "master" - name = "golang.org/x/net" - [prune] go-tests = true unused-packages = true diff --git a/cmd/priv_val_server/main.go b/cmd/priv_val_server/main.go index 0d18f8ed2..9f3ec73ca 100644 --- a/cmd/priv_val_server/main.go +++ b/cmd/priv_val_server/main.go @@ -4,6 +4,7 @@ import ( "flag" "os" + crypto "github.com/tendermint/go-crypto" cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" @@ -36,7 +37,7 @@ func main() { *chainID, *addr, privVal, - nil, + crypto.GenPrivKeyEd25519(), ) err := rs.Start() if err != nil { diff --git a/node/node.go b/node/node.go index 83ac50ec6..dffdb83e8 100644 --- a/node/node.go +++ b/node/node.go @@ -183,7 +183,7 @@ func NewNode(config *cfg.Config, pvsc = priv_val.NewSocketClient( logger.With("module", "priv_val"), config.PrivValidatorListenAddr, - &privKey, + privKey, ) ) diff --git a/types/priv_validator/socket.go b/types/priv_validator/socket.go index 05bc77710..26cab72b9 100644 --- a/types/priv_validator/socket.go +++ b/types/priv_validator/socket.go @@ -11,39 +11,53 @@ import ( wire "github.com/tendermint/go-wire" cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" - "golang.org/x/net/netutil" p2pconn "github.com/tendermint/tendermint/p2p/conn" "github.com/tendermint/tendermint/types" ) const ( - defaultConnDeadlineSeconds = 3 - defaultConnWaitSeconds = 60 - defaultDialRetries = 10 - defaultSignersMax = 1 + defaultAcceptDeadlineSeconds = 3 + defaultConnDeadlineSeconds = 3 + defaultConnHeartBeatSeconds = 30 + defaultConnWaitSeconds = 60 + defaultDialRetries = 10 ) // Socket errors. var ( - ErrDialRetryMax = errors.New("Error max client retries") - ErrConnWaitTimeout = errors.New("Error waiting for external connection") - ErrConnTimeout = errors.New("Error connection timed out") + ErrDialRetryMax = errors.New("dialed maximum retries") + ErrConnWaitTimeout = errors.New("waited for remote signer for too long") + ErrConnTimeout = errors.New("remote signer timed out") ) var ( - connDeadline = time.Second * defaultConnDeadlineSeconds + acceptDeadline = time.Second + defaultAcceptDeadlineSeconds + connDeadline = time.Second * defaultConnDeadlineSeconds + connHeartbeat = time.Second * defaultConnHeartBeatSeconds ) // SocketClientOption sets an optional parameter on the SocketClient. type SocketClientOption func(*SocketClient) +// SocketClientAcceptDeadline sets the deadline for the SocketClient listener. +// A zero time value disables the deadline. +func SocketClientAcceptDeadline(deadline time.Duration) SocketClientOption { + return func(sc *SocketClient) { sc.acceptDeadline = deadline } +} + // SocketClientConnDeadline sets the read and write deadline for connections // from external signing processes. func SocketClientConnDeadline(deadline time.Duration) SocketClientOption { return func(sc *SocketClient) { sc.connDeadline = deadline } } +// SocketClientHeartbeat sets the period on which to check the liveness of the +// connected Signer connections. +func SocketClientHeartbeat(period time.Duration) SocketClientOption { + return func(sc *SocketClient) { sc.connHeartbeat = period } +} + // SocketClientConnWait sets the timeout duration before connection of external // signing processes are considered to be unsuccessful. func SocketClientConnWait(timeout time.Duration) SocketClientOption { @@ -56,9 +70,11 @@ type SocketClient struct { cmn.BaseService addr string + acceptDeadline time.Duration connDeadline time.Duration + connHeartbeat time.Duration connWaitTimeout time.Duration - privKey *crypto.PrivKeyEd25519 + privKey crypto.PrivKeyEd25519 conn net.Conn listener net.Listener @@ -71,11 +87,13 @@ var _ types.PrivValidator2 = (*SocketClient)(nil) func NewSocketClient( logger log.Logger, socketAddr string, - privKey *crypto.PrivKeyEd25519, + privKey crypto.PrivKeyEd25519, ) *SocketClient { sc := &SocketClient{ addr: socketAddr, - connDeadline: time.Second * defaultConnDeadlineSeconds, + acceptDeadline: acceptDeadline, + connDeadline: connDeadline, + connHeartbeat: connHeartbeat, connWaitTimeout: time.Second * defaultConnWaitSeconds, privKey: privKey, } @@ -85,57 +103,6 @@ func NewSocketClient( return sc } -// OnStart implements cmn.Service. -func (sc *SocketClient) OnStart() error { - if sc.listener == nil { - if err := sc.listen(); err != nil { - sc.Logger.Error( - "OnStart", - "err", errors.Wrap(err, "failed to listen"), - ) - - return err - } - } - - conn, err := sc.waitConnection() - if err != nil { - sc.Logger.Error( - "OnStart", - "err", errors.Wrap(err, "failed to accept connection"), - ) - - return err - } - - sc.conn = conn - - return nil -} - -// OnStop implements cmn.Service. -func (sc *SocketClient) OnStop() { - sc.BaseService.OnStop() - - if sc.conn != nil { - if err := sc.conn.Close(); err != nil { - sc.Logger.Error( - "OnStop", - "err", errors.Wrap(err, "failed to close connection"), - ) - } - } - - if sc.listener != nil { - if err := sc.listener.Close(); err != nil { - sc.Logger.Error( - "OnStop", - "err", errors.Wrap(err, "failed to close listener"), - ) - } - } -} - // GetAddress implements PrivValidator. // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. func (sc *SocketClient) GetAddress() types.Address { @@ -240,6 +207,53 @@ func (sc *SocketClient) SignHeartbeat( return nil } +// OnStart implements cmn.Service. +func (sc *SocketClient) OnStart() error { + if err := sc.listen(); err != nil { + sc.Logger.Error( + "OnStart", + "err", errors.Wrap(err, "failed to listen"), + ) + + return err + } + + conn, err := sc.waitConnection() + if err != nil { + sc.Logger.Error( + "OnStart", + "err", errors.Wrap(err, "failed to accept connection"), + ) + + return err + } + + sc.conn = conn + + return nil +} + +// OnStop implements cmn.Service. +func (sc *SocketClient) OnStop() { + if sc.conn != nil { + if err := sc.conn.Close(); err != nil { + sc.Logger.Error( + "OnStop", + "err", errors.Wrap(err, "failed to close connection"), + ) + } + } + + if sc.listener != nil { + if err := sc.listener.Close(); err != nil { + sc.Logger.Error( + "OnStop", + "err", errors.Wrap(err, "failed to close listener"), + ) + } + } +} + func (sc *SocketClient) acceptConnection() (net.Conn, error) { conn, err := sc.listener.Accept() if err != nil { @@ -250,17 +264,11 @@ func (sc *SocketClient) acceptConnection() (net.Conn, error) { } - if err := conn.SetDeadline(time.Now().Add(sc.connDeadline)); err != nil { + conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap()) + if err != nil { return nil, err } - if sc.privKey != nil { - conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap()) - if err != nil { - return nil, err - } - } - return conn, nil } @@ -270,7 +278,12 @@ func (sc *SocketClient) listen() error { return err } - sc.listener = netutil.LimitListener(ln, defaultSignersMax) + sc.listener = newTCPTimeoutListener( + ln, + sc.acceptDeadline, + sc.connDeadline, + sc.connHeartbeat, + ) return nil } @@ -297,6 +310,9 @@ func (sc *SocketClient) waitConnection() (net.Conn, error) { case conn := <-connc: return conn, nil case err := <-errc: + if _, ok := err.(timeoutError); ok { + return nil, errors.Wrap(ErrConnWaitTimeout, err.Error()) + } return nil, err case <-time.After(sc.connWaitTimeout): return nil, ErrConnWaitTimeout @@ -319,8 +335,7 @@ func RemoteSignerConnRetries(retries int) RemoteSignerOption { return func(ss *RemoteSigner) { ss.connRetries = retries } } -// RemoteSigner implements PrivValidator. -// It responds to requests over a socket +// RemoteSigner implements PrivValidator by dialing to a socket. type RemoteSigner struct { cmn.BaseService @@ -328,19 +343,18 @@ type RemoteSigner struct { chainID string connDeadline time.Duration connRetries int - privKey *crypto.PrivKeyEd25519 + privKey crypto.PrivKeyEd25519 privVal PrivValidator conn net.Conn } -// NewRemoteSigner returns an instance of -// RemoteSigner. +// NewRemoteSigner returns an instance of RemoteSigner. func NewRemoteSigner( logger log.Logger, chainID, socketAddr string, privVal PrivValidator, - privKey *crypto.PrivKeyEd25519, + privKey crypto.PrivKeyEd25519, ) *RemoteSigner { rs := &RemoteSigner{ addr: socketAddr, @@ -382,17 +396,12 @@ func (rs *RemoteSigner) OnStop() { } func (rs *RemoteSigner) connect() (net.Conn, error) { - retries := defaultDialRetries - -RETRY_LOOP: - for retries > 0 { + for retries := rs.connRetries; retries > 0; retries-- { // Don't sleep if it is the first retry. - if retries != defaultDialRetries { + if retries != rs.connRetries { time.Sleep(rs.connDeadline) } - retries-- - conn, err := cmn.Connect(rs.addr) if err != nil { rs.Logger.Error( @@ -401,7 +410,7 @@ RETRY_LOOP: "err", errors.Wrap(err, "connection failed"), ) - continue RETRY_LOOP + continue } if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil { @@ -412,16 +421,14 @@ RETRY_LOOP: continue } - if rs.privKey != nil { - conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap()) - if err != nil { - rs.Logger.Error( - "sc connect", - "err", errors.Wrap(err, "encrypting connection failed"), - ) + conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap()) + if err != nil { + rs.Logger.Error( + "connect", + "err", errors.Wrap(err, "encrypting connection failed"), + ) - continue RETRY_LOOP - } + continue } return conn, nil @@ -444,7 +451,7 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) { return } - var res PrivValidatorSocketMsg + var res PrivValMsg switch r := req.(type) { case *PubKeyMsg: @@ -487,12 +494,11 @@ const ( msgTypeSignHeartbeat = byte(0x12) ) -// PrivValidatorSocketMsg is a message sent between PrivValidatorSocket client -// and server. -type PrivValidatorSocketMsg interface{} +// PrivValMsg is sent between RemoteSigner and SocketClient. +type PrivValMsg interface{} var _ = wire.RegisterInterface( - struct{ PrivValidatorSocketMsg }{}, + struct{ PrivValMsg }{}, wire.ConcreteType{&PubKeyMsg{}, msgTypePubKey}, wire.ConcreteType{&SignVoteMsg{}, msgTypeSignVote}, wire.ConcreteType{&SignProposalMsg{}, msgTypeSignProposal}, @@ -519,27 +525,27 @@ type SignHeartbeatMsg struct { Heartbeat *types.Heartbeat } -func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) { +func readMsg(r io.Reader) (PrivValMsg, error) { var ( n int err error ) - read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err) + read := wire.ReadBinary(struct{ PrivValMsg }{}, r, 0, &n, &err) if err != nil { - if opErr, ok := err.(*net.OpError); ok { - return nil, errors.Wrapf(ErrConnTimeout, opErr.Addr.String()) + if _, ok := err.(timeoutError); ok { + return nil, errors.Wrap(ErrConnTimeout, err.Error()) } return nil, err } - w, ok := read.(struct{ PrivValidatorSocketMsg }) + w, ok := read.(struct{ PrivValMsg }) if !ok { return nil, errors.New("unknown type") } - return w.PrivValidatorSocketMsg, nil + return w.PrivValMsg, nil } func writeMsg(w io.Writer, msg interface{}) error { @@ -549,9 +555,9 @@ func writeMsg(w io.Writer, msg interface{}) error { ) // TODO(xla): This extra wrap should be gone with the sdk-2 update. - wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err) - if opErr, ok := err.(*net.OpError); ok { - return errors.Wrapf(ErrConnTimeout, opErr.Addr.String()) + wire.WriteBinary(struct{ PrivValMsg }{msg}, w, &n, &err) + if _, ok := err.(timeoutError); ok { + return errors.Wrap(ErrConnTimeout, err.Error()) } return err diff --git a/types/priv_validator/socket_tcp.go b/types/priv_validator/socket_tcp.go new file mode 100644 index 000000000..2421eb9f4 --- /dev/null +++ b/types/priv_validator/socket_tcp.go @@ -0,0 +1,66 @@ +package types + +import ( + "net" + "time" +) + +// timeoutError can be used to check if an error returned from the netp package +// was due to a timeout. +type timeoutError interface { + Timeout() bool +} + +// tcpTimeoutListener implements net.Listener. +var _ net.Listener = (*tcpTimeoutListener)(nil) + +// tcpTimeoutListener wraps a *net.TCPListener to standardise protocol timeouts +// and potentially other tuning parameters. +type tcpTimeoutListener struct { + *net.TCPListener + + acceptDeadline time.Duration + connDeadline time.Duration + period time.Duration +} + +// newTCPTimeoutListener returns an instance of tcpTimeoutListener. +func newTCPTimeoutListener( + ln net.Listener, + acceptDeadline, connDeadline time.Duration, + period time.Duration, +) tcpTimeoutListener { + return tcpTimeoutListener{ + TCPListener: ln.(*net.TCPListener), + acceptDeadline: acceptDeadline, + connDeadline: connDeadline, + period: period, + } +} + +// Accept implements net.Listener. +func (ln tcpTimeoutListener) Accept() (net.Conn, error) { + err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline)) + if err != nil { + return nil, err + } + + tc, err := ln.AcceptTCP() + if err != nil { + return nil, err + } + + if err := tc.SetDeadline(time.Now().Add(ln.connDeadline)); err != nil { + return nil, err + } + + if err := tc.SetKeepAlive(true); err != nil { + return nil, err + } + + if err := tc.SetKeepAlivePeriod(ln.period); err != nil { + return nil, err + } + + return tc, nil +} diff --git a/types/priv_validator/socket_tcp_test.go b/types/priv_validator/socket_tcp_test.go new file mode 100644 index 000000000..cd95ab0b9 --- /dev/null +++ b/types/priv_validator/socket_tcp_test.go @@ -0,0 +1,64 @@ +package types + +import ( + "net" + "testing" + "time" +) + +func TestTCPTimeoutListenerAcceptDeadline(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + ln = newTCPTimeoutListener(ln, time.Millisecond, time.Second, time.Second) + + _, err = ln.Accept() + opErr, ok := err.(*net.OpError) + if !ok { + t.Fatalf("have %v, want *net.OpError", err) + } + + if have, want := opErr.Op, "accept"; have != want { + t.Errorf("have %v, want %v", have, want) + } +} + +func TestTCPTimeoutListenerConnDeadline(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + ln = newTCPTimeoutListener(ln, time.Second, time.Millisecond, time.Second) + + donec := make(chan struct{}) + go func(ln net.Listener) { + defer close(donec) + + c, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + + time.Sleep(2 * time.Millisecond) + + _, err = c.Write([]byte("foo")) + opErr, ok := err.(*net.OpError) + if !ok { + t.Fatalf("have %v, want *net.OpError", err) + } + + if have, want := opErr.Op, "write"; have != want { + t.Errorf("have %v, want %v", have, want) + } + }(ln) + + _, err = net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + <-donec +} diff --git a/types/priv_validator/socket_test.go b/types/priv_validator/socket_test.go index 36f09f40c..2859c9452 100644 --- a/types/priv_validator/socket_test.go +++ b/types/priv_validator/socket_test.go @@ -1,6 +1,8 @@ package types import ( + "fmt" + "net" "testing" "time" @@ -12,57 +14,55 @@ import ( cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" + p2pconn "github.com/tendermint/tendermint/p2p/conn" "github.com/tendermint/tendermint/types" ) func TestSocketClientAddress(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ) defer sc.Stop() defer rs.Stop() serverAddr, err := rs.privVal.Address() - require.NoError(err) + require.NoError(t, err) clientAddr, err := sc.Address() - require.NoError(err) + require.NoError(t, err) - assert.Equal(serverAddr, clientAddr) + assert.Equal(t, serverAddr, clientAddr) // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. - assert.Equal(serverAddr, sc.GetAddress()) + assert.Equal(t, serverAddr, sc.GetAddress()) } func TestSocketClientPubKey(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ) defer sc.Stop() defer rs.Stop() clientKey, err := sc.PubKey() - require.NoError(err) + require.NoError(t, err) privKey, err := rs.privVal.PubKey() - require.NoError(err) + require.NoError(t, err) - assert.Equal(privKey, clientKey) + assert.Equal(t, privKey, clientKey) // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. - assert.Equal(privKey, sc.GetPubKey()) + assert.Equal(t, privKey, sc.GetPubKey()) } func TestSocketClientProposal(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ts = time.Now() privProposal = &types.Proposal{Timestamp: ts} @@ -71,16 +71,15 @@ func TestSocketClientProposal(t *testing.T) { defer sc.Stop() defer rs.Stop() - require.NoError(rs.privVal.SignProposal(chainID, privProposal)) - require.NoError(sc.SignProposal(chainID, clientProposal)) - assert.Equal(privProposal.Signature, clientProposal.Signature) + require.NoError(t, rs.privVal.SignProposal(chainID, privProposal)) + require.NoError(t, sc.SignProposal(chainID, clientProposal)) + assert.Equal(t, privProposal.Signature, clientProposal.Signature) } func TestSocketClientVote(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ts = time.Now() vType = types.VoteTypePrecommit @@ -90,16 +89,15 @@ func TestSocketClientVote(t *testing.T) { defer sc.Stop() defer rs.Stop() - require.NoError(rs.privVal.SignVote(chainID, want)) - require.NoError(sc.SignVote(chainID, have)) - assert.Equal(want.Signature, have.Signature) + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) } func TestSocketClientHeartbeat(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) want = &types.Heartbeat{} have = &types.Heartbeat{} @@ -107,79 +105,133 @@ func TestSocketClientHeartbeat(t *testing.T) { defer sc.Stop() defer rs.Stop() - require.NoError(rs.privVal.SignHeartbeat(chainID, want)) - require.NoError(sc.SignHeartbeat(chainID, have)) - assert.Equal(want.Signature, have.Signature) + require.NoError(t, rs.privVal.SignHeartbeat(chainID, want)) + require.NoError(t, sc.SignHeartbeat(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) } -func TestSocketClientDeadline(t *testing.T) { +func TestSocketClientAcceptDeadline(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - readyc = make(chan struct{}) - sc = NewSocketClient( + sc = NewSocketClient( log.TestingLogger(), "127.0.0.1:0", - nil, + crypto.GenPrivKeyEd25519(), ) ) defer sc.Stop() - SocketClientConnDeadline(time.Millisecond)(sc) + SocketClientAcceptDeadline(time.Millisecond)(sc) - require.NoError(sc.listen()) + assert.Equal(t, errors.Cause(sc.Start()), ErrConnWaitTimeout) +} + +func TestSocketClientDeadline(t *testing.T) { + var ( + addr = testFreeAddr(t) + listenc = make(chan struct{}) + sc = NewSocketClient( + log.TestingLogger(), + addr, + crypto.GenPrivKeyEd25519(), + ) + ) + + SocketClientConnDeadline(10 * time.Millisecond)(sc) + SocketClientConnWait(500 * time.Millisecond)(sc) go func(sc *SocketClient) { - require.NoError(sc.Start()) - assert.True(sc.IsRunning()) + defer close(listenc) - readyc <- struct{}{} + require.NoError(t, sc.Start()) + + assert.True(t, sc.IsRunning()) }(sc) - _, err := cmn.Connect(sc.listener.Addr().String()) - require.NoError(err) + for { + conn, err := cmn.Connect(addr) + if err != nil { + continue + } - <-readyc + _, err = p2pconn.MakeSecretConnection( + conn, + crypto.GenPrivKeyEd25519().Wrap(), + ) + if err == nil { + break + } + } + + <-listenc - _, err = sc.PubKey() - assert.Equal(errors.Cause(err), ErrConnTimeout) + // Sleep to guarantee deadline has been hit. + time.Sleep(20 * time.Microsecond) + + _, err := sc.PubKey() + assert.Equal(t, errors.Cause(err), ErrConnTimeout) } func TestSocketClientWait(t *testing.T) { - var ( - assert, _ = assert.New(t), require.New(t) - logger = log.TestingLogger() - privKey = crypto.GenPrivKeyEd25519() - sc = NewSocketClient( - logger, - "127.0.0.1:0", - &privKey, - ) + sc := NewSocketClient( + log.TestingLogger(), + "127.0.0.1:0", + crypto.GenPrivKeyEd25519(), ) defer sc.Stop() SocketClientConnWait(time.Millisecond)(sc) - assert.EqualError(sc.Start(), ErrConnWaitTimeout.Error()) + assert.Equal(t, errors.Cause(sc.Start()), ErrConnWaitTimeout) } func TestRemoteSignerRetry(t *testing.T) { var ( - assert, _ = assert.New(t), require.New(t) - privKey = crypto.GenPrivKeyEd25519() - rs = NewRemoteSigner( - log.TestingLogger(), - cmn.RandStr(12), - "127.0.0.1:0", - NewTestPrivValidator(types.GenSigner()), - &privKey, - ) + attemptc = make(chan int) + retries = 2 + ) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + go func(ln net.Listener, attemptc chan<- int) { + attempts := 0 + + for { + conn, err := ln.Accept() + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + attempts++ + + if attempts == retries { + attemptc <- attempts + break + } + } + }(ln, attemptc) + + rs := NewRemoteSigner( + log.TestingLogger(), + cmn.RandStr(12), + ln.Addr().String(), + NewTestPrivValidator(types.GenSigner()), + crypto.GenPrivKeyEd25519(), ) defer rs.Stop() RemoteSignerConnDeadline(time.Millisecond)(rs) - RemoteSignerConnRetries(2)(rs) + RemoteSignerConnRetries(retries)(rs) - assert.EqualError(rs.Start(), ErrDialRetryMax.Error()) + assert.Equal(t, errors.Cause(rs.Start()), ErrDialRetryMax) + + select { + case attempts := <-attemptc: + assert.Equal(t, retries, attempts) + case <-time.After(100 * time.Millisecond): + t.Error("expected remote to observe connection attempts") + } } func testSetupSocketPair( @@ -187,40 +239,48 @@ func testSetupSocketPair( chainID string, ) (*SocketClient, *RemoteSigner) { var ( - assert, require = assert.New(t), require.New(t) - logger = log.TestingLogger() - signer = types.GenSigner() - clientPrivKey = crypto.GenPrivKeyEd25519() - remotePrivKey = crypto.GenPrivKeyEd25519() - privVal = NewTestPrivValidator(signer) - readyc = make(chan struct{}) - sc = NewSocketClient( + addr = testFreeAddr(t) + logger = log.TestingLogger() + signer = types.GenSigner() + privVal = NewTestPrivValidator(signer) + readyc = make(chan struct{}) + rs = NewRemoteSigner( logger, - "127.0.0.1:0", - &clientPrivKey, + chainID, + addr, + privVal, + crypto.GenPrivKeyEd25519(), + ) + sc = NewSocketClient( + logger, + addr, + crypto.GenPrivKeyEd25519(), ) ) - require.NoError(sc.listen()) - go func(sc *SocketClient) { - require.NoError(sc.Start()) - assert.True(sc.IsRunning()) + require.NoError(t, sc.Start()) + assert.True(t, sc.IsRunning()) readyc <- struct{}{} }(sc) - rs := NewRemoteSigner( - logger, - chainID, - sc.listener.Addr().String(), - privVal, - &remotePrivKey, - ) - require.NoError(rs.Start()) - assert.True(rs.IsRunning()) + RemoteSignerConnDeadline(time.Millisecond)(rs) + RemoteSignerConnRetries(1e6)(rs) + + require.NoError(t, rs.Start()) + assert.True(t, rs.IsRunning()) <-readyc return sc, rs } + +// testFreeAddr claims a free port so we don't block on listener being ready. +func testFreeAddr(t *testing.T) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) +}