diff --git a/privval/client_test.go b/privval/client_test.go index 7fae6bf8b..3c3270649 100644 --- a/privval/client_test.go +++ b/privval/client_test.go @@ -27,120 +27,170 @@ var ( testHeartbeatTimeout3o2 = 6 * time.Millisecond // 3/2 of the other one ) -func TestSocketPVAddress(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - ) - defer sc.Stop() - defer rs.Stop() +type socketTestCase struct { + addr string + dialer Dialer +} - serverAddr := rs.privVal.GetPubKey().Address() - clientAddr := sc.GetPubKey().Address() +func socketTestCases(t *testing.T) []socketTestCase { + tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) + unixFilePath, err := testUnixAddr() + require.NoError(t, err) + unixAddr := fmt.Sprintf("unix://%s", unixFilePath) + return []socketTestCase{ + socketTestCase{ + addr: tcpAddr, + dialer: DialTCPFn(tcpAddr, testConnDeadline, ed25519.GenPrivKey()), + }, + socketTestCase{ + addr: unixAddr, + dialer: DialUnixFn(unixFilePath), + }, + } +} - assert.Equal(t, serverAddr, clientAddr) +func TestSocketPVAddress(t *testing.T) { + for _, tc := range socketTestCases(t) { + // Execute the test within a closure to ensure the deferred statements + // are called between each for loop iteration, for isolated test cases. + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) + ) + defer sc.Stop() + defer rs.Stop() + + serverAddr := rs.privVal.GetPubKey().Address() + clientAddr := sc.GetPubKey().Address() + + assert.Equal(t, serverAddr, clientAddr) + }() + } } func TestSocketPVPubKey(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - ) - defer sc.Stop() - defer rs.Stop() + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) + ) + defer sc.Stop() + defer rs.Stop() - clientKey := sc.GetPubKey() + clientKey := sc.GetPubKey() - privvalPubKey := rs.privVal.GetPubKey() + privvalPubKey := rs.privVal.GetPubKey() - assert.Equal(t, privvalPubKey, clientKey) + assert.Equal(t, privvalPubKey, clientKey) + }() + } } func TestSocketPVProposal(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - - ts = time.Now() - privProposal = &types.Proposal{Timestamp: ts} - clientProposal = &types.Proposal{Timestamp: ts} - ) - defer sc.Stop() - defer rs.Stop() - - require.NoError(t, rs.privVal.SignProposal(chainID, privProposal)) - require.NoError(t, sc.SignProposal(chainID, clientProposal)) - assert.Equal(t, privProposal.Signature, clientProposal.Signature) + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) + + ts = time.Now() + privProposal = &types.Proposal{Timestamp: ts} + clientProposal = &types.Proposal{Timestamp: ts} + ) + defer sc.Stop() + defer rs.Stop() + + require.NoError(t, rs.privVal.SignProposal(chainID, privProposal)) + require.NoError(t, sc.SignProposal(chainID, clientProposal)) + assert.Equal(t, privProposal.Signature, clientProposal.Signature) + }() + } } func TestSocketPVVote(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - - ts = time.Now() - vType = types.PrecommitType - want = &types.Vote{Timestamp: ts, Type: vType} - have = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + }() + } } func TestSocketPVVoteResetDeadline(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - - ts = time.Now() - vType = types.PrecommitType - want = &types.Vote{Timestamp: ts, Type: vType} - have = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - time.Sleep(testConnDeadline2o3) - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) - - // This would exceed the deadline if it was not extended by the previous message - time.Sleep(testConnDeadline2o3) - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + time.Sleep(testConnDeadline2o3) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + + // This would exceed the deadline if it was not extended by the previous message + time.Sleep(testConnDeadline2o3) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + }() + } } func TestSocketPVVoteKeepalive(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - - ts = time.Now() - vType = types.PrecommitType - want = &types.Vote{Timestamp: ts, Type: vType} - have = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - time.Sleep(testConnDeadline * 2) - - require.NoError(t, rs.privVal.SignVote(chainID, want)) - require.NoError(t, sc.SignVote(chainID, have)) - assert.Equal(t, want.Signature, have.Signature) + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + time.Sleep(testConnDeadline * 2) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + }() + } } -func TestSocketPVDeadline(t *testing.T) { +// TestSocketPVDeadlineTCPOnly is not relevant to Unix domain sockets, since the +// OS knows instantaneously the state of both sides of the connection. +func TestSocketPVDeadlineTCPOnly(t *testing.T) { var ( - addr = testFreeAddr(t) + addr = testFreeTCPAddr(t) listenc = make(chan struct{}) thisConnTimeout = 100 * time.Millisecond sc = newSocketVal(log.TestingLogger(), addr, thisConnTimeout) @@ -172,218 +222,195 @@ func TestSocketPVDeadline(t *testing.T) { <-listenc } -func TestRemoteSignerRetry(t *testing.T) { - var ( - attemptc = make(chan int) - retries = 2 - ) - - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - go func(ln net.Listener, attemptc chan<- int) { - attempts := 0 - - for { - conn, err := ln.Accept() - require.NoError(t, err) - - err = conn.Close() - require.NoError(t, err) - - attempts++ - - if attempts == retries { - attemptc <- attempts - break - } - } - }(ln, attemptc) - - rs := NewRemoteSigner( - log.TestingLogger(), - cmn.RandStr(12), - types.NewMockPV(), - DialTCPFn(ln.Addr().String(), testConnDeadline, ed25519.GenPrivKey()), - ) - defer rs.Stop() - - RemoteSignerConnDeadline(time.Millisecond)(rs) - RemoteSignerConnRetries(retries)(rs) - - assert.Equal(t, rs.Start(), ErrDialRetryMax) - - select { - case attempts := <-attemptc: - assert.Equal(t, retries, attempts) - case <-time.After(100 * time.Millisecond): - t.Error("expected remote to observe connection attempts") - } -} - func TestRemoteSignVoteErrors(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV()) - - ts = time.Now() - vType = types.PrecommitType - vote = &types.Vote{Timestamp: ts, Type: vType} - ) - defer sc.Stop() - defer rs.Stop() - - err := sc.SignVote("", vote) - require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) - - err = rs.privVal.SignVote(chainID, vote) - require.Error(t, err) - err = sc.SignVote(chainID, vote) - require.Error(t, err) + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer) + + ts = time.Now() + vType = types.PrecommitType + vote = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + err := sc.SignVote("", vote) + require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) + + err = rs.privVal.SignVote(chainID, vote) + require.Error(t, err) + err = sc.SignVote(chainID, vote) + require.Error(t, err) + }() + } } func TestRemoteSignProposalErrors(t *testing.T) { - var ( - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV()) - - ts = time.Now() - proposal = &types.Proposal{Timestamp: ts} - ) - defer sc.Stop() - defer rs.Stop() - - err := sc.SignProposal("", proposal) - require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) - - err = rs.privVal.SignProposal(chainID, proposal) - require.Error(t, err) - - err = sc.SignProposal(chainID, proposal) - require.Error(t, err) + for _, tc := range socketTestCases(t) { + func() { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer) + + ts = time.Now() + proposal = &types.Proposal{Timestamp: ts} + ) + defer sc.Stop() + defer rs.Stop() + + err := sc.SignProposal("", proposal) + require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) + + err = rs.privVal.SignProposal(chainID, proposal) + require.Error(t, err) + + err = sc.SignProposal(chainID, proposal) + require.Error(t, err) + }() + } } func TestErrUnexpectedResponse(t *testing.T) { - var ( - addr = testFreeAddr(t) - logger = log.TestingLogger() - chainID = cmn.RandStr(12) - readyc = make(chan struct{}) - errc = make(chan error, 1) - - rs = NewRemoteSigner( - logger, - chainID, - types.NewMockPV(), - DialTCPFn(addr, testConnDeadline, ed25519.GenPrivKey()), - ) - sc = newSocketVal(logger, addr, testConnDeadline) - ) - - testStartSocketPV(t, readyc, sc) - defer sc.Stop() - RemoteSignerConnDeadline(time.Millisecond)(rs) - RemoteSignerConnRetries(100)(rs) - // we do not want to Start() the remote signer here and instead use the connection to - // reply with intentionally wrong replies below: - rsConn, err := rs.connect() - defer rsConn.Close() - require.NoError(t, err) - require.NotNil(t, rsConn) - // send over public key to get the remote signer running: - go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) - <-readyc - - // Proposal: - go func(errc chan error) { - errc <- sc.SignProposal(chainID, &types.Proposal{}) - }(errc) - // read request and write wrong response: - go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) - err = <-errc - require.Error(t, err) - require.Equal(t, err, ErrUnexpectedResponse) - - // Vote: - go func(errc chan error) { - errc <- sc.SignVote(chainID, &types.Vote{}) - }(errc) - // read request and write wrong response: - go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) - err = <-errc - require.Error(t, err) - require.Equal(t, err, ErrUnexpectedResponse) + for _, tc := range socketTestCases(t) { + func() { + var ( + logger = log.TestingLogger() + chainID = cmn.RandStr(12) + readyc = make(chan struct{}) + errc = make(chan error, 1) + + rs = NewRemoteSigner( + logger, + chainID, + types.NewMockPV(), + tc.dialer, + ) + sc = newSocketVal(logger, tc.addr, testConnDeadline) + ) + + testStartSocketPV(t, readyc, sc) + defer sc.Stop() + RemoteSignerConnDeadline(time.Millisecond)(rs) + RemoteSignerConnRetries(100)(rs) + // we do not want to Start() the remote signer here and instead use the connection to + // reply with intentionally wrong replies below: + rsConn, err := rs.connect() + defer rsConn.Close() + require.NoError(t, err) + require.NotNil(t, rsConn) + // send over public key to get the remote signer running: + go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) + <-readyc + + // Proposal: + go func(errc chan error) { + errc <- sc.SignProposal(chainID, &types.Proposal{}) + }(errc) + // read request and write wrong response: + go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) + err = <-errc + require.Error(t, err) + require.Equal(t, err, ErrUnexpectedResponse) + + // Vote: + go func(errc chan error) { + errc <- sc.SignVote(chainID, &types.Vote{}) + }(errc) + // read request and write wrong response: + go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) + err = <-errc + require.Error(t, err) + require.Equal(t, err, ErrUnexpectedResponse) + }() + } } -func TestRetryTCPConnToRemoteSigner(t *testing.T) { - var ( - addr = testFreeAddr(t) - logger = log.TestingLogger() - chainID = cmn.RandStr(12) - readyc = make(chan struct{}) - - rs = NewRemoteSigner( - logger, - chainID, - types.NewMockPV(), - DialTCPFn(addr, testConnDeadline, ed25519.GenPrivKey()), - ) - thisConnTimeout = testConnDeadline - sc = newSocketVal(logger, addr, thisConnTimeout) - ) - // Ping every: - SocketValHeartbeat(testHeartbeatTimeout)(sc) - - RemoteSignerConnDeadline(testConnDeadline)(rs) - RemoteSignerConnRetries(10)(rs) - - testStartSocketPV(t, readyc, sc) - defer sc.Stop() - require.NoError(t, rs.Start()) - assert.True(t, rs.IsRunning()) - - <-readyc - time.Sleep(testHeartbeatTimeout * 2) - - rs.Stop() - rs2 := NewRemoteSigner( - logger, - chainID, - types.NewMockPV(), - DialTCPFn(addr, testConnDeadline, ed25519.GenPrivKey()), - ) - // let some pings pass - time.Sleep(testHeartbeatTimeout3o2) - require.NoError(t, rs2.Start()) - assert.True(t, rs2.IsRunning()) - defer rs2.Stop() - - // give the client some time to re-establish the conn to the remote signer - // should see sth like this in the logs: - // - // E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" - // I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal - time.Sleep(testConnDeadline * 2) +func TestRetryConnToRemoteSigner(t *testing.T) { + for _, tc := range socketTestCases(t) { + func() { + var ( + logger = log.TestingLogger() + chainID = cmn.RandStr(12) + readyc = make(chan struct{}) + + rs = NewRemoteSigner( + logger, + chainID, + types.NewMockPV(), + tc.dialer, + ) + thisConnTimeout = testConnDeadline + sc = newSocketVal(logger, tc.addr, thisConnTimeout) + ) + // Ping every: + SocketValHeartbeat(testHeartbeatTimeout)(sc) + + RemoteSignerConnDeadline(testConnDeadline)(rs) + RemoteSignerConnRetries(10)(rs) + + testStartSocketPV(t, readyc, sc) + defer sc.Stop() + require.NoError(t, rs.Start()) + assert.True(t, rs.IsRunning()) + + <-readyc + time.Sleep(testHeartbeatTimeout * 2) + + rs.Stop() + rs2 := NewRemoteSigner( + logger, + chainID, + types.NewMockPV(), + tc.dialer, + ) + // let some pings pass + time.Sleep(testHeartbeatTimeout3o2) + require.NoError(t, rs2.Start()) + assert.True(t, rs2.IsRunning()) + defer rs2.Stop() + + // give the client some time to re-establish the conn to the remote signer + // should see sth like this in the logs: + // + // E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" + // I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal + time.Sleep(testConnDeadline * 2) + }() + } } func newSocketVal(logger log.Logger, addr string, connDeadline time.Duration) *SocketVal { - ln, err := net.Listen(cmn.ProtocolAndAddress(addr)) + proto, address := cmn.ProtocolAndAddress(addr) + ln, err := net.Listen(proto, address) + logger.Info("Listening at", "proto", proto, "address", address) if err != nil { panic(err) } - tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) - TCPListenerAcceptDeadline(testAcceptDeadline)(tcpLn) - TCPListenerConnDeadline(testConnDeadline)(tcpLn) - return NewSocketVal(logger, tcpLn) + var svln net.Listener + if proto == "unix" { + unixLn := NewUnixListener(ln) + UnixListenerAcceptDeadline(testAcceptDeadline)(unixLn) + UnixListenerConnDeadline(connDeadline)(unixLn) + svln = unixLn + } else { + tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) + TCPListenerAcceptDeadline(testAcceptDeadline)(tcpLn) + TCPListenerConnDeadline(connDeadline)(tcpLn) + svln = tcpLn + } + return NewSocketVal(logger, svln) } func testSetupSocketPair( t *testing.T, chainID string, privValidator types.PrivValidator, + addr string, + dialer Dialer, ) (*SocketVal, *RemoteSigner) { var ( - addr = testFreeAddr(t) logger = log.TestingLogger() privVal = privValidator readyc = make(chan struct{}) @@ -391,7 +418,7 @@ func testSetupSocketPair( logger, chainID, privVal, - DialTCPFn(addr, testConnDeadline, ed25519.GenPrivKey()), + dialer, ) thisConnTimeout = testConnDeadline @@ -429,8 +456,8 @@ func testStartSocketPV(t *testing.T, readyc chan struct{}, sc *SocketVal) { }(sc) } -// testFreeAddr claims a free port so we don't block on listener being ready. -func testFreeAddr(t *testing.T) string { +// testFreeTCPAddr claims a free port so we don't block on listener being ready. +func testFreeTCPAddr(t *testing.T) string { ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) defer ln.Close() diff --git a/privval/remote_signer_test.go b/privval/remote_signer_test.go new file mode 100644 index 000000000..8927e2242 --- /dev/null +++ b/privval/remote_signer_test.go @@ -0,0 +1,68 @@ +package privval + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/crypto/ed25519" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +// TestRemoteSignerRetryTCPOnly will test connection retry attempts over TCP. We +// don't need this for Unix sockets because the OS instantly knows the state of +// both ends of the socket connection. This basically causes the +// RemoteSigner.dialer() call inside RemoteSigner.connect() to return +// successfully immediately, putting an instant stop to any retry attempts. +func TestRemoteSignerRetryTCPOnly(t *testing.T) { + var ( + attemptc = make(chan int) + retries = 2 + ) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + go func(ln net.Listener, attemptc chan<- int) { + attempts := 0 + + for { + conn, err := ln.Accept() + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + attempts++ + + if attempts == retries { + attemptc <- attempts + break + } + } + }(ln, attemptc) + + rs := NewRemoteSigner( + log.TestingLogger(), + cmn.RandStr(12), + types.NewMockPV(), + DialTCPFn(ln.Addr().String(), testConnDeadline, ed25519.GenPrivKey()), + ) + defer rs.Stop() + + RemoteSignerConnDeadline(time.Millisecond)(rs) + RemoteSignerConnRetries(retries)(rs) + + assert.Equal(t, rs.Start(), ErrDialRetryMax) + + select { + case attempts := <-attemptc: + assert.Equal(t, retries, attempts) + case <-time.After(100 * time.Millisecond): + t.Error("expected remote to observe connection attempts") + } +} diff --git a/privval/socket.go b/privval/socket.go index 96fa6c8e8..bd9cd9209 100644 --- a/privval/socket.go +++ b/privval/socket.go @@ -157,7 +157,7 @@ type timeoutConn struct { connDeadline time.Duration } -// newTimeoutConn returns an instance of newTCPTimeoutConn. +// newTimeoutConn returns an instance of timeoutConn. func newTimeoutConn( conn net.Conn, connDeadline time.Duration) *timeoutConn { diff --git a/privval/socket_test.go b/privval/socket_test.go index dfce45b57..b411b7f3a 100644 --- a/privval/socket_test.go +++ b/privval/socket_test.go @@ -29,7 +29,7 @@ type listenerTestCase struct { // testUnixAddr will attempt to obtain a platform-independent temporary file // name for a Unix socket func testUnixAddr() (string, error) { - f, err := ioutil.TempFile("", "tendermint-privval-test") + f, err := ioutil.TempFile("", "tendermint-privval-test-*") if err != nil { return "", err }