Browse Source

Use secret connection

pull/1204/head
Alexander Simmerl 7 years ago
committed by Ethan Buchman
parent
commit
18f7e52562
2 changed files with 129 additions and 41 deletions
  1. +102
    -25
      types/priv_validator/socket.go
  2. +27
    -16
      types/priv_validator/socket_test.go

+ 102
- 25
types/priv_validator/socket.go View File

@ -12,8 +12,8 @@ import (
"github.com/tendermint/go-wire/data"
cmn "github.com/tendermint/tmlibs/common"
"github.com/tendermint/tmlibs/log"
"golang.org/x/net/netutil"
"github.com/tendermint/tendermint/p2p"
"github.com/tendermint/tendermint/types"
)
@ -26,7 +26,8 @@ var _ types.PrivValidator = (*PrivValidatorSocketClient)(nil)
type PrivValidatorSocketClient struct {
cmn.BaseService
conn net.Conn
conn net.Conn
privKey *crypto.PrivKeyEd25519
ID types.ValidatorID
SocketAddress string
@ -38,11 +39,18 @@ const (
// NewPrivValidatorSocketClient returns an instance of
// PrivValidatorSocketClient.
func NewPrivValidatorSocketClient(logger log.Logger, socketAddr string) *PrivValidatorSocketClient {
func NewPrivValidatorSocketClient(
logger log.Logger,
socketAddr string,
privKey *crypto.PrivKeyEd25519,
) *PrivValidatorSocketClient {
pvsc := &PrivValidatorSocketClient{
SocketAddress: socketAddr,
privKey: privKey,
}
pvsc.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketClient", pvsc)
return pvsc
}
@ -54,6 +62,7 @@ func (pvsc *PrivValidatorSocketClient) OnStart() error {
var err error
var conn net.Conn
RETRY_LOOP:
for {
conn, err = cmn.Connect(pvsc.SocketAddress)
@ -62,6 +71,15 @@ RETRY_LOOP:
time.Sleep(time.Second * dialRetryIntervalSeconds)
continue RETRY_LOOP
}
if pvsc.privKey != nil {
conn, err = p2p.MakeSecretConnection(conn, *pvsc.privKey)
if err != nil {
pvsc.Logger.Error("failed to encrypt connection: " + err.Error())
continue RETRY_LOOP
}
}
pvsc.conn = conn
return nil
}
@ -84,40 +102,67 @@ func (pvsc *PrivValidatorSocketClient) Address() data.Bytes {
// PubKey implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) PubKey() crypto.PubKey {
res, err := readWrite(pvsc.conn, &PubKeyMsg{})
err := writeMsg(pvsc.conn, &PubKeyMsg{})
if err != nil {
panic(err)
}
res, err := readMsg(pvsc.conn)
if err != nil {
panic(err)
}
return res.(*PubKeyMsg).PubKey
}
// SignVote implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) SignVote(chainID string, vote *types.Vote) error {
res, err := readWrite(pvsc.conn, &SignVoteMsg{Vote: vote})
err := writeMsg(pvsc.conn, &SignVoteMsg{Vote: vote})
if err != nil {
return err
}
res, err := readMsg(pvsc.conn)
if err != nil {
return err
}
*vote = *res.(SignVoteMsg).Vote
return nil
}
// SignProposal implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) SignProposal(chainID string, proposal *types.Proposal) error {
res, err := readWrite(pvsc.conn, &SignProposalMsg{Proposal: proposal})
err := writeMsg(pvsc.conn, &SignProposalMsg{Proposal: proposal})
if err != nil {
return err
}
res, err := readMsg(pvsc.conn)
if err != nil {
return err
}
*proposal = *res.(SignProposalMsg).Proposal
return nil
}
// SignHeartbeat implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error {
res, err := readWrite(pvsc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
err := writeMsg(pvsc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
if err != nil {
return err
}
res, err := readMsg(pvsc.conn)
if err != nil {
return err
}
*heartbeat = *res.(SignHeartbeatMsg).Heartbeat
return nil
}
@ -131,7 +176,7 @@ type PrivValidatorSocketServer struct {
proto, addr string
listener net.Listener
maxConnections int
privKey crypto.PrivKeyEd25519
privKey *crypto.PrivKeyEd25519
privVal PrivValidator
chainID string
@ -143,7 +188,7 @@ func NewPrivValidatorSocketServer(
logger log.Logger,
socketAddr, chainID string,
privVal PrivValidator,
privKey crypto.PrivKeyEd25519,
privKey *crypto.PrivKeyEd25519,
maxConnections int,
) *PrivValidatorSocketServer {
proto, addr := cmn.ProtocolAndAddress(socketAddr)
@ -166,7 +211,8 @@ func (pvss *PrivValidatorSocketServer) OnStart() error {
return err
}
pvss.listener = netutil.LimitListener(ln, pvss.maxConnections)
// pvss.listener = netutil.LimitListener(ln, pvss.maxConnections)
pvss.listener = ln
go pvss.acceptConnectionsRoutine()
@ -195,6 +241,19 @@ func (pvss *PrivValidatorSocketServer) acceptConnectionsRoutine() {
continue
}
if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil {
pvss.Logger.Error("failed to set timeout for ocnnection: " + err.Error())
continue
}
if pvss.privKey != nil {
conn, err = p2p.MakeSecretConnection(conn, *pvss.privKey)
if err != nil {
pvss.Logger.Error("Failed to make secret connection: " + err.Error())
continue
}
}
go pvss.handleConnection(conn)
}
}
@ -209,7 +268,9 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
req, err := readMsg(conn)
if err != nil {
pvss.Logger.Error("readMsg", "err", err)
if err != io.EOF {
pvss.Logger.Error("readMsg", "err", err)
}
return
}
@ -219,22 +280,38 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
case *PubKeyMsg:
res = &PubKeyMsg{pvss.privVal.PubKey()}
case *SignVoteMsg:
pvss.privVal.SignVote(pvss.chainID, r.Vote)
err := pvss.privVal.SignVote(pvss.chainID, r.Vote)
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
return
}
res = &SignVoteMsg{r.Vote}
case *SignProposalMsg:
pvss.privVal.SignProposal(pvss.chainID, r.Proposal)
err := pvss.privVal.SignProposal(pvss.chainID, r.Proposal)
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
return
}
res = &SignProposalMsg{r.Proposal}
case *SignHeartbeatMsg:
pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat)
err := pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat)
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
return
}
res = &SignHeartbeatMsg{r.Heartbeat}
default:
panic(fmt.Sprintf("unknown msg: %v", r))
pvss.Logger.Error("handleConnection", "err", fmt.Sprintf("unknown msg: %v", r))
return
}
b := wire.BinaryBytes(res)
_, err = conn.Write(b)
err = writeMsg(conn, res)
if err != nil {
panic(err)
pvss.Logger.Error("handleConnection", "err", err)
return
}
}
}
@ -279,15 +356,15 @@ func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) {
return w.PrivValidatorSocketMsg, nil
}
func readWrite(conn net.Conn, req PrivValidatorSocketMsg) (res PrivValidatorSocketMsg, err error) {
b := wire.BinaryBytes(req)
func writeMsg(w io.Writer, msg interface{}) error {
var (
err error
n int
)
_, err = conn.Write(b)
if err != nil {
return nil, err
}
wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err)
return readMsg(conn)
return err
}
func decodeMsg(bz []byte) (msg PrivValidatorSocketMsg, err error) {


+ 27
- 16
types/priv_validator/socket_test.go View File

@ -4,6 +4,9 @@ import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
crypto "github.com/tendermint/go-crypto"
"github.com/tendermint/tendermint/types"
"github.com/tendermint/tmlibs/log"
@ -11,34 +14,42 @@ import (
func TestPrivValidatorSocketServer(t *testing.T) {
var (
chainID = "test-chain"
logger = log.TestingLogger()
signer = types.GenSigner()
privKey = crypto.GenPrivKeyEd25519()
privVal = NewTestPrivValidator(signer)
pvss = NewPrivValidatorSocketServer(
assert, require = assert.New(t), require.New(t)
chainID = "test-chain"
logger = log.TestingLogger()
signer = types.GenSigner()
clientPrivKey = crypto.GenPrivKeyEd25519()
serverPrivKey = crypto.GenPrivKeyEd25519()
privVal = NewTestPrivValidator(signer)
pvss = NewPrivValidatorSocketServer(
logger,
"127.0.0.1:0",
chainID,
privVal,
privKey,
&serverPrivKey,
1,
)
)
err := pvss.Start()
if err != nil {
t.Fatal(err)
}
require.Nil(err)
defer pvss.Stop()
c := NewPrivValidatorSocketClient(logger, pvss.listener.Addr().String())
assert.True(pvss.IsRunning())
err = c.Start()
if err != nil {
t.Fatal(err)
}
pvsc := NewPrivValidatorSocketClient(
logger,
pvss.listener.Addr().String(),
&clientPrivKey,
)
err = pvsc.Start()
require.Nil(err)
defer pvsc.Stop()
assert.True(pvsc.IsRunning())
if have, want := c.PubKey(), pvss.privVal.PubKey(); !reflect.DeepEqual(have, want) {
if have, want := pvsc.PubKey(), pvss.privVal.PubKey(); !reflect.DeepEqual(have, want) {
t.Errorf("have %v, want %v", have, want)
}
}

Loading…
Cancel
Save