package privval

import (
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/tendermint/tendermint/crypto/ed25519"
	cmn "github.com/tendermint/tendermint/libs/common"
	"github.com/tendermint/tendermint/libs/log"
	"github.com/tendermint/tendermint/types"
)

// Timeouts shared by the socket tests in this file.
var (
	testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second

	testTimeoutReadWrite    = 100 * time.Millisecond
	testTimeoutReadWrite2o3 = 66 * time.Millisecond // 2/3 of the other one

	testTimeoutHeartbeat = 10 * time.Millisecond
	// NOTE(review): the "3o2" suffix suggests 3/2, but 6ms is actually 3/5 of
	// testTimeoutHeartbeat. The value is kept as-is.
	testTimeoutHeartbeat3o2 = 6 * time.Millisecond
)

// socketTestCase pairs a listen address with the dialer used to reach it.
type socketTestCase struct {
	addr   string
	dialer SocketDialer
}

// socketTestCases returns one TCP and one Unix domain socket test case, each
// bound to a freshly chosen address.
func socketTestCases(t *testing.T) []socketTestCase {
	tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t))
	unixFilePath, err := testUnixAddr()
	require.NoError(t, err)
	unixAddr := fmt.Sprintf("unix://%s", unixFilePath)

	return []socketTestCase{
		{
			addr:   tcpAddr,
			dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()),
		},
		{
			addr:   unixAddr,
			dialer: DialUnixFn(unixFilePath),
		},
	}
}

// TestSocketPVAddress checks that the address derived from the public key
// fetched over the socket equals the one of the signer's own private validator.
func TestSocketPVAddress(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		// Execute the test within a closure to ensure the deferred statements
		// are called between each for loop iteration, for isolated test cases.
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			serviceAddr := serviceEndpoint.privVal.GetPubKey().Address()
			validatorAddr := validatorEndpoint.GetPubKey().Address()

			assert.Equal(t, serviceAddr, validatorAddr)
		}()
	}
}

// TestSocketPVPubKey checks that the public key fetched over the socket equals
// the one of the signer's own private validator.
func TestSocketPVPubKey(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			clientKey := validatorEndpoint.GetPubKey()
			privvalPubKey := serviceEndpoint.privVal.GetPubKey()

			assert.Equal(t, privvalPubKey, clientKey)
		}()
	}
}

// TestSocketPVProposal signs the same proposal locally and through the socket
// and expects identical signatures.
func TestSocketPVProposal(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			ts := time.Now()
			privProposal := &types.Proposal{Timestamp: ts}
			clientProposal := &types.Proposal{Timestamp: ts}

			require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, privProposal))
			require.NoError(t, validatorEndpoint.SignProposal(chainID, clientProposal))

			assert.Equal(t, privProposal.Signature, clientProposal.Signature)
		}()
	}
}

// TestSocketPVVote signs the same vote locally and through the socket and
// expects identical signatures.
func TestSocketPVVote(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			ts := time.Now()
			vType := types.PrecommitType
			want := &types.Vote{Timestamp: ts, Type: vType}
			have := &types.Vote{Timestamp: ts, Type: vType}

			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
			assert.Equal(t, want.Signature, have.Signature)
		}()
	}
}

// TestSocketPVVoteResetDeadline signs twice with a pause of 2/3 of the
// read/write timeout before each request; the second request only succeeds if
// the first one pushed the deadline forward.
func TestSocketPVVoteResetDeadline(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			ts := time.Now()
			vType := types.PrecommitType
			want := &types.Vote{Timestamp: ts, Type: vType}
			have := &types.Vote{Timestamp: ts, Type: vType}

			time.Sleep(testTimeoutReadWrite2o3)

			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
			assert.Equal(t, want.Signature, have.Signature)

			// This would exceed the deadline if it was not extended by the previous message
			time.Sleep(testTimeoutReadWrite2o3)

			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
			assert.Equal(t, want.Signature, have.Signature)
		}()
	}
}

// TestSocketPVVoteKeepalive stays idle for twice the read/write timeout and
// then expects signing through the socket to still work.
func TestSocketPVVoteKeepalive(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			ts := time.Now()
			vType := types.PrecommitType
			want := &types.Vote{Timestamp: ts, Type: vType}
			have := &types.Vote{Timestamp: ts, Type: vType}

			time.Sleep(testTimeoutReadWrite * 2)

			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
			assert.Equal(t, want.Signature, have.Signature)
		}()
	}
}

