From ca619c80b6e046244e724e64117cdf6127f7a28b Mon Sep 17 00:00:00 2001 From: Alexander Simmerl Date: Tue, 6 Mar 2018 16:11:17 +0100 Subject: [PATCH] Stop privVal socket client on node shutdown --- node/node.go | 7 ++++- types/priv_validator/socket.go | 40 ++++++++++++++--------------- types/priv_validator/socket_test.go | 2 +- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/node/node.go b/node/node.go index 2d48fe013..d40322fad 100644 --- a/node/node.go +++ b/node/node.go @@ -441,8 +441,13 @@ func (n *Node) OnStop() { } n.eventBus.Stop() - n.indexerService.Stop() + + if pvsc, ok := n.privValidator.(*priv_val.SocketClient); ok { + if err := pvsc.Stop(); err != nil { + n.Logger.Error("Error stopping priv validator socket client", "err", err) + } + } } // RunForever waits for an interrupt signal and stops the node. diff --git a/types/priv_validator/socket.go b/types/priv_validator/socket.go index 61e369866..c07a3014a 100644 --- a/types/priv_validator/socket.go +++ b/types/priv_validator/socket.go @@ -32,17 +32,17 @@ var ( ) // SocketClientOption sets an optional parameter on the SocketClient. -type SocketClientOption func(*socketClient) +type SocketClientOption func(*SocketClient) // SocketClientTimeout sets the timeout for connecting to the external socket // address. func SocketClientTimeout(timeout time.Duration) SocketClientOption { - return func(sc *socketClient) { sc.connectTimeout = timeout } + return func(sc *SocketClient) { sc.connectTimeout = timeout } } -// socketClient implements PrivValidator, it uses a socket to request signatures +// SocketClient implements PrivValidator, it uses a socket to request signatures // from an external process. -type socketClient struct { +type SocketClient struct { cmn.BaseService conn net.Conn @@ -52,28 +52,28 @@ type socketClient struct { connectTimeout time.Duration } -// Check that socketClient implements PrivValidator2. -var _ types.PrivValidator2 = (*socketClient)(nil) +// Check that SocketClient implements PrivValidator2. +var _ types.PrivValidator2 = (*SocketClient)(nil) -// NewSocketClient returns an instance of socketClient. +// NewSocketClient returns an instance of SocketClient. func NewSocketClient( logger log.Logger, socketAddr string, privKey *crypto.PrivKeyEd25519, -) *socketClient { - sc := &socketClient{ +) *SocketClient { + sc := &SocketClient{ addr: socketAddr, connectTimeout: time.Second * defaultConnDeadlineSeconds, privKey: privKey, } - sc.BaseService = *cmn.NewBaseService(logger, "privValidatorsocketClient", sc) + sc.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketClient", sc) return sc } // OnStart implements cmn.Service. -func (sc *socketClient) OnStart() error { +func (sc *SocketClient) OnStart() error { if err := sc.BaseService.OnStart(); err != nil { return err } @@ -89,7 +89,7 @@ func (sc *socketClient) OnStart() error { } // OnStop implements cmn.Service. -func (sc *socketClient) OnStop() { +func (sc *SocketClient) OnStop() { sc.BaseService.OnStop() if sc.conn != nil { @@ -99,7 +99,7 @@ func (sc *socketClient) OnStop() { // GetAddress implements PrivValidator. // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. -func (sc *socketClient) GetAddress() types.Address { +func (sc *SocketClient) GetAddress() types.Address { addr, err := sc.Address() if err != nil { panic(err) @@ -109,7 +109,7 @@ func (sc *socketClient) GetAddress() types.Address { } // Address is an alias for PubKey().Address(). -func (sc *socketClient) Address() (cmn.HexBytes, error) { +func (sc *SocketClient) Address() (cmn.HexBytes, error) { p, err := sc.PubKey() if err != nil { return nil, err @@ -120,7 +120,7 @@ func (sc *socketClient) Address() (cmn.HexBytes, error) { // GetPubKey implements PrivValidator. // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. -func (sc *socketClient) GetPubKey() crypto.PubKey { +func (sc *SocketClient) GetPubKey() crypto.PubKey { pubKey, err := sc.PubKey() if err != nil { panic(err) @@ -130,7 +130,7 @@ func (sc *socketClient) GetPubKey() crypto.PubKey { } // PubKey implements PrivValidator2. -func (sc *socketClient) PubKey() (crypto.PubKey, error) { +func (sc *SocketClient) PubKey() (crypto.PubKey, error) { err := writeMsg(sc.conn, &PubKeyMsg{}) if err != nil { return crypto.PubKey{}, err @@ -145,7 +145,7 @@ func (sc *socketClient) PubKey() (crypto.PubKey, error) { } // SignVote implements PrivValidator2. -func (sc *socketClient) SignVote(chainID string, vote *types.Vote) error { +func (sc *SocketClient) SignVote(chainID string, vote *types.Vote) error { err := writeMsg(sc.conn, &SignVoteMsg{Vote: vote}) if err != nil { return err @@ -162,7 +162,7 @@ func (sc *socketClient) SignVote(chainID string, vote *types.Vote) error { } // SignProposal implements PrivValidator2. -func (sc *socketClient) SignProposal(chainID string, proposal *types.Proposal) error { +func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) error { err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal}) if err != nil { return err @@ -179,7 +179,7 @@ func (sc *socketClient) SignProposal(chainID string, proposal *types.Proposal) e } // SignHeartbeat implements PrivValidator2. -func (sc *socketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error { +func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error { err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat}) if err != nil { return err @@ -195,7 +195,7 @@ func (sc *socketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat return nil } -func (sc *socketClient) connect() (net.Conn, error) { +func (sc *SocketClient) connect() (net.Conn, error) { retries := defaultDialRetryMax RETRY_LOOP: diff --git a/types/priv_validator/socket_test.go b/types/priv_validator/socket_test.go index d4928d3e2..d3d81580d 100644 --- a/types/priv_validator/socket_test.go +++ b/types/priv_validator/socket_test.go @@ -128,7 +128,7 @@ func TestSocketClientConnectRetryMax(t *testing.T) { assert.EqualError(sc.Start(), ErrDialRetryMax.Error()) } -func testSetupSocketPair(t *testing.T, chainID string) (*socketClient, *PrivValidatorSocketServer) { +func testSetupSocketPair(t *testing.T, chainID string) (*SocketClient, *PrivValidatorSocketServer) { var ( assert, require = assert.New(t), require.New(t) logger = log.TestingLogger()