Browse Source

privval: improve client shutdown to prevent resource leak (#7544)

wb/rollback-test-fix
Sam Kleinman 2 years ago
committed by GitHub
parent
commit
841629f5b7
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 56 additions and 24 deletions
  1. +4
    -4
      privval/retry_signer_client.go
  2. +13
    -8
      privval/signer_client.go
  3. +6
    -1
      privval/signer_client_test.go
  4. +13
    -2
      privval/signer_dialer_endpoint.go
  5. +4
    -1
      privval/signer_endpoint.go
  6. +7
    -7
      privval/signer_listener_endpoint.go
  7. +8
    -0
      privval/signer_listener_endpoint_test.go
  8. +1
    -1
      privval/signer_server.go

+ 4
- 4
privval/retry_signer_client.go View File

@ -34,15 +34,15 @@ func (sc *RetrySignerClient) IsConnected() bool {
return sc.next.IsConnected()
}
func (sc *RetrySignerClient) WaitForConnection(maxWait time.Duration) error {
return sc.next.WaitForConnection(maxWait)
func (sc *RetrySignerClient) WaitForConnection(ctx context.Context, maxWait time.Duration) error {
return sc.next.WaitForConnection(ctx, maxWait)
}
//--------------------------------------------------------
// Implement PrivValidator
func (sc *RetrySignerClient) Ping() error {
return sc.next.Ping()
func (sc *RetrySignerClient) Ping(ctx context.Context) error {
return sc.next.Ping(ctx)
}
func (sc *RetrySignerClient) GetPubKey(ctx context.Context) (crypto.PubKey, error) {


+ 13
- 8
privval/signer_client.go View File

@ -41,7 +41,12 @@ func NewSignerClient(ctx context.Context, endpoint *SignerListenerEndpoint, chai
// Close closes the underlying connection
func (sc *SignerClient) Close() error {
return sc.endpoint.Close()
err := sc.endpoint.Stop()
cerr := sc.endpoint.Close()
if err != nil {
return err
}
return cerr
}
// IsConnected indicates whether the signer is connected to a remote signing service
@ -50,16 +55,16 @@ func (sc *SignerClient) IsConnected() bool {
}
// WaitForConnection waits maxWait for a connection or returns a timeout error
func (sc *SignerClient) WaitForConnection(maxWait time.Duration) error {
return sc.endpoint.WaitForConnection(maxWait)
func (sc *SignerClient) WaitForConnection(ctx context.Context, maxWait time.Duration) error {
return sc.endpoint.WaitForConnection(ctx, maxWait)
}
//--------------------------------------------------------
// Implement PrivValidator
// Ping sends a ping request to the remote signer
func (sc *SignerClient) Ping() error {
response, err := sc.endpoint.SendRequest(mustWrapMsg(&privvalproto.PingRequest{}))
func (sc *SignerClient) Ping(ctx context.Context) error {
response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg(&privvalproto.PingRequest{}))
if err != nil {
sc.logger.Error("SignerClient::Ping", "err", err)
return nil
@ -76,7 +81,7 @@ func (sc *SignerClient) Ping() error {
// GetPubKey retrieves a public key from a remote signer
// returns an error if client is not able to provide the key
func (sc *SignerClient) GetPubKey(ctx context.Context) (crypto.PubKey, error) {
response, err := sc.endpoint.SendRequest(mustWrapMsg(&privvalproto.PubKeyRequest{ChainId: sc.chainID}))
response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg(&privvalproto.PubKeyRequest{ChainId: sc.chainID}))
if err != nil {
return nil, fmt.Errorf("send: %w", err)
}
@ -99,7 +104,7 @@ func (sc *SignerClient) GetPubKey(ctx context.Context) (crypto.PubKey, error) {
// SignVote requests a remote signer to sign a vote
func (sc *SignerClient) SignVote(ctx context.Context, chainID string, vote *tmproto.Vote) error {
response, err := sc.endpoint.SendRequest(mustWrapMsg(&privvalproto.SignVoteRequest{Vote: vote, ChainId: chainID}))
response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg(&privvalproto.SignVoteRequest{Vote: vote, ChainId: chainID}))
if err != nil {
return err
}
@ -119,7 +124,7 @@ func (sc *SignerClient) SignVote(ctx context.Context, chainID string, vote *tmpr
// SignProposal requests a remote signer to sign a proposal
func (sc *SignerClient) SignProposal(ctx context.Context, chainID string, proposal *tmproto.Proposal) error {
response, err := sc.endpoint.SendRequest(mustWrapMsg(
response, err := sc.endpoint.SendRequest(ctx, mustWrapMsg(
&privvalproto.SignProposalRequest{Proposal: proposal, ChainId: chainID},
))
if err != nil {


+ 6
- 1
privval/signer_client_test.go View File

@ -57,6 +57,7 @@ func getSignerTestCases(ctx context.Context, t *testing.T, logger log.Logger) []
signerServer: ss,
})
t.Cleanup(ss.Wait)
t.Cleanup(sc.endpoint.Wait)
}
return testCases
@ -72,10 +73,14 @@ func TestSignerClose(t *testing.T) {
for _, tc := range getSignerTestCases(bctx, t, logger) {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(leaktest.Check(t))
defer tc.closer()
assert.NoError(t, tc.signerClient.Close())
assert.NoError(t, tc.signerServer.Stop())
t.Cleanup(tc.signerClient.endpoint.Wait)
t.Cleanup(tc.signerServer.Wait)
})
}
}
@ -89,7 +94,7 @@ func TestSignerPing(t *testing.T) {
logger := log.NewTestingLogger(t)
for _, tc := range getSignerTestCases(ctx, t, logger) {
err := tc.signerClient.Ping()
err := tc.signerClient.Ping(ctx)
assert.NoError(t, err)
}
}


+ 13
- 2
privval/signer_dialer_endpoint.go View File

@ -74,20 +74,31 @@ func NewSignerDialerEndpoint(
func (sd *SignerDialerEndpoint) OnStart(context.Context) error { return nil }
func (sd *SignerDialerEndpoint) OnStop() {}
func (sd *SignerDialerEndpoint) ensureConnection() error {
func (sd *SignerDialerEndpoint) ensureConnection(ctx context.Context) error {
if sd.IsConnected() {
return nil
}
timer := time.NewTimer(0)
defer timer.Stop()
retries := 0
for retries < sd.maxConnRetries {
if err := ctx.Err(); err != nil {
return err
}
conn, err := sd.dialer()
if err != nil {
retries++
sd.logger.Debug("SignerDialer: Reconnection failed", "retries", retries, "max", sd.maxConnRetries, "err", err)
// Wait between retries
time.Sleep(sd.retryWait)
timer.Reset(sd.retryWait)
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
}
} else {
sd.SetConnection(conn)
sd.logger.Debug("SignerDialer: Connection Ready")


+ 4
- 1
privval/signer_endpoint.go View File

@ -1,6 +1,7 @@
package privval
import (
"context"
"fmt"
"net"
"sync"
@ -54,11 +55,13 @@ func (se *signerEndpoint) GetAvailableConnection(connectionAvailableCh chan net.
}
// WaitConnection waits up to maxWait for a connection to become available
func (se *signerEndpoint) WaitConnection(connectionAvailableCh chan net.Conn, maxWait time.Duration) error {
func (se *signerEndpoint) WaitConnection(ctx context.Context, connectionAvailableCh chan net.Conn, maxWait time.Duration) error {
se.connMtx.Lock()
defer se.connMtx.Unlock()
select {
case <-ctx.Done():
return ctx.Err()
case se.conn = <-connectionAvailableCh:
case <-time.After(maxWait):
return ErrConnectionTimeout


+ 7
- 7
privval/signer_listener_endpoint.go View File

@ -99,18 +99,18 @@ func (sl *SignerListenerEndpoint) OnStop() {
}
// WaitForConnection waits maxWait for a connection or returns a timeout error
func (sl *SignerListenerEndpoint) WaitForConnection(maxWait time.Duration) error {
func (sl *SignerListenerEndpoint) WaitForConnection(ctx context.Context, maxWait time.Duration) error {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
return sl.ensureConnection(maxWait)
return sl.ensureConnection(ctx, maxWait)
}
// SendRequest ensures there is a connection, sends a request and waits for a response
func (sl *SignerListenerEndpoint) SendRequest(request privvalproto.Message) (*privvalproto.Message, error) {
func (sl *SignerListenerEndpoint) SendRequest(ctx context.Context, request privvalproto.Message) (*privvalproto.Message, error) {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
err := sl.ensureConnection(sl.timeoutAccept)
err := sl.ensureConnection(ctx, sl.timeoutAccept)
if err != nil {
return nil, err
}
@ -131,7 +131,7 @@ func (sl *SignerListenerEndpoint) SendRequest(request privvalproto.Message) (*pr
return &res, nil
}
func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error {
func (sl *SignerListenerEndpoint) ensureConnection(ctx context.Context, maxWait time.Duration) error {
if sl.IsConnected() {
return nil
}
@ -144,7 +144,7 @@ func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error
// block until connected or timeout
sl.logger.Info("SignerListener: Blocking for connection")
sl.triggerConnect()
return sl.WaitConnection(sl.connectionAvailableCh, maxWait)
return sl.WaitConnection(ctx, sl.connectionAvailableCh, maxWait)
}
func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) {
@ -207,7 +207,7 @@ func (sl *SignerListenerEndpoint) pingLoop(ctx context.Context) {
select {
case <-sl.pingTimer.C:
{
_, err := sl.SendRequest(mustWrapMsg(&privvalproto.PingRequest{}))
_, err := sl.SendRequest(ctx, mustWrapMsg(&privvalproto.PingRequest{}))
if err != nil {
sl.logger.Error("SignerListener: Ping timeout")
sl.triggerReconnect()


+ 8
- 0
privval/signer_listener_endpoint_test.go View File

@ -6,6 +6,7 @@ import (
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -89,6 +90,8 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) {
}
func TestRetryConnToRemoteSigner(t *testing.T) {
t.Cleanup(leaktest.Check(t))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -102,6 +105,7 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
thisConnTimeout = testTimeoutReadWrite
listenerEndpoint = newSignerListenerEndpoint(t, logger, tc.addr, thisConnTimeout)
)
t.Cleanup(listenerEndpoint.Wait)
dialerEndpoint := NewSignerDialerEndpoint(
logger,
@ -116,6 +120,8 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
require.NoError(t, signerServer.Start(ctx))
assert.True(t, signerServer.IsRunning())
t.Cleanup(signerServer.Wait)
<-endpointIsOpenCh
if err := signerServer.Stop(); err != nil {
t.Error(err)
@ -130,6 +136,8 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
// let some pings pass
require.NoError(t, signerServer2.Start(ctx))
assert.True(t, signerServer2.IsRunning())
t.Cleanup(signerServer2.Wait)
t.Cleanup(func() { _ = signerServer2.Stop() })
// give the client some time to re-establish the conn to the remote signer
// should see something like this in the logs:


+ 1
- 1
privval/signer_server.go View File

@ -97,7 +97,7 @@ func (ss *SignerServer) serviceLoop(ctx context.Context) {
case <-ctx.Done():
return
default:
if err := ss.endpoint.ensureConnection(); err != nil {
if err := ss.endpoint.ensureConnection(ctx); err != nil {
return
}
ss.servicePendingRequest(ctx)


Loading…
Cancel
Save