Browse Source

Invert privVal socket communication

Follow-up to #1255 aligning with the expectation that the external
signing process connects to the node. The SocketClient will block on
start until one connection has been established, support for multiple
signers connected simultaneously is a planned future extension.

* SocketClient accepts connection
* PrivValSocketServer renamed to RemoteSigner
* extend tests
pull/1286/head
Alexander Simmerl 6 years ago
parent
commit
589781721a
No known key found for this signature in database GPG Key ID: 4694E95C9CC61BDA
3 changed files with 330 additions and 171 deletions
  1. +15
    -10
      cmd/priv_val_server/main.go
  2. +208
    -116
      types/priv_validator/socket.go
  3. +107
    -45
      types/priv_validator/socket_test.go

+ 15
- 10
cmd/priv_val_server/main.go View File

@ -12,36 +12,41 @@ import (
func main() {
var (
addr = flag.String("addr", ":46659", "Address of client to connect to")
chainID = flag.String("chain-id", "mychain", "chain id")
listenAddr = flag.String("laddr", ":46659", "Validator listen address (0.0.0.0:0 means any interface, any port")
maxConn = flag.Int("clients", 3, "maximum of concurrent connections")
privValPath = flag.String("priv", "", "priv val file path")
logger = log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val")
logger = log.NewTMLogger(
log.NewSyncWriter(os.Stdout),
).With("module", "priv_val")
)
flag.Parse()
logger.Info(
"Starting private validator",
"addr", *addr,
"chainID", *chainID,
"listenAddr", *listenAddr,
"maxConn", *maxConn,
"privPath", *privValPath,
)
privVal := priv_val.LoadPrivValidatorJSON(*privValPath)
pvss := priv_val.NewPrivValidatorSocketServer(
rs := priv_val.NewRemoteSigner(
logger,
*chainID,
*listenAddr,
*maxConn,
*addr,
privVal,
nil,
)
pvss.Start()
err := rs.Start()
if err != nil {
panic(err)
}
cmn.TrapSignal(func() {
pvss.Stop()
err := rs.Stop()
if err != nil {
panic(err)
}
})
}

+ 208
- 116
types/priv_validator/socket.go View File

@ -19,12 +19,16 @@ import (
const (
defaultConnDeadlineSeconds = 3
defaultDialRetryMax = 10
defaultConnWaitSeconds = 60
defaultDialRetries = 10
defaultSignersMax = 1
)
// Socket errors.
var (
ErrDialRetryMax = errors.New("Error max client retries")
ErrDialRetryMax = errors.New("Error max client retries")
ErrConnWaitTimeout = errors.New("Error waiting for external connection")
ErrConnTimeout = errors.New("Error connection timed out")
)
var (
@ -34,10 +38,16 @@ var (
// SocketClientOption sets an optional parameter on the SocketClient.
type SocketClientOption func(*SocketClient)
// SocketClientTimeout sets the timeout for connecting to the external socket
// address.
func SocketClientTimeout(timeout time.Duration) SocketClientOption {
return func(sc *SocketClient) { sc.connectTimeout = timeout }
// SocketClientConnDeadline sets the read and write deadline for connections
// from external signing processes.
func SocketClientConnDeadline(deadline time.Duration) SocketClientOption {
return func(sc *SocketClient) { sc.connDeadline = deadline }
}
// SocketClientConnWait sets the timeout duration before connection of external
// signing processes are considered to be unsuccessful.
func SocketClientConnWait(timeout time.Duration) SocketClientOption {
return func(sc *SocketClient) { sc.connWaitTimeout = timeout }
}
// SocketClient implements PrivValidator, it uses a socket to request signatures
@ -45,11 +55,13 @@ func SocketClientTimeout(timeout time.Duration) SocketClientOption {
type SocketClient struct {
cmn.BaseService
conn net.Conn
privKey *crypto.PrivKeyEd25519
addr string
connDeadline time.Duration
connWaitTimeout time.Duration
privKey *crypto.PrivKeyEd25519
addr string
connectTimeout time.Duration
conn net.Conn
listener net.Listener
}
// Check that SocketClient implements PrivValidator2.
@ -62,24 +74,37 @@ func NewSocketClient(
privKey *crypto.PrivKeyEd25519,
) *SocketClient {
sc := &SocketClient{
addr: socketAddr,
connectTimeout: time.Second * defaultConnDeadlineSeconds,
privKey: privKey,
addr: socketAddr,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connWaitTimeout: time.Second * defaultConnWaitSeconds,
privKey: privKey,
}
sc.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketClient", sc)
sc.BaseService = *cmn.NewBaseService(logger, "SocketClient", sc)
return sc
}
// OnStart implements cmn.Service.
func (sc *SocketClient) OnStart() error {
if err := sc.BaseService.OnStart(); err != nil {
return err
if sc.listener == nil {
if err := sc.listen(); err != nil {
sc.Logger.Error(
"OnStart",
"err", errors.Wrap(err, "failed to listen"),
)
return err
}
}
conn, err := sc.connect()
conn, err := sc.waitConnection()
if err != nil {
sc.Logger.Error(
"OnStart",
"err", errors.Wrap(err, "failed to accept connection"),
)
return err
}
@ -93,7 +118,21 @@ func (sc *SocketClient) OnStop() {
sc.BaseService.OnStop()
if sc.conn != nil {
sc.conn.Close()
if err := sc.conn.Close(); err != nil {
sc.Logger.Error(
"OnStop",
"err", errors.Wrap(err, "failed to close connection"),
)
}
}
if sc.listener != nil {
if err := sc.listener.Close(); err != nil {
sc.Logger.Error(
"OnStop",
"err", errors.Wrap(err, "failed to close listener"),
)
}
}
}
@ -162,7 +201,10 @@ func (sc *SocketClient) SignVote(chainID string, vote *types.Vote) error {
}
// SignProposal implements PrivValidator2.
func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) error {
func (sc *SocketClient) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal})
if err != nil {
return err
@ -179,7 +221,10 @@ func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) e
}
// SignHeartbeat implements PrivValidator2.
func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error {
func (sc *SocketClient) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
if err != nil {
return err
@ -195,166 +240,206 @@ func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat
return nil
}
func (sc *SocketClient) connect() (net.Conn, error) {
retries := defaultDialRetryMax
RETRY_LOOP:
for retries > 0 {
if retries != defaultDialRetryMax {
time.Sleep(sc.connectTimeout)
func (sc *SocketClient) acceptConnection() (net.Conn, error) {
conn, err := sc.listener.Accept()
if err != nil {
if !sc.IsRunning() {
return nil, nil // Ignore error from listener closing.
}
return nil, err
retries--
}
conn, err := cmn.Connect(sc.addr)
if err != nil {
sc.Logger.Error(
"sc connect",
"addr", sc.addr,
"err", errors.Wrap(err, "connection failed"),
)
if err := conn.SetDeadline(time.Now().Add(sc.connDeadline)); err != nil {
return nil, err
}
continue RETRY_LOOP
if sc.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap())
if err != nil {
return nil, err
}
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
sc.Logger.Error(
"sc connect",
"err", errors.Wrap(err, "setting connection timeout failed"),
)
continue
}
return conn, nil
}
if sc.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap())
if err != nil {
sc.Logger.Error(
"sc connect",
"err", errors.Wrap(err, "encrypting connection failed"),
)
func (sc *SocketClient) listen() error {
ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr))
if err != nil {
return err
}
continue RETRY_LOOP
}
sc.listener = netutil.LimitListener(ln, defaultSignersMax)
return nil
}
// waitConnection uses the configured wait timeout to error if no external
// process connects in the time period.
func (sc *SocketClient) waitConnection() (net.Conn, error) {
var (
connc = make(chan net.Conn, 1)
errc = make(chan error, 1)
)
go func(connc chan<- net.Conn, errc chan<- error) {
conn, err := sc.acceptConnection()
if err != nil {
errc <- err
return
}
connc <- conn
}(connc, errc)
select {
case conn := <-connc:
return conn, nil
case err := <-errc:
return nil, err
case <-time.After(sc.connWaitTimeout):
return nil, ErrConnWaitTimeout
}
return nil, ErrDialRetryMax
}
//---------------------------------------------------------
// PrivValidatorSocketServer implements PrivValidator.
// RemoteSignerOption sets an optional parameter on the RemoteSigner.
type RemoteSignerOption func(*RemoteSigner)
// RemoteSignerConnDeadline sets the read and write deadline for connections
// from external signing processes.
func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption {
return func(ss *RemoteSigner) { ss.connDeadline = deadline }
}
// RemoteSignerConnRetries sets the amount of attempted retries to connect.
func RemoteSignerConnRetries(retries int) RemoteSignerOption {
return func(ss *RemoteSigner) { ss.connRetries = retries }
}
// RemoteSigner implements PrivValidator.
// It responds to requests over a socket
type PrivValidatorSocketServer struct {
type RemoteSigner struct {
cmn.BaseService
proto, addr string
listener net.Listener
maxConnections int
privKey *crypto.PrivKeyEd25519
addr string
chainID string
connDeadline time.Duration
connRetries int
privKey *crypto.PrivKeyEd25519
privVal PrivValidator
privVal PrivValidator
chainID string
conn net.Conn
}
// NewPrivValidatorSocketServer returns an instance of
// PrivValidatorSocketServer.
func NewPrivValidatorSocketServer(
// NewRemoteSigner returns an instance of
// RemoteSigner.
func NewRemoteSigner(
logger log.Logger,
chainID, socketAddr string,
maxConnections int,
privVal PrivValidator,
privKey *crypto.PrivKeyEd25519,
) *PrivValidatorSocketServer {
proto, addr := cmn.ProtocolAndAddress(socketAddr)
pvss := &PrivValidatorSocketServer{
proto: proto,
addr: addr,
maxConnections: maxConnections,
privKey: privKey,
privVal: privVal,
chainID: chainID,
) *RemoteSigner {
rs := &RemoteSigner{
addr: socketAddr,
chainID: chainID,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connRetries: defaultDialRetries,
privKey: privKey,
privVal: privVal,
}
pvss.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketServer", pvss)
return pvss
rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs)
return rs
}
// OnStart implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStart() error {
ln, err := net.Listen(pvss.proto, pvss.addr)
func (rs *RemoteSigner) OnStart() error {
conn, err := rs.connect()
if err != nil {
rs.Logger.Error("OnStart", "err", errors.Wrap(err, "connect"))
return err
}
pvss.listener = netutil.LimitListener(ln, pvss.maxConnections)
go pvss.acceptConnections()
go rs.handleConnection(conn)
return nil
}
// OnStop implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStop() {
if pvss.listener == nil {
func (rs *RemoteSigner) OnStop() {
if rs.conn == nil {
return
}
if err := pvss.listener.Close(); err != nil {
pvss.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed"))
if err := rs.conn.Close(); err != nil {
rs.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed"))
}
}
func (pvss *PrivValidatorSocketServer) acceptConnections() {
for {
conn, err := pvss.listener.Accept()
func (rs *RemoteSigner) connect() (net.Conn, error) {
retries := defaultDialRetries
RETRY_LOOP:
for retries > 0 {
// Don't sleep if it is the first retry.
if retries != defaultDialRetries {
time.Sleep(rs.connDeadline)
}
retries--
conn, err := cmn.Connect(rs.addr)
if err != nil {
if !pvss.IsRunning() {
return // Ignore error from listener closing.
}
pvss.Logger.Error(
"acceptConnections",
"err", errors.Wrap(err, "failed to accept connection"),
rs.Logger.Error(
"connect",
"addr", rs.addr,
"err", errors.Wrap(err, "connection failed"),
)
continue
continue RETRY_LOOP
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
pvss.Logger.Error(
"acceptConnetions",
rs.Logger.Error(
"connect",
"err", errors.Wrap(err, "setting connection timeout failed"),
)
continue
}
if pvss.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, pvss.privKey.Wrap())
if rs.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap())
if err != nil {
pvss.Logger.Error(
"acceptConnections",
"err", errors.Wrap(err, "secret connection failed"),
rs.Logger.Error(
"sc connect",
"err", errors.Wrap(err, "encrypting connection failed"),
)
continue
continue RETRY_LOOP
}
}
go pvss.handleConnection(conn)
return conn, nil
}
}
func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
defer conn.Close()
return nil, ErrDialRetryMax
}
func (rs *RemoteSigner) handleConnection(conn net.Conn) {
for {
if !pvss.IsRunning() {
if !rs.IsRunning() {
return // Ignore error from listener closing.
}
req, err := readMsg(conn)
if err != nil {
if err != io.EOF {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
}
return
}
@ -365,29 +450,29 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
case *PubKeyMsg:
var p crypto.PubKey
p, err = pvss.privVal.PubKey()
p, err = rs.privVal.PubKey()
res = &PubKeyMsg{p}
case *SignVoteMsg:
err = pvss.privVal.SignVote(pvss.chainID, r.Vote)
err = rs.privVal.SignVote(rs.chainID, r.Vote)
res = &SignVoteMsg{r.Vote}
case *SignProposalMsg:
err = pvss.privVal.SignProposal(pvss.chainID, r.Proposal)
err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
res = &SignProposalMsg{r.Proposal}
case *SignHeartbeatMsg:
err = pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat)
err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
res = &SignHeartbeatMsg{r.Heartbeat}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
return
}
err = writeMsg(conn, res)
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
return
}
}
@ -442,6 +527,10 @@ func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) {
read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
return nil, errors.Wrapf(ErrConnTimeout, opErr.Addr.String())
}
return nil, err
}
@ -461,6 +550,9 @@ func writeMsg(w io.Writer, msg interface{}) error {
// TODO(xla): This extra wrap should be gone with the sdk-2 update.
wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err)
if opErr, ok := err.(*net.OpError); ok {
return errors.Wrapf(ErrConnTimeout, opErr.Addr.String())
}
return err
}

+ 107
- 45
types/priv_validator/socket_test.go View File

@ -4,10 +4,12 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
crypto "github.com/tendermint/go-crypto"
cmn "github.com/tendermint/tmlibs/common"
"github.com/tendermint/tmlibs/log"
"github.com/tendermint/tendermint/types"
@ -16,13 +18,13 @@ import (
func TestSocketClientAddress(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
serverAddr, err := pvss.privVal.Address()
serverAddr, err := rs.privVal.Address()
require.NoError(err)
clientAddr, err := sc.Address()
@ -38,16 +40,16 @@ func TestSocketClientAddress(t *testing.T) {
func TestSocketClientPubKey(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
clientKey, err := sc.PubKey()
require.NoError(err)
privKey, err := pvss.privVal.PubKey()
privKey, err := rs.privVal.PubKey()
require.NoError(err)
assert.Equal(privKey, clientKey)
@ -59,17 +61,17 @@ func TestSocketClientPubKey(t *testing.T) {
func TestSocketClientProposal(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
privProposal = &types.Proposal{Timestamp: ts}
clientProposal = &types.Proposal{Timestamp: ts}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignProposal(chainID, privProposal))
require.NoError(rs.privVal.SignProposal(chainID, privProposal))
require.NoError(sc.SignProposal(chainID, clientProposal))
assert.Equal(privProposal.Signature, clientProposal.Signature)
}
@ -77,8 +79,8 @@ func TestSocketClientProposal(t *testing.T) {
func TestSocketClientVote(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
vType = types.VoteTypePrecommit
@ -86,9 +88,9 @@ func TestSocketClientVote(t *testing.T) {
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignVote(chainID, want))
require.NoError(rs.privVal.SignVote(chainID, want))
require.NoError(sc.SignVote(chainID, have))
assert.Equal(want.Signature, have.Signature)
}
@ -96,69 +98,129 @@ func TestSocketClientVote(t *testing.T) {
func TestSocketClientHeartbeat(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
want = &types.Heartbeat{}
have = &types.Heartbeat{}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignHeartbeat(chainID, want))
require.NoError(rs.privVal.SignHeartbeat(chainID, want))
require.NoError(sc.SignHeartbeat(chainID, have))
assert.Equal(want.Signature, have.Signature)
}
func TestSocketClientConnectRetryMax(t *testing.T) {
func TestSocketClientDeadline(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
logger = log.TestingLogger()
clientPrivKey = crypto.GenPrivKeyEd25519()
sc = NewSocketClient(
assert, require = assert.New(t), require.New(t)
readyc = make(chan struct{})
sc = NewSocketClient(
log.TestingLogger(),
"127.0.0.1:0",
nil,
)
)
defer sc.Stop()
SocketClientConnDeadline(time.Millisecond)(sc)
require.NoError(sc.listen())
go func(sc *SocketClient) {
require.NoError(sc.Start())
assert.True(sc.IsRunning())
readyc <- struct{}{}
}(sc)
_, err := cmn.Connect(sc.listener.Addr().String())
require.NoError(err)
<-readyc
_, err = sc.PubKey()
assert.Equal(errors.Cause(err), ErrConnTimeout)
}
func TestSocketClientWait(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
logger = log.TestingLogger()
privKey = crypto.GenPrivKeyEd25519()
sc = NewSocketClient(
logger,
"127.0.0.1:0",
&clientPrivKey,
&privKey,
)
)
defer sc.Stop()
SocketClientTimeout(time.Millisecond)(sc)
SocketClientConnWait(time.Millisecond)(sc)
assert.EqualError(sc.Start(), ErrDialRetryMax.Error())
assert.EqualError(sc.Start(), ErrConnWaitTimeout.Error())
}
func testSetupSocketPair(t *testing.T, chainID string) (*SocketClient, *PrivValidatorSocketServer) {
func TestRemoteSignerRetry(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
privKey = crypto.GenPrivKeyEd25519()
rs = NewRemoteSigner(
log.TestingLogger(),
cmn.RandStr(12),
"127.0.0.1:0",
NewTestPrivValidator(types.GenSigner()),
&privKey,
)
)
defer rs.Stop()
RemoteSignerConnDeadline(time.Millisecond)(rs)
RemoteSignerConnRetries(2)(rs)
assert.EqualError(rs.Start(), ErrDialRetryMax.Error())
}
func testSetupSocketPair(
t *testing.T,
chainID string,
) (*SocketClient, *RemoteSigner) {
var (
assert, require = assert.New(t), require.New(t)
logger = log.TestingLogger()
signer = types.GenSigner()
clientPrivKey = crypto.GenPrivKeyEd25519()
serverPrivKey = crypto.GenPrivKeyEd25519()
remotePrivKey = crypto.GenPrivKeyEd25519()
privVal = NewTestPrivValidator(signer)
pvss = NewPrivValidatorSocketServer(
readyc = make(chan struct{})
sc = NewSocketClient(
logger,
chainID,
"127.0.0.1:0",
1,
privVal,
&serverPrivKey,
&clientPrivKey,
)
)
err := pvss.Start()
require.NoError(err)
assert.True(pvss.IsRunning())
require.NoError(sc.listen())
sc := NewSocketClient(
go func(sc *SocketClient) {
require.NoError(sc.Start())
assert.True(sc.IsRunning())
readyc <- struct{}{}
}(sc)
rs := NewRemoteSigner(
logger,
pvss.listener.Addr().String(),
&clientPrivKey,
chainID,
sc.listener.Addr().String(),
privVal,
&remotePrivKey,
)
require.NoError(rs.Start())
assert.True(rs.IsRunning())
err = sc.Start()
require.NoError(err)
assert.True(sc.IsRunning())
<-readyc
return sc, pvss
return sc, rs
}

Loading…
Cancel
Save