diff --git a/node/node.go b/node/node.go index ed0fa1198..9939f1c65 100644 --- a/node/node.go +++ b/node/node.go @@ -215,7 +215,7 @@ func NewNode(config *cfg.Config, // TODO: persist this key so external signer // can actually authenticate us privKey = ed25519.GenPrivKey() - pvsc = privval.NewSocketPV( + pvsc = privval.NewTCPVal( logger.With("module", "privval"), config.PrivValidatorListenAddr, privKey, @@ -579,7 +579,7 @@ func (n *Node) OnStop() { } } - if pvsc, ok := n.privValidator.(*privval.SocketPV); ok { + if pvsc, ok := n.privValidator.(*privval.TCPVal); ok { if err := pvsc.Stop(); err != nil { n.Logger.Error("Error stopping priv validator socket client", "err", err) } diff --git a/privval/ipc.go b/privval/ipc.go new file mode 100644 index 000000000..eda23fe6f --- /dev/null +++ b/privval/ipc.go @@ -0,0 +1,120 @@ +package privval + +import ( + "net" + "time" + + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +// IPCValOption sets an optional parameter on the SocketPV. +type IPCValOption func(*IPCVal) + +// IPCValConnTimeout sets the read and write timeout for connections +// from external signing processes. +func IPCValConnTimeout(timeout time.Duration) IPCValOption { + return func(sc *IPCVal) { sc.connTimeout = timeout } +} + +// IPCValHeartbeat sets the period on which to check the liveness of the +// connected Signer connections. +func IPCValHeartbeat(period time.Duration) IPCValOption { + return func(sc *IPCVal) { sc.connHeartbeat = period } +} + +// IPCVal implements PrivValidator, it uses a unix socket to request signatures +// from an external process. +type IPCVal struct { + cmn.BaseService + *RemoteSignerClient + + addr string + + connTimeout time.Duration + connHeartbeat time.Duration + + conn net.Conn + cancelPing chan struct{} + pingTicker *time.Ticker +} + +// Check that IPCVal implements PrivValidator. +var _ types.PrivValidator = (*IPCVal)(nil) + +// NewIPCVal returns an instance of IPCVal. +func NewIPCVal( + logger log.Logger, + socketAddr string, +) *IPCVal { + sc := &IPCVal{ + addr: socketAddr, + connTimeout: connTimeout, + connHeartbeat: connHeartbeat, + } + + sc.BaseService = *cmn.NewBaseService(logger, "IPCVal", sc) + + return sc +} + +// OnStart implements cmn.Service. +func (sc *IPCVal) OnStart() error { + err := sc.connect() + if err != nil { + sc.Logger.Error("OnStart", "err", err) + return err + } + + sc.RemoteSignerClient = NewRemoteSignerClient(sc.conn) + + // Start a routine to keep the connection alive + sc.cancelPing = make(chan struct{}, 1) + sc.pingTicker = time.NewTicker(sc.connHeartbeat) + go func() { + for { + select { + case <-sc.pingTicker.C: + err := sc.Ping() + if err != nil { + sc.Logger.Error("Ping", "err", err) + } + case <-sc.cancelPing: + sc.pingTicker.Stop() + return + } + } + }() + + return nil +} + +// OnStop implements cmn.Service. +func (sc *IPCVal) OnStop() { + if sc.cancelPing != nil { + close(sc.cancelPing) + } + + if sc.conn != nil { + if err := sc.conn.Close(); err != nil { + sc.Logger.Error("OnStop", "err", err) + } + } +} + +func (sc *IPCVal) connect() error { + la, err := net.ResolveUnixAddr("unix", sc.addr) + if err != nil { + return err + } + + conn, err := net.DialUnix("unix", nil, la) + if err != nil { + return err + } + + sc.conn = newTimeoutConn(conn, sc.connTimeout) + + return nil +} diff --git a/privval/ipc_server.go b/privval/ipc_server.go new file mode 100644 index 000000000..d3907cbdb --- /dev/null +++ b/privval/ipc_server.go @@ -0,0 +1,131 @@ +package privval + +import ( + "io" + "net" + "time" + + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +// IPCRemoteSignerOption sets an optional parameter on the IPCRemoteSigner. +type IPCRemoteSignerOption func(*IPCRemoteSigner) + +// IPCRemoteSignerConnDeadline sets the read and write deadline for connections +// from external signing processes. +func IPCRemoteSignerConnDeadline(deadline time.Duration) IPCRemoteSignerOption { + return func(ss *IPCRemoteSigner) { ss.connDeadline = deadline } +} + +// IPCRemoteSignerConnRetries sets the amount of attempted retries to connect. +func IPCRemoteSignerConnRetries(retries int) IPCRemoteSignerOption { + return func(ss *IPCRemoteSigner) { ss.connRetries = retries } +} + +// IPCRemoteSigner is a RPC implementation of PrivValidator that listens on a unix socket. +type IPCRemoteSigner struct { + cmn.BaseService + + addr string + chainID string + connDeadline time.Duration + connRetries int + privVal types.PrivValidator + + listener *net.UnixListener +} + +// NewIPCRemoteSigner returns an instance of IPCRemoteSigner. +func NewIPCRemoteSigner( + logger log.Logger, + chainID, socketAddr string, + privVal types.PrivValidator, +) *IPCRemoteSigner { + rs := &IPCRemoteSigner{ + addr: socketAddr, + chainID: chainID, + connDeadline: time.Second * defaultConnDeadlineSeconds, + connRetries: defaultDialRetries, + privVal: privVal, + } + + rs.BaseService = *cmn.NewBaseService(logger, "IPCRemoteSigner", rs) + + return rs +} + +// OnStart implements cmn.Service. +func (rs *IPCRemoteSigner) OnStart() error { + err := rs.listen() + if err != nil { + err = cmn.ErrorWrap(err, "listen") + rs.Logger.Error("OnStart", "err", err) + return err + } + + go func() { + for { + conn, err := rs.listener.AcceptUnix() + if err != nil { + return + } + go rs.handleConnection(conn) + } + }() + + return nil +} + +// OnStop implements cmn.Service. +func (rs *IPCRemoteSigner) OnStop() { + if rs.listener != nil { + if err := rs.listener.Close(); err != nil { + rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) + } + } +} + +func (rs *IPCRemoteSigner) listen() error { + la, err := net.ResolveUnixAddr("unix", rs.addr) + if err != nil { + return err + } + + rs.listener, err = net.ListenUnix("unix", la) + + return err +} + +func (rs *IPCRemoteSigner) handleConnection(conn net.Conn) { + for { + if !rs.IsRunning() { + return // Ignore error from listener closing. + } + + // Reset the connection deadline + conn.SetDeadline(time.Now().Add(rs.connDeadline)) + + req, err := readMsg(conn) + if err != nil { + if err != io.EOF { + rs.Logger.Error("handleConnection", "err", err) + } + return + } + + res, err := handleRequest(req, rs.chainID, rs.privVal) + + if err != nil { + // only log the error; we'll reply with an error in res + rs.Logger.Error("handleConnection", "err", err) + } + + err = writeMsg(conn, res) + if err != nil { + rs.Logger.Error("handleConnection", "err", err) + return + } + } +} diff --git a/privval/ipc_test.go b/privval/ipc_test.go new file mode 100644 index 000000000..c8d6dfc77 --- /dev/null +++ b/privval/ipc_test.go @@ -0,0 +1,147 @@ +package privval + +import ( + "io/ioutil" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +func TestIPCPVVote(t *testing.T) { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV()) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) +} + +func TestIPCPVVoteResetDeadline(t *testing.T) { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV()) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + time.Sleep(3 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + + // This would exceed the deadline if it was not extended by the previous message + time.Sleep(3 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) +} + +func TestIPCPVVoteKeepalive(t *testing.T) { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV()) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + time.Sleep(10 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) +} + +func testSetupIPCSocketPair( + t *testing.T, + chainID string, + privValidator types.PrivValidator, +) (*IPCVal, *IPCRemoteSigner) { + addr, err := testUnixAddr() + require.NoError(t, err) + + var ( + logger = log.TestingLogger() + privVal = privValidator + readyc = make(chan struct{}) + rs = NewIPCRemoteSigner( + logger, + chainID, + addr, + privVal, + ) + sc = NewIPCVal( + logger, + addr, + ) + ) + + IPCValConnTimeout(5 * time.Millisecond)(sc) + IPCValHeartbeat(time.Millisecond)(sc) + + IPCRemoteSignerConnDeadline(time.Millisecond * 5)(rs) + + testStartIPCRemoteSigner(t, readyc, rs) + + <-readyc + + require.NoError(t, sc.Start()) + assert.True(t, sc.IsRunning()) + + return sc, rs +} + +func testStartIPCRemoteSigner(t *testing.T, readyc chan struct{}, rs *IPCRemoteSigner) { + go func(rs *IPCRemoteSigner) { + require.NoError(t, rs.Start()) + assert.True(t, rs.IsRunning()) + + readyc <- struct{}{} + }(rs) +} + +func testUnixAddr() (string, error) { + f, err := ioutil.TempFile("/tmp", "nettest") + if err != nil { + return "", err + } + + addr := f.Name() + err = f.Close() + if err != nil { + return "", err + } + err = os.Remove(addr) + if err != nil { + return "", err + } + + return addr, nil +} diff --git a/privval/remote_signer.go b/privval/remote_signer.go new file mode 100644 index 000000000..399ee7905 --- /dev/null +++ b/privval/remote_signer.go @@ -0,0 +1,303 @@ +package privval + +import ( + "fmt" + "io" + "net" + "sync" + + "github.com/tendermint/go-amino" + "github.com/tendermint/tendermint/crypto" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/types" +) + +// RemoteSignerClient implements PrivValidator, it uses a socket to request signatures +// from an external process. +type RemoteSignerClient struct { + conn net.Conn + lock sync.Mutex +} + +// Check that RemoteSignerClient implements PrivValidator. +var _ types.PrivValidator = (*RemoteSignerClient)(nil) + +// NewRemoteSignerClient returns an instance of RemoteSignerClient. +func NewRemoteSignerClient( + conn net.Conn, +) *RemoteSignerClient { + sc := &RemoteSignerClient{ + conn: conn, + } + return sc +} + +// GetAddress implements PrivValidator. +func (sc *RemoteSignerClient) GetAddress() types.Address { + pubKey, err := sc.getPubKey() + if err != nil { + panic(err) + } + + return pubKey.Address() +} + +// GetPubKey implements PrivValidator. +func (sc *RemoteSignerClient) GetPubKey() crypto.PubKey { + pubKey, err := sc.getPubKey() + if err != nil { + panic(err) + } + + return pubKey +} + +func (sc *RemoteSignerClient) getPubKey() (crypto.PubKey, error) { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &PubKeyMsg{}) + if err != nil { + return nil, err + } + + res, err := readMsg(sc.conn) + if err != nil { + return nil, err + } + + return res.(*PubKeyMsg).PubKey, nil +} + +// SignVote implements PrivValidator. +func (sc *RemoteSignerClient) SignVote(chainID string, vote *types.Vote) error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + + resp, ok := res.(*SignedVoteResponse) + if !ok { + return ErrUnexpectedResponse + } + if resp.Error != nil { + return resp.Error + } + *vote = *resp.Vote + + return nil +} + +// SignProposal implements PrivValidator. +func (sc *RemoteSignerClient) SignProposal( + chainID string, + proposal *types.Proposal, +) error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + resp, ok := res.(*SignedProposalResponse) + if !ok { + return ErrUnexpectedResponse + } + if resp.Error != nil { + return resp.Error + } + *proposal = *resp.Proposal + + return nil +} + +// SignHeartbeat implements PrivValidator. +func (sc *RemoteSignerClient) SignHeartbeat( + chainID string, + heartbeat *types.Heartbeat, +) error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + resp, ok := res.(*SignedHeartbeatResponse) + if !ok { + return ErrUnexpectedResponse + } + if resp.Error != nil { + return resp.Error + } + *heartbeat = *resp.Heartbeat + + return nil +} + +// Ping is used to check connection health. +func (sc *RemoteSignerClient) Ping() error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &PingRequest{}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + _, ok := res.(*PingResponse) + if !ok { + return ErrUnexpectedResponse + } + + return nil +} + +// RemoteSignerMsg is sent between RemoteSigner and the RemoteSigner client. +type RemoteSignerMsg interface{} + +func RegisterRemoteSignerMsg(cdc *amino.Codec) { + cdc.RegisterInterface((*RemoteSignerMsg)(nil), nil) + cdc.RegisterConcrete(&PubKeyMsg{}, "tendermint/remotesigner/PubKeyMsg", nil) + cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil) + cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil) + cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil) + cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil) + cdc.RegisterConcrete(&SignHeartbeatRequest{}, "tendermint/remotesigner/SignHeartbeatRequest", nil) + cdc.RegisterConcrete(&SignedHeartbeatResponse{}, "tendermint/remotesigner/SignedHeartbeatResponse", nil) + cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil) + cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil) +} + +// PubKeyMsg is a PrivValidatorSocket message containing the public key. +type PubKeyMsg struct { + PubKey crypto.PubKey +} + +// SignVoteRequest is a PrivValidatorSocket message containing a vote. +type SignVoteRequest struct { + Vote *types.Vote +} + +// SignedVoteResponse is a PrivValidatorSocket message containing a signed vote along with a potenial error message. +type SignedVoteResponse struct { + Vote *types.Vote + Error *RemoteSignerError +} + +// SignProposalRequest is a PrivValidatorSocket message containing a Proposal. +type SignProposalRequest struct { + Proposal *types.Proposal +} + +type SignedProposalResponse struct { + Proposal *types.Proposal + Error *RemoteSignerError +} + +// SignHeartbeatRequest is a PrivValidatorSocket message containing a Heartbeat. +type SignHeartbeatRequest struct { + Heartbeat *types.Heartbeat +} + +type SignedHeartbeatResponse struct { + Heartbeat *types.Heartbeat + Error *RemoteSignerError +} + +// PingRequest is a PrivValidatorSocket message to keep the connection alive. +type PingRequest struct { +} + +type PingResponse struct { +} + +// RemoteSignerError allows (remote) validators to include meaningful error descriptions in their reply. +type RemoteSignerError struct { + // TODO(ismail): create an enum of known errors + Code int + Description string +} + +func (e *RemoteSignerError) Error() string { + return fmt.Sprintf("RemoteSigner returned error #%d: %s", e.Code, e.Description) +} + +func readMsg(r io.Reader) (msg RemoteSignerMsg, err error) { + const maxRemoteSignerMsgSize = 1024 * 10 + _, err = cdc.UnmarshalBinaryReader(r, &msg, maxRemoteSignerMsgSize) + if _, ok := err.(timeoutError); ok { + err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) + } + return +} + +func writeMsg(w io.Writer, msg interface{}) (err error) { + _, err = cdc.MarshalBinaryWriter(w, msg) + if _, ok := err.(timeoutError); ok { + err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) + } + return +} + +func handleRequest(req RemoteSignerMsg, chainID string, privVal types.PrivValidator) (RemoteSignerMsg, error) { + var res RemoteSignerMsg + var err error + + switch r := req.(type) { + case *PubKeyMsg: + var p crypto.PubKey + p = privVal.GetPubKey() + res = &PubKeyMsg{p} + case *SignVoteRequest: + err = privVal.SignVote(chainID, r.Vote) + if err != nil { + res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} + } else { + res = &SignedVoteResponse{r.Vote, nil} + } + case *SignProposalRequest: + err = privVal.SignProposal(chainID, r.Proposal) + if err != nil { + res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} + } else { + res = &SignedProposalResponse{r.Proposal, nil} + } + case *SignHeartbeatRequest: + err = privVal.SignHeartbeat(chainID, r.Heartbeat) + if err != nil { + res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}} + } else { + res = &SignedHeartbeatResponse{r.Heartbeat, nil} + } + case *PingRequest: + res = &PingResponse{} + default: + err = fmt.Errorf("unknown msg: %v", r) + } + + return res, err +} diff --git a/privval/socket.go b/privval/socket.go deleted file mode 100644 index 64d4c46d0..000000000 --- a/privval/socket.go +++ /dev/null @@ -1,605 +0,0 @@ -package privval - -import ( - "errors" - "fmt" - "io" - "net" - "time" - - "github.com/tendermint/go-amino" - - "github.com/tendermint/tendermint/crypto" - "github.com/tendermint/tendermint/crypto/ed25519" - cmn "github.com/tendermint/tendermint/libs/common" - "github.com/tendermint/tendermint/libs/log" - p2pconn "github.com/tendermint/tendermint/p2p/conn" - "github.com/tendermint/tendermint/types" -) - -const ( - defaultAcceptDeadlineSeconds = 30 // tendermint waits this long for remote val to connect - defaultConnDeadlineSeconds = 3 // must be set before each read - defaultConnHeartBeatSeconds = 30 // tcp keep-alive period - defaultConnWaitSeconds = 60 // XXX: is this redundant with the accept deadline? - defaultDialRetries = 10 // try to connect to tendermint this many times -) - -// Socket errors. -var ( - ErrDialRetryMax = errors.New("dialed maximum retries") - ErrConnWaitTimeout = errors.New("waited for remote signer for too long") - ErrConnTimeout = errors.New("remote signer timed out") - ErrUnexpectedResponse = errors.New("received unexpected response") -) - -// SocketPVOption sets an optional parameter on the SocketPV. -type SocketPVOption func(*SocketPV) - -// SocketPVAcceptDeadline sets the deadline for the SocketPV listener. -// A zero time value disables the deadline. -func SocketPVAcceptDeadline(deadline time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.acceptDeadline = deadline } -} - -// SocketPVConnDeadline sets the read and write deadline for connections -// from external signing processes. -func SocketPVConnDeadline(deadline time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.connDeadline = deadline } -} - -// SocketPVHeartbeat sets the period on which to check the liveness of the -// connected Signer connections. -func SocketPVHeartbeat(period time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.connHeartbeat = period } -} - -// SocketPVConnWait sets the timeout duration before connection of external -// signing processes are considered to be unsuccessful. -func SocketPVConnWait(timeout time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.connWaitTimeout = timeout } -} - -// SocketPV implements PrivValidator, it uses a socket to request signatures -// from an external process. -type SocketPV struct { - cmn.BaseService - - addr string - acceptDeadline time.Duration - connDeadline time.Duration - connHeartbeat time.Duration - connWaitTimeout time.Duration - privKey ed25519.PrivKeyEd25519 - - conn net.Conn - listener net.Listener -} - -// Check that SocketPV implements PrivValidator. -var _ types.PrivValidator = (*SocketPV)(nil) - -// NewSocketPV returns an instance of SocketPV. -func NewSocketPV( - logger log.Logger, - socketAddr string, - privKey ed25519.PrivKeyEd25519, -) *SocketPV { - sc := &SocketPV{ - addr: socketAddr, - acceptDeadline: time.Second * defaultAcceptDeadlineSeconds, - connDeadline: time.Second * defaultConnDeadlineSeconds, - connHeartbeat: time.Second * defaultConnHeartBeatSeconds, - connWaitTimeout: time.Second * defaultConnWaitSeconds, - privKey: privKey, - } - - sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc) - - return sc -} - -// GetAddress implements PrivValidator. -func (sc *SocketPV) GetAddress() types.Address { - addr, err := sc.getAddress() - if err != nil { - panic(err) - } - - return addr -} - -// Address is an alias for PubKey().Address(). -func (sc *SocketPV) getAddress() (cmn.HexBytes, error) { - p, err := sc.getPubKey() - if err != nil { - return nil, err - } - - return p.Address(), nil -} - -// GetPubKey implements PrivValidator. -func (sc *SocketPV) GetPubKey() crypto.PubKey { - pubKey, err := sc.getPubKey() - if err != nil { - panic(err) - } - - return pubKey -} - -func (sc *SocketPV) getPubKey() (crypto.PubKey, error) { - err := writeMsg(sc.conn, &PubKeyMsg{}) - if err != nil { - return nil, err - } - - res, err := readMsg(sc.conn) - if err != nil { - return nil, err - } - - return res.(*PubKeyMsg).PubKey, nil -} - -// SignVote implements PrivValidator. -func (sc *SocketPV) SignVote(chainID string, vote *types.Vote) error { - err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote}) - if err != nil { - return err - } - - res, err := readMsg(sc.conn) - if err != nil { - return err - } - - resp, ok := res.(*SignedVoteResponse) - if !ok { - return ErrUnexpectedResponse - } - if resp.Error != nil { - return fmt.Errorf("remote error occurred: code: %v, description: %s", - resp.Error.Code, - resp.Error.Description) - } - *vote = *resp.Vote - - return nil -} - -// SignProposal implements PrivValidator. -func (sc *SocketPV) SignProposal( - chainID string, - proposal *types.Proposal, -) error { - err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal}) - if err != nil { - return err - } - - res, err := readMsg(sc.conn) - if err != nil { - return err - } - resp, ok := res.(*SignedProposalResponse) - if !ok { - return ErrUnexpectedResponse - } - if resp.Error != nil { - return fmt.Errorf("remote error occurred: code: %v, description: %s", - resp.Error.Code, - resp.Error.Description) - } - *proposal = *resp.Proposal - - return nil -} - -// SignHeartbeat implements PrivValidator. -func (sc *SocketPV) SignHeartbeat( - chainID string, - heartbeat *types.Heartbeat, -) error { - err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat}) - if err != nil { - return err - } - - res, err := readMsg(sc.conn) - if err != nil { - return err - } - resp, ok := res.(*SignedHeartbeatResponse) - if !ok { - return ErrUnexpectedResponse - } - if resp.Error != nil { - return fmt.Errorf("remote error occurred: code: %v, description: %s", - resp.Error.Code, - resp.Error.Description) - } - *heartbeat = *resp.Heartbeat - - return nil -} - -// OnStart implements cmn.Service. -func (sc *SocketPV) OnStart() error { - if err := sc.listen(); err != nil { - err = cmn.ErrorWrap(err, "failed to listen") - sc.Logger.Error( - "OnStart", - "err", err, - ) - return err - } - - conn, err := sc.waitConnection() - if err != nil { - err = cmn.ErrorWrap(err, "failed to accept connection") - sc.Logger.Error( - "OnStart", - "err", err, - ) - - return err - } - - sc.conn = conn - - return nil -} - -// OnStop implements cmn.Service. -func (sc *SocketPV) OnStop() { - if sc.conn != nil { - if err := sc.conn.Close(); err != nil { - err = cmn.ErrorWrap(err, "failed to close connection") - sc.Logger.Error( - "OnStop", - "err", err, - ) - } - } - - if sc.listener != nil { - if err := sc.listener.Close(); err != nil { - err = cmn.ErrorWrap(err, "failed to close listener") - sc.Logger.Error( - "OnStop", - "err", err, - ) - } - } -} - -func (sc *SocketPV) acceptConnection() (net.Conn, error) { - conn, err := sc.listener.Accept() - if err != nil { - if !sc.IsRunning() { - return nil, nil // Ignore error from listener closing. - } - return nil, err - - } - - conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey) - if err != nil { - return nil, err - } - - return conn, nil -} - -func (sc *SocketPV) listen() error { - ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr)) - if err != nil { - return err - } - - sc.listener = newTCPTimeoutListener( - ln, - sc.acceptDeadline, - sc.connDeadline, - sc.connHeartbeat, - ) - - return nil -} - -// waitConnection uses the configured wait timeout to error if no external -// process connects in the time period. -func (sc *SocketPV) waitConnection() (net.Conn, error) { - var ( - connc = make(chan net.Conn, 1) - errc = make(chan error, 1) - ) - - go func(connc chan<- net.Conn, errc chan<- error) { - conn, err := sc.acceptConnection() - if err != nil { - errc <- err - return - } - - connc <- conn - }(connc, errc) - - select { - case conn := <-connc: - return conn, nil - case err := <-errc: - if _, ok := err.(timeoutError); ok { - return nil, cmn.ErrorWrap(ErrConnWaitTimeout, err.Error()) - } - return nil, err - case <-time.After(sc.connWaitTimeout): - return nil, ErrConnWaitTimeout - } -} - -//--------------------------------------------------------- - -// RemoteSignerOption sets an optional parameter on the RemoteSigner. -type RemoteSignerOption func(*RemoteSigner) - -// RemoteSignerConnDeadline sets the read and write deadline for connections -// from external signing processes. -func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption { - return func(ss *RemoteSigner) { ss.connDeadline = deadline } -} - -// RemoteSignerConnRetries sets the amount of attempted retries to connect. -func RemoteSignerConnRetries(retries int) RemoteSignerOption { - return func(ss *RemoteSigner) { ss.connRetries = retries } -} - -// RemoteSigner implements PrivValidator by dialing to a socket. -type RemoteSigner struct { - cmn.BaseService - - addr string - chainID string - connDeadline time.Duration - connRetries int - privKey ed25519.PrivKeyEd25519 - privVal types.PrivValidator - - conn net.Conn -} - -// NewRemoteSigner returns an instance of RemoteSigner. -func NewRemoteSigner( - logger log.Logger, - chainID, socketAddr string, - privVal types.PrivValidator, - privKey ed25519.PrivKeyEd25519, -) *RemoteSigner { - rs := &RemoteSigner{ - addr: socketAddr, - chainID: chainID, - connDeadline: time.Second * defaultConnDeadlineSeconds, - connRetries: defaultDialRetries, - privKey: privKey, - privVal: privVal, - } - - rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs) - - return rs -} - -// OnStart implements cmn.Service. -func (rs *RemoteSigner) OnStart() error { - conn, err := rs.connect() - if err != nil { - err = cmn.ErrorWrap(err, "connect") - rs.Logger.Error("OnStart", "err", err) - return err - } - - go rs.handleConnection(conn) - - return nil -} - -// OnStop implements cmn.Service. -func (rs *RemoteSigner) OnStop() { - if rs.conn == nil { - return - } - - if err := rs.conn.Close(); err != nil { - rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) - } -} - -func (rs *RemoteSigner) connect() (net.Conn, error) { - for retries := rs.connRetries; retries > 0; retries-- { - // Don't sleep if it is the first retry. - if retries != rs.connRetries { - time.Sleep(rs.connDeadline) - } - - conn, err := cmn.Connect(rs.addr) - if err != nil { - err = cmn.ErrorWrap(err, "connection failed") - rs.Logger.Error( - "connect", - "addr", rs.addr, - "err", err, - ) - - continue - } - - if err := conn.SetDeadline(time.Now().Add(time.Second * defaultConnDeadlineSeconds)); err != nil { - err = cmn.ErrorWrap(err, "setting connection timeout failed") - rs.Logger.Error( - "connect", - "err", err, - ) - continue - } - - conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey) - if err != nil { - err = cmn.ErrorWrap(err, "encrypting connection failed") - rs.Logger.Error( - "connect", - "err", err, - ) - - continue - } - - return conn, nil - } - - return nil, ErrDialRetryMax -} - -func (rs *RemoteSigner) handleConnection(conn net.Conn) { - for { - if !rs.IsRunning() { - return // Ignore error from listener closing. - } - - req, err := readMsg(conn) - if err != nil { - if err != io.EOF { - rs.Logger.Error("handleConnection", "err", err) - } - return - } - - var res SocketPVMsg - - switch r := req.(type) { - case *PubKeyMsg: - var p crypto.PubKey - p = rs.privVal.GetPubKey() - res = &PubKeyMsg{p} - case *SignVoteRequest: - err = rs.privVal.SignVote(rs.chainID, r.Vote) - if err != nil { - res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} - } else { - res = &SignedVoteResponse{r.Vote, nil} - } - case *SignProposalRequest: - err = rs.privVal.SignProposal(rs.chainID, r.Proposal) - if err != nil { - res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} - } else { - res = &SignedProposalResponse{r.Proposal, nil} - } - case *SignHeartbeatRequest: - err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat) - if err != nil { - res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}} - } else { - res = &SignedHeartbeatResponse{r.Heartbeat, nil} - } - default: - err = fmt.Errorf("unknown msg: %v", r) - } - - if err != nil { - // only log the error; we'll reply with an error in res - rs.Logger.Error("handleConnection", "err", err) - } - - err = writeMsg(conn, res) - if err != nil { - rs.Logger.Error("handleConnection", "err", err) - return - } - } -} - -//--------------------------------------------------------- - -// SocketPVMsg is sent between RemoteSigner and SocketPV. -type SocketPVMsg interface{} - -func RegisterSocketPVMsg(cdc *amino.Codec) { - cdc.RegisterInterface((*SocketPVMsg)(nil), nil) - cdc.RegisterConcrete(&PubKeyMsg{}, "tendermint/socketpv/PubKeyMsg", nil) - cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/socketpv/SignVoteRequest", nil) - cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/socketpv/SignedVoteResponse", nil) - cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/socketpv/SignProposalRequest", nil) - cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/socketpv/SignedProposalResponse", nil) - cdc.RegisterConcrete(&SignHeartbeatRequest{}, "tendermint/socketpv/SignHeartbeatRequest", nil) - cdc.RegisterConcrete(&SignedHeartbeatResponse{}, "tendermint/socketpv/SignedHeartbeatResponse", nil) -} - -// PubKeyMsg is a PrivValidatorSocket message containing the public key. -type PubKeyMsg struct { - PubKey crypto.PubKey -} - -// SignVoteRequest is a PrivValidatorSocket message containing a vote. -type SignVoteRequest struct { - Vote *types.Vote -} - -// SignedVoteResponse is a PrivValidatorSocket message containing a signed vote along with a potenial error message. -type SignedVoteResponse struct { - Vote *types.Vote - Error *RemoteSignerError -} - -// SignProposalRequest is a PrivValidatorSocket message containing a Proposal. -type SignProposalRequest struct { - Proposal *types.Proposal -} - -type SignedProposalResponse struct { - Proposal *types.Proposal - Error *RemoteSignerError -} - -// SignHeartbeatRequest is a PrivValidatorSocket message containing a Heartbeat. -type SignHeartbeatRequest struct { - Heartbeat *types.Heartbeat -} - -type SignedHeartbeatResponse struct { - Heartbeat *types.Heartbeat - Error *RemoteSignerError -} - -// RemoteSignerError allows (remote) validators to include meaningful error descriptions in their reply. -type RemoteSignerError struct { - // TODO(ismail): create an enum of known errors - Code int - Description string -} - -func readMsg(r io.Reader) (msg SocketPVMsg, err error) { - const maxSocketPVMsgSize = 1024 * 10 - - // set deadline before trying to read - conn := r.(net.Conn) - if err := conn.SetDeadline(time.Now().Add(time.Second * defaultConnDeadlineSeconds)); err != nil { - err = cmn.ErrorWrap(err, "setting connection timeout failed in readMsg") - return msg, err - } - - _, err = cdc.UnmarshalBinaryReader(r, &msg, maxSocketPVMsgSize) - if _, ok := err.(timeoutError); ok { - err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) - } - return -} - -func writeMsg(w io.Writer, msg interface{}) (err error) { - _, err = cdc.MarshalBinaryWriter(w, msg) - if _, ok := err.(timeoutError); ok { - err = cmn.ErrorWrap(ErrConnTimeout, err.Error()) - } - return -} diff --git a/privval/tcp.go b/privval/tcp.go new file mode 100644 index 000000000..11bd833c0 --- /dev/null +++ b/privval/tcp.go @@ -0,0 +1,214 @@ +package privval + +import ( + "errors" + "net" + "time" + + "github.com/tendermint/tendermint/crypto/ed25519" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + p2pconn "github.com/tendermint/tendermint/p2p/conn" + "github.com/tendermint/tendermint/types" +) + +const ( + defaultAcceptDeadlineSeconds = 3 + defaultConnDeadlineSeconds = 3 + defaultConnHeartBeatSeconds = 2 + defaultDialRetries = 10 +) + +// Socket errors. +var ( + ErrDialRetryMax = errors.New("dialed maximum retries") + ErrConnTimeout = errors.New("remote signer timed out") + ErrUnexpectedResponse = errors.New("received unexpected response") +) + +var ( + acceptDeadline = time.Second * defaultAcceptDeadlineSeconds + connTimeout = time.Second * defaultConnDeadlineSeconds + connHeartbeat = time.Second * defaultConnHeartBeatSeconds +) + +// TCPValOption sets an optional parameter on the SocketPV. +type TCPValOption func(*TCPVal) + +// TCPValAcceptDeadline sets the deadline for the TCPVal listener. +// A zero time value disables the deadline. +func TCPValAcceptDeadline(deadline time.Duration) TCPValOption { + return func(sc *TCPVal) { sc.acceptDeadline = deadline } +} + +// TCPValConnTimeout sets the read and write timeout for connections +// from external signing processes. +func TCPValConnTimeout(timeout time.Duration) TCPValOption { + return func(sc *TCPVal) { sc.connTimeout = timeout } +} + +// TCPValHeartbeat sets the period on which to check the liveness of the +// connected Signer connections. +func TCPValHeartbeat(period time.Duration) TCPValOption { + return func(sc *TCPVal) { sc.connHeartbeat = period } +} + +// TCPVal implements PrivValidator, it uses a socket to request signatures +// from an external process. +type TCPVal struct { + cmn.BaseService + *RemoteSignerClient + + addr string + acceptDeadline time.Duration + connTimeout time.Duration + connHeartbeat time.Duration + privKey ed25519.PrivKeyEd25519 + + conn net.Conn + listener net.Listener + cancelPing chan struct{} + pingTicker *time.Ticker +} + +// Check that TCPVal implements PrivValidator. +var _ types.PrivValidator = (*TCPVal)(nil) + +// NewTCPVal returns an instance of TCPVal. +func NewTCPVal( + logger log.Logger, + socketAddr string, + privKey ed25519.PrivKeyEd25519, +) *TCPVal { + sc := &TCPVal{ + addr: socketAddr, + acceptDeadline: acceptDeadline, + connTimeout: connTimeout, + connHeartbeat: connHeartbeat, + privKey: privKey, + } + + sc.BaseService = *cmn.NewBaseService(logger, "TCPVal", sc) + + return sc +} + +// OnStart implements cmn.Service. +func (sc *TCPVal) OnStart() error { + if err := sc.listen(); err != nil { + sc.Logger.Error("OnStart", "err", err) + return err + } + + conn, err := sc.waitConnection() + if err != nil { + sc.Logger.Error("OnStart", "err", err) + return err + } + + sc.conn = conn + + sc.RemoteSignerClient = NewRemoteSignerClient(sc.conn) + + // Start a routine to keep the connection alive + sc.cancelPing = make(chan struct{}, 1) + sc.pingTicker = time.NewTicker(sc.connHeartbeat) + go func() { + for { + select { + case <-sc.pingTicker.C: + err := sc.Ping() + if err != nil { + sc.Logger.Error( + "Ping", + "err", err, + ) + } + case <-sc.cancelPing: + sc.pingTicker.Stop() + return + } + } + }() + + return nil +} + +// OnStop implements cmn.Service. +func (sc *TCPVal) OnStop() { + if sc.cancelPing != nil { + close(sc.cancelPing) + } + + if sc.conn != nil { + if err := sc.conn.Close(); err != nil { + sc.Logger.Error("OnStop", "err", err) + } + } + + if sc.listener != nil { + if err := sc.listener.Close(); err != nil { + sc.Logger.Error("OnStop", "err", err) + } + } +} + +func (sc *TCPVal) acceptConnection() (net.Conn, error) { + conn, err := sc.listener.Accept() + if err != nil { + if !sc.IsRunning() { + return nil, nil // Ignore error from listener closing. + } + return nil, err + + } + + conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey) + if err != nil { + return nil, err + } + + return conn, nil +} + +func (sc *TCPVal) listen() error { + ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr)) + if err != nil { + return err + } + + sc.listener = newTCPTimeoutListener( + ln, + sc.acceptDeadline, + sc.connTimeout, + sc.connHeartbeat, + ) + + return nil +} + +// waitConnection uses the configured wait timeout to error if no external +// process connects in the time period. +func (sc *TCPVal) waitConnection() (net.Conn, error) { + var ( + connc = make(chan net.Conn, 1) + errc = make(chan error, 1) + ) + + go func(connc chan<- net.Conn, errc chan<- error) { + conn, err := sc.acceptConnection() + if err != nil { + errc <- err + return + } + + connc <- conn + }(connc, errc) + + select { + case conn := <-connc: + return conn, nil + case err := <-errc: + return nil, err + } +} diff --git a/privval/tcp_server.go b/privval/tcp_server.go new file mode 100644 index 000000000..694023d76 --- /dev/null +++ b/privval/tcp_server.go @@ -0,0 +1,160 @@ +package privval + +import ( + "io" + "net" + "time" + + "github.com/tendermint/tendermint/crypto/ed25519" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + p2pconn "github.com/tendermint/tendermint/p2p/conn" + "github.com/tendermint/tendermint/types" +) + +// RemoteSignerOption sets an optional parameter on the RemoteSigner. +type RemoteSignerOption func(*RemoteSigner) + +// RemoteSignerConnDeadline sets the read and write deadline for connections +// from external signing processes. +func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption { + return func(ss *RemoteSigner) { ss.connDeadline = deadline } +} + +// RemoteSignerConnRetries sets the amount of attempted retries to connect. +func RemoteSignerConnRetries(retries int) RemoteSignerOption { + return func(ss *RemoteSigner) { ss.connRetries = retries } +} + +// RemoteSigner implements PrivValidator by dialing to a socket. +type RemoteSigner struct { + cmn.BaseService + + addr string + chainID string + connDeadline time.Duration + connRetries int + privKey ed25519.PrivKeyEd25519 + privVal types.PrivValidator + + conn net.Conn +} + +// NewRemoteSigner returns an instance of RemoteSigner. +func NewRemoteSigner( + logger log.Logger, + chainID, socketAddr string, + privVal types.PrivValidator, + privKey ed25519.PrivKeyEd25519, +) *RemoteSigner { + rs := &RemoteSigner{ + addr: socketAddr, + chainID: chainID, + connDeadline: time.Second * defaultConnDeadlineSeconds, + connRetries: defaultDialRetries, + privKey: privKey, + privVal: privVal, + } + + rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs) + + return rs +} + +// OnStart implements cmn.Service. +func (rs *RemoteSigner) OnStart() error { + conn, err := rs.connect() + if err != nil { + rs.Logger.Error("OnStart", "err", err) + return err + } + + go rs.handleConnection(conn) + + return nil +} + +// OnStop implements cmn.Service. +func (rs *RemoteSigner) OnStop() { + if rs.conn == nil { + return + } + + if err := rs.conn.Close(); err != nil { + rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) + } +} + +func (rs *RemoteSigner) connect() (net.Conn, error) { + for retries := rs.connRetries; retries > 0; retries-- { + // Don't sleep if it is the first retry. + if retries != rs.connRetries { + time.Sleep(rs.connDeadline) + } + + conn, err := cmn.Connect(rs.addr) + if err != nil { + rs.Logger.Error( + "connect", + "addr", rs.addr, + "err", err, + ) + + continue + } + + if err := conn.SetDeadline(time.Now().Add(connTimeout)); err != nil { + rs.Logger.Error( + "connect", + "err", err, + ) + continue + } + + conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey) + if err != nil { + rs.Logger.Error( + "connect", + "err", err, + ) + + continue + } + + return conn, nil + } + + return nil, ErrDialRetryMax +} + +func (rs *RemoteSigner) handleConnection(conn net.Conn) { + for { + if !rs.IsRunning() { + return // Ignore error from listener closing. + } + + // Reset the connection deadline + conn.SetDeadline(time.Now().Add(rs.connDeadline)) + + req, err := readMsg(conn) + if err != nil { + if err != io.EOF { + rs.Logger.Error("handleConnection", "err", err) + } + return + } + + res, err := handleRequest(req, rs.chainID, rs.privVal) + + if err != nil { + // only log the error; we'll reply with an error in res + rs.Logger.Error("handleConnection", "err", err) + } + + err = writeMsg(conn, res) + if err != nil { + rs.Logger.Error("handleConnection", "err", err) + return + } + } +} diff --git a/privval/socket_tcp.go b/privval/tcp_socket.go similarity index 59% rename from privval/socket_tcp.go rename to privval/tcp_socket.go index b26db00c2..2b17bf26e 100644 --- a/privval/socket_tcp.go +++ b/privval/tcp_socket.go @@ -24,6 +24,13 @@ type tcpTimeoutListener struct { period time.Duration } +// timeoutConn wraps a net.Conn to standardise protocol timeouts / deadline resets. +type timeoutConn struct { + net.Conn + + connDeadline time.Duration +} + // newTCPTimeoutListener returns an instance of tcpTimeoutListener. func newTCPTimeoutListener( ln net.Listener, @@ -38,6 +45,16 @@ func newTCPTimeoutListener( } } +// newTimeoutConn returns an instance of newTCPTimeoutConn. +func newTimeoutConn( + conn net.Conn, + connDeadline time.Duration) *timeoutConn { + return &timeoutConn{ + conn, + connDeadline, + } +} + // Accept implements net.Listener. func (ln tcpTimeoutListener) Accept() (net.Conn, error) { err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline)) @@ -50,17 +67,24 @@ func (ln tcpTimeoutListener) Accept() (net.Conn, error) { return nil, err } - if err := tc.SetDeadline(time.Now().Add(ln.connDeadline)); err != nil { - return nil, err - } + // Wrap the conn in our timeout wrapper + conn := newTimeoutConn(tc, ln.connDeadline) - if err := tc.SetKeepAlive(true); err != nil { - return nil, err - } + return conn, nil +} - if err := tc.SetKeepAlivePeriod(ln.period); err != nil { - return nil, err - } +// Read implements net.Listener. +func (c timeoutConn) Read(b []byte) (n int, err error) { + // Reset deadline + c.Conn.SetReadDeadline(time.Now().Add(c.connDeadline)) + + return c.Conn.Read(b) +} + +// Write implements net.Listener. +func (c timeoutConn) Write(b []byte) (n int, err error) { + // Reset deadline + c.Conn.SetWriteDeadline(time.Now().Add(c.connDeadline)) - return tc, nil + return c.Conn.Write(b) } diff --git a/privval/socket_tcp_test.go b/privval/tcp_socket_test.go similarity index 91% rename from privval/socket_tcp_test.go rename to privval/tcp_socket_test.go index 44a673c0c..285e73ed5 100644 --- a/privval/socket_tcp_test.go +++ b/privval/tcp_socket_test.go @@ -44,13 +44,14 @@ func TestTCPTimeoutListenerConnDeadline(t *testing.T) { time.Sleep(2 * time.Millisecond) - _, err = c.Write([]byte("foo")) + msg := make([]byte, 200) + _, err = c.Read(msg) opErr, ok := err.(*net.OpError) if !ok { t.Fatalf("have %v, want *net.OpError", err) } - if have, want := opErr.Op, "write"; have != want { + if have, want := opErr.Op, "read"; have != want { t.Errorf("have %v, want %v", have, want) } }(ln) diff --git a/privval/socket_test.go b/privval/tcp_test.go similarity index 82% rename from privval/socket_test.go rename to privval/tcp_test.go index aa2e15fa0..6549759d0 100644 --- a/privval/socket_test.go +++ b/privval/tcp_test.go @@ -27,8 +27,7 @@ func TestSocketPVAddress(t *testing.T) { serverAddr := rs.privVal.GetAddress() - clientAddr, err := sc.getAddress() - require.NoError(t, err) + clientAddr := sc.GetAddress() assert.Equal(t, serverAddr, clientAddr) @@ -91,52 +90,83 @@ func TestSocketPVVote(t *testing.T) { assert.Equal(t, want.Signature, have.Signature) } -func TestSocketPVHeartbeat(t *testing.T) { +func TestSocketPVVoteResetDeadline(t *testing.T) { var ( chainID = cmn.RandStr(12) sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - want = &types.Heartbeat{} - have = &types.Heartbeat{} + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} ) defer sc.Stop() defer rs.Stop() - require.NoError(t, rs.privVal.SignHeartbeat(chainID, want)) - require.NoError(t, sc.SignHeartbeat(chainID, have)) + time.Sleep(3 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + + // This would exceed the deadline if it was not extended by the previous message + time.Sleep(3 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) assert.Equal(t, want.Signature, have.Signature) } -func TestSocketPVAcceptDeadline(t *testing.T) { +func TestSocketPVVoteKeepalive(t *testing.T) { var ( - sc = NewSocketPV( - log.TestingLogger(), - "127.0.0.1:0", - ed25519.GenPrivKey(), - ) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) + + ts = time.Now() + vType = types.PrecommitType + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} ) defer sc.Stop() + defer rs.Stop() + + time.Sleep(10 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) +} + +func TestSocketPVHeartbeat(t *testing.T) { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV()) - SocketPVAcceptDeadline(time.Millisecond)(sc) + want = &types.Heartbeat{} + have = &types.Heartbeat{} + ) + defer sc.Stop() + defer rs.Stop() - assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout) + require.NoError(t, rs.privVal.SignHeartbeat(chainID, want)) + require.NoError(t, sc.SignHeartbeat(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) } func TestSocketPVDeadline(t *testing.T) { var ( addr = testFreeAddr(t) listenc = make(chan struct{}) - sc = NewSocketPV( + sc = NewTCPVal( log.TestingLogger(), addr, ed25519.GenPrivKey(), ) ) - SocketPVConnDeadline(100 * time.Millisecond)(sc) - SocketPVConnWait(500 * time.Millisecond)(sc) + TCPValConnTimeout(100 * time.Millisecond)(sc) - go func(sc *SocketPV) { + go func(sc *TCPVal) { defer close(listenc) require.NoError(t, sc.Start()) @@ -161,26 +191,10 @@ func TestSocketPVDeadline(t *testing.T) { <-listenc - // Sleep to guarantee deadline has been hit. - time.Sleep(20 * time.Microsecond) - _, err := sc.getPubKey() assert.Equal(t, err.(cmn.Error).Data(), ErrConnTimeout) } -func TestSocketPVWait(t *testing.T) { - sc := NewSocketPV( - log.TestingLogger(), - "127.0.0.1:0", - ed25519.GenPrivKey(), - ) - defer sc.Stop() - - SocketPVConnWait(time.Millisecond)(sc) - - assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout) -} - func TestRemoteSignerRetry(t *testing.T) { var ( attemptc = make(chan int) @@ -221,7 +235,7 @@ func TestRemoteSignerRetry(t *testing.T) { RemoteSignerConnDeadline(time.Millisecond)(rs) RemoteSignerConnRetries(retries)(rs) - assert.Equal(t, rs.Start().(cmn.Error).Data(), ErrDialRetryMax) + assert.Equal(t, rs.Start(), ErrDialRetryMax) select { case attempts := <-attemptc: @@ -328,7 +342,7 @@ func TestErrUnexpectedResponse(t *testing.T) { types.NewMockPV(), ed25519.GenPrivKey(), ) - sc = NewSocketPV( + sc = NewTCPVal( logger, addr, ed25519.GenPrivKey(), @@ -383,7 +397,7 @@ func testSetupSocketPair( t *testing.T, chainID string, privValidator types.PrivValidator, -) (*SocketPV, *RemoteSigner) { +) (*TCPVal, *RemoteSigner) { var ( addr = testFreeAddr(t) logger = log.TestingLogger() @@ -396,18 +410,20 @@ func testSetupSocketPair( privVal, ed25519.GenPrivKey(), ) - sc = NewSocketPV( + sc = NewTCPVal( logger, addr, ed25519.GenPrivKey(), ) ) - testStartSocketPV(t, readyc, sc) - - RemoteSignerConnDeadline(time.Millisecond)(rs) + TCPValConnTimeout(5 * time.Millisecond)(sc) + TCPValHeartbeat(2 * time.Millisecond)(sc) + RemoteSignerConnDeadline(5 * time.Millisecond)(rs) RemoteSignerConnRetries(1e6)(rs) + testStartSocketPV(t, readyc, sc) + require.NoError(t, rs.Start()) assert.True(t, rs.IsRunning()) @@ -416,7 +432,7 @@ func testSetupSocketPair( return sc, rs } -func testReadWriteResponse(t *testing.T, resp SocketPVMsg, rsConn net.Conn) { +func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { _, err := readMsg(rsConn) require.NoError(t, err) @@ -424,8 +440,8 @@ func testReadWriteResponse(t *testing.T, resp SocketPVMsg, rsConn net.Conn) { require.NoError(t, err) } -func testStartSocketPV(t *testing.T, readyc chan struct{}, sc *SocketPV) { - go func(sc *SocketPV) { +func testStartSocketPV(t *testing.T, readyc chan struct{}, sc *TCPVal) { + go func(sc *TCPVal) { require.NoError(t, sc.Start()) assert.True(t, sc.IsRunning()) diff --git a/privval/wire.go b/privval/wire.go index 50660ff34..2e11677e4 100644 --- a/privval/wire.go +++ b/privval/wire.go @@ -9,5 +9,5 @@ var cdc = amino.NewCodec() func init() { cryptoAmino.RegisterAmino(cdc) - RegisterSocketPVMsg(cdc) + RegisterRemoteSignerMsg(cdc) }