You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

220 lines
5.5 KiB

package privval
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/libs/service"
privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval"
)
// SignerListenerEndpointOption sets an optional parameter on the SignerListenerEndpoint.
type SignerListenerEndpointOption func(*SignerListenerEndpoint)
// SignerListenerEndpointTimeoutReadWrite sets the read and write timeout for
// connections from external signing processes.
//
// Default: 5s
func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListenerEndpointOption {
return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout }
}
// SignerListenerEndpoint listens for an external process to dial in and keeps
// the connection alive by dropping and reconnecting.
//
// The process will send pings every ~3s (read/write timeout * 2/3) to keep the
// connection alive.
type SignerListenerEndpoint struct {
signerEndpoint
listener net.Listener
connectRequestCh chan struct{}
connectionAvailableCh chan net.Conn
timeoutAccept time.Duration
pingTimer *time.Ticker
pingInterval time.Duration
instanceMtx sync.Mutex // Ensures instance public methods access, i.e. SendRequest
}
// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
func NewSignerListenerEndpoint(
logger log.Logger,
listener net.Listener,
options ...SignerListenerEndpointOption,
) *SignerListenerEndpoint {
sl := &SignerListenerEndpoint{
listener: listener,
timeoutAccept: defaultTimeoutAcceptSeconds * time.Second,
}
sl.signerEndpoint.logger = logger
sl.BaseService = *service.NewBaseService(logger, "SignerListenerEndpoint", sl)
sl.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second
for _, optionFunc := range options {
optionFunc(sl)
}
return sl
}
// OnStart implements service.Service.
func (sl *SignerListenerEndpoint) OnStart(ctx context.Context) error {
sl.connectRequestCh = make(chan struct{})
sl.connectionAvailableCh = make(chan net.Conn)
// NOTE: ping timeout must be less than read/write timeout
sl.pingInterval = time.Duration(sl.signerEndpoint.timeoutReadWrite.Milliseconds()*2/3) * time.Millisecond
sl.pingTimer = time.NewTicker(sl.pingInterval)
go sl.serviceLoop(ctx)
go sl.pingLoop(ctx)
sl.connectRequestCh <- struct{}{}
return nil
}
// OnStop implements service.Service
func (sl *SignerListenerEndpoint) OnStop() {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
_ = sl.Close()
// Stop listening
if sl.listener != nil {
if err := sl.listener.Close(); err != nil {
sl.logger.Error("Closing Listener", "err", err)
sl.listener = nil
}
}
sl.pingTimer.Stop()
}
// WaitForConnection waits maxWait for a connection or returns a timeout error
func (sl *SignerListenerEndpoint) WaitForConnection(ctx context.Context, maxWait time.Duration) error {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
return sl.ensureConnection(ctx, maxWait)
}
// SendRequest ensures there is a connection, sends a request and waits for a response
func (sl *SignerListenerEndpoint) SendRequest(ctx context.Context, request privvalproto.Message) (*privvalproto.Message, error) {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
err := sl.ensureConnection(ctx, sl.timeoutAccept)
if err != nil {
return nil, err
}
err = sl.WriteMessage(request)
if err != nil {
return nil, err
}
res, err := sl.ReadMessage()
if err != nil {
return nil, err
}
// Reset pingTimer to avoid sending unnecessary pings.
sl.pingTimer.Reset(sl.pingInterval)
return &res, nil
}
func (sl *SignerListenerEndpoint) ensureConnection(ctx context.Context, maxWait time.Duration) error {
if sl.IsConnected() {
return nil
}
// Is there a connection ready? then use it
if sl.GetAvailableConnection(sl.connectionAvailableCh) {
return nil
}
// block until connected or timeout
sl.logger.Info("SignerListener: Blocking for connection")
sl.triggerConnect()
return sl.WaitConnection(ctx, sl.connectionAvailableCh, maxWait)
}
func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) {
if !sl.IsRunning() || sl.listener == nil {
return nil, fmt.Errorf("endpoint is closing")
}
// wait for a new conn
sl.logger.Info("SignerListener: Listening for new connection")
conn, err := sl.listener.Accept()
if err != nil {
return nil, err
}
return conn, nil
}
func (sl *SignerListenerEndpoint) triggerConnect() {
select {
case sl.connectRequestCh <- struct{}{}:
default:
}
}
func (sl *SignerListenerEndpoint) triggerReconnect() {
sl.DropConnection()
sl.triggerConnect()
}
func (sl *SignerListenerEndpoint) serviceLoop(ctx context.Context) {
for {
select {
case <-sl.connectRequestCh:
{
conn, err := sl.acceptNewConnection()
if err == nil {
sl.logger.Info("SignerListener: Connected")
// We have a good connection, wait for someone that needs one otherwise cancellation
select {
case sl.connectionAvailableCh <- conn:
case <-ctx.Done():
return
}
}
select {
case sl.connectRequestCh <- struct{}{}:
default:
}
}
case <-ctx.Done():
return
}
}
}
func (sl *SignerListenerEndpoint) pingLoop(ctx context.Context) {
for {
select {
case <-sl.pingTimer.C:
{
_, err := sl.SendRequest(ctx, mustWrapMsg(&privvalproto.PingRequest{}))
if err != nil {
sl.logger.Error("SignerListener: Ping timeout")
sl.triggerReconnect()
}
}
case <-ctx.Done():
return
}
}
}