diff --git a/cmd/priv_val_server/main.go b/cmd/priv_val_server/main.go index 651707e83..b494a442c 100644 --- a/cmd/priv_val_server/main.go +++ b/cmd/priv_val_server/main.go @@ -2,7 +2,6 @@ package main import ( "flag" - "fmt" "os" cmn "github.com/tendermint/tmlibs/common" @@ -11,20 +10,27 @@ import ( priv_val "github.com/tendermint/tendermint/types/priv_validator" ) -var chainID = flag.String("chain-id", "mychain", "chain id") -var privValPath = flag.String("priv", "", "priv val file path") - func main() { + var ( + chainID = flag.String("chain-id", "mychain", "chain id") + privValPath = flag.String("priv", "", "priv val file path") + socketAddr = flag.String("socket.addr", ":46659", "socket bind addr") + + logger = log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val") + ) flag.Parse() - logger := log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val") - socketAddr := "localhost:46659" - fmt.Println(*chainID) - fmt.Println(*privValPath) + logger.Info("args", "chainID", *chainID, "privPath", *privValPath) + privVal := priv_val.LoadPrivValidatorJSON(*privValPath) - pvss := priv_val.NewPrivValidatorSocketServer(logger, socketAddr, *chainID, privVal) - pvss.Start() + pvss := priv_val.NewPrivValidatorSocketServer( + logger, + *socketAddr, + *chainID, + privVal, + ) + // pvss.Start() cmn.TrapSignal(func() { pvss.Stop() diff --git a/types/priv_validator/socket.go b/types/priv_validator/socket.go index 2ed4c46f2..c22e2698d 100644 --- a/types/priv_validator/socket.go +++ b/types/priv_validator/socket.go @@ -3,6 +3,7 @@ package types import ( "bytes" "fmt" + "io" "net" "time" @@ -11,6 +12,7 @@ import ( "github.com/tendermint/go-wire/data" cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" + "golang.org/x/net/netutil" "github.com/tendermint/tendermint/types" ) @@ -82,16 +84,16 @@ func (pvsc *PrivValidatorSocketClient) Address() data.Bytes { // PubKey implements PrivValidator. func (pvsc *PrivValidatorSocketClient) PubKey() crypto.PubKey { - res, err := readWrite(pvsc.conn, PubKeyMsg{}) + res, err := readWrite(pvsc.conn, &PubKeyMsg{}) if err != nil { panic(err) } - return res.(PubKeyMsg).PubKey + return res.(*PubKeyMsg).PubKey } // SignVote implements PrivValidator. func (pvsc *PrivValidatorSocketClient) SignVote(chainID string, vote *types.Vote) error { - res, err := readWrite(pvsc.conn, SignVoteMsg{Vote: vote}) + res, err := readWrite(pvsc.conn, &SignVoteMsg{Vote: vote}) if err != nil { return err } @@ -101,7 +103,7 @@ func (pvsc *PrivValidatorSocketClient) SignVote(chainID string, vote *types.Vote // SignProposal implements PrivValidator. func (pvsc *PrivValidatorSocketClient) SignProposal(chainID string, proposal *types.Proposal) error { - res, err := readWrite(pvsc.conn, SignProposalMsg{Proposal: proposal}) + res, err := readWrite(pvsc.conn, &SignProposalMsg{Proposal: proposal}) if err != nil { return err } @@ -111,7 +113,7 @@ func (pvsc *PrivValidatorSocketClient) SignProposal(chainID string, proposal *ty // SignHeartbeat implements PrivValidator. func (pvsc *PrivValidatorSocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error { - res, err := readWrite(pvsc.conn, SignHeartbeatMsg{Heartbeat: heartbeat}) + res, err := readWrite(pvsc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat}) if err != nil { return err } @@ -126,8 +128,10 @@ func (pvsc *PrivValidatorSocketClient) SignHeartbeat(chainID string, heartbeat * type PrivValidatorSocketServer struct { cmn.BaseService - proto, addr string - listener net.Listener + proto, addr string + listener net.Listener + maxConnections int + privKey crypto.PrivKeyEd25519 privVal PrivValidator chainID string @@ -135,13 +139,21 @@ type PrivValidatorSocketServer struct { // NewPrivValidatorSocketServer returns an instance of // PrivValidatorSocketServer. -func NewPrivValidatorSocketServer(logger log.Logger, socketAddr, chainID string, privVal PrivValidator) *PrivValidatorSocketServer { +func NewPrivValidatorSocketServer( + logger log.Logger, + socketAddr, chainID string, + privVal PrivValidator, + privKey crypto.PrivKeyEd25519, + maxConnections int, +) *PrivValidatorSocketServer { proto, addr := cmn.ProtocolAndAddress(socketAddr) pvss := &PrivValidatorSocketServer{ - proto: proto, - addr: addr, - privVal: privVal, - chainID: chainID, + proto: proto, + addr: addr, + maxConnections: maxConnections, + privKey: privKey, + privVal: privVal, + chainID: chainID, } pvss.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketServer", pvss) return pvss @@ -149,22 +161,20 @@ func NewPrivValidatorSocketServer(logger log.Logger, socketAddr, chainID string, // OnStart implements cmn.Service. func (pvss *PrivValidatorSocketServer) OnStart() error { - if err := pvss.BaseService.OnStart(); err != nil { - return err - } ln, err := net.Listen(pvss.proto, pvss.addr) if err != nil { return err } - pvss.listener = ln + + pvss.listener = netutil.LimitListener(ln, pvss.maxConnections) + go pvss.acceptConnectionsRoutine() + return nil } // OnStop implements cmn.Service. func (pvss *PrivValidatorSocketServer) OnStop() { - pvss.BaseService.OnStop() - if pvss.listener == nil { return } @@ -176,9 +186,6 @@ func (pvss *PrivValidatorSocketServer) OnStop() { func (pvss *PrivValidatorSocketServer) acceptConnectionsRoutine() { for { - // Accept a connection - pvss.Logger.Info("Waiting for new connection...") - conn, err := pvss.listener.Accept() if err != nil { if !pvss.IsRunning() { @@ -188,47 +195,46 @@ func (pvss *PrivValidatorSocketServer) acceptConnectionsRoutine() { continue } - pvss.Logger.Info("Accepted a new connection") + go pvss.handleConnection(conn) + } +} - // read/write - for { - if !pvss.IsRunning() { - return // Ignore error from listener closing. - } +func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) { + defer conn.Close() - var n int - var err error - b := wire.ReadByteSlice(conn, 0, &n, &err) //XXX: no max - if err != nil { - panic(err) - } + for { + if !pvss.IsRunning() { + return // Ignore error from listener closing. + } - req, err := decodeMsg(b) - if err != nil { - panic(err) - } - var res PrivValidatorSocketMsg - switch r := req.(type) { - case PubKeyMsg: - res = PubKeyMsg{pvss.privVal.PubKey()} - case SignVoteMsg: - pvss.privVal.SignVote(pvss.chainID, r.Vote) - res = SignVoteMsg{r.Vote} - case SignProposalMsg: - pvss.privVal.SignProposal(pvss.chainID, r.Proposal) - res = SignProposalMsg{r.Proposal} - case SignHeartbeatMsg: - pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat) - res = SignHeartbeatMsg{r.Heartbeat} - default: - panic(fmt.Sprintf("unknown msg: %v", r)) - } + req, err := readMsg(conn) + if err != nil { + pvss.Logger.Error("readMsg", "err", err) + return + } - b = wire.BinaryBytes(res) - _, err = conn.Write(b) - if err != nil { - panic(err) - } + var res PrivValidatorSocketMsg + + switch r := req.(type) { + case *PubKeyMsg: + res = &PubKeyMsg{pvss.privVal.PubKey()} + case *SignVoteMsg: + pvss.privVal.SignVote(pvss.chainID, r.Vote) + res = &SignVoteMsg{r.Vote} + case *SignProposalMsg: + pvss.privVal.SignProposal(pvss.chainID, r.Proposal) + res = &SignProposalMsg{r.Proposal} + case *SignHeartbeatMsg: + pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat) + res = &SignHeartbeatMsg{r.Heartbeat} + default: + panic(fmt.Sprintf("unknown msg: %v", r)) + } + + b := wire.BinaryBytes(res) + _, err = conn.Write(b) + if err != nil { + panic(err) } } } @@ -254,23 +260,47 @@ var _ = wire.RegisterInterface( wire.ConcreteType{&SignHeartbeatMsg{}, msgTypeSignHeartbeat}, ) +func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) { + var ( + n int + err error + ) + + read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err) + if err != nil { + return nil, err + } + + w, ok := read.(struct{ PrivValidatorSocketMsg }) + if !ok { + return nil, fmt.Errorf("unknown type") + } + + return w.PrivValidatorSocketMsg, nil +} + func readWrite(conn net.Conn, req PrivValidatorSocketMsg) (res PrivValidatorSocketMsg, err error) { b := wire.BinaryBytes(req) + _, err = conn.Write(b) if err != nil { return nil, err } - var n int - b = wire.ReadByteSlice(conn, 0, &n, &err) //XXX: no max - return decodeMsg(b) + return readMsg(conn) } func decodeMsg(bz []byte) (msg PrivValidatorSocketMsg, err error) { - n := new(int) - r := bytes.NewReader(bz) - msgI := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, n, &err) + var ( + r = bytes.NewReader(bz) + + n int + ) + + msgI := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err) + msg = msgI.(struct{ PrivValidatorSocketMsg }).PrivValidatorSocketMsg + return msg, err } diff --git a/types/priv_validator/socket_test.go b/types/priv_validator/socket_test.go new file mode 100644 index 000000000..971e5308b --- /dev/null +++ b/types/priv_validator/socket_test.go @@ -0,0 +1,44 @@ +package types + +import ( + "reflect" + "testing" + + crypto "github.com/tendermint/go-crypto" + "github.com/tendermint/tendermint/types" + "github.com/tendermint/tmlibs/log" +) + +func TestPrivValidatorSocketServer(t *testing.T) { + var ( + chainID = "test-chain" + logger = log.TestingLogger() + signer = types.GenSigner() + privKey = crypto.GenPrivKeyEd25519() + privVal = NewTestPrivValidator(signer) + pvss = NewPrivValidatorSocketServer( + logger, + "127.0.0.1:0", + chainID, + privVal, + privKey, + 1, + ) + ) + + err := pvss.Start() + if err != nil { + t.Fatal(err) + } + + c := NewPrivValidatorSocketClient(logger, pvss.listener.Addr().String()) + + err = c.Start() + if err != nil { + t.Fatal(err) + } + + if have, want := c.PubKey(), pvss.privVal.PubKey(); !reflect.DeepEqual(have, want) { + t.Errorf("have %v, want %v", have, want) + } +}