Browse Source

Merge pull request #1286 from tendermint/feature/xla-priv-val-invert-dial

Invert privVal socket communication
pull/1239/merge
Ethan Buchman 7 years ago
committed by GitHub
parent
commit
cd2ba4aa7f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 330 additions and 171 deletions
  1. +15
    -10
      cmd/priv_val_server/main.go
  2. +208
    -116
      types/priv_validator/socket.go
  3. +107
    -45
      types/priv_validator/socket_test.go

+ 15
- 10
cmd/priv_val_server/main.go View File

@ -12,36 +12,41 @@ import (
func main() {
var (
addr = flag.String("addr", ":46659", "Address of client to connect to")
chainID = flag.String("chain-id", "mychain", "chain id")
listenAddr = flag.String("laddr", ":46659", "Validator listen address (0.0.0.0:0 means any interface, any port")
maxConn = flag.Int("clients", 3, "maximum of concurrent connections")
privValPath = flag.String("priv", "", "priv val file path")
logger = log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val")
logger = log.NewTMLogger(
log.NewSyncWriter(os.Stdout),
).With("module", "priv_val")
)
flag.Parse()
logger.Info(
"Starting private validator",
"addr", *addr,
"chainID", *chainID,
"listenAddr", *listenAddr,
"maxConn", *maxConn,
"privPath", *privValPath,
)
privVal := priv_val.LoadPrivValidatorJSON(*privValPath)
pvss := priv_val.NewPrivValidatorSocketServer(
rs := priv_val.NewRemoteSigner(
logger,
*chainID,
*listenAddr,
*maxConn,
*addr,
privVal,
nil,
)
pvss.Start()
err := rs.Start()
if err != nil {
panic(err)
}
cmn.TrapSignal(func() {
pvss.Stop()
err := rs.Stop()
if err != nil {
panic(err)
}
})
}

+ 208
- 116
types/priv_validator/socket.go View File

@ -19,12 +19,16 @@ import (
const (
defaultConnDeadlineSeconds = 3
defaultDialRetryMax = 10
defaultConnWaitSeconds = 60
defaultDialRetries = 10
defaultSignersMax = 1
)
// Socket errors.
var (
ErrDialRetryMax = errors.New("Error max client retries")
ErrDialRetryMax = errors.New("Error max client retries")
ErrConnWaitTimeout = errors.New("Error waiting for external connection")
ErrConnTimeout = errors.New("Error connection timed out")
)
var (
@ -34,10 +38,16 @@ var (
// SocketClientOption sets an optional parameter on the SocketClient.
type SocketClientOption func(*SocketClient)
// SocketClientTimeout sets the timeout for connecting to the external socket
// address.
func SocketClientTimeout(timeout time.Duration) SocketClientOption {
return func(sc *SocketClient) { sc.connectTimeout = timeout }
// SocketClientConnDeadline returns an option that sets the read and write
// deadline applied to connections accepted from external signing processes.
func SocketClientConnDeadline(deadline time.Duration) SocketClientOption {
	return func(client *SocketClient) {
		client.connDeadline = deadline
	}
}
// SocketClientConnWait returns an option that sets how long the client waits
// for an external signing process to connect before the attempt is considered
// unsuccessful.
func SocketClientConnWait(timeout time.Duration) SocketClientOption {
	return func(client *SocketClient) {
		client.connWaitTimeout = timeout
	}
}
// SocketClient implements PrivValidator, it uses a socket to request signatures
@ -45,11 +55,13 @@ func SocketClientTimeout(timeout time.Duration) SocketClientOption {
type SocketClient struct {
cmn.BaseService
conn net.Conn
privKey *crypto.PrivKeyEd25519
addr string
connDeadline time.Duration
connWaitTimeout time.Duration
privKey *crypto.PrivKeyEd25519
addr string
connectTimeout time.Duration
conn net.Conn
listener net.Listener
}
// Check that SocketClient implements PrivValidator2.
@ -62,24 +74,37 @@ func NewSocketClient(
privKey *crypto.PrivKeyEd25519,
) *SocketClient {
sc := &SocketClient{
addr: socketAddr,
connectTimeout: time.Second * defaultConnDeadlineSeconds,
privKey: privKey,
addr: socketAddr,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connWaitTimeout: time.Second * defaultConnWaitSeconds,
privKey: privKey,
}
sc.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketClient", sc)
sc.BaseService = *cmn.NewBaseService(logger, "SocketClient", sc)
return sc
}
// OnStart implements cmn.Service.
func (sc *SocketClient) OnStart() error {
if err := sc.BaseService.OnStart(); err != nil {
return err
if sc.listener == nil {
if err := sc.listen(); err != nil {
sc.Logger.Error(
"OnStart",
"err", errors.Wrap(err, "failed to listen"),
)
return err
}
}
conn, err := sc.connect()
conn, err := sc.waitConnection()
if err != nil {
sc.Logger.Error(
"OnStart",
"err", errors.Wrap(err, "failed to accept connection"),
)
return err
}
@ -93,7 +118,21 @@ func (sc *SocketClient) OnStop() {
sc.BaseService.OnStop()
if sc.conn != nil {
sc.conn.Close()
if err := sc.conn.Close(); err != nil {
sc.Logger.Error(
"OnStop",
"err", errors.Wrap(err, "failed to close connection"),
)
}
}
if sc.listener != nil {
if err := sc.listener.Close(); err != nil {
sc.Logger.Error(
"OnStop",
"err", errors.Wrap(err, "failed to close listener"),
)
}
}
}
@ -162,7 +201,10 @@ func (sc *SocketClient) SignVote(chainID string, vote *types.Vote) error {
}
// SignProposal implements PrivValidator2.
func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) error {
func (sc *SocketClient) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal})
if err != nil {
return err
@ -179,7 +221,10 @@ func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) e
}
// SignHeartbeat implements PrivValidator2.
func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error {
func (sc *SocketClient) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
if err != nil {
return err
@ -195,166 +240,206 @@ func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat
return nil
}
func (sc *SocketClient) connect() (net.Conn, error) {
retries := defaultDialRetryMax
RETRY_LOOP:
for retries > 0 {
if retries != defaultDialRetryMax {
time.Sleep(sc.connectTimeout)
func (sc *SocketClient) acceptConnection() (net.Conn, error) {
conn, err := sc.listener.Accept()
if err != nil {
if !sc.IsRunning() {
return nil, nil // Ignore error from listener closing.
}
return nil, err
retries--
}
conn, err := cmn.Connect(sc.addr)
if err != nil {
sc.Logger.Error(
"sc connect",
"addr", sc.addr,
"err", errors.Wrap(err, "connection failed"),
)
if err := conn.SetDeadline(time.Now().Add(sc.connDeadline)); err != nil {
return nil, err
}
continue RETRY_LOOP
if sc.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap())
if err != nil {
return nil, err
}
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
sc.Logger.Error(
"sc connect",
"err", errors.Wrap(err, "setting connection timeout failed"),
)
continue
}
return conn, nil
}
if sc.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap())
if err != nil {
sc.Logger.Error(
"sc connect",
"err", errors.Wrap(err, "encrypting connection failed"),
)
func (sc *SocketClient) listen() error {
ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr))
if err != nil {
return err
}
continue RETRY_LOOP
}
sc.listener = netutil.LimitListener(ln, defaultSignersMax)
return nil
}
// waitConnection uses the configured wait timeout to error if no external
// process connects in the time period.
func (sc *SocketClient) waitConnection() (net.Conn, error) {
var (
connc = make(chan net.Conn, 1)
errc = make(chan error, 1)
)
go func(connc chan<- net.Conn, errc chan<- error) {
conn, err := sc.acceptConnection()
if err != nil {
errc <- err
return
}
connc <- conn
}(connc, errc)
select {
case conn := <-connc:
return conn, nil
case err := <-errc:
return nil, err
case <-time.After(sc.connWaitTimeout):
return nil, ErrConnWaitTimeout
}
return nil, ErrDialRetryMax
}
//---------------------------------------------------------
// PrivValidatorSocketServer implements PrivValidator.
// RemoteSignerOption sets an optional parameter on the RemoteSigner.
type RemoteSignerOption func(*RemoteSigner)
// RemoteSignerConnDeadline returns an option that sets the RemoteSigner's
// connDeadline.
// NOTE(review): connect() in this file sleeps for connDeadline between dial
// attempts; confirm whether it is also meant to bound reads and writes on the
// established connection.
func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption {
	return func(signer *RemoteSigner) {
		signer.connDeadline = deadline
	}
}
// RemoteSignerConnRetries returns an option that sets the number of connection
// attempts stored in the RemoteSigner's connRetries field.
// NOTE(review): connect() in this file counts down from defaultDialRetries and
// does not appear to read connRetries; confirm that this option takes effect.
func RemoteSignerConnRetries(retries int) RemoteSignerOption {
	return func(signer *RemoteSigner) {
		signer.connRetries = retries
	}
}
// RemoteSigner implements PrivValidator.
// It responds to requests over a socket.
type PrivValidatorSocketServer struct {
type RemoteSigner struct {
cmn.BaseService
proto, addr string
listener net.Listener
maxConnections int
privKey *crypto.PrivKeyEd25519
addr string
chainID string
connDeadline time.Duration
connRetries int
privKey *crypto.PrivKeyEd25519
privVal PrivValidator
privVal PrivValidator
chainID string
conn net.Conn
}
// NewPrivValidatorSocketServer returns an instance of
// PrivValidatorSocketServer.
func NewPrivValidatorSocketServer(
// NewRemoteSigner returns an instance of
// RemoteSigner.
func NewRemoteSigner(
logger log.Logger,
chainID, socketAddr string,
maxConnections int,
privVal PrivValidator,
privKey *crypto.PrivKeyEd25519,
) *PrivValidatorSocketServer {
proto, addr := cmn.ProtocolAndAddress(socketAddr)
pvss := &PrivValidatorSocketServer{
proto: proto,
addr: addr,
maxConnections: maxConnections,
privKey: privKey,
privVal: privVal,
chainID: chainID,
) *RemoteSigner {
rs := &RemoteSigner{
addr: socketAddr,
chainID: chainID,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connRetries: defaultDialRetries,
privKey: privKey,
privVal: privVal,
}
pvss.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketServer", pvss)
return pvss
rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs)
return rs
}
// OnStart implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStart() error {
ln, err := net.Listen(pvss.proto, pvss.addr)
func (rs *RemoteSigner) OnStart() error {
conn, err := rs.connect()
if err != nil {
rs.Logger.Error("OnStart", "err", errors.Wrap(err, "connect"))
return err
}
pvss.listener = netutil.LimitListener(ln, pvss.maxConnections)
go pvss.acceptConnections()
go rs.handleConnection(conn)
return nil
}
// OnStop implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStop() {
if pvss.listener == nil {
func (rs *RemoteSigner) OnStop() {
if rs.conn == nil {
return
}
if err := pvss.listener.Close(); err != nil {
pvss.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed"))
if err := rs.conn.Close(); err != nil {
rs.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed"))
}
}
func (pvss *PrivValidatorSocketServer) acceptConnections() {
for {
conn, err := pvss.listener.Accept()
func (rs *RemoteSigner) connect() (net.Conn, error) {
retries := defaultDialRetries
RETRY_LOOP:
for retries > 0 {
// Don't sleep if it is the first retry.
if retries != defaultDialRetries {
time.Sleep(rs.connDeadline)
}
retries--
conn, err := cmn.Connect(rs.addr)
if err != nil {
if !pvss.IsRunning() {
return // Ignore error from listener closing.
}
pvss.Logger.Error(
"acceptConnections",
"err", errors.Wrap(err, "failed to accept connection"),
rs.Logger.Error(
"connect",
"addr", rs.addr,
"err", errors.Wrap(err, "connection failed"),
)
continue
continue RETRY_LOOP
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
pvss.Logger.Error(
"acceptConnetions",
rs.Logger.Error(
"connect",
"err", errors.Wrap(err, "setting connection timeout failed"),
)
continue
}
if pvss.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, pvss.privKey.Wrap())
if rs.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap())
if err != nil {
pvss.Logger.Error(
"acceptConnections",
"err", errors.Wrap(err, "secret connection failed"),
rs.Logger.Error(
"sc connect",
"err", errors.Wrap(err, "encrypting connection failed"),
)
continue
continue RETRY_LOOP
}
}
go pvss.handleConnection(conn)
return conn, nil
}
}
func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
defer conn.Close()
return nil, ErrDialRetryMax
}
func (rs *RemoteSigner) handleConnection(conn net.Conn) {
for {
if !pvss.IsRunning() {
if !rs.IsRunning() {
return // Ignore error from listener closing.
}
req, err := readMsg(conn)
if err != nil {
if err != io.EOF {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
}
return
}
@ -365,29 +450,29 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
case *PubKeyMsg:
var p crypto.PubKey
p, err = pvss.privVal.PubKey()
p, err = rs.privVal.PubKey()
res = &PubKeyMsg{p}
case *SignVoteMsg:
err = pvss.privVal.SignVote(pvss.chainID, r.Vote)
err = rs.privVal.SignVote(rs.chainID, r.Vote)
res = &SignVoteMsg{r.Vote}
case *SignProposalMsg:
err = pvss.privVal.SignProposal(pvss.chainID, r.Proposal)
err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
res = &SignProposalMsg{r.Proposal}
case *SignHeartbeatMsg:
err = pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat)
err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
res = &SignHeartbeatMsg{r.Heartbeat}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
return
}
err = writeMsg(conn, res)
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
return
}
}
@ -442,6 +527,10 @@ func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) {
read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
return nil, errors.Wrapf(ErrConnTimeout, opErr.Addr.String())
}
return nil, err
}
@ -461,6 +550,9 @@ func writeMsg(w io.Writer, msg interface{}) error {
// TODO(xla): This extra wrap should be gone with the sdk-2 update.
wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err)
if opErr, ok := err.(*net.OpError); ok {
return errors.Wrapf(ErrConnTimeout, opErr.Addr.String())
}
return err
}

+ 107
- 45
types/priv_validator/socket_test.go View File

@ -4,10 +4,12 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
crypto "github.com/tendermint/go-crypto"
cmn "github.com/tendermint/tmlibs/common"
"github.com/tendermint/tmlibs/log"
"github.com/tendermint/tendermint/types"
@ -16,13 +18,13 @@ import (
func TestSocketClientAddress(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
serverAddr, err := pvss.privVal.Address()
serverAddr, err := rs.privVal.Address()
require.NoError(err)
clientAddr, err := sc.Address()
@ -38,16 +40,16 @@ func TestSocketClientAddress(t *testing.T) {
func TestSocketClientPubKey(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
clientKey, err := sc.PubKey()
require.NoError(err)
privKey, err := pvss.privVal.PubKey()
privKey, err := rs.privVal.PubKey()
require.NoError(err)
assert.Equal(privKey, clientKey)
@ -59,17 +61,17 @@ func TestSocketClientPubKey(t *testing.T) {
func TestSocketClientProposal(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
privProposal = &types.Proposal{Timestamp: ts}
clientProposal = &types.Proposal{Timestamp: ts}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignProposal(chainID, privProposal))
require.NoError(rs.privVal.SignProposal(chainID, privProposal))
require.NoError(sc.SignProposal(chainID, clientProposal))
assert.Equal(privProposal.Signature, clientProposal.Signature)
}
@ -77,8 +79,8 @@ func TestSocketClientProposal(t *testing.T) {
func TestSocketClientVote(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
vType = types.VoteTypePrecommit
@ -86,9 +88,9 @@ func TestSocketClientVote(t *testing.T) {
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignVote(chainID, want))
require.NoError(rs.privVal.SignVote(chainID, want))
require.NoError(sc.SignVote(chainID, have))
assert.Equal(want.Signature, have.Signature)
}
@ -96,69 +98,129 @@ func TestSocketClientVote(t *testing.T) {
func TestSocketClientHeartbeat(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
want = &types.Heartbeat{}
have = &types.Heartbeat{}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignHeartbeat(chainID, want))
require.NoError(rs.privVal.SignHeartbeat(chainID, want))
require.NoError(sc.SignHeartbeat(chainID, have))
assert.Equal(want.Signature, have.Signature)
}
func TestSocketClientConnectRetryMax(t *testing.T) {
func TestSocketClientDeadline(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
logger = log.TestingLogger()
clientPrivKey = crypto.GenPrivKeyEd25519()
sc = NewSocketClient(
assert, require = assert.New(t), require.New(t)
readyc = make(chan struct{})
sc = NewSocketClient(
log.TestingLogger(),
"127.0.0.1:0",
nil,
)
)
defer sc.Stop()
SocketClientConnDeadline(time.Millisecond)(sc)
require.NoError(sc.listen())
go func(sc *SocketClient) {
require.NoError(sc.Start())
assert.True(sc.IsRunning())
readyc <- struct{}{}
}(sc)
_, err := cmn.Connect(sc.listener.Addr().String())
require.NoError(err)
<-readyc
_, err = sc.PubKey()
assert.Equal(errors.Cause(err), ErrConnTimeout)
}
func TestSocketClientWait(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
logger = log.TestingLogger()
privKey = crypto.GenPrivKeyEd25519()
sc = NewSocketClient(
logger,
"127.0.0.1:0",
&clientPrivKey,
&privKey,
)
)
defer sc.Stop()
SocketClientTimeout(time.Millisecond)(sc)
SocketClientConnWait(time.Millisecond)(sc)
assert.EqualError(sc.Start(), ErrDialRetryMax.Error())
assert.EqualError(sc.Start(), ErrConnWaitTimeout.Error())
}
func testSetupSocketPair(t *testing.T, chainID string) (*SocketClient, *PrivValidatorSocketServer) {
// TestRemoteSignerRetry builds a RemoteSigner pointed at "127.0.0.1:0",
// shortens its connection deadline to one millisecond, and expects Start to
// fail with ErrDialRetryMax.
// NOTE(review): connect() in this file loops on defaultDialRetries rather than
// rs.connRetries, so the RemoteSignerConnRetries(2) call below may not change
// the number of attempts — confirm.
func TestRemoteSignerRetry(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
privKey = crypto.GenPrivKeyEd25519()
rs = NewRemoteSigner(
log.TestingLogger(),
cmn.RandStr(12),
"127.0.0.1:0",
NewTestPrivValidator(types.GenSigner()),
&privKey,
)
)
// Stop is deferred even though Start is expected to fail.
defer rs.Stop()
// Options are applied directly to the already-constructed signer.
RemoteSignerConnDeadline(time.Millisecond)(rs)
RemoteSignerConnRetries(2)(rs)
assert.EqualError(rs.Start(), ErrDialRetryMax.Error())
}
func testSetupSocketPair(
t *testing.T,
chainID string,
) (*SocketClient, *RemoteSigner) {
var (
assert, require = assert.New(t), require.New(t)
logger = log.TestingLogger()
signer = types.GenSigner()
clientPrivKey = crypto.GenPrivKeyEd25519()
serverPrivKey = crypto.GenPrivKeyEd25519()
remotePrivKey = crypto.GenPrivKeyEd25519()
privVal = NewTestPrivValidator(signer)
pvss = NewPrivValidatorSocketServer(
readyc = make(chan struct{})
sc = NewSocketClient(
logger,
chainID,
"127.0.0.1:0",
1,
privVal,
&serverPrivKey,
&clientPrivKey,
)
)
err := pvss.Start()
require.NoError(err)
assert.True(pvss.IsRunning())
require.NoError(sc.listen())
sc := NewSocketClient(
go func(sc *SocketClient) {
require.NoError(sc.Start())
assert.True(sc.IsRunning())
readyc <- struct{}{}
}(sc)
rs := NewRemoteSigner(
logger,
pvss.listener.Addr().String(),
&clientPrivKey,
chainID,
sc.listener.Addr().String(),
privVal,
&remotePrivKey,
)
require.NoError(rs.Start())
assert.True(rs.IsRunning())
err = sc.Start()
require.NoError(err)
assert.True(sc.IsRunning())
<-readyc
return sc, pvss
return sc, rs
}

Loading…
Cancel
Save