diff --git a/types/priv_validator/socket.go b/types/priv_validator/socket.go index c22e2698d..35cfe4efd 100644 --- a/types/priv_validator/socket.go +++ b/types/priv_validator/socket.go @@ -12,8 +12,8 @@ import ( "github.com/tendermint/go-wire/data" cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" - "golang.org/x/net/netutil" + "github.com/tendermint/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -26,7 +26,8 @@ var _ types.PrivValidator = (*PrivValidatorSocketClient)(nil) type PrivValidatorSocketClient struct { cmn.BaseService - conn net.Conn + conn net.Conn + privKey *crypto.PrivKeyEd25519 ID types.ValidatorID SocketAddress string @@ -38,11 +39,18 @@ const ( // NewPrivValidatorSocketClient returns an instance of // PrivValidatorSocketClient. -func NewPrivValidatorSocketClient(logger log.Logger, socketAddr string) *PrivValidatorSocketClient { +func NewPrivValidatorSocketClient( + logger log.Logger, + socketAddr string, + privKey *crypto.PrivKeyEd25519, +) *PrivValidatorSocketClient { pvsc := &PrivValidatorSocketClient{ SocketAddress: socketAddr, + privKey: privKey, } + pvsc.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketClient", pvsc) + return pvsc } @@ -54,6 +62,7 @@ func (pvsc *PrivValidatorSocketClient) OnStart() error { var err error var conn net.Conn + RETRY_LOOP: for { conn, err = cmn.Connect(pvsc.SocketAddress) @@ -62,6 +71,15 @@ RETRY_LOOP: time.Sleep(time.Second * dialRetryIntervalSeconds) continue RETRY_LOOP } + + if pvsc.privKey != nil { + conn, err = p2p.MakeSecretConnection(conn, *pvsc.privKey) + if err != nil { + pvsc.Logger.Error("failed to encrypt connection: " + err.Error()) + continue RETRY_LOOP + } + } + pvsc.conn = conn return nil } @@ -84,40 +102,67 @@ func (pvsc *PrivValidatorSocketClient) Address() data.Bytes { // PubKey implements PrivValidator. func (pvsc *PrivValidatorSocketClient) PubKey() crypto.PubKey { - res, err := readWrite(pvsc.conn, &PubKeyMsg{}) + err := writeMsg(pvsc.conn, &PubKeyMsg{}) if err != nil { panic(err) } + + res, err := readMsg(pvsc.conn) + if err != nil { + panic(err) + } + return res.(*PubKeyMsg).PubKey } // SignVote implements PrivValidator. func (pvsc *PrivValidatorSocketClient) SignVote(chainID string, vote *types.Vote) error { - res, err := readWrite(pvsc.conn, &SignVoteMsg{Vote: vote}) + err := writeMsg(pvsc.conn, &SignVoteMsg{Vote: vote}) if err != nil { return err } + + res, err := readMsg(pvsc.conn) + if err != nil { + return err + } + *vote = *res.(SignVoteMsg).Vote + return nil } // SignProposal implements PrivValidator. func (pvsc *PrivValidatorSocketClient) SignProposal(chainID string, proposal *types.Proposal) error { - res, err := readWrite(pvsc.conn, &SignProposalMsg{Proposal: proposal}) + err := writeMsg(pvsc.conn, &SignProposalMsg{Proposal: proposal}) if err != nil { return err } + + res, err := readMsg(pvsc.conn) + if err != nil { + return err + } + *proposal = *res.(SignProposalMsg).Proposal + return nil } // SignHeartbeat implements PrivValidator. func (pvsc *PrivValidatorSocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error { - res, err := readWrite(pvsc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat}) + err := writeMsg(pvsc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat}) if err != nil { return err } + + res, err := readMsg(pvsc.conn) + if err != nil { + return err + } + *heartbeat = *res.(SignHeartbeatMsg).Heartbeat + return nil } @@ -131,7 +176,7 @@ type PrivValidatorSocketServer struct { proto, addr string listener net.Listener maxConnections int - privKey crypto.PrivKeyEd25519 + privKey *crypto.PrivKeyEd25519 privVal PrivValidator chainID string @@ -143,7 +188,7 @@ func NewPrivValidatorSocketServer( logger log.Logger, socketAddr, chainID string, privVal PrivValidator, - privKey crypto.PrivKeyEd25519, + privKey *crypto.PrivKeyEd25519, maxConnections int, ) *PrivValidatorSocketServer { proto, addr := cmn.ProtocolAndAddress(socketAddr) @@ -166,7 +211,8 @@ func (pvss *PrivValidatorSocketServer) OnStart() error { return err } - pvss.listener = netutil.LimitListener(ln, pvss.maxConnections) + // pvss.listener = netutil.LimitListener(ln, pvss.maxConnections) + pvss.listener = ln go pvss.acceptConnectionsRoutine() @@ -195,6 +241,19 @@ func (pvss *PrivValidatorSocketServer) acceptConnectionsRoutine() { continue } + if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil { + pvss.Logger.Error("failed to set timeout for ocnnection: " + err.Error()) + continue + } + + if pvss.privKey != nil { + conn, err = p2p.MakeSecretConnection(conn, *pvss.privKey) + if err != nil { + pvss.Logger.Error("Failed to make secret connection: " + err.Error()) + continue + } + } + go pvss.handleConnection(conn) } } @@ -209,7 +268,9 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) { req, err := readMsg(conn) if err != nil { - pvss.Logger.Error("readMsg", "err", err) + if err != io.EOF { + pvss.Logger.Error("readMsg", "err", err) + } return } @@ -219,22 +280,38 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) { case *PubKeyMsg: res = &PubKeyMsg{pvss.privVal.PubKey()} case *SignVoteMsg: - pvss.privVal.SignVote(pvss.chainID, r.Vote) + err := pvss.privVal.SignVote(pvss.chainID, r.Vote) + if err != nil { + pvss.Logger.Error("handleConnection", "err", err) + return + } + res = &SignVoteMsg{r.Vote} case *SignProposalMsg: - pvss.privVal.SignProposal(pvss.chainID, r.Proposal) + err := pvss.privVal.SignProposal(pvss.chainID, r.Proposal) + if err != nil { + pvss.Logger.Error("handleConnection", "err", err) + return + } + res = &SignProposalMsg{r.Proposal} case *SignHeartbeatMsg: - pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat) + err := pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat) + if err != nil { + pvss.Logger.Error("handleConnection", "err", err) + return + } + res = &SignHeartbeatMsg{r.Heartbeat} default: - panic(fmt.Sprintf("unknown msg: %v", r)) + pvss.Logger.Error("handleConnection", "err", fmt.Sprintf("unknown msg: %v", r)) + return } - b := wire.BinaryBytes(res) - _, err = conn.Write(b) + err = writeMsg(conn, res) if err != nil { - panic(err) + pvss.Logger.Error("handleConnection", "err", err) + return } } } @@ -279,15 +356,15 @@ func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) { return w.PrivValidatorSocketMsg, nil } -func readWrite(conn net.Conn, req PrivValidatorSocketMsg) (res PrivValidatorSocketMsg, err error) { - b := wire.BinaryBytes(req) +func writeMsg(w io.Writer, msg interface{}) error { + var ( + err error + n int + ) - _, err = conn.Write(b) - if err != nil { - return nil, err - } + wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err) - return readMsg(conn) + return err } func decodeMsg(bz []byte) (msg PrivValidatorSocketMsg, err error) { diff --git a/types/priv_validator/socket_test.go b/types/priv_validator/socket_test.go index 971e5308b..16cc59c5f 100644 --- a/types/priv_validator/socket_test.go +++ b/types/priv_validator/socket_test.go @@ -4,6 +4,9 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + crypto "github.com/tendermint/go-crypto" "github.com/tendermint/tendermint/types" "github.com/tendermint/tmlibs/log" @@ -11,34 +14,42 @@ import ( func TestPrivValidatorSocketServer(t *testing.T) { var ( - chainID = "test-chain" - logger = log.TestingLogger() - signer = types.GenSigner() - privKey = crypto.GenPrivKeyEd25519() - privVal = NewTestPrivValidator(signer) - pvss = NewPrivValidatorSocketServer( + assert, require = assert.New(t), require.New(t) + chainID = "test-chain" + logger = log.TestingLogger() + signer = types.GenSigner() + clientPrivKey = crypto.GenPrivKeyEd25519() + serverPrivKey = crypto.GenPrivKeyEd25519() + privVal = NewTestPrivValidator(signer) + pvss = NewPrivValidatorSocketServer( logger, "127.0.0.1:0", chainID, privVal, - privKey, + &serverPrivKey, 1, ) ) err := pvss.Start() - if err != nil { - t.Fatal(err) - } + require.Nil(err) + defer pvss.Stop() - c := NewPrivValidatorSocketClient(logger, pvss.listener.Addr().String()) + assert.True(pvss.IsRunning()) - err = c.Start() - if err != nil { - t.Fatal(err) - } + pvsc := NewPrivValidatorSocketClient( + logger, + pvss.listener.Addr().String(), + &clientPrivKey, + ) + + err = pvsc.Start() + require.Nil(err) + defer pvsc.Stop() + + assert.True(pvsc.IsRunning()) - if have, want := c.PubKey(), pvss.privVal.PubKey(); !reflect.DeepEqual(have, want) { + if have, want := pvsc.PubKey(), pvss.privVal.PubKey(); !reflect.DeepEqual(have, want) { t.Errorf("have %v, want %v", have, want) } }