// TestSocketPVDeadline connects to the validator endpoint without ever
// speaking the protocol and expects Start to fail with a connection timeout.
func TestSocketPVDeadline(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			var (
				listenc           = make(chan struct{})
				thisConnTimeout   = 100 * time.Millisecond
				validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout)
			)

			go func(sc *SignerValidatorEndpoint) {
				defer close(listenc)

				// Note: the TCP connection times out at the accept() phase,
				// whereas the Unix domain sockets connection times out while
				// attempting to fetch the remote signer's public key.
				assert.True(t, IsConnTimeout(sc.Start()))

				assert.False(t, sc.IsRunning())
			}(validatorEndpoint)

			for {
				conn, err := cmn.Connect(tc.addr)
				if err == nil {
					// Keep the silent connection open until the endpoint has
					// timed out, then release it when this closure returns
					// instead of leaking the socket.
					defer conn.Close()
					break
				}
			}

			<-listenc
		}()
	}
}

// TestRemoteSignVoteErrors uses an erroring private validator and expects its
// error to surface both locally and through the socket.
func TestRemoteSignVoteErrors(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			ts := time.Now()
			vType := types.PrecommitType
			vote := &types.Vote{Timestamp: ts, Type: vType}

			err := validatorEndpoint.SignVote("", vote)
			// Checked assertion: fail the test instead of panicking when the
			// error is nil or of an unexpected type.
			rsErr, ok := err.(*RemoteSignerError)
			require.True(t, ok, "expected *RemoteSignerError, got %T", err)
			require.Equal(t, types.ErroringMockPVErr.Error(), rsErr.Description)

			err = serviceEndpoint.privVal.SignVote(chainID, vote)
			require.Error(t, err)

			err = validatorEndpoint.SignVote(chainID, vote)
			require.Error(t, err)
		}()
	}
}

// TestRemoteSignProposalErrors uses an erroring private validator and expects
// its error to surface both locally and through the socket.
func TestRemoteSignProposalErrors(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewErroringMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			ts := time.Now()
			proposal := &types.Proposal{Timestamp: ts}

			err := validatorEndpoint.SignProposal("", proposal)
			// Checked assertion: fail the test instead of panicking when the
			// error is nil or of an unexpected type.
			rsErr, ok := err.(*RemoteSignerError)
			require.True(t, ok, "expected *RemoteSignerError, got %T", err)
			require.Equal(t, types.ErroringMockPVErr.Error(), rsErr.Description)

			err = serviceEndpoint.privVal.SignProposal(chainID, proposal)
			require.Error(t, err)

			err = validatorEndpoint.SignProposal(chainID, proposal)
			require.Error(t, err)
		}()
	}
}

// TestErrUnexpectedResponse answers sign requests with a response of the wrong
// type and expects the validator endpoint to return ErrUnexpectedResponse.
func TestErrUnexpectedResponse(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			var (
				logger          = log.TestingLogger()
				chainID         = cmn.RandStr(12)
				readyCh         = make(chan struct{})
				errCh           = make(chan error, 1)
				serviceEndpoint = NewSignerServiceEndpoint(
					logger,
					chainID,
					types.NewMockPV(),
					tc.dialer,
				)
				validatorEndpoint = newSignerValidatorEndpoint(
					logger,
					tc.addr,
					testTimeoutReadWrite)
			)

			testStartEndpoint(t, readyCh, validatorEndpoint)
			defer validatorEndpoint.Stop()
			SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint)
			SignerServiceEndpointConnRetries(100)(serviceEndpoint)

			// we do not want to Start() the remote signer here and instead use the connection to
			// reply with intentionally wrong replies below:
			rsConn, err := serviceEndpoint.connect()
			require.NoError(t, err)
			require.NotNil(t, rsConn)
			defer rsConn.Close()

			// send over public key to get the remote signer running:
			go testReadWriteResponse(t, &PubKeyResponse{}, rsConn)
			<-readyCh

			// Proposal:
			go func(errc chan error) {
				errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{})
			}(errCh)

			// read request and write wrong response:
			go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn)
			err = <-errCh
			require.Error(t, err)
			require.Equal(t, ErrUnexpectedResponse, err)

			// Vote:
			go func(errc chan error) {
				errc <- validatorEndpoint.SignVote(chainID, &types.Vote{})
			}(errCh)

			// read request and write wrong response:
			go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn)
			err = <-errCh
			require.Error(t, err)
			require.Equal(t, ErrUnexpectedResponse, err)
		}()
	}
}

