You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

506 lines
14 KiB

package privval
import (
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/crypto/ed25519"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
)
// Timeouts shared by the socket tests below.
var (
	// testTimeoutAccept bounds how long a test listener waits in accept().
	testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second

	// testTimeoutReadWrite is the per-connection read/write deadline.
	testTimeoutReadWrite = 100 * time.Millisecond
	// testTimeoutReadWrite2o3 is roughly 2/3 of testTimeoutReadWrite.
	testTimeoutReadWrite2o3 = 66 * time.Millisecond

	// testTimeoutHeartbeat is the validator endpoint's ping interval.
	testTimeoutHeartbeat = 10 * time.Millisecond
	// testTimeoutHeartbeat3o2 is roughly 2/3 of testTimeoutHeartbeat
	// (equivalently, the heartbeat is about 3/2 of this value).
	testTimeoutHeartbeat3o2 = 6 * time.Millisecond
)
// socketTestCase pairs the address a validator endpoint listens on with the
// dialer a signer service endpoint uses to reach that address.
type socketTestCase struct {
	// addr is a "proto://address" string (e.g. "tcp://..." or "unix://...").
	addr string

	// dialer is handed to NewSignerServiceEndpoint to connect to addr.
	dialer SocketDialer
}
// socketTestCases returns one case per supported transport: a TCP address on a
// freshly reserved loopback port, and a Unix domain socket path from
// testUnixAddr. The test fails immediately if no Unix path can be obtained.
func socketTestCases(t *testing.T) []socketTestCase {
	tcpAddr := "tcp://" + testFreeTCPAddr(t)

	unixPath, err := testUnixAddr()
	require.NoError(t, err)

	tcpCase := socketTestCase{
		addr:   tcpAddr,
		dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()),
	}
	unixCase := socketTestCase{
		addr:   "unix://" + unixPath,
		dialer: DialUnixFn(unixPath),
	}

	return []socketTestCase{tcpCase, unixCase}
}
// TestSocketPVAddress checks that the address reported through the socket
// matches the address of the private validator behind the signer service.
func TestSocketPVAddress(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		// One closure per transport so the deferred Stop calls fire before
		// the next iteration starts.
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			// Arguments are evaluated left to right: local key first, then remote.
			assert.Equal(t,
				serviceEndpoint.privVal.GetPubKey().Address(),
				validatorEndpoint.GetPubKey().Address(),
			)
		}()
	}
}
// TestSocketPVPubKey checks that the public key fetched through the socket
// equals the one held by the private validator behind the signer service.
func TestSocketPVPubKey(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			fromSocket := validatorEndpoint.GetPubKey()
			fromSigner := serviceEndpoint.privVal.GetPubKey()
			assert.Equal(t, fromSigner, fromSocket)
		}()
	}
}
// TestSocketPVProposal signs the same proposal directly with the private
// validator and through the socket, and expects identical signatures.
func TestSocketPVProposal(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			now := time.Now()
			local := &types.Proposal{Timestamp: now}
			remote := &types.Proposal{Timestamp: now}

			require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, local))
			require.NoError(t, validatorEndpoint.SignProposal(chainID, remote))
			assert.Equal(t, local.Signature, remote.Signature)
		}()
	}
}
// TestSocketPVVote signs the same precommit directly with the private
// validator and through the socket, and expects identical signatures.
func TestSocketPVVote(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			now := time.Now()
			newPrecommit := func() *types.Vote {
				return &types.Vote{Timestamp: now, Type: types.PrecommitType}
			}
			want, have := newPrecommit(), newPrecommit()

			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
			assert.Equal(t, want.Signature, have.Signature)
		}()
	}
}
// TestSocketPVVoteResetDeadline performs two signing rounds, each preceded by a
// pause of ~2/3 of the read/write timeout. The pauses add up to more than one
// timeout, so the second round passes only if each exchange extends the
// connection deadline.
func TestSocketPVVoteResetDeadline(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			now := time.Now()
			want := &types.Vote{Timestamp: now, Type: types.PrecommitType}
			have := &types.Vote{Timestamp: now, Type: types.PrecommitType}

			signAndCompare := func() {
				require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
				require.NoError(t, validatorEndpoint.SignVote(chainID, have))
				assert.Equal(t, want.Signature, have.Signature)
			}

			time.Sleep(testTimeoutReadWrite2o3)
			signAndCompare()

			// This would exceed the deadline if it was not extended by the previous message
			time.Sleep(testTimeoutReadWrite2o3)
			signAndCompare()
		}()
	}
}
// TestSocketPVVoteKeepalive idles for twice the read/write timeout before
// signing. Signing must still work afterwards, which relies on the heartbeat
// configured by testSetupSocketPair keeping the connection usable.
func TestSocketPVVoteKeepalive(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			chainID := cmn.RandStr(12)
			validatorEndpoint, serviceEndpoint := testSetupSocketPair(
				t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			now := time.Now()
			want := &types.Vote{Timestamp: now, Type: types.PrecommitType}
			have := &types.Vote{Timestamp: now, Type: types.PrecommitType}

			idle := 2 * testTimeoutReadWrite
			time.Sleep(idle)

			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
			assert.Equal(t, want.Signature, have.Signature)
		}()
	}
}
// TestSocketPVDeadline checks that a validator endpoint whose peer connects but
// never speaks fails to start with a connection-timeout error and is left in
// the not-running state.
func TestSocketPVDeadline(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			var (
				listenc           = make(chan struct{})
				thisConnTimeout   = 100 * time.Millisecond
				validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout)
			)

			go func(sc *SignerValidatorEndpoint) {
				defer close(listenc)

				// Note: the TCP connection times out at the accept() phase,
				// whereas the Unix domain sockets connection times out while
				// attempting to fetch the remote signer's public key.
				assert.True(t, IsConnTimeout(sc.Start()))
				assert.False(t, sc.IsRunning())
			}(validatorEndpoint)

			// Dial until the listener accepts. The connection must stay open and
			// silent while the endpoint waits on it; closing it early would
			// presumably make the endpoint see EOF rather than a timeout.
			var conn net.Conn
			for {
				c, err := cmn.Connect(tc.addr)
				if err == nil {
					conn = c
					break
				}
			}
			// Release the dialed connection once this iteration is done (after
			// the endpoint has given up below) instead of leaking it.
			defer conn.Close()

			<-listenc
		}()
	}
}
// TestRemoteSignVoteErrors checks that a signing error raised by the remote
// private validator comes back through the socket as a *RemoteSignerError
// carrying the original message. Signing with the real chain ID must also fail,
// both locally and remotely.
func TestRemoteSignVoteErrors(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			var (
				chainID                            = cmn.RandStr(12)
				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
					t,
					chainID,
					types.NewErroringMockPV(),
					tc.addr,
					tc.dialer)

				ts    = time.Now()
				vType = types.PrecommitType
				vote  = &types.Vote{Timestamp: ts, Type: vType}
			)

			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			err := validatorEndpoint.SignVote("", vote)
			// Check the error and its type before dereferencing it, so a nil or
			// unexpected error fails the test instead of panicking.
			require.Error(t, err)
			rsErr, ok := err.(*RemoteSignerError)
			require.True(t, ok, "expected *RemoteSignerError, got %T", err)
			require.Equal(t, types.ErroringMockPVErr.Error(), rsErr.Description)

			err = serviceEndpoint.privVal.SignVote(chainID, vote)
			require.Error(t, err)

			err = validatorEndpoint.SignVote(chainID, vote)
			require.Error(t, err)
		}()
	}
}
// TestRemoteSignProposalErrors checks that a proposal-signing error raised by
// the remote private validator comes back through the socket as a
// *RemoteSignerError carrying the original message. Signing with the real chain
// ID must also fail, both locally and remotely.
func TestRemoteSignProposalErrors(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			var (
				chainID                            = cmn.RandStr(12)
				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
					t,
					chainID,
					types.NewErroringMockPV(),
					tc.addr,
					tc.dialer)

				ts       = time.Now()
				proposal = &types.Proposal{Timestamp: ts}
			)

			defer validatorEndpoint.Stop()
			defer serviceEndpoint.Stop()

			err := validatorEndpoint.SignProposal("", proposal)
			// Check the error and its type before dereferencing it, so a nil or
			// unexpected error fails the test instead of panicking.
			require.Error(t, err)
			rsErr, ok := err.(*RemoteSignerError)
			require.True(t, ok, "expected *RemoteSignerError, got %T", err)
			require.Equal(t, types.ErroringMockPVErr.Error(), rsErr.Description)

			err = serviceEndpoint.privVal.SignProposal(chainID, proposal)
			require.Error(t, err)

			err = validatorEndpoint.SignProposal(chainID, proposal)
			require.Error(t, err)
		}()
	}
}
// TestErrUnexpectedResponse drives the validator endpoint with a hand-held
// signer connection. That connection answers each sign request with the wrong
// response type, and the endpoint must report ErrUnexpectedResponse.
func TestErrUnexpectedResponse(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			var (
				logger  = log.TestingLogger()
				chainID = cmn.RandStr(12)
				readyCh = make(chan struct{})
				errCh   = make(chan error, 1)

				serviceEndpoint = NewSignerServiceEndpoint(
					logger,
					chainID,
					types.NewMockPV(),
					tc.dialer,
				)

				validatorEndpoint = newSignerValidatorEndpoint(
					logger,
					tc.addr,
					testTimeoutReadWrite)
			)

			testStartEndpoint(t, readyCh, validatorEndpoint)
			defer validatorEndpoint.Stop()
			SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint)
			SignerServiceEndpointConnRetries(100)(serviceEndpoint)

			// we do not want to Start() the remote signer here and instead use the connection to
			// reply with intentionally wrong replies below:
			rsConn, err := serviceEndpoint.connect()
			require.NoError(t, err)
			require.NotNil(t, rsConn)
			defer rsConn.Close()

			// send over public key to get the remote signer running:
			go testReadWriteResponse(t, &PubKeyResponse{}, rsConn)
			<-readyCh

			// Proposal:
			go func(errc chan error) {
				errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{})
			}(errCh)

			// read request and write wrong response:
			go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn)
			err = <-errCh
			require.Error(t, err)
			// testify expects (expected, actual); keep that order for readable failures.
			require.Equal(t, ErrUnexpectedResponse, err)

			// Vote:
			go func(errc chan error) {
				errc <- validatorEndpoint.SignVote(chainID, &types.Vote{})
			}(errCh)

			// read request and write wrong response:
			go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn)
			err = <-errCh
			require.Error(t, err)
			require.Equal(t, ErrUnexpectedResponse, err)
		}()
	}
}
// TestRetryConnToRemoteSigner starts a signer service, lets it run for a couple
// of heartbeat intervals, then stops it and starts a replacement signer service
// against the same validator endpoint. It gives the endpoint time to reconnect;
// success is only observable in the logs.
func TestRetryConnToRemoteSigner(t *testing.T) {
	for _, tc := range socketTestCases(t) {
		func() {
			logger := log.TestingLogger()
			chainID := cmn.RandStr(12)
			readyCh := make(chan struct{})

			// newService builds a signer service endpoint backed by a fresh MockPV.
			newService := func() *SignerServiceEndpoint {
				return NewSignerServiceEndpoint(logger, chainID, types.NewMockPV(), tc.dialer)
			}

			serviceEndpoint := newService()
			validatorEndpoint := newSignerValidatorEndpoint(logger, tc.addr, testTimeoutReadWrite)

			// Ping every:
			SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint)
			SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint)
			SignerServiceEndpointConnRetries(10)(serviceEndpoint)

			testStartEndpoint(t, readyCh, validatorEndpoint)
			defer validatorEndpoint.Stop()

			require.NoError(t, serviceEndpoint.Start())
			assert.True(t, serviceEndpoint.IsRunning())
			<-readyCh

			// Let about two heartbeat intervals elapse, then drop the first signer.
			time.Sleep(2 * testTimeoutHeartbeat)
			serviceEndpoint.Stop()

			replacement := newService()

			// let some pings pass
			time.Sleep(testTimeoutHeartbeat3o2)
			require.NoError(t, replacement.Start())
			assert.True(t, replacement.IsRunning())
			defer replacement.Stop()

			// give the client some time to re-establish the conn to the remote signer
			// should see sth like this in the logs:
			//
			// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out"
			// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal
			time.Sleep(2 * testTimeoutReadWrite)
		}()
	}
}
// newSignerValidatorEndpoint listens on addr (a "proto://address" string) and
// wraps the listener in a SignerValidatorEndpoint.
//
// "unix" addresses get a UnixListener. Any other protocol gets a TCPListener
// keyed with a freshly generated ed25519 private key. Both listeners use
// testTimeoutAccept as the accept timeout and timeoutReadWrite as the
// read/write timeout.
//
// It panics if net.Listen fails; it is a test-only helper with no *testing.T.
func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint {
	proto, address := cmn.ProtocolAndAddress(addr)

	ln, err := net.Listen(proto, address)
	if err != nil {
		panic(err)
	}
	// Only report listening once the listener actually exists.
	logger.Info("Listening at", "proto", proto, "address", address)

	var listener net.Listener
	if proto == "unix" {
		unixLn := NewUnixListener(ln)
		UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn)
		UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn)
		listener = unixLn
	} else {
		tcpLn := NewTCPListener(ln, ed25519.GenPrivKey())
		TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn)
		TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn)
		listener = tcpLn
	}

	return NewSignerValidatorEndpoint(logger, listener)
}
// testSetupSocketPair wires a validator endpoint listening on addr to a signer
// service endpoint that dials it with socketDialer and signs with
// privValidator. It returns only after both endpoints have started; the caller
// is responsible for stopping them.
//
// The validator endpoint pings every testTimeoutHeartbeat. The service endpoint
// uses testTimeoutReadWrite for I/O and a very large connection-retry budget.
func testSetupSocketPair(
	t *testing.T,
	chainID string,
	privValidator types.PrivValidator,
	addr string,
	socketDialer SocketDialer,
) (*SignerValidatorEndpoint, *SignerServiceEndpoint) {
	// Attribute assertion failures to the calling test, not this helper.
	t.Helper()

	var (
		logger          = log.TestingLogger()
		readyc          = make(chan struct{})
		serviceEndpoint = NewSignerServiceEndpoint(
			logger,
			chainID,
			privValidator,
			socketDialer,
		)
		validatorEndpoint = newSignerValidatorEndpoint(logger, addr, testTimeoutReadWrite)
	)

	SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint)
	SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint)
	SignerServiceEndpointConnRetries(1e6)(serviceEndpoint)

	// The validator endpoint's Start runs in the background. The service
	// endpoint is started here, then we wait for the validator side to signal
	// readiness.
	testStartEndpoint(t, readyc, validatorEndpoint)

	require.NoError(t, serviceEndpoint.Start())
	assert.True(t, serviceEndpoint.IsRunning())

	<-readyc

	return validatorEndpoint, serviceEndpoint
}
// testReadWriteResponse reads one request message from rsConn and replies with
// resp, whatever the request was. If the read fails, no reply is written.
//
// Callers run this in its own goroutine. It therefore reports failures with
// assert (t.Errorf) rather than require: t.FailNow may only be called from the
// goroutine running the test function.
func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) {
	if _, err := readMsg(rsConn); !assert.NoError(t, err) {
		return
	}
	assert.NoError(t, writeMsg(rsConn, resp))
}
// testStartEndpoint starts sc in a background goroutine and sends one value on
// readyCh when Start has returned. Callers must receive from readyCh.
//
// Failures are reported with assert, not require, because t.FailNow must not
// be called from a goroutine other than the test's. readyCh is signalled even
// when Start fails, so the waiting test fails promptly instead of blocking
// until the global test timeout.
func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) {
	go func(sc *SignerValidatorEndpoint) {
		defer func() { readyCh <- struct{}{} }()

		if !assert.NoError(t, sc.Start()) {
			return
		}
		assert.True(t, sc.IsRunning())
	}(sc)
}
// testFreeTCPAddr asks the OS for an unused loopback TCP port and returns it as
// "127.0.0.1:<port>". The probe listener is closed before returning, so tests
// can bind the address themselves without waiting on a listener to be ready.
// The port is only known to be free at the moment of the probe.
func testFreeTCPAddr(t *testing.T) string {
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer probe.Close()

	port := probe.Addr().(*net.TCPAddr).Port
	return fmt.Sprintf("127.0.0.1:%d", port)
}