package privval
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/tendermint/tendermint/crypto"
|
|
cmn "github.com/tendermint/tendermint/libs/common"
|
|
"github.com/tendermint/tendermint/libs/log"
|
|
"github.com/tendermint/tendermint/types"
|
|
)
|
|
|
|
const (
|
|
defaultMaxDialRetries = 10
|
|
)
|
|
|
|
// SignerServiceEndpointOption sets an optional parameter on the SignerDialerEndpoint.
|
|
type SignerServiceEndpointOption func(*SignerDialerEndpoint)
|
|
|
|
// SignerServiceEndpointTimeoutReadWrite sets the read and write timeout for connections
|
|
// from external signing processes.
|
|
func SignerServiceEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption {
|
|
return func(ss *SignerDialerEndpoint) { ss.timeoutReadWrite = timeout }
|
|
}
|
|
|
|
// SignerServiceEndpointConnRetries sets the amount of attempted retries to AcceptNewConnection.
|
|
func SignerServiceEndpointConnRetries(retries int) SignerServiceEndpointOption {
|
|
return func(ss *SignerDialerEndpoint) { ss.maxConnRetries = retries }
|
|
}
|
|
|
|
// TODO(jleni): Create a common type for a signerEndpoint (common for both listener/dialer)
|
|
// getConnection
|
|
// AcceptNewConnection
|
|
// read
|
|
// write
|
|
// close
|
|
|
|
// SignerDialerEndpoint dials using its dialer and responds to any
|
|
// signature requests using its privVal.
|
|
type SignerDialerEndpoint struct {
|
|
cmn.BaseService
|
|
|
|
mtx sync.Mutex
|
|
dialer SocketDialer
|
|
conn net.Conn
|
|
|
|
timeoutReadWrite time.Duration
|
|
maxConnRetries int
|
|
|
|
chainID string
|
|
privVal types.PrivValidator
|
|
|
|
stopCh chan struct{}
|
|
stoppedCh chan struct{}
|
|
}
|
|
|
|
// NewSignerDialerEndpoint returns a SignerDialerEndpoint that will dial using the given
|
|
// dialer and respond to any signature requests over the connection
|
|
// using the given privVal.
|
|
func NewSignerDialerEndpoint(
|
|
logger log.Logger,
|
|
chainID string,
|
|
privVal types.PrivValidator,
|
|
dialer SocketDialer,
|
|
) *SignerDialerEndpoint {
|
|
|
|
se := &SignerDialerEndpoint{
|
|
dialer: dialer,
|
|
timeoutReadWrite: defaultTimeoutReadWriteSeconds * time.Second,
|
|
maxConnRetries: defaultMaxDialRetries,
|
|
|
|
chainID: chainID,
|
|
privVal: privVal,
|
|
}
|
|
|
|
se.BaseService = *cmn.NewBaseService(logger, "SignerDialerEndpoint", se)
|
|
return se
|
|
}
|
|
|
|
// OnStart implements cmn.Service.
|
|
func (ss *SignerDialerEndpoint) OnStart() error {
|
|
ss.Logger.Debug("SignerDialerEndpoint: OnStart")
|
|
|
|
ss.stopCh = make(chan struct{})
|
|
ss.stoppedCh = make(chan struct{})
|
|
|
|
go ss.serviceLoop()
|
|
|
|
ss.Logger.Debug("OnStart - done")
|
|
return nil
|
|
}
|
|
|
|
// OnStop implements cmn.Service.
|
|
func (ss *SignerDialerEndpoint) OnStop() {
|
|
// Trigger a stop and wait
|
|
close(ss.stopCh)
|
|
<-ss.stoppedCh
|
|
|
|
if ss.conn != nil {
|
|
if err := ss.conn.Close(); err != nil {
|
|
ss.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed"))
|
|
ss.Logger.Debug("Reset conn")
|
|
ss.conn = nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ss *SignerDialerEndpoint) serviceLoop() {
|
|
defer close(ss.stoppedCh)
|
|
|
|
retries := 0
|
|
var err error
|
|
|
|
for {
|
|
select {
|
|
default:
|
|
{
|
|
ss.Logger.Debug("Try connect", "retries", retries, "max", ss.maxConnRetries)
|
|
|
|
if retries > ss.maxConnRetries {
|
|
ss.Logger.Error("Maximum retries reached", "retries", retries)
|
|
return
|
|
}
|
|
|
|
if ss.conn == nil {
|
|
ss.conn, err = ss.dialer()
|
|
|
|
if err != nil {
|
|
ss.conn = nil // Explicitly set to nil because dialer returns an interface (https://golang.org/doc/faq#nil_error)
|
|
retries++
|
|
continue
|
|
}
|
|
}
|
|
|
|
retries = 0
|
|
ss.handleRequest()
|
|
}
|
|
|
|
case <-ss.stopCh:
|
|
{
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ss *SignerDialerEndpoint) readMessage() (msg RemoteSignerMsg, err error) {
|
|
// TODO(jleni): Avoid duplication. Unify endpoints
|
|
if ss.conn == nil {
|
|
return nil, fmt.Errorf("not connected")
|
|
}
|
|
|
|
// Reset read deadline
|
|
deadline := time.Now().Add(ss.timeoutReadWrite)
|
|
ss.Logger.Debug("SignerDialerEndpoint: readMessage", "deadline", deadline)
|
|
|
|
err = ss.conn.SetReadDeadline(deadline)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
const maxRemoteSignerMsgSize = 1024 * 10
|
|
_, err = cdc.UnmarshalBinaryLengthPrefixedReader(ss.conn, &msg, maxRemoteSignerMsgSize)
|
|
if _, ok := err.(timeoutError); ok {
|
|
err = cmn.ErrorWrap(ErrDialerTimeout, err.Error())
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (ss *SignerDialerEndpoint) writeMessage(msg RemoteSignerMsg) (err error) {
|
|
// TODO(jleni): Avoid duplication. Unify endpoints
|
|
|
|
if ss.conn == nil {
|
|
return fmt.Errorf("not connected")
|
|
}
|
|
|
|
// Reset read deadline
|
|
deadline := time.Now().Add(ss.timeoutReadWrite)
|
|
ss.Logger.Debug("SignerDialerEndpoint: readMessage", "deadline", deadline)
|
|
err = ss.conn.SetWriteDeadline(deadline)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
_, err = cdc.MarshalBinaryLengthPrefixedWriter(ss.conn, msg)
|
|
if _, ok := err.(timeoutError); ok {
|
|
err = cmn.ErrorWrap(ErrDialerTimeout, err.Error())
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (ss *SignerDialerEndpoint) handleRequest() {
|
|
if !ss.IsRunning() {
|
|
return // Ignore error from listener closing.
|
|
}
|
|
|
|
ss.Logger.Info("SignerDialerEndpoint: connected", "timeout", ss.timeoutReadWrite)
|
|
|
|
req, err := ss.readMessage()
|
|
if err != nil {
|
|
if err != io.EOF {
|
|
ss.Logger.Error("SignerDialerEndpoint handleMessage", "err", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
res, err := handleMessage(req, ss.chainID, ss.privVal)
|
|
|
|
if err != nil {
|
|
// only log the error; we'll reply with an error in res
|
|
ss.Logger.Error("handleMessage handleMessage", "err", err)
|
|
}
|
|
|
|
err = ss.writeMessage(res)
|
|
if err != nil {
|
|
ss.Logger.Error("handleMessage writeMessage", "err", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func handleMessage(req RemoteSignerMsg, chainID string, privVal types.PrivValidator) (RemoteSignerMsg, error) {
|
|
var res RemoteSignerMsg
|
|
var err error
|
|
|
|
switch r := req.(type) {
|
|
case *PubKeyRequest:
|
|
var p crypto.PubKey
|
|
p = privVal.GetPubKey()
|
|
res = &PubKeyResponse{p, nil}
|
|
|
|
case *SignVoteRequest:
|
|
err = privVal.SignVote(chainID, r.Vote)
|
|
if err != nil {
|
|
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}}
|
|
} else {
|
|
res = &SignedVoteResponse{r.Vote, nil}
|
|
}
|
|
|
|
case *SignProposalRequest:
|
|
err = privVal.SignProposal(chainID, r.Proposal)
|
|
if err != nil {
|
|
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}}
|
|
} else {
|
|
res = &SignedProposalResponse{r.Proposal, nil}
|
|
}
|
|
|
|
default:
|
|
err = fmt.Errorf("unknown msg: %v", r)
|
|
}
|
|
|
|
return res, err
|
|
}
|