diff --git a/privval/signer_dialer_endpoint.go b/privval/signer_dialer_endpoint.go index e9bf80234..5dbf6fe7e 100644 --- a/privval/signer_dialer_endpoint.go +++ b/privval/signer_dialer_endpoint.go @@ -104,13 +104,7 @@ func (sd *SignerDialerEndpoint) Close() error { defer sd.mtx.Unlock() sd.Logger.Debug("SignerDialerEndpoint: Close") - if sd.conn != nil { - if err := sd.conn.Close(); err != nil { - sd.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) - sd.conn = nil - } - } - + sd.dropConnection() return nil } @@ -121,7 +115,35 @@ func (sd *SignerDialerEndpoint) IsConnected() bool { return sd.isConnected() } -// IsConnected indicates if there is an active connection +func (sd *SignerDialerEndpoint) handleRequest() { + if !sd.IsRunning() { + return // Ignore error from listener closing. + } + + sd.Logger.Info("SignerDialerEndpoint: connected", "timeout", sd.timeoutReadWrite) + + req, err := sd.readMessage() + if err != nil { + if err != io.EOF { + sd.Logger.Error("SignerDialerEndpoint handleMessage", "err", err) + } + return + } + + res, err := HandleValidatorRequest(req, sd.chainID, sd.privVal) + + if err != nil { + // only log the error; we'll reply with an error in res + sd.Logger.Error("handleMessage handleMessage", "err", err) + } + + err = sd.writeMessage(res) + if err != nil { + sd.Logger.Error("handleMessage writeMessage", "err", err) + return + } +} + func (sd *SignerDialerEndpoint) isConnected() bool { return sd.IsRunning() && sd.conn != nil } @@ -176,32 +198,12 @@ func (sd *SignerDialerEndpoint) writeMessage(msg RemoteSignerMsg) (err error) { return } -func (sd *SignerDialerEndpoint) handleRequest() { - if !sd.IsRunning() { - return // Ignore error from listener closing. - } - - sd.Logger.Info("SignerDialerEndpoint: connected", "timeout", sd.timeoutReadWrite) - - req, err := sd.readMessage() - if err != nil { - if err != io.EOF { - sd.Logger.Error("SignerDialerEndpoint handleMessage", "err", err) +func (sd *SignerDialerEndpoint) dropConnection() { + if sd.conn != nil { + if err := sd.conn.Close(); err != nil { + sd.Logger.Error("SignerDialerEndpoint::dropConnection", "err", err) } - return - } - - res, err := HandleValidatorRequest(req, sd.chainID, sd.privVal) - - if err != nil { - // only log the error; we'll reply with an error in res - sd.Logger.Error("handleMessage handleMessage", "err", err) - } - - err = sd.writeMessage(res) - if err != nil { - sd.Logger.Error("handleMessage writeMessage", "err", err) - return + sd.conn = nil } } diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index 8ce214689..0d411c185 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -229,6 +229,24 @@ func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error return nil } +func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) { + sl.Logger.Debug("SignerListenerEndpoint: AcceptNewConnection") + + if !sl.IsRunning() || sl.listener == nil { + return nil, fmt.Errorf("endpoint is closing") + } + + // wait for a new conn + conn, err := sl.listener.Accept() + if err != nil { + sl.Logger.Debug("listener accept failed", "err", err) + return nil, err + } + + sl.Logger.Info("SignerListenerEndpoint: New connection") + return conn, nil +} + func (sl *SignerListenerEndpoint) dropConnection() { if sl.conn != nil { if err := sl.conn.Close(); err != nil { @@ -275,21 +293,3 @@ func (sl *SignerListenerEndpoint) serviceLoop() { } } } - -func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) { - sl.Logger.Debug("SignerListenerEndpoint: AcceptNewConnection") - - if !sl.IsRunning() || sl.listener == nil { - return nil, fmt.Errorf("endpoint is closing") - } - - // wait for a new conn - conn, err := sl.listener.Accept() - if err != nil { - sl.Logger.Debug("listener accept failed", "err", err) - return nil, err - } - - sl.Logger.Info("SignerListenerEndpoint: New connection") - return conn, nil -}