From 841629f5b7223825601263002f8ec92214db4489 Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Tue, 11 Jan 2022 15:09:19 -0500 Subject: [PATCH] privval: improve client shutdown to prevent resource leak (#7544) --- privval/retry_signer_client.go | 8 ++++---- privval/signer_client.go | 21 +++++++++++++-------- privval/signer_client_test.go | 7 ++++++- privval/signer_dialer_endpoint.go | 15 +++++++++++++-- privval/signer_endpoint.go | 5 ++++- privval/signer_listener_endpoint.go | 14 +++++++------- privval/signer_listener_endpoint_test.go | 8 ++++++++ privval/signer_server.go | 2 +- 8 files changed, 56 insertions(+), 24 deletions(-) diff --git a/privval/retry_signer_client.go b/privval/retry_signer_client.go index ccd9834e4..6dacc9a28 100644 --- a/privval/retry_signer_client.go +++ b/privval/retry_signer_client.go @@ -34,15 +34,15 @@ func (sc *RetrySignerClient) IsConnected() bool { return sc.next.IsConnected() } -func (sc *RetrySignerClient) WaitForConnection(maxWait time.Duration) error { - return sc.next.WaitForConnection(maxWait) +func (sc *RetrySignerClient) WaitForConnection(ctx context.Context, maxWait time.Duration) error { + return sc.next.WaitForConnection(ctx, maxWait) } //-------------------------------------------------------- // Implement PrivValidator -func (sc *RetrySignerClient) Ping() error { - return sc.next.Ping() +func (sc *RetrySignerClient) Ping(ctx context.Context) error { + return sc.next.Ping(ctx) } func (sc *RetrySignerClient) GetPubKey(ctx context.Context) (crypto.PubKey, error) { diff --git a/privval/signer_client.go b/privval/signer_client.go index 981b4e175..3247a74b7 100644 --- a/privval/signer_client.go +++ b/privval/signer_client.go @@ -41,7 +41,12 @@ func NewSignerClient(ctx context.Context, endpoint *SignerListenerEndpoint, chai // Close closes the underlying connection func (sc *SignerClient) Close() error { - return sc.endpoint.Close() + err := sc.endpoint.Stop() + cerr := sc.endpoint.Close() + if err != nil { + return err + } + return cerr } // IsConnected indicates with the signer is connected to a remote signing service @@ -50,16 +55,16 @@ func (sc *SignerClient) IsConnected() bool { } // WaitForConnection waits maxWait for a connection or returns a timeout error -func (sc *SignerClient) WaitForConnection(maxWait time.Duration) error { - return sc.endpoint.WaitForConnection(maxWait) +func (sc *SignerClient) WaitForConnection(ctx context.Context, maxWait time.Duration) error { + return sc.endpoint.WaitForConnection(ctx, maxWait) } //-------------------------------------------------------- // Implement PrivValidator // Ping sends a ping request to the remote signer -func (sc *SignerClient) Ping() error { - response, err := sc.endpoint.SendRequest(mustWrapMsg(&privvalproto.PingRequest{})) +func (sc *SignerClient) Ping(ctx context.Context) error { + response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg(&privvalproto.PingRequest{})) if err != nil { sc.logger.Error("SignerClient::Ping", "err", err) return nil @@ -76,7 +81,7 @@ func (sc *SignerClient) Ping() error { // GetPubKey retrieves a public key from a remote signer // returns an error if client is not able to provide the key func (sc *SignerClient) GetPubKey(ctx context.Context) (crypto.PubKey, error) { - response, err := sc.endpoint.SendRequest(mustWrapMsg(&privvalproto.PubKeyRequest{ChainId: sc.chainID})) + response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg(&privvalproto.PubKeyRequest{ChainId: sc.chainID})) if err != nil { return nil, fmt.Errorf("send: %w", err) } @@ -99,7 +104,7 @@ func (sc *SignerClient) GetPubKey(ctx context.Context) (crypto.PubKey, error) { // SignVote requests a remote signer to sign a vote func (sc *SignerClient) SignVote(ctx context.Context, chainID string, vote *tmproto.Vote) error { - response, err := sc.endpoint.SendRequest(mustWrapMsg(&privvalproto.SignVoteRequest{Vote: vote, ChainId: chainID})) + response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg(&privvalproto.SignVoteRequest{Vote: vote, ChainId: chainID})) if err != nil { return err } @@ -119,7 +124,7 @@ func (sc *SignerClient) SignVote(ctx context.Context, chainID string, vote *tmpr // SignProposal requests a remote signer to sign a proposal func (sc *SignerClient) SignProposal(ctx context.Context, chainID string, proposal *tmproto.Proposal) error { - response, err := sc.endpoint.SendRequest(mustWrapMsg( + response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg( &privvalproto.SignProposalRequest{Proposal: proposal, ChainId: chainID}, )) if err != nil { diff --git a/privval/signer_client_test.go b/privval/signer_client_test.go index a7cddbd6e..6f90095c1 100644 --- a/privval/signer_client_test.go +++ b/privval/signer_client_test.go @@ -57,6 +57,7 @@ func getSignerTestCases(ctx context.Context, t *testing.T, logger log.Logger) [] signerServer: ss, }) t.Cleanup(ss.Wait) + t.Cleanup(sc.endpoint.Wait) } return testCases @@ -72,10 +73,14 @@ func TestSignerClose(t *testing.T) { for _, tc := range getSignerTestCases(bctx, t, logger) { t.Run(tc.name, func(t *testing.T) { + t.Cleanup(leaktest.Check(t)) + defer tc.closer() assert.NoError(t, tc.signerClient.Close()) assert.NoError(t, tc.signerServer.Stop()) + t.Cleanup(tc.signerClient.endpoint.Wait) + t.Cleanup(tc.signerServer.Wait) }) } } @@ -89,7 +94,7 @@ func TestSignerPing(t *testing.T) { logger := log.NewTestingLogger(t) for _, tc := range getSignerTestCases(ctx, t, logger) { - err := tc.signerClient.Ping() + err := tc.signerClient.Ping(ctx) assert.NoError(t, err) } } diff --git a/privval/signer_dialer_endpoint.go b/privval/signer_dialer_endpoint.go index 76b3bd501..b291a7ef5 100644 --- a/privval/signer_dialer_endpoint.go +++ b/privval/signer_dialer_endpoint.go @@ -74,20 +74,31 @@ func NewSignerDialerEndpoint( func (sd *SignerDialerEndpoint) OnStart(context.Context) error { return nil } func (sd *SignerDialerEndpoint) OnStop() {} -func (sd *SignerDialerEndpoint) ensureConnection() error { +func (sd *SignerDialerEndpoint) ensureConnection(ctx context.Context) error { if sd.IsConnected() { return nil } + timer := time.NewTimer(0) + defer timer.Stop() retries := 0 for retries < sd.maxConnRetries { + if err := ctx.Err(); err != nil { + return err + } conn, err := sd.dialer() if err != nil { retries++ sd.logger.Debug("SignerDialer: Reconnection failed", "retries", retries, "max", sd.maxConnRetries, "err", err) + // Wait between retries - time.Sleep(sd.retryWait) + timer.Reset(sd.retryWait) + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + } } else { sd.SetConnection(conn) sd.logger.Debug("SignerDialer: Connection Ready") diff --git a/privval/signer_endpoint.go b/privval/signer_endpoint.go index 5cf4f7be7..8810bdf85 100644 --- a/privval/signer_endpoint.go +++ b/privval/signer_endpoint.go @@ -1,6 +1,7 @@ package privval import ( + "context" "fmt" "net" "sync" @@ -54,11 +55,13 @@ func (se *signerEndpoint) GetAvailableConnection(connectionAvailableCh chan net. } // TryGetConnection retrieves a connection if it is already available -func (se *signerEndpoint) WaitConnection(connectionAvailableCh chan net.Conn, maxWait time.Duration) error { +func (se *signerEndpoint) WaitConnection(ctx context.Context, connectionAvailableCh chan net.Conn, maxWait time.Duration) error { se.connMtx.Lock() defer se.connMtx.Unlock() select { + case <-ctx.Done(): + return ctx.Err() case se.conn = <-connectionAvailableCh: case <-time.After(maxWait): return ErrConnectionTimeout diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index ff2c0b7c2..12c915973 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -99,18 +99,18 @@ func (sl *SignerListenerEndpoint) OnStop() { } // WaitForConnection waits maxWait for a connection or returns a timeout error -func (sl *SignerListenerEndpoint) WaitForConnection(maxWait time.Duration) error { +func (sl *SignerListenerEndpoint) WaitForConnection(ctx context.Context, maxWait time.Duration) error { sl.instanceMtx.Lock() defer sl.instanceMtx.Unlock() - return sl.ensureConnection(maxWait) + return sl.ensureConnection(ctx, maxWait) } // SendRequest ensures there is a connection, sends a request and waits for a response -func (sl *SignerListenerEndpoint) SendRequest(request privvalproto.Message) (*privvalproto.Message, error) { +func (sl *SignerListenerEndpoint) SendRequest(ctx context.Context, request privvalproto.Message) (*privvalproto.Message, error) { sl.instanceMtx.Lock() defer sl.instanceMtx.Unlock() - err := sl.ensureConnection(sl.timeoutAccept) + err := sl.ensureConnection(ctx, sl.timeoutAccept) if err != nil { return nil, err } @@ -131,7 +131,7 @@ func (sl *SignerListenerEndpoint) SendRequest(request privvalproto.Message) (*pr return &res, nil } -func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error { +func (sl *SignerListenerEndpoint) ensureConnection(ctx context.Context, maxWait time.Duration) error { if sl.IsConnected() { return nil } @@ -144,7 +144,7 @@ func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error // block until connected or timeout sl.logger.Info("SignerListener: Blocking for connection") sl.triggerConnect() - return sl.WaitConnection(sl.connectionAvailableCh, maxWait) + return sl.WaitConnection(ctx, sl.connectionAvailableCh, maxWait) } func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) { @@ -207,7 +207,7 @@ func (sl *SignerListenerEndpoint) pingLoop(ctx context.Context) { select { case <-sl.pingTimer.C: { - _, err := sl.SendRequest(mustWrapMsg(&privvalproto.PingRequest{})) + _, err := sl.SendRequest(ctx, mustWrapMsg(&privvalproto.PingRequest{})) if err != nil { sl.logger.Error("SignerListener: Ping timeout") sl.triggerReconnect() diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 47fe812c9..148c6acfb 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/fortytw2/leaktest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -89,6 +90,8 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { } func TestRetryConnToRemoteSigner(t *testing.T) { + t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -102,6 +105,7 @@ func TestRetryConnToRemoteSigner(t *testing.T) { thisConnTimeout = testTimeoutReadWrite listenerEndpoint = newSignerListenerEndpoint(t, logger, tc.addr, thisConnTimeout) ) + t.Cleanup(listenerEndpoint.Wait) dialerEndpoint := NewSignerDialerEndpoint( logger, @@ -116,6 +120,8 @@ func TestRetryConnToRemoteSigner(t *testing.T) { require.NoError(t, signerServer.Start(ctx)) assert.True(t, signerServer.IsRunning()) + t.Cleanup(signerServer.Wait) + <-endpointIsOpenCh if err := signerServer.Stop(); err != nil { t.Error(err) @@ -130,6 +136,8 @@ func TestRetryConnToRemoteSigner(t *testing.T) { // let some pings pass require.NoError(t, signerServer2.Start(ctx)) assert.True(t, signerServer2.IsRunning()) + t.Cleanup(signerServer2.Wait) + t.Cleanup(func() { _ = signerServer2.Stop() }) // give the client some time to re-establish the conn to the remote signer // should see sth like this in the logs: diff --git a/privval/signer_server.go b/privval/signer_server.go index e98d78b75..4945b8150 100644 --- a/privval/signer_server.go +++ b/privval/signer_server.go @@ -97,7 +97,7 @@ func (ss *SignerServer) serviceLoop(ctx context.Context) { case <-ctx.Done(): return default: - if err := ss.endpoint.ensureConnection(); err != nil { + if err := ss.endpoint.ensureConnection(ctx); err != nil { return } ss.servicePendingRequest(ctx)