// TestRetryConnToRemoteSigner stops the first signer service and starts a
// second one against the same validator endpoint, then waits long enough for
// the connection to be re-established.
func TestRetryConnToRemoteSigner(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			var (
				logger          = log.TestingLogger()
				chainID         = cmn.RandStr(12)
				readyCh         = make(chan struct{})
				serviceEndpoint = NewSignerServiceEndpoint(
					logger,
					chainID,
					types.NewMockPV(),
					tc.dialer,
				)
				thisConnTimeout   = testTimeoutReadWrite
				validatorEndpoint = newSignerValidatorEndpoint(logger, tc.addr, thisConnTimeout)
			)

			// Ping every:
			SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint)

			SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint)
			SignerServiceEndpointConnRetries(10)(serviceEndpoint)

			testStartEndpoint(t, readyCh, validatorEndpoint)
			defer validatorEndpoint.Stop()
			require.NoError(t, serviceEndpoint.Start())
			assert.True(t, serviceEndpoint.IsRunning())

			<-readyCh
			time.Sleep(testTimeoutHeartbeat * 2)

			serviceEndpoint.Stop()
			rs2 := NewSignerServiceEndpoint(
				logger,
				chainID,
				types.NewMockPV(),
				tc.dialer,
			)
			// let some pings pass
			time.Sleep(testTimeoutHeartbeat3o2)
			require.NoError(t, rs2.Start())
			assert.True(t, rs2.IsRunning())
			defer rs2.Stop()

			// give the client some time to re-establish the conn to the remote signer
			// should see sth like this in the logs:
			//
			// E[10016-01-10|17:12:46.128] Ping                                         err="remote signer timed out"
			// I[10016-01-10|17:16:42.447] Re-created connection to remote signer       impl=SocketVal
			time.Sleep(testTimeoutReadWrite * 2)
		}()
	}
}

// newSignerValidatorEndpoint listens on addr and wraps the listener in the
// Unix or TCP listener type (by protocol), applying the test accept timeout and
// the given read/write timeout. It panics if listening fails.
func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint {
	proto, address := cmn.ProtocolAndAddress(addr)

	ln, err := net.Listen(proto, address)
	if err != nil {
		panic(err)
	}
	// Only report that we are listening once net.Listen has succeeded.
	logger.Info("Listening at", "proto", proto, "address", address)

	var listener net.Listener

	if proto == "unix" {
		unixLn := NewUnixListener(ln)
		UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn)
		UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn)
		listener = unixLn
	} else {
		tcpLn := NewTCPListener(ln, ed25519.GenPrivKey())
		TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn)
		TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn)
		listener = tcpLn
	}

	return NewSignerValidatorEndpoint(logger, listener)
}

// testSetupSocketPair starts a validator endpoint listening on addr and a
// signer service endpoint dialing it, waits until the validator endpoint
// reports ready, and returns both running endpoints.
func testSetupSocketPair(
	t *testing.T,
	chainID string,
	privValidator types.PrivValidator,
	addr string,
	socketDialer SocketDialer,
) (*SignerValidatorEndpoint, *SignerServiceEndpoint) {
	var (
		logger          = log.TestingLogger()
		privVal         = privValidator
		readyc          = make(chan struct{})
		serviceEndpoint = NewSignerServiceEndpoint(
			logger,
			chainID,
			privVal,
			socketDialer,
		)

		thisConnTimeout   = testTimeoutReadWrite
		validatorEndpoint = newSignerValidatorEndpoint(logger, addr, thisConnTimeout)
	)

	SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint)
	SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint)
	SignerServiceEndpointConnRetries(1e6)(serviceEndpoint)

	testStartEndpoint(t, readyc, validatorEndpoint)

	require.NoError(t, serviceEndpoint.Start())
	assert.True(t, serviceEndpoint.IsRunning())

	<-readyc

	return validatorEndpoint, serviceEndpoint
}

// testReadWriteResponse reads one message from rsConn and replies with resp.
//
// Callers in this file run it on its own goroutine, so failures are reported
// with assert (t.Error-style) rather than require: t.FailNow must only be
// called from the goroutine running the test.
func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) {
	_, err := readMsg(rsConn)
	if !assert.NoError(t, err) {
		return
	}

	assert.NoError(t, writeMsg(rsConn, resp))
}

// testStartEndpoint starts sc on a new goroutine and sends on readyCh once
// Start has returned without error. If Start fails the test is marked as
// failed and nothing is sent on readyCh.
func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) {
	go func(sc *SignerValidatorEndpoint) {
		// assert instead of require: t.FailNow must not be called from a
		// goroutine other than the one running the test.
		if !assert.NoError(t, sc.Start()) {
			return
		}
		assert.True(t, sc.IsRunning())

		readyCh <- struct{}{}
	}(sc)
}

// testFreeTCPAddr claims a free port so we don't block on listener being ready.
func testFreeTCPAddr(t *testing.T) string {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer ln.Close()

	return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port)
}