You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

257 lines
6.0 KiB

package privval
import (
"fmt"
"io"
"net"
"sync"
"time"
"github.com/tendermint/tendermint/crypto"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
)
const (
// defaultMaxDialRetries is the value NewSignerDialerEndpoint assigns to
// maxConnRetries, the retry limit checked by serviceLoop.
defaultMaxDialRetries = 10
)
// SignerServiceEndpointOption sets an optional parameter on the SignerDialerEndpoint.
// An option is a function that receives the endpoint and modifies it in place.
type SignerServiceEndpointOption func(*SignerDialerEndpoint)
// SignerServiceEndpointTimeoutReadWrite returns an option that stores timeout
// as the endpoint's timeoutReadWrite, the duration used to compute the
// deadline of each read and each write on the connection.
func SignerServiceEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption {
	return func(endpoint *SignerDialerEndpoint) {
		endpoint.timeoutReadWrite = timeout
	}
}
// SignerServiceEndpointConnRetries returns an option that stores retries as
// the endpoint's maxConnRetries, the number of consecutive failed connection
// attempts tolerated before the service loop gives up.
func SignerServiceEndpointConnRetries(retries int) SignerServiceEndpointOption {
	return func(endpoint *SignerDialerEndpoint) {
		endpoint.maxConnRetries = retries
	}
}
// TODO(jleni): Create a common type for a signerEndpoint (common for both listener/dialer)
// getConnection
// AcceptNewConnection
// read
// write
// close
// SignerDialerEndpoint dials using its dialer and responds to any
// signature requests using its privVal.
//
// The work happens in serviceLoop, a goroutine started by OnStart and
// stopped by OnStop.
type SignerDialerEndpoint struct {
// Embedded service base; supplies Logger and IsRunning used by the methods.
cmn.BaseService
// NOTE(review): mtx is not locked anywhere in the code visible here;
// confirm whether it is still needed.
mtx sync.Mutex
// dialer is called by serviceLoop to obtain a connection when conn is nil.
dialer SocketDialer
// conn is the current connection, or nil when not connected.
conn net.Conn
// timeoutReadWrite is added to the current time to form the deadline of
// each read and each write on conn.
timeoutReadWrite time.Duration
// maxConnRetries bounds consecutive failed dials: serviceLoop exits once
// the retry counter exceeds it.
maxConnRetries int
// chainID and privVal are handed to handleMessage for every request.
chainID string
privVal types.PrivValidator
// stopCh is created by OnStart and closed by OnStop to ask serviceLoop to exit.
stopCh chan struct{}
// stoppedCh is created by OnStart and closed by serviceLoop when it returns.
stoppedCh chan struct{}
}
// NewSignerDialerEndpoint builds a SignerDialerEndpoint that connects through
// dialer and answers signature requests for chainID with privVal.
//
// The read/write timeout and the retry limit start at their package defaults.
// The returned endpoint is not started; logger is passed to the embedded
// BaseService.
func NewSignerDialerEndpoint(
	logger log.Logger,
	chainID string,
	privVal types.PrivValidator,
	dialer SocketDialer,
) *SignerDialerEndpoint {
	endpoint := new(SignerDialerEndpoint)

	// Caller-supplied collaborators.
	endpoint.dialer = dialer
	endpoint.chainID = chainID
	endpoint.privVal = privVal

	// Package defaults.
	endpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second
	endpoint.maxConnRetries = defaultMaxDialRetries

	// The base service needs the endpoint itself as its implementation.
	endpoint.BaseService = *cmn.NewBaseService(logger, "SignerDialerEndpoint", endpoint)

	return endpoint
}
// OnStart implements cmn.Service.
//
// It allocates the stop/stopped signalling channels and launches serviceLoop
// in its own goroutine. It always returns nil.
func (ss *SignerDialerEndpoint) OnStart() error {
	ss.Logger.Debug("SignerDialerEndpoint: OnStart")

	// Fresh channels for this run: stopCh is closed by OnStop, stoppedCh by
	// serviceLoop when it exits.
	ss.stopCh, ss.stoppedCh = make(chan struct{}), make(chan struct{})

	go ss.serviceLoop()

	ss.Logger.Debug("OnStart - done")
	return nil
}
// OnStop implements cmn.Service.
//
// It signals serviceLoop to exit, waits until it has done so, then closes the
// current connection (if any) and clears it. The connection is cleared whether
// or not Close succeeds, so the endpoint never keeps a reference to a closed
// connection; a Close failure is only logged.
func (ss *SignerDialerEndpoint) OnStop() {
	// Trigger a stop and wait for serviceLoop to finish; after this point no
	// other goroutine started by this endpoint touches ss.conn.
	close(ss.stopCh)
	<-ss.stoppedCh

	if ss.conn == nil {
		return
	}

	if err := ss.conn.Close(); err != nil {
		ss.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed"))
	}

	// Always drop the reference: previously this only happened when Close
	// failed, leaving a closed conn in place after a clean shutdown.
	ss.Logger.Debug("Reset conn")
	ss.conn = nil
}
// serviceLoop is the endpoint's worker goroutine, started by OnStart.
//
// Each iteration first checks for a stop request. It then makes sure a
// connection exists, dialing when ss.conn is nil, and handles one request via
// handleRequest. A failed dial is logged and increments the retry counter; a
// successful pass resets it. The loop exits when stopCh is closed or when the
// counter exceeds maxConnRetries. stoppedCh is closed on exit so OnStop can
// wait for termination.
func (ss *SignerDialerEndpoint) serviceLoop() {
	defer close(ss.stoppedCh)

	retries := 0
	for {
		// Non-blocking stop check: proceed only when no stop was requested.
		select {
		case <-ss.stopCh:
			return
		default:
		}

		ss.Logger.Debug("Try connect", "retries", retries, "max", ss.maxConnRetries)
		if retries > ss.maxConnRetries {
			ss.Logger.Error("Maximum retries reached", "retries", retries)
			return
		}

		if ss.conn == nil {
			conn, err := ss.dialer()
			if err != nil {
				// Report the failure instead of dropping it silently. ss.conn
				// is only assigned on success, so it stays a true nil
				// interface (https://golang.org/doc/faq#nil_error).
				ss.Logger.Error("Dial failed", "retries", retries, "err", err)
				retries++
				continue
			}
			ss.conn = conn
		}

		retries = 0
		ss.handleRequest()
	}
}
// readMessage reads one length-prefixed message from the current connection.
//
// It returns an error if there is no connection or if the read deadline
// (now + timeoutReadWrite) cannot be set. A decoding error that is a
// timeoutError is re-wrapped as ErrDialerTimeout; any other error is returned
// unchanged.
func (ss *SignerDialerEndpoint) readMessage() (msg RemoteSignerMsg, err error) {
	// TODO(jleni): Avoid duplication. Unify endpoints
	if ss.conn == nil {
		return nil, fmt.Errorf("not connected")
	}

	// Upper bound on the encoded size accepted from the peer.
	const maxRemoteSignerMsgSize = 1024 * 10

	// Push the read deadline forward for this message.
	readBy := time.Now().Add(ss.timeoutReadWrite)
	ss.Logger.Debug("SignerDialerEndpoint: readMessage", "deadline", readBy)
	if err = ss.conn.SetReadDeadline(readBy); err != nil {
		return msg, err
	}

	_, err = cdc.UnmarshalBinaryLengthPrefixedReader(ss.conn, &msg, maxRemoteSignerMsgSize)
	if _, timedOut := err.(timeoutError); timedOut {
		err = cmn.ErrorWrap(ErrDialerTimeout, err.Error())
	}
	return msg, err
}
// writeMessage writes msg to the current connection as one length-prefixed
// message.
//
// It returns an error if there is no connection or if the write deadline
// (now + timeoutReadWrite) cannot be set. An encoding/write error that is a
// timeoutError is re-wrapped as ErrDialerTimeout; any other error is returned
// unchanged.
func (ss *SignerDialerEndpoint) writeMessage(msg RemoteSignerMsg) (err error) {
	// TODO(jleni): Avoid duplication. Unify endpoints
	if ss.conn == nil {
		return fmt.Errorf("not connected")
	}

	// Reset write deadline
	deadline := time.Now().Add(ss.timeoutReadWrite)
	// Label the log entry as a write; it was previously copy-pasted from
	// readMessage, making reads and writes indistinguishable in the logs.
	ss.Logger.Debug("SignerDialerEndpoint: writeMessage", "deadline", deadline)
	if err = ss.conn.SetWriteDeadline(deadline); err != nil {
		return err
	}

	_, err = cdc.MarshalBinaryLengthPrefixedWriter(ss.conn, msg)
	if _, ok := err.(timeoutError); ok {
		err = cmn.ErrorWrap(ErrDialerTimeout, err.Error())
	}
	return err
}
// handleRequest serves a single request on the current connection: it reads
// one message, passes it to handleMessage, and writes the response back.
//
// It does nothing when the service is not running. Read errors end the call
// (io.EOF silently, anything else logged). An error from handleMessage is
// logged; when it produced a response, that response is still sent. When
// handleMessage produced no response at all (it returns nil for request types
// it does not recognize), nothing is written.
func (ss *SignerDialerEndpoint) handleRequest() {
	if !ss.IsRunning() {
		return // Ignore error from listener closing.
	}

	ss.Logger.Info("SignerDialerEndpoint: connected", "timeout", ss.timeoutReadWrite)

	req, err := ss.readMessage()
	if err != nil {
		if err != io.EOF {
			ss.Logger.Error("SignerDialerEndpoint handleMessage", "err", err)
		}
		return
	}

	res, err := handleMessage(req, ss.chainID, ss.privVal)
	if err != nil {
		// only log the error; when res is set it already carries the error
		// for the peer
		ss.Logger.Error("handleMessage handleMessage", "err", err)
	}

	// No response was built (unknown request type): do not hand a nil message
	// to the encoder.
	if res == nil {
		return
	}

	if err := ss.writeMessage(res); err != nil {
		ss.Logger.Error("handleMessage writeMessage", "err", err)
	}
}
// handleMessage builds the response for a single request by dispatching on
// the request's concrete type:
//
//   - *PubKeyRequest: replies with privVal's public key.
//   - *SignVoteRequest: asks privVal to sign the vote for chainID and replies
//     with the vote, or with a RemoteSignerError carrying the failure text.
//   - *SignProposalRequest: same as above, for the proposal.
//
// A signing failure is returned both inside the response and as the error
// result. Any other request type yields a nil response and an
// "unknown msg" error.
func handleMessage(req RemoteSignerMsg, chainID string, privVal types.PrivValidator) (RemoteSignerMsg, error) {
	switch r := req.(type) {
	case *PubKeyRequest:
		var pubKey crypto.PubKey = privVal.GetPubKey()
		return &PubKeyResponse{pubKey, nil}, nil

	case *SignVoteRequest:
		if err := privVal.SignVote(chainID, r.Vote); err != nil {
			return &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}}, err
		}
		return &SignedVoteResponse{r.Vote, nil}, nil

	case *SignProposalRequest:
		if err := privVal.SignProposal(chainID, r.Proposal); err != nil {
			return &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}}, err
		}
		return &SignedProposalResponse{r.Proposal, nil}, nil

	default:
		return nil, fmt.Errorf("unknown msg: %v", r)
	}
}