Browse Source

Correct server protocol

pull/1204/head
Alexander Simmerl 7 years ago
committed by Ethan Buchman
parent
commit
fec541373d
3 changed files with 154 additions and 74 deletions
  1. +16
    -10
      cmd/priv_val_server/main.go
  2. +94
    -64
      types/priv_validator/socket.go
  3. +44
    -0
      types/priv_validator/socket_test.go

+ 16
- 10
cmd/priv_val_server/main.go View File

@ -2,7 +2,6 @@ package main
import (
"flag"
"fmt"
"os"
cmn "github.com/tendermint/tmlibs/common"
@ -11,20 +10,27 @@ import (
priv_val "github.com/tendermint/tendermint/types/priv_validator"
)
var chainID = flag.String("chain-id", "mychain", "chain id")
var privValPath = flag.String("priv", "", "priv val file path")
func main() {
var (
chainID = flag.String("chain-id", "mychain", "chain id")
privValPath = flag.String("priv", "", "priv val file path")
socketAddr = flag.String("socket.addr", ":46659", "socket bind addr")
logger = log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val")
)
flag.Parse()
logger := log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val")
socketAddr := "localhost:46659"
fmt.Println(*chainID)
fmt.Println(*privValPath)
logger.Info("args", "chainID", *chainID, "privPath", *privValPath)
privVal := priv_val.LoadPrivValidatorJSON(*privValPath)
pvss := priv_val.NewPrivValidatorSocketServer(logger, socketAddr, *chainID, privVal)
pvss.Start()
pvss := priv_val.NewPrivValidatorSocketServer(
logger,
*socketAddr,
*chainID,
privVal,
)
// pvss.Start()
cmn.TrapSignal(func() {
pvss.Stop()


+ 94
- 64
types/priv_validator/socket.go View File

@ -3,6 +3,7 @@ package types
import (
"bytes"
"fmt"
"io"
"net"
"time"
@ -11,6 +12,7 @@ import (
"github.com/tendermint/go-wire/data"
cmn "github.com/tendermint/tmlibs/common"
"github.com/tendermint/tmlibs/log"
"golang.org/x/net/netutil"
"github.com/tendermint/tendermint/types"
)
@ -82,16 +84,16 @@ func (pvsc *PrivValidatorSocketClient) Address() data.Bytes {
// PubKey implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) PubKey() crypto.PubKey {
res, err := readWrite(pvsc.conn, PubKeyMsg{})
res, err := readWrite(pvsc.conn, &PubKeyMsg{})
if err != nil {
panic(err)
}
return res.(PubKeyMsg).PubKey
return res.(*PubKeyMsg).PubKey
}
// SignVote implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) SignVote(chainID string, vote *types.Vote) error {
res, err := readWrite(pvsc.conn, SignVoteMsg{Vote: vote})
res, err := readWrite(pvsc.conn, &SignVoteMsg{Vote: vote})
if err != nil {
return err
}
@ -101,7 +103,7 @@ func (pvsc *PrivValidatorSocketClient) SignVote(chainID string, vote *types.Vote
// SignProposal implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) SignProposal(chainID string, proposal *types.Proposal) error {
res, err := readWrite(pvsc.conn, SignProposalMsg{Proposal: proposal})
res, err := readWrite(pvsc.conn, &SignProposalMsg{Proposal: proposal})
if err != nil {
return err
}
@ -111,7 +113,7 @@ func (pvsc *PrivValidatorSocketClient) SignProposal(chainID string, proposal *ty
// SignHeartbeat implements PrivValidator.
func (pvsc *PrivValidatorSocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error {
res, err := readWrite(pvsc.conn, SignHeartbeatMsg{Heartbeat: heartbeat})
res, err := readWrite(pvsc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
if err != nil {
return err
}
@ -126,8 +128,10 @@ func (pvsc *PrivValidatorSocketClient) SignHeartbeat(chainID string, heartbeat *
type PrivValidatorSocketServer struct {
cmn.BaseService
proto, addr string
listener net.Listener
proto, addr string
listener net.Listener
maxConnections int
privKey crypto.PrivKeyEd25519
privVal PrivValidator
chainID string
@ -135,13 +139,21 @@ type PrivValidatorSocketServer struct {
// NewPrivValidatorSocketServer returns an instance of
// PrivValidatorSocketServer.
func NewPrivValidatorSocketServer(logger log.Logger, socketAddr, chainID string, privVal PrivValidator) *PrivValidatorSocketServer {
func NewPrivValidatorSocketServer(
logger log.Logger,
socketAddr, chainID string,
privVal PrivValidator,
privKey crypto.PrivKeyEd25519,
maxConnections int,
) *PrivValidatorSocketServer {
proto, addr := cmn.ProtocolAndAddress(socketAddr)
pvss := &PrivValidatorSocketServer{
proto: proto,
addr: addr,
privVal: privVal,
chainID: chainID,
proto: proto,
addr: addr,
maxConnections: maxConnections,
privKey: privKey,
privVal: privVal,
chainID: chainID,
}
pvss.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketServer", pvss)
return pvss
@ -149,22 +161,20 @@ func NewPrivValidatorSocketServer(logger log.Logger, socketAddr, chainID string,
// OnStart implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStart() error {
if err := pvss.BaseService.OnStart(); err != nil {
return err
}
ln, err := net.Listen(pvss.proto, pvss.addr)
if err != nil {
return err
}
pvss.listener = ln
pvss.listener = netutil.LimitListener(ln, pvss.maxConnections)
go pvss.acceptConnectionsRoutine()
return nil
}
// OnStop implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStop() {
pvss.BaseService.OnStop()
if pvss.listener == nil {
return
}
@ -176,9 +186,6 @@ func (pvss *PrivValidatorSocketServer) OnStop() {
func (pvss *PrivValidatorSocketServer) acceptConnectionsRoutine() {
for {
// Accept a connection
pvss.Logger.Info("Waiting for new connection...")
conn, err := pvss.listener.Accept()
if err != nil {
if !pvss.IsRunning() {
@ -188,47 +195,46 @@ func (pvss *PrivValidatorSocketServer) acceptConnectionsRoutine() {
continue
}
pvss.Logger.Info("Accepted a new connection")
go pvss.handleConnection(conn)
}
}
// read/write
for {
if !pvss.IsRunning() {
return // Ignore error from listener closing.
}
func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
defer conn.Close()
var n int
var err error
b := wire.ReadByteSlice(conn, 0, &n, &err) //XXX: no max
if err != nil {
panic(err)
}
for {
if !pvss.IsRunning() {
return // Ignore error from listener closing.
}
req, err := decodeMsg(b)
if err != nil {
panic(err)
}
var res PrivValidatorSocketMsg
switch r := req.(type) {
case PubKeyMsg:
res = PubKeyMsg{pvss.privVal.PubKey()}
case SignVoteMsg:
pvss.privVal.SignVote(pvss.chainID, r.Vote)
res = SignVoteMsg{r.Vote}
case SignProposalMsg:
pvss.privVal.SignProposal(pvss.chainID, r.Proposal)
res = SignProposalMsg{r.Proposal}
case SignHeartbeatMsg:
pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat)
res = SignHeartbeatMsg{r.Heartbeat}
default:
panic(fmt.Sprintf("unknown msg: %v", r))
}
req, err := readMsg(conn)
if err != nil {
pvss.Logger.Error("readMsg", "err", err)
return
}
b = wire.BinaryBytes(res)
_, err = conn.Write(b)
if err != nil {
panic(err)
}
var res PrivValidatorSocketMsg
switch r := req.(type) {
case *PubKeyMsg:
res = &PubKeyMsg{pvss.privVal.PubKey()}
case *SignVoteMsg:
pvss.privVal.SignVote(pvss.chainID, r.Vote)
res = &SignVoteMsg{r.Vote}
case *SignProposalMsg:
pvss.privVal.SignProposal(pvss.chainID, r.Proposal)
res = &SignProposalMsg{r.Proposal}
case *SignHeartbeatMsg:
pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat)
res = &SignHeartbeatMsg{r.Heartbeat}
default:
panic(fmt.Sprintf("unknown msg: %v", r))
}
b := wire.BinaryBytes(res)
_, err = conn.Write(b)
if err != nil {
panic(err)
}
}
}
@ -254,23 +260,47 @@ var _ = wire.RegisterInterface(
wire.ConcreteType{&SignHeartbeatMsg{}, msgTypeSignHeartbeat},
)
func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) {
var (
n int
err error
)
read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err)
if err != nil {
return nil, err
}
w, ok := read.(struct{ PrivValidatorSocketMsg })
if !ok {
return nil, fmt.Errorf("unknown type")
}
return w.PrivValidatorSocketMsg, nil
}
func readWrite(conn net.Conn, req PrivValidatorSocketMsg) (res PrivValidatorSocketMsg, err error) {
b := wire.BinaryBytes(req)
_, err = conn.Write(b)
if err != nil {
return nil, err
}
var n int
b = wire.ReadByteSlice(conn, 0, &n, &err) //XXX: no max
return decodeMsg(b)
return readMsg(conn)
}
func decodeMsg(bz []byte) (msg PrivValidatorSocketMsg, err error) {
n := new(int)
r := bytes.NewReader(bz)
msgI := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, n, &err)
var (
r = bytes.NewReader(bz)
n int
)
msgI := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err)
msg = msgI.(struct{ PrivValidatorSocketMsg }).PrivValidatorSocketMsg
return msg, err
}


+ 44
- 0
types/priv_validator/socket_test.go View File

@ -0,0 +1,44 @@
package types
import (
"reflect"
"testing"
crypto "github.com/tendermint/go-crypto"
"github.com/tendermint/tendermint/types"
"github.com/tendermint/tmlibs/log"
)
func TestPrivValidatorSocketServer(t *testing.T) {
var (
chainID = "test-chain"
logger = log.TestingLogger()
signer = types.GenSigner()
privKey = crypto.GenPrivKeyEd25519()
privVal = NewTestPrivValidator(signer)
pvss = NewPrivValidatorSocketServer(
logger,
"127.0.0.1:0",
chainID,
privVal,
privKey,
1,
)
)
err := pvss.Start()
if err != nil {
t.Fatal(err)
}
c := NewPrivValidatorSocketClient(logger, pvss.listener.Addr().String())
err = c.Start()
if err != nil {
t.Fatal(err)
}
if have, want := c.PubKey(), pvss.privVal.PubKey(); !reflect.DeepEqual(have, want) {
t.Errorf("have %v, want %v", have, want)
}
}

Loading…
Cancel
Save