Browse Source

First stab: p2p/conn

pull/1347/head
Jae Kwon 7 years ago
parent
commit
6c345f9fa2
8 changed files with 408 additions and 371 deletions
  1. +47
    -27
      Gopkg.lock
  2. +4
    -5
      Gopkg.toml
  3. +123
    -116
      p2p/conn/connection.go
  4. +81
    -88
      p2p/conn/connection_test.go
  5. +61
    -39
      p2p/conn/secret_connection.go
  6. +78
    -36
      p2p/conn/secret_connection_test.go
  7. +14
    -0
      p2p/conn/wire.go
  8. +0
    -60
      wire/wire.go

+ 47
- 27
Gopkg.lock View File

@ -2,9 +2,10 @@
[[projects]] [[projects]]
branch = "master"
name = "github.com/btcsuite/btcd" name = "github.com/btcsuite/btcd"
packages = ["btcec"] packages = ["btcec"]
revision = "50de9da05b50eb15658bb350f6ea24368a111ab7"
revision = "2be2f12b358dc57d70b8f501b00be450192efbc3"
[[projects]] [[projects]]
name = "github.com/davecgh/go-spew" name = "github.com/davecgh/go-spew"
@ -19,10 +20,10 @@
revision = "95f809107225be108efcf10a3509e4ea6ceef3c4" revision = "95f809107225be108efcf10a3509e4ea6ceef3c4"
[[projects]] [[projects]]
branch = "master"
name = "github.com/fortytw2/leaktest" name = "github.com/fortytw2/leaktest"
packages = ["."] packages = ["."]
revision = "3b724c3d7b8729a35bf4e577f71653aec6e53513"
revision = "a5ef70473c97b71626b9abeda80ee92ba2a7de9e"
version = "v1.2.0"
[[projects]] [[projects]]
name = "github.com/fsnotify/fsnotify" name = "github.com/fsnotify/fsnotify"
@ -96,6 +97,7 @@
".", ".",
"hcl/ast", "hcl/ast",
"hcl/parser", "hcl/parser",
"hcl/printer",
"hcl/scanner", "hcl/scanner",
"hcl/strconv", "hcl/strconv",
"hcl/token", "hcl/token",
@ -103,7 +105,7 @@
"json/scanner", "json/scanner",
"json/token" "json/token"
] ]
revision = "23c074d0eceb2b8a5bfdbb271ab780cde70f05a8"
revision = "f40e974e75af4e271d97ce0fc917af5898ae7bda"
[[projects]] [[projects]]
name = "github.com/inconshreveable/mousetrap" name = "github.com/inconshreveable/mousetrap"
@ -126,12 +128,14 @@
[[projects]] [[projects]]
name = "github.com/magiconair/properties" name = "github.com/magiconair/properties"
packages = ["."] packages = ["."]
revision = "49d762b9817ba1c2e9d0c69183c2b4a8b8f1d934"
revision = "c3beff4c2358b44d0493c7dda585e7db7ff28ae6"
version = "v1.7.6"
[[projects]] [[projects]]
branch = "master"
name = "github.com/mitchellh/mapstructure" name = "github.com/mitchellh/mapstructure"
packages = ["."] packages = ["."]
revision = "b4575eea38cca1123ec2dc90c26529b5c5acfcff"
revision = "00c29f56e2386353d58c599509e8dc3801b0d716"
[[projects]] [[projects]]
name = "github.com/pelletier/go-toml" name = "github.com/pelletier/go-toml"
@ -169,8 +173,8 @@
[[projects]] [[projects]]
name = "github.com/spf13/cast" name = "github.com/spf13/cast"
packages = ["."] packages = ["."]
revision = "acbeb36b902d72a7a4c18e8f3241075e7ab763e4"
version = "v1.1.0"
revision = "8965335b8c7107321228e3e3702cab9832751bac"
version = "v1.2.0"
[[projects]] [[projects]]
name = "github.com/spf13/cobra" name = "github.com/spf13/cobra"
@ -193,8 +197,8 @@
[[projects]] [[projects]]
name = "github.com/spf13/viper" name = "github.com/spf13/viper"
packages = ["."] packages = ["."]
revision = "25b30aa063fc18e48662b86996252eabdcf2f0c7"
version = "v1.0.0"
revision = "b5e8006cbee93ec955a89ab31e0e3ce3204f3736"
version = "v1.0.2"
[[projects]] [[projects]]
name = "github.com/stretchr/testify" name = "github.com/stretchr/testify"
@ -206,6 +210,7 @@
version = "v1.2.1" version = "v1.2.1"
[[projects]] [[projects]]
branch = "master"
name = "github.com/syndtr/goleveldb" name = "github.com/syndtr/goleveldb"
packages = [ packages = [
"leveldb", "leveldb",
@ -221,7 +226,7 @@
"leveldb/table", "leveldb/table",
"leveldb/util" "leveldb/util"
] ]
revision = "34011bf325bce385408353a30b101fe5e923eb6e"
revision = "169b1b37be738edb2813dab48c97a549bcf99bb5"
[[projects]] [[projects]]
branch = "develop" branch = "develop"
@ -234,7 +239,7 @@
"server", "server",
"types" "types"
] ]
revision = "9e0e00bef42aebf6b402f66bf0f3dc607de8a6f3"
revision = "4e0218467649fecf17ebc5e8161f1c888fc8ff22"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -247,10 +252,16 @@
revision = "d8387025d2b9d158cf4efb07e7ebf814bcce2057" revision = "d8387025d2b9d158cf4efb07e7ebf814bcce2057"
[[projects]] [[projects]]
branch = "develop"
name = "github.com/tendermint/go-amino"
packages = ["."]
revision = "b1f32ee20e73716d8bfe695365c0a812b2bd8ef9"
[[projects]]
branch = "develop"
name = "github.com/tendermint/go-crypto" name = "github.com/tendermint/go-crypto"
packages = ["."] packages = ["."]
revision = "c3e19f3ea26f5c3357e0bcbb799b0761ef923755"
version = "v0.5.0"
revision = "a3800da0a15c8272cbd3c155e024bff28fe9692c"
[[projects]] [[projects]]
name = "github.com/tendermint/go-wire" name = "github.com/tendermint/go-wire"
@ -259,10 +270,10 @@
"data" "data"
] ]
revision = "fa721242b042ecd4c6ed1a934ee740db4f74e45c" revision = "fa721242b042ecd4c6ed1a934ee740db4f74e45c"
source = "github.com/tendermint/go-amino"
version = "v0.7.3" version = "v0.7.3"
[[projects]] [[projects]]
branch = "develop"
name = "github.com/tendermint/tmlibs" name = "github.com/tendermint/tmlibs"
packages = [ packages = [
"autofile", "autofile",
@ -278,10 +289,10 @@
"pubsub/query", "pubsub/query",
"test" "test"
] ]
revision = "1b9b5652a199ab0be2e781393fb275b66377309d"
version = "v0.7.0"
revision = "4e5c655944c9a636eaed549e6ad8fd8011fb4d42"
[[projects]] [[projects]]
branch = "master"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
packages = [ packages = [
"curve25519", "curve25519",
@ -293,7 +304,7 @@
"ripemd160", "ripemd160",
"salsa20/salsa" "salsa20/salsa"
] ]
revision = "1875d0a70c90e57f11972aefd42276df65e895b9"
revision = "c3a3ad6d03f7a915c0f7e194b7152974bb73d287"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -307,12 +318,13 @@
"lex/httplex", "lex/httplex",
"trace" "trace"
] ]
revision = "cbe0f9307d0156177f9dd5dc85da1a31abc5f2fb"
revision = "6078986fec03a1dcc236c34816c71b0e05018fda"
[[projects]] [[projects]]
branch = "master"
name = "golang.org/x/sys" name = "golang.org/x/sys"
packages = ["unix"] packages = ["unix"]
revision = "37707fdb30a5b38865cfb95e5aab41707daec7fd"
revision = "7ceb54c8418b8f9cdf0177b511d5cbb06e9fae39"
[[projects]] [[projects]]
name = "golang.org/x/text" name = "golang.org/x/text"
@ -332,21 +344,27 @@
"unicode/norm", "unicode/norm",
"unicode/rangetable" "unicode/rangetable"
] ]
revision = "e19ae1496984b1c655b8044a65c0300a3c878dd3"
revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
version = "v0.3.0"
[[projects]] [[projects]]
branch = "master"
name = "google.golang.org/genproto" name = "google.golang.org/genproto"
packages = ["googleapis/rpc/status"] packages = ["googleapis/rpc/status"]
revision = "4eb30f4778eed4c258ba66527a0d4f9ec8a36c45"
revision = "f8c8703595236ae70fdf8789ecb656ea0bcdcf46"
[[projects]] [[projects]]
name = "google.golang.org/grpc" name = "google.golang.org/grpc"
packages = [ packages = [
".", ".",
"balancer", "balancer",
"balancer/base",
"balancer/roundrobin",
"codes", "codes",
"connectivity", "connectivity",
"credentials", "credentials",
"encoding",
"encoding/proto",
"grpclb/grpc_lb_v1/messages", "grpclb/grpc_lb_v1/messages",
"grpclog", "grpclog",
"internal", "internal",
@ -355,23 +373,25 @@
"naming", "naming",
"peer", "peer",
"resolver", "resolver",
"resolver/dns",
"resolver/passthrough",
"stats", "stats",
"status", "status",
"tap", "tap",
"transport" "transport"
] ]
revision = "401e0e00e4bb830a10496d64cd95e068c5bf50de"
version = "v1.7.3"
revision = "8e4536a86ab602859c20df5ebfd0bd4228d08655"
version = "v1.10.0"
[[projects]] [[projects]]
name = "gopkg.in/yaml.v2" name = "gopkg.in/yaml.v2"
packages = ["."] packages = ["."]
revision = "d670f9405373e636a5a2765eea47fac0c9bc91a4"
version = "v2.0.0"
revision = "7f97868eec74b32b0982dd158a51a446d1da7eb5"
version = "v2.1.1"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "ed9db0be72a900f4812675f683db20eff9d64ef4511dc00ad29a810da65909c2"
inputs-digest = "6da81f319b092e227b5d2c9de3b10296e9bb7287c02adb38fe547147e9e5e447"
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

+ 4
- 5
Gopkg.toml View File

@ -75,16 +75,15 @@
[[constraint]] [[constraint]]
name = "github.com/tendermint/go-crypto" name = "github.com/tendermint/go-crypto"
version = "0.5.0"
branch = "develop"
[[constraint]] [[constraint]]
name = "github.com/tendermint/go-wire"
source = "github.com/tendermint/go-amino"
version = "0.7.3"
name = "github.com/tendermint/go-amino"
branch = "develop"
[[constraint]] [[constraint]]
name = "github.com/tendermint/tmlibs" name = "github.com/tendermint/tmlibs"
version = "0.7.0"
branch = "develop"
[[constraint]] [[constraint]]
name = "google.golang.org/grpc" name = "google.golang.org/grpc"


+ 123
- 116
p2p/conn/connection.go View File

@ -7,18 +7,22 @@ import (
"io" "io"
"math" "math"
"net" "net"
"reflect"
"runtime/debug" "runtime/debug"
"sync/atomic" "sync/atomic"
"time" "time"
wire "github.com/tendermint/go-wire"
amino "github.com/tendermint/go-amino"
cmn "github.com/tendermint/tmlibs/common" cmn "github.com/tendermint/tmlibs/common"
flow "github.com/tendermint/tmlibs/flowrate" flow "github.com/tendermint/tmlibs/flowrate"
"github.com/tendermint/tmlibs/log" "github.com/tendermint/tmlibs/log"
) )
const ( const (
numBatchMsgPackets = 10
defaultMaxPacketMsgPayloadSize = 1024
maxPacketMsgOverheadSize = 10 // It's actually lower but good enough
numBatchPacketMsgs = 10
minReadBufferSize = 1024 minReadBufferSize = 1024
minWriteBufferSize = 65536 minWriteBufferSize = 65536
updateStats = 2 * time.Second updateStats = 2 * time.Second
@ -58,8 +62,7 @@ There are two methods for sending messages:
`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued `Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued
for the channel with the given id byte `chID`, or until the request times out. for the channel with the given id byte `chID`, or until the request times out.
The message `msg` is serialized using the `tendermint/wire` submodule's
`WriteBinary()` reflection routine.
The message `msg` is serialized using Go-Amino.
`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's `TrySend(chID, msg)` is a nonblocking call that returns false if the channel's
queue is full. queue is full.
@ -69,19 +72,19 @@ Inbound message bytes are handled with an onReceive callback function.
type MConnection struct { type MConnection struct {
cmn.BaseService cmn.BaseService
conn net.Conn
bufReader *bufio.Reader
bufWriter *bufio.Writer
sendMonitor *flow.Monitor
recvMonitor *flow.Monitor
send chan struct{}
pong chan struct{}
channels []*Channel
channelsIdx map[byte]*Channel
onReceive receiveCbFunc
onError errorCbFunc
errored uint32
config *MConnConfig
conn net.Conn
bufConnReader *bufio.Reader
bufConnWriter *bufio.Writer
sendMonitor *flow.Monitor
recvMonitor *flow.Monitor
send chan struct{}
pong chan struct{}
channels []*Channel
channelsIdx map[byte]*Channel
onReceive receiveCbFunc
onError errorCbFunc
errored uint32
config *MConnConfig
quit chan struct{} quit chan struct{}
flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled.
@ -102,7 +105,7 @@ type MConnConfig struct {
RecvRate int64 `mapstructure:"recv_rate"` RecvRate int64 `mapstructure:"recv_rate"`
// Maximum payload size // Maximum payload size
MaxMsgPacketPayloadSize int `mapstructure:"max_msg_packet_payload_size"`
MaxPacketMsgPayloadSize int `mapstructure:"max_msg_packet_payload_size"`
// Interval to flush writes (throttled) // Interval to flush writes (throttled)
FlushThrottle time.Duration `mapstructure:"flush_throttle"` FlushThrottle time.Duration `mapstructure:"flush_throttle"`
@ -114,8 +117,8 @@ type MConnConfig struct {
PongTimeout time.Duration `mapstructure:"pong_timeout"` PongTimeout time.Duration `mapstructure:"pong_timeout"`
} }
func (cfg *MConnConfig) maxMsgPacketTotalSize() int {
return cfg.MaxMsgPacketPayloadSize + maxMsgPacketOverheadSize
func (cfg *MConnConfig) maxPacketMsgTotalSize() int {
return cfg.MaxPacketMsgPayloadSize + maxPacketMsgOverheadSize
} }
// DefaultMConnConfig returns the default config. // DefaultMConnConfig returns the default config.
@ -123,7 +126,7 @@ func DefaultMConnConfig() *MConnConfig {
return &MConnConfig{ return &MConnConfig{
SendRate: defaultSendRate, SendRate: defaultSendRate,
RecvRate: defaultRecvRate, RecvRate: defaultRecvRate,
MaxMsgPacketPayloadSize: defaultMaxMsgPacketPayloadSize,
MaxPacketMsgPayloadSize: defaultMaxPacketMsgPayloadSize,
FlushThrottle: defaultFlushThrottle, FlushThrottle: defaultFlushThrottle,
PingInterval: defaultPingInterval, PingInterval: defaultPingInterval,
PongTimeout: defaultPongTimeout, PongTimeout: defaultPongTimeout,
@ -147,16 +150,16 @@ func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onRec
} }
mconn := &MConnection{ mconn := &MConnection{
conn: conn,
bufReader: bufio.NewReaderSize(conn, minReadBufferSize),
bufWriter: bufio.NewWriterSize(conn, minWriteBufferSize),
sendMonitor: flow.New(0, 0),
recvMonitor: flow.New(0, 0),
send: make(chan struct{}, 1),
pong: make(chan struct{}, 1),
onReceive: onReceive,
onError: onError,
config: config,
conn: conn,
bufConnReader: bufio.NewReaderSize(conn, minReadBufferSize),
bufConnWriter: bufio.NewWriterSize(conn, minWriteBufferSize),
sendMonitor: flow.New(0, 0),
recvMonitor: flow.New(0, 0),
send: make(chan struct{}, 1),
pong: make(chan struct{}, 1),
onReceive: onReceive,
onError: onError,
config: config,
} }
// Create channels // Create channels
@ -221,7 +224,7 @@ func (c *MConnection) String() string {
func (c *MConnection) flush() { func (c *MConnection) flush() {
c.Logger.Debug("Flush", "conn", c) c.Logger.Debug("Flush", "conn", c)
err := c.bufWriter.Flush()
err := c.bufConnWriter.Flush()
if err != nil { if err != nil {
c.Logger.Error("MConnection flush failed", "err", err) c.Logger.Error("MConnection flush failed", "err", err)
} }
@ -251,7 +254,7 @@ func (c *MConnection) Send(chID byte, msg interface{}) bool {
return false return false
} }
c.Logger.Debug("Send", "channel", chID, "conn", c, "msg", msg) //, "bytes", wire.BinaryBytes(msg))
c.Logger.Debug("Send", "channel", chID, "conn", c, "msg", msg)
// Send message to channel. // Send message to channel.
channel, ok := c.channelsIdx[chID] channel, ok := c.channelsIdx[chID]
@ -260,7 +263,7 @@ func (c *MConnection) Send(chID byte, msg interface{}) bool {
return false return false
} }
success := channel.sendBytes(wire.BinaryBytes(msg))
success := channel.sendBytes(cdc.MustMarshalBinary(msg))
if success { if success {
// Wake up sendRoutine if necessary // Wake up sendRoutine if necessary
select { select {
@ -289,7 +292,7 @@ func (c *MConnection) TrySend(chID byte, msg interface{}) bool {
return false return false
} }
ok = channel.trySendBytes(wire.BinaryBytes(msg))
ok = channel.trySendBytes(cdc.MustMarshalBinary(msg))
if ok { if ok {
// Wake up sendRoutine if necessary // Wake up sendRoutine if necessary
select { select {
@ -322,12 +325,13 @@ func (c *MConnection) sendRoutine() {
FOR_LOOP: FOR_LOOP:
for { for {
var n int
var _n int64
var err error var err error
SELECTION:
select { select {
case <-c.flushTimer.Ch: case <-c.flushTimer.Ch:
// NOTE: flushTimer.Set() must be called every time // NOTE: flushTimer.Set() must be called every time
// something is written to .bufWriter.
// something is written to .bufConnWriter.
c.flush() c.flush()
case <-c.chStatsTimer.Chan(): case <-c.chStatsTimer.Chan():
for _, channel := range c.channels { for _, channel := range c.channels {
@ -335,8 +339,11 @@ FOR_LOOP:
} }
case <-c.pingTimer.Chan(): case <-c.pingTimer.Chan():
c.Logger.Debug("Send Ping") c.Logger.Debug("Send Ping")
wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
c.sendMonitor.Update(int(n))
_n, err = cdc.MarshalBinaryWriter(c.bufConnWriter, PacketPing{})
if err != nil {
break SELECTION
}
c.sendMonitor.Update(int(_n))
c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout) c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() { c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
select { select {
@ -354,14 +361,17 @@ FOR_LOOP:
} }
case <-c.pong: case <-c.pong:
c.Logger.Debug("Send Pong") c.Logger.Debug("Send Pong")
wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
c.sendMonitor.Update(int(n))
_n, err = cdc.MarshalBinaryWriter(c.bufConnWriter, PacketPong{})
if err != nil {
break SELECTION
}
c.sendMonitor.Update(int(_n))
c.flush() c.flush()
case <-c.quit: case <-c.quit:
break FOR_LOOP break FOR_LOOP
case <-c.send: case <-c.send:
// Send some msgPackets
eof := c.sendSomeMsgPackets()
// Send some PacketMsgs
eof := c.sendSomePacketMsgs()
if !eof { if !eof {
// Keep sendRoutine awake. // Keep sendRoutine awake.
select { select {
@ -387,15 +397,15 @@ FOR_LOOP:
// Returns true if messages from channels were exhausted. // Returns true if messages from channels were exhausted.
// Blocks in accordance to .sendMonitor throttling. // Blocks in accordance to .sendMonitor throttling.
func (c *MConnection) sendSomeMsgPackets() bool {
func (c *MConnection) sendSomePacketMsgs() bool {
// Block until .sendMonitor says we can write. // Block until .sendMonitor says we can write.
// Once we're ready we send more than we asked for, // Once we're ready we send more than we asked for,
// but amortized it should even out. // but amortized it should even out.
c.sendMonitor.Limit(c.config.maxMsgPacketTotalSize(), atomic.LoadInt64(&c.config.SendRate), true)
c.sendMonitor.Limit(c.config.maxPacketMsgTotalSize(), atomic.LoadInt64(&c.config.SendRate), true)
// Now send some msgPackets.
for i := 0; i < numBatchMsgPackets; i++ {
if c.sendMsgPacket() {
// Now send some PacketMsgs.
for i := 0; i < numBatchPacketMsgs; i++ {
if c.sendPacketMsg() {
return true return true
} }
} }
@ -403,8 +413,8 @@ func (c *MConnection) sendSomeMsgPackets() bool {
} }
// Returns true if messages from channels were exhausted. // Returns true if messages from channels were exhausted.
func (c *MConnection) sendMsgPacket() bool {
// Choose a channel to create a msgPacket from.
func (c *MConnection) sendPacketMsg() bool {
// Choose a channel to create a PacketMsg from.
// The chosen channel will be the one whose recentlySent/priority is the least. // The chosen channel will be the one whose recentlySent/priority is the least.
var leastRatio float32 = math.MaxFloat32 var leastRatio float32 = math.MaxFloat32
var leastChannel *Channel var leastChannel *Channel
@ -425,22 +435,22 @@ func (c *MConnection) sendMsgPacket() bool {
if leastChannel == nil { if leastChannel == nil {
return true return true
} else { } else {
// c.Logger.Info("Found a msgPacket to send")
// c.Logger.Info("Found a PacketMsg to send")
} }
// Make & send a msgPacket from this channel
n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
// Make & send a PacketMsg from this channel
_n, err := leastChannel.writePacketMsgTo(c.bufConnWriter)
if err != nil { if err != nil {
c.Logger.Error("Failed to write msgPacket", "err", err)
c.Logger.Error("Failed to write PacketMsg", "err", err)
c.stopForError(err) c.stopForError(err)
return true return true
} }
c.sendMonitor.Update(int(n))
c.sendMonitor.Update(int(_n))
c.flushTimer.Set() c.flushTimer.Set()
return false return false
} }
// recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer.
// recvRoutine reads PacketMsgs and reconstructs the message using the channels' "recving" buffer.
// After a whole message has been assembled, it's pushed to onReceive(). // After a whole message has been assembled, it's pushed to onReceive().
// Blocks depending on how the connection is throttled. // Blocks depending on how the connection is throttled.
// Otherwise, it never blocks. // Otherwise, it never blocks.
@ -450,13 +460,13 @@ func (c *MConnection) recvRoutine() {
FOR_LOOP: FOR_LOOP:
for { for {
// Block until .recvMonitor says we can read. // Block until .recvMonitor says we can read.
c.recvMonitor.Limit(c.config.maxMsgPacketTotalSize(), atomic.LoadInt64(&c.config.RecvRate), true)
c.recvMonitor.Limit(c.config.maxPacketMsgTotalSize(), atomic.LoadInt64(&c.config.RecvRate), true)
/* /*
// Peek into bufReader for debugging
if numBytes := c.bufReader.Buffered(); numBytes > 0 {
// Peek into bufConnReader for debugging
if numBytes := c.bufConnReader.Buffered(); numBytes > 0 {
log.Info("Peek connection buffer", "numBytes", numBytes, "bytes", log15.Lazy{func() []byte { log.Info("Peek connection buffer", "numBytes", numBytes, "bytes", log15.Lazy{func() []byte {
bytes, err := c.bufReader.Peek(cmn.MinInt(numBytes, 100))
bytes, err := c.bufConnReader.Peek(cmn.MinInt(numBytes, 100))
if err == nil { if err == nil {
return bytes return bytes
} else { } else {
@ -468,10 +478,11 @@ FOR_LOOP:
*/ */
// Read packet type // Read packet type
var n int
var packet Packet
var _n int64
var err error var err error
pktType := wire.ReadByte(c.bufReader, &n, &err)
c.recvMonitor.Update(int(n))
_n, err = cdc.UnmarshalBinaryReader(c.bufConnReader, &packet, int64(c.config.maxPacketMsgTotalSize()))
c.recvMonitor.Update(int(_n))
if err != nil { if err != nil {
if c.IsRunning() { if c.IsRunning() {
c.Logger.Error("Connection failed @ recvRoutine (reading byte)", "conn", c, "err", err) c.Logger.Error("Connection failed @ recvRoutine (reading byte)", "conn", c, "err", err)
@ -481,8 +492,8 @@ FOR_LOOP:
} }
// Read more depending on packet type. // Read more depending on packet type.
switch pktType {
case packetTypePing:
switch pkt := packet.(type) {
case PacketPing:
// TODO: prevent abuse, as they cause flush()'s. // TODO: prevent abuse, as they cause flush()'s.
// https://github.com/tendermint/tendermint/issues/1190 // https://github.com/tendermint/tendermint/issues/1190
c.Logger.Debug("Receive Ping") c.Logger.Debug("Receive Ping")
@ -491,24 +502,14 @@ FOR_LOOP:
default: default:
// never block // never block
} }
case packetTypePong:
case PacketPong:
c.Logger.Debug("Receive Pong") c.Logger.Debug("Receive Pong")
select { select {
case c.pongTimeoutCh <- false: case c.pongTimeoutCh <- false:
default: default:
// never block // never block
} }
case packetTypeMsg:
pkt, n, err := msgPacket{}, int(0), error(nil)
wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err)
c.recvMonitor.Update(int(n))
if err != nil {
if c.IsRunning() {
c.Logger.Error("Connection failed @ recvRoutine", "conn", c, "err", err)
c.stopForError(err)
}
break FOR_LOOP
}
case PacketMsg:
channel, ok := c.channelsIdx[pkt.ChannelID] channel, ok := c.channelsIdx[pkt.ChannelID]
if !ok || channel == nil { if !ok || channel == nil {
err := fmt.Errorf("Unknown channel %X", pkt.ChannelID) err := fmt.Errorf("Unknown channel %X", pkt.ChannelID)
@ -517,7 +518,7 @@ FOR_LOOP:
break FOR_LOOP break FOR_LOOP
} }
msgBytes, err := channel.recvMsgPacket(pkt)
msgBytes, err := channel.recvPacketMsg(pkt)
if err != nil { if err != nil {
if c.IsRunning() { if c.IsRunning() {
c.Logger.Error("Connection failed @ recvRoutine", "conn", c, "err", err) c.Logger.Error("Connection failed @ recvRoutine", "conn", c, "err", err)
@ -531,7 +532,7 @@ FOR_LOOP:
c.onReceive(pkt.ChannelID, msgBytes) c.onReceive(pkt.ChannelID, msgBytes)
} }
default: default:
err := fmt.Errorf("Unknown message type %X", pktType)
err := fmt.Errorf("Unknown message type %v", reflect.TypeOf(packet))
c.Logger.Error("Connection failed @ recvRoutine", "conn", c, "err", err) c.Logger.Error("Connection failed @ recvRoutine", "conn", c, "err", err)
c.stopForError(err) c.stopForError(err)
break FOR_LOOP break FOR_LOOP
@ -623,7 +624,7 @@ type Channel struct {
sending []byte sending []byte
recentlySent int64 // exponential moving average recentlySent int64 // exponential moving average
maxMsgPacketPayloadSize int
maxPacketMsgPayloadSize int
Logger log.Logger Logger log.Logger
} }
@ -638,7 +639,7 @@ func newChannel(conn *MConnection, desc ChannelDescriptor) *Channel {
desc: desc, desc: desc,
sendQueue: make(chan []byte, desc.SendQueueCapacity), sendQueue: make(chan []byte, desc.SendQueueCapacity),
recving: make([]byte, 0, desc.RecvBufferCapacity), recving: make([]byte, 0, desc.RecvBufferCapacity),
maxMsgPacketPayloadSize: conn.config.MaxMsgPacketPayloadSize,
maxPacketMsgPayloadSize: conn.config.MaxPacketMsgPayloadSize,
} }
} }
@ -683,8 +684,8 @@ func (ch *Channel) canSend() bool {
return ch.loadSendQueueSize() < defaultSendQueueCapacity return ch.loadSendQueueSize() < defaultSendQueueCapacity
} }
// Returns true if any msgPackets are pending to be sent.
// Call before calling nextMsgPacket()
// Returns true if any PacketMsgs are pending to be sent.
// Call before calling nextPacketMsg()
// Goroutine-safe // Goroutine-safe
func (ch *Channel) isSendPending() bool { func (ch *Channel) isSendPending() bool {
if len(ch.sending) == 0 { if len(ch.sending) == 0 {
@ -696,12 +697,12 @@ func (ch *Channel) isSendPending() bool {
return true return true
} }
// Creates a new msgPacket to send.
// Creates a new PacketMsg to send.
// Not goroutine-safe // Not goroutine-safe
func (ch *Channel) nextMsgPacket() msgPacket {
packet := msgPacket{}
func (ch *Channel) nextPacketMsg() PacketMsg {
packet := PacketMsg{}
packet.ChannelID = byte(ch.desc.ID) packet.ChannelID = byte(ch.desc.ID)
maxSize := ch.maxMsgPacketPayloadSize
maxSize := ch.maxPacketMsgPayloadSize
packet.Bytes = ch.sending[:cmn.MinInt(maxSize, len(ch.sending))] packet.Bytes = ch.sending[:cmn.MinInt(maxSize, len(ch.sending))]
if len(ch.sending) <= maxSize { if len(ch.sending) <= maxSize {
packet.EOF = byte(0x01) packet.EOF = byte(0x01)
@ -714,30 +715,24 @@ func (ch *Channel) nextMsgPacket() msgPacket {
return packet return packet
} }
// Writes next msgPacket to w.
// Writes next PacketMsg to w and updates c.recentlySent.
// Not goroutine-safe // Not goroutine-safe
func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) {
packet := ch.nextMsgPacket()
func (ch *Channel) writePacketMsgTo(w io.Writer) (n int64, err error) {
var packet = ch.nextPacketMsg()
ch.Logger.Debug("Write Msg Packet", "conn", ch.conn, "packet", packet) ch.Logger.Debug("Write Msg Packet", "conn", ch.conn, "packet", packet)
writeMsgPacketTo(packet, w, &n, &err)
if err == nil {
ch.recentlySent += int64(n)
}
n, err = cdc.MarshalBinaryWriter(w, packet)
ch.recentlySent += n
return return
} }
func writeMsgPacketTo(packet msgPacket, w io.Writer, n *int, err *error) {
wire.WriteByte(packetTypeMsg, w, n, err)
wire.WriteBinary(packet, w, n, err)
}
// Handles incoming msgPackets. It returns a message bytes if message is
// complete. NOTE message bytes may change on next call to recvMsgPacket.
// Handles incoming PacketMsgs. It returns a message bytes if message is
// complete. NOTE message bytes may change on next call to recvPacketMsg.
// Not goroutine-safe // Not goroutine-safe
func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) {
func (ch *Channel) recvPacketMsg(packet PacketMsg) ([]byte, error) {
ch.Logger.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet) ch.Logger.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet)
if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) {
return nil, wire.ErrBinaryReadOverflow
var recvCap, recvReceived = ch.desc.RecvMessageCapacity, len(ch.recving) + len(packet.Bytes)
if recvCap < recvReceived {
return nil, fmt.Errorf("Received message exceeds available capacity: %v < %v", recvCap, recvReceived)
} }
ch.recving = append(ch.recving, packet.Bytes...) ch.recving = append(ch.recving, packet.Bytes...)
if packet.EOF == byte(0x01) { if packet.EOF == byte(0x01) {
@ -761,24 +756,36 @@ func (ch *Channel) updateStats() {
ch.recentlySent = int64(float64(ch.recentlySent) * 0.8) ch.recentlySent = int64(float64(ch.recentlySent) * 0.8)
} }
//-----------------------------------------------------------------------------
//----------------------------------------
// Packet
const (
defaultMaxMsgPacketPayloadSize = 1024
type Packet interface {
AssertIsPacket()
}
maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
packetTypePing = byte(0x01)
packetTypePong = byte(0x02)
packetTypeMsg = byte(0x03)
)
func RegisterPacket(cdc *amino.Codec) {
cdc.RegisterInterface((*Packet)(nil), nil)
cdc.RegisterConcrete(PacketPing{}, "tendermint/p2p/PacketPing", nil)
cdc.RegisterConcrete(PacketPong{}, "tendermint/p2p/PacketPong", nil)
cdc.RegisterConcrete(PacketMsg{}, "tendermint/p2p/PacketMsg", nil)
}
func (_ PacketPing) AssertIsPacket() {}
func (_ PacketPong) AssertIsPacket() {}
func (_ PacketMsg) AssertIsPacket() {}
type PacketPing struct {
}
type PacketPong struct {
}
// Messages in channels are chopped into smaller msgPackets for multiplexing.
type msgPacket struct {
type PacketMsg struct {
ChannelID byte ChannelID byte
EOF byte // 1 means message ends here. EOF byte // 1 means message ends here.
Bytes []byte Bytes []byte
} }
func (p msgPacket) String() string {
return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
func (mp PacketMsg) String() string {
return fmt.Sprintf("PacketMsg{%X:%X T:%X}", mp.ChannelID, mp.Bytes, mp.EOF)
} }

+ 81
- 88
p2p/conn/connection_test.go View File

@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
wire "github.com/tendermint/go-wire"
"github.com/tendermint/go-amino"
"github.com/tendermint/tmlibs/log" "github.com/tendermint/tmlibs/log"
) )
@ -32,41 +32,37 @@ func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msg
} }
func TestMConnectionSend(t *testing.T) { func TestMConnectionSend(t *testing.T) {
assert, require := assert.New(t), require.New(t)
server, client := NetPipe() server, client := NetPipe()
defer server.Close() // nolint: errcheck defer server.Close() // nolint: errcheck
defer client.Close() // nolint: errcheck defer client.Close() // nolint: errcheck
mconn := createTestMConnection(client) mconn := createTestMConnection(client)
err := mconn.Start() err := mconn.Start()
require.Nil(err)
require.Nil(t, err)
defer mconn.Stop() defer mconn.Stop()
msg := "Ant-Man" msg := "Ant-Man"
assert.True(mconn.Send(0x01, msg))
assert.True(t, mconn.Send(0x01, msg))
// Note: subsequent Send/TrySend calls could pass because we are reading from // Note: subsequent Send/TrySend calls could pass because we are reading from
// the send queue in a separate goroutine. // the send queue in a separate goroutine.
_, err = server.Read(make([]byte, len(msg))) _, err = server.Read(make([]byte, len(msg)))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
assert.True(mconn.CanSend(0x01))
assert.True(t, mconn.CanSend(0x01))
msg = "Spider-Man" msg = "Spider-Man"
assert.True(mconn.TrySend(0x01, msg))
assert.True(t, mconn.TrySend(0x01, msg))
_, err = server.Read(make([]byte, len(msg))) _, err = server.Read(make([]byte, len(msg)))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
assert.False(mconn.CanSend(0x05), "CanSend should return false because channel is unknown")
assert.False(mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown")
assert.False(t, mconn.CanSend(0x05), "CanSend should return false because channel is unknown")
assert.False(t, mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown")
} }
func TestMConnectionReceive(t *testing.T) { func TestMConnectionReceive(t *testing.T) {
assert, require := assert.New(t), require.New(t)
server, client := NetPipe() server, client := NetPipe()
defer server.Close() // nolint: errcheck defer server.Close() // nolint: errcheck
defer client.Close() // nolint: errcheck defer client.Close() // nolint: errcheck
@ -81,20 +77,20 @@ func TestMConnectionReceive(t *testing.T) {
} }
mconn1 := createMConnectionWithCallbacks(client, onReceive, onError) mconn1 := createMConnectionWithCallbacks(client, onReceive, onError)
err := mconn1.Start() err := mconn1.Start()
require.Nil(err)
require.Nil(t, err)
defer mconn1.Stop() defer mconn1.Stop()
mconn2 := createTestMConnection(server) mconn2 := createTestMConnection(server)
err = mconn2.Start() err = mconn2.Start()
require.Nil(err)
require.Nil(t, err)
defer mconn2.Stop() defer mconn2.Stop()
msg := "Cyclops" msg := "Cyclops"
assert.True(mconn2.Send(0x01, msg))
assert.True(t, mconn2.Send(0x01, msg))
select { select {
case receivedBytes := <-receivedCh: case receivedBytes := <-receivedCh:
assert.Equal([]byte(msg), receivedBytes[2:]) // first 3 bytes are internal
assert.Equal(t, []byte(msg), receivedBytes[2:]) // first 3 bytes are internal
case err := <-errorsCh: case err := <-errorsCh:
t.Fatalf("Expected %s, got %+v", msg, err) t.Fatalf("Expected %s, got %+v", msg, err)
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):
@ -103,20 +99,18 @@ func TestMConnectionReceive(t *testing.T) {
} }
func TestMConnectionStatus(t *testing.T) { func TestMConnectionStatus(t *testing.T) {
assert, require := assert.New(t), require.New(t)
server, client := NetPipe() server, client := NetPipe()
defer server.Close() // nolint: errcheck defer server.Close() // nolint: errcheck
defer client.Close() // nolint: errcheck defer client.Close() // nolint: errcheck
mconn := createTestMConnection(client) mconn := createTestMConnection(client)
err := mconn.Start() err := mconn.Start()
require.Nil(err)
require.Nil(t, err)
defer mconn.Stop() defer mconn.Stop()
status := mconn.Status() status := mconn.Status()
assert.NotNil(status)
assert.Zero(status.Channels[0].SendQueueSize)
assert.NotNil(t, status)
assert.Zero(t, status.Channels[0].SendQueueSize)
} }
func TestMConnectionPongTimeoutResultsInError(t *testing.T) { func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
@ -140,7 +134,9 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
serverGotPing := make(chan struct{}) serverGotPing := make(chan struct{})
go func() { go func() {
// read ping // read ping
server.Read(make([]byte, 1))
var pkt PacketPing
_, err = cdc.UnmarshalBinaryReader(server, &pkt, 1024)
assert.Nil(t, err)
serverGotPing <- struct{}{} serverGotPing <- struct{}{}
}() }()
<-serverGotPing <-serverGotPing
@ -175,21 +171,22 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
defer mconn.Stop() defer mconn.Stop()
// sending 3 pongs in a row (abuse) // sending 3 pongs in a row (abuse)
_, err = server.Write([]byte{packetTypePong})
_, err = server.Write(cdc.MustMarshalBinary(PacketPong{}))
require.Nil(t, err) require.Nil(t, err)
_, err = server.Write([]byte{packetTypePong})
_, err = server.Write(cdc.MustMarshalBinary(PacketPong{}))
require.Nil(t, err) require.Nil(t, err)
_, err = server.Write([]byte{packetTypePong})
_, err = server.Write(cdc.MustMarshalBinary(PacketPong{}))
require.Nil(t, err) require.Nil(t, err)
serverGotPing := make(chan struct{}) serverGotPing := make(chan struct{})
go func() { go func() {
// read ping (one byte) // read ping (one byte)
_, err = server.Read(make([]byte, 1))
var packet, err = Packet(nil), error(nil)
_, err = cdc.UnmarshalBinaryReader(server, &packet, 1024)
require.Nil(t, err) require.Nil(t, err)
serverGotPing <- struct{}{} serverGotPing <- struct{}{}
// respond with pong // respond with pong
_, err = server.Write([]byte{packetTypePong})
_, err = server.Write(cdc.MustMarshalBinary(PacketPong{}))
require.Nil(t, err) require.Nil(t, err)
}() }()
<-serverGotPing <-serverGotPing
@ -225,17 +222,18 @@ func TestMConnectionMultiplePings(t *testing.T) {
// sending 3 pings in a row (abuse) // sending 3 pings in a row (abuse)
// see https://github.com/tendermint/tendermint/issues/1190 // see https://github.com/tendermint/tendermint/issues/1190
_, err = server.Write([]byte{packetTypePing})
_, err = server.Write(cdc.MustMarshalBinary(PacketPing{}))
require.Nil(t, err) require.Nil(t, err)
_, err = server.Read(make([]byte, 1))
var pkt PacketPong
_, err = cdc.UnmarshalBinaryReader(server, &pkt, 1024)
require.Nil(t, err) require.Nil(t, err)
_, err = server.Write([]byte{packetTypePing})
_, err = server.Write(cdc.MustMarshalBinary(PacketPing{}))
require.Nil(t, err) require.Nil(t, err)
_, err = server.Read(make([]byte, 1))
_, err = cdc.UnmarshalBinaryReader(server, &pkt, 1024)
require.Nil(t, err) require.Nil(t, err)
_, err = server.Write([]byte{packetTypePing})
_, err = server.Write(cdc.MustMarshalBinary(PacketPing{}))
require.Nil(t, err) require.Nil(t, err)
_, err = server.Read(make([]byte, 1))
_, err = cdc.UnmarshalBinaryReader(server, &pkt, 1024)
require.Nil(t, err) require.Nil(t, err)
assert.True(t, mconn.IsRunning()) assert.True(t, mconn.IsRunning())
@ -262,18 +260,21 @@ func TestMConnectionPingPongs(t *testing.T) {
serverGotPing := make(chan struct{}) serverGotPing := make(chan struct{})
go func() { go func() {
// read ping // read ping
server.Read(make([]byte, 1))
var pkt PacketPing
_, err = cdc.UnmarshalBinaryReader(server, &pkt, 1024)
require.Nil(t, err)
serverGotPing <- struct{}{} serverGotPing <- struct{}{}
// respond with pong // respond with pong
_, err = server.Write([]byte{packetTypePong})
_, err = server.Write(cdc.MustMarshalBinary(PacketPong{}))
require.Nil(t, err) require.Nil(t, err)
time.Sleep(mconn.config.PingInterval) time.Sleep(mconn.config.PingInterval)
// read ping // read ping
server.Read(make([]byte, 1))
_, err = cdc.UnmarshalBinaryReader(server, &pkt, 1024)
require.Nil(t, err)
// respond with pong // respond with pong
_, err = server.Write([]byte{packetTypePong})
_, err = server.Write(cdc.MustMarshalBinary(PacketPong{}))
require.Nil(t, err) require.Nil(t, err)
}() }()
<-serverGotPing <-serverGotPing
@ -290,8 +291,6 @@ func TestMConnectionPingPongs(t *testing.T) {
} }
func TestMConnectionStopsAndReturnsError(t *testing.T) { func TestMConnectionStopsAndReturnsError(t *testing.T) {
assert, require := assert.New(t), require.New(t)
server, client := NetPipe() server, client := NetPipe()
defer server.Close() // nolint: errcheck defer server.Close() // nolint: errcheck
defer client.Close() // nolint: errcheck defer client.Close() // nolint: errcheck
@ -306,7 +305,7 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) {
} }
mconn := createMConnectionWithCallbacks(client, onReceive, onError) mconn := createMConnectionWithCallbacks(client, onReceive, onError)
err := mconn.Start() err := mconn.Start()
require.Nil(err)
require.Nil(t, err)
defer mconn.Stop() defer mconn.Stop()
if err := client.Close(); err != nil { if err := client.Close(); err != nil {
@ -317,14 +316,14 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) {
case receivedBytes := <-receivedCh: case receivedBytes := <-receivedCh:
t.Fatalf("Expected error, got %v", receivedBytes) t.Fatalf("Expected error, got %v", receivedBytes)
case err := <-errorsCh: case err := <-errorsCh:
assert.NotNil(err)
assert.False(mconn.IsRunning())
assert.NotNil(t, err)
assert.False(t, mconn.IsRunning())
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):
t.Fatal("Did not receive error in 500ms") t.Fatal("Did not receive error in 500ms")
} }
} }
func newClientAndServerConnsForReadErrors(require *require.Assertions, chOnErr chan struct{}) (*MConnection, *MConnection) {
func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) (*MConnection, *MConnection) {
server, client := NetPipe() server, client := NetPipe()
onReceive := func(chID byte, msgBytes []byte) {} onReceive := func(chID byte, msgBytes []byte) {}
@ -338,7 +337,7 @@ func newClientAndServerConnsForReadErrors(require *require.Assertions, chOnErr c
mconnClient := NewMConnection(client, chDescs, onReceive, onError) mconnClient := NewMConnection(client, chDescs, onReceive, onError)
mconnClient.SetLogger(log.TestingLogger().With("module", "client")) mconnClient.SetLogger(log.TestingLogger().With("module", "client"))
err := mconnClient.Start() err := mconnClient.Start()
require.Nil(err)
require.Nil(t, err)
// create server conn with 1 channel // create server conn with 1 channel
// it fires on chOnErr when there's an error // it fires on chOnErr when there's an error
@ -349,7 +348,7 @@ func newClientAndServerConnsForReadErrors(require *require.Assertions, chOnErr c
mconnServer := createMConnectionWithCallbacks(server, onReceive, onError) mconnServer := createMConnectionWithCallbacks(server, onReceive, onError)
mconnServer.SetLogger(serverLogger) mconnServer.SetLogger(serverLogger)
err = mconnServer.Start() err = mconnServer.Start()
require.Nil(err)
require.Nil(t, err)
return mconnClient, mconnServer return mconnClient, mconnServer
} }
@ -364,50 +363,45 @@ func expectSend(ch chan struct{}) bool {
} }
func TestMConnectionReadErrorBadEncoding(t *testing.T) { func TestMConnectionReadErrorBadEncoding(t *testing.T) {
assert, require := assert.New(t), require.New(t)
chOnErr := make(chan struct{}) chOnErr := make(chan struct{})
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(require, chOnErr)
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
defer mconnClient.Stop() defer mconnClient.Stop()
defer mconnServer.Stop() defer mconnServer.Stop()
client := mconnClient.conn client := mconnClient.conn
msg := "Ant-Man"
// send badly encoded msgPacket // send badly encoded msgPacket
var n int
var err error
wire.WriteByte(packetTypeMsg, client, &n, &err)
wire.WriteByteSlice([]byte(msg), client, &n, &err)
assert.True(expectSend(chOnErr), "badly encoded msgPacket")
bz := cdc.MustMarshalBinary(PacketMsg{})
bz[4] += 0x01 // Invalid prefix bytes.
// Write it.
_, err := client.Write(bz)
assert.Nil(t, err)
assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
} }
func TestMConnectionReadErrorUnknownChannel(t *testing.T) { func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
assert, require := assert.New(t), require.New(t)
chOnErr := make(chan struct{}) chOnErr := make(chan struct{})
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(require, chOnErr)
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
defer mconnClient.Stop() defer mconnClient.Stop()
defer mconnServer.Stop() defer mconnServer.Stop()
msg := "Ant-Man" msg := "Ant-Man"
// fail to send msg on channel unknown by client // fail to send msg on channel unknown by client
assert.False(mconnClient.Send(0x03, msg))
assert.False(t, mconnClient.Send(0x03, msg))
// send msg on channel unknown by the server. // send msg on channel unknown by the server.
// should cause an error // should cause an error
assert.True(mconnClient.Send(0x02, msg))
assert.True(expectSend(chOnErr), "unknown channel")
assert.True(t, mconnClient.Send(0x02, msg))
assert.True(t, expectSend(chOnErr), "unknown channel")
} }
func TestMConnectionReadErrorLongMessage(t *testing.T) { func TestMConnectionReadErrorLongMessage(t *testing.T) {
assert, require := assert.New(t), require.New(t)
chOnErr := make(chan struct{}) chOnErr := make(chan struct{})
chOnRcv := make(chan struct{}) chOnRcv := make(chan struct{})
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(require, chOnErr)
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
defer mconnClient.Stop() defer mconnClient.Stop()
defer mconnServer.Stop() defer mconnServer.Stop()
@ -418,65 +412,64 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) {
client := mconnClient.conn client := mconnClient.conn
// send msg thats just right // send msg thats just right
var n int
var err error var err error
packet := msgPacket{
var packet = PacketMsg{
ChannelID: 0x01, ChannelID: 0x01,
Bytes: make([]byte, mconnClient.config.maxMsgPacketTotalSize()-5),
Bytes: make([]byte, mconnClient.config.maxPacketMsgTotalSize()-12),
EOF: 1, EOF: 1,
} }
writeMsgPacketTo(packet, client, &n, &err)
assert.True(expectSend(chOnRcv), "msg just right")
_, err = cdc.MarshalBinaryWriter(client, packet)
assert.Nil(t, err)
assert.True(t, expectSend(chOnRcv), "msg just right")
// send msg thats too long // send msg thats too long
packet = msgPacket{
packet = PacketMsg{
ChannelID: 0x01, ChannelID: 0x01,
Bytes: make([]byte, mconnClient.config.maxMsgPacketTotalSize()-4),
Bytes: make([]byte, mconnClient.config.maxPacketMsgTotalSize()-11),
EOF: 1, EOF: 1,
} }
writeMsgPacketTo(packet, client, &n, &err)
assert.True(expectSend(chOnErr), "msg too long")
_, err = cdc.MarshalBinaryWriter(client, packet)
assert.Nil(t, err)
assert.True(t, expectSend(chOnErr), "msg too long")
} }
func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
assert, require := assert.New(t), require.New(t)
chOnErr := make(chan struct{}) chOnErr := make(chan struct{})
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(require, chOnErr)
mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
defer mconnClient.Stop() defer mconnClient.Stop()
defer mconnServer.Stop() defer mconnServer.Stop()
// send msg with unknown msg type // send msg with unknown msg type
var n int
var err error
wire.WriteByte(0x04, mconnClient.conn, &n, &err)
assert.True(expectSend(chOnErr), "unknown msg type")
err := error(nil)
err = amino.EncodeUvarint(mconnClient.conn, 4)
assert.Nil(t, err)
_, err = mconnClient.conn.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF})
assert.Nil(t, err)
assert.True(t, expectSend(chOnErr), "unknown msg type")
} }
func TestMConnectionTrySend(t *testing.T) { func TestMConnectionTrySend(t *testing.T) {
assert, require := assert.New(t), require.New(t)
server, client := NetPipe() server, client := NetPipe()
defer server.Close() defer server.Close()
defer client.Close() defer client.Close()
mconn := createTestMConnection(client) mconn := createTestMConnection(client)
err := mconn.Start() err := mconn.Start()
require.Nil(err)
require.Nil(t, err)
defer mconn.Stop() defer mconn.Stop()
msg := "Semicolon-Woman" msg := "Semicolon-Woman"
resultCh := make(chan string, 2) resultCh := make(chan string, 2)
assert.True(mconn.TrySend(0x01, msg))
assert.True(t, mconn.TrySend(0x01, msg))
server.Read(make([]byte, len(msg))) server.Read(make([]byte, len(msg)))
assert.True(mconn.CanSend(0x01))
assert.True(mconn.TrySend(0x01, msg))
assert.False(mconn.CanSend(0x01))
assert.True(t, mconn.CanSend(0x01))
assert.True(t, mconn.TrySend(0x01, msg))
assert.False(t, mconn.CanSend(0x01))
go func() { go func() {
mconn.TrySend(0x01, msg) mconn.TrySend(0x01, msg)
resultCh <- "TrySend" resultCh <- "TrySend"
}() }()
assert.False(mconn.CanSend(0x01))
assert.False(mconn.TrySend(0x01, msg))
assert.Equal("TrySend", <-resultCh)
assert.False(t, mconn.CanSend(0x01))
assert.False(t, mconn.TrySend(0x01, msg))
assert.Equal(t, "TrySend", <-resultCh)
} }

+ 61
- 39
p2p/conn/secret_connection.go View File

@ -12,6 +12,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"time" "time"
@ -21,16 +22,14 @@ import (
"golang.org/x/crypto/ripemd160" "golang.org/x/crypto/ripemd160"
"github.com/tendermint/go-crypto" "github.com/tendermint/go-crypto"
"github.com/tendermint/go-wire"
cmn "github.com/tendermint/tmlibs/common" cmn "github.com/tendermint/tmlibs/common"
) )
// 2 + 1024 == 1026 total frame size
const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize
// 4 + 1024 == 1028 total frame size
const dataLenSize = 4
const dataMaxSize = 1024 const dataMaxSize = 1024
const totalFrameSize = dataMaxSize + dataLenSize const totalFrameSize = dataMaxSize + dataLenSize
const sealedFrameSize = totalFrameSize + secretbox.Overhead const sealedFrameSize = totalFrameSize + secretbox.Overhead
const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays
// Implements net.Conn // Implements net.Conn
type SecretConnection struct { type SecretConnection struct {
@ -92,6 +91,7 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKey) (*
// Share (in secret) each other's pubkey & challenge signature // Share (in secret) each other's pubkey & challenge signature
authSigMsg, err := shareAuthSignature(sc, locPubKey, locSignature) authSigMsg, err := shareAuthSignature(sc, locPubKey, locSignature)
if err != nil { if err != nil {
fmt.Println(">>>", err)
return nil, err return nil, err
} }
remPubKey, remSignature := authSigMsg.Key, authSigMsg.Sig remPubKey, remSignature := authSigMsg.Key, authSigMsg.Sig
@ -123,7 +123,7 @@ func (sc *SecretConnection) Write(data []byte) (n int, err error) {
data = nil data = nil
} }
chunkLength := len(chunk) chunkLength := len(chunk)
binary.BigEndian.PutUint16(frame, uint16(chunkLength))
binary.BigEndian.PutUint32(frame, uint32(chunkLength))
copy(frame[dataLenSize:], chunk) copy(frame[dataLenSize:], chunk)
// encrypt the frame // encrypt the frame
@ -167,7 +167,7 @@ func (sc *SecretConnection) Read(data []byte) (n int, err error) {
incr2Nonce(sc.recvNonce) incr2Nonce(sc.recvNonce)
// end decryption // end decryption
var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes
var chunkLength = binary.BigEndian.Uint32(frame) // read the first two bytes
if chunkLength > dataMaxSize { if chunkLength > dataMaxSize {
return 0, errors.New("chunkLength is greater than dataMaxSize") return 0, errors.New("chunkLength is greater than dataMaxSize")
} }
@ -200,26 +200,41 @@ func genEphKeys() (ephPub, ephPriv *[32]byte) {
} }
func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) {
var err1, err2 error
cmn.Parallel(
func() {
_, err1 = conn.Write(locEphPub[:])
// Send our pubkey and receive theirs in tandem.
var trs, _ = cmn.Parallel(
func(_ int) (val interface{}, err error, abort bool) {
var _, err1 = cdc.MarshalBinaryWriter(conn, locEphPub)
if err1 != nil {
return nil, err1, true // abort
} else {
return nil, nil, false
}
}, },
func() {
remEphPub = new([32]byte)
_, err2 = io.ReadFull(conn, remEphPub[:])
func(_ int) (val interface{}, err error, abort bool) {
var _remEphPub [32]byte
var _, err2 = cdc.UnmarshalBinaryReader(conn, &_remEphPub, 1024*1024) // TODO
if err2 != nil {
return nil, err2, true // abort
} else {
return _remEphPub, nil, false
}
}, },
) )
if err1 != nil {
return nil, err1
}
if err2 != nil {
return nil, err2
// If error:
if trs.FirstError() != nil {
err = trs.FirstError()
return
} else if trs.FirstPanic() != nil {
err = fmt.Errorf("Panic: %v", trs.FirstPanic())
return
} }
return remEphPub, nil
// Otherwise:
var _remEphPub = trs.FirstValue().([32]byte)
return &_remEphPub, nil
} }
func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) { func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) {
@ -268,33 +283,40 @@ type authSigMessage struct {
Sig crypto.Signature Sig crypto.Signature
} }
func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKey, signature crypto.Signature) (*authSigMessage, error) {
var recvMsg authSigMessage
var err1, err2 error
func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKey, signature crypto.Signature) (recvMsg authSigMessage, err error) {
cmn.Parallel(
func() {
msgBytes := wire.BinaryBytes(authSigMessage{pubKey.Wrap(), signature.Wrap()})
_, err1 = sc.Write(msgBytes)
// Send our info and receive theirs in tandem.
var trs, _ = cmn.Parallel(
func(_ int) (val interface{}, err error, abort bool) {
var _, err1 = cdc.MarshalBinaryWriter(sc, authSigMessage{pubKey, signature})
if err1 != nil {
return nil, err1, true // abort
} else {
return nil, nil, false
}
}, },
func() {
readBuffer := make([]byte, authSigMsgSize)
_, err2 = io.ReadFull(sc, readBuffer)
func(_ int) (val interface{}, err error, abort bool) {
var _recvMsg authSigMessage
var _, err2 = cdc.UnmarshalBinaryReader(sc, &_recvMsg, 1024*1024) // TODO
if err2 != nil { if err2 != nil {
return
return nil, err2, true // abort
} else {
return _recvMsg, nil, false
} }
n := int(0) // not used.
recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage)
})
},
)
if err1 != nil {
return nil, err1
}
if err2 != nil {
return nil, err2
// If error:
if trs.FirstError() != nil {
err = trs.FirstError()
return
} else if trs.FirstPanic() != nil {
err = fmt.Errorf("Panic: %v", trs.FirstPanic())
return
} }
return &recvMsg, nil
var _recvMsg = trs.FirstValue().(authSigMessage)
return _recvMsg, nil
} }
//-------------------------------------------------------------------------------- //--------------------------------------------------------------------------------


+ 78
- 36
p2p/conn/secret_connection_test.go View File

@ -1,9 +1,12 @@
package conn package conn
import ( import (
"fmt"
"io" "io"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
crypto "github.com/tendermint/go-crypto" crypto "github.com/tendermint/go-crypto"
cmn "github.com/tendermint/tmlibs/common" cmn "github.com/tendermint/tmlibs/common"
) )
@ -30,39 +33,49 @@ func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
} }
func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
fooConn, barConn := makeKVStoreConnPair()
fooPrvKey := crypto.GenPrivKeyEd25519().Wrap()
fooPubKey := fooPrvKey.PubKey()
barPrvKey := crypto.GenPrivKeyEd25519().Wrap()
barPubKey := barPrvKey.PubKey()
cmn.Parallel(
func() {
var err error
var fooConn, barConn = makeKVStoreConnPair()
var fooPrvKey = crypto.GenPrivKeyEd25519()
var fooPubKey = fooPrvKey.PubKey()
var barPrvKey = crypto.GenPrivKeyEd25519()
var barPubKey = barPrvKey.PubKey()
// Make connections from both sides in parallel.
var trs, ok = cmn.Parallel(
func(_ int) (val interface{}, err error, abort bool) {
fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
if err != nil { if err != nil {
tb.Errorf("Failed to establish SecretConnection for foo: %v", err) tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
return
return nil, err, true
} }
remotePubBytes := fooSecConn.RemotePubKey() remotePubBytes := fooSecConn.RemotePubKey()
if !remotePubBytes.Equals(barPubKey) { if !remotePubBytes.Equals(barPubKey) {
tb.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
barPubKey, fooSecConn.RemotePubKey()) barPubKey, fooSecConn.RemotePubKey())
tb.Error(err)
return nil, err, false
} }
return nil, nil, false
}, },
func() {
var err error
func(_ int) (val interface{}, err error, abort bool) {
barSecConn, err = MakeSecretConnection(barConn, barPrvKey) barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
if barSecConn == nil { if barSecConn == nil {
tb.Errorf("Failed to establish SecretConnection for bar: %v", err) tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
return
return nil, err, true
} }
remotePubBytes := barSecConn.RemotePubKey() remotePubBytes := barSecConn.RemotePubKey()
if !remotePubBytes.Equals(fooPubKey) { if !remotePubBytes.Equals(fooPubKey) {
tb.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
err = fmt.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
fooPubKey, barSecConn.RemotePubKey()) fooPubKey, barSecConn.RemotePubKey())
tb.Error(err)
return nil, nil, false
} }
})
return nil, nil, false
},
)
require.Nil(tb, trs.FirstPanic())
require.Nil(tb, trs.FirstError())
require.True(tb, ok, "Unexpected task abortion")
return return
} }
@ -89,59 +102,80 @@ func TestSecretConnectionReadWrite(t *testing.T) {
} }
// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
genNodeRunner := func(nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) func() {
return func() {
// Node handskae
nodePrvKey := crypto.GenPrivKeyEd25519().Wrap()
genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
return func(_ int) (interface{}, error, bool) {
// Initiate cryptographic private key and secret connection trhough nodeConn.
nodePrvKey := crypto.GenPrivKeyEd25519()
nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
if err != nil { if err != nil {
t.Errorf("Failed to establish SecretConnection for node: %v", err) t.Errorf("Failed to establish SecretConnection for node: %v", err)
return
return nil, err, true
} }
// In parallel, handle reads and writes
cmn.Parallel(
func() {
// Node writes
// In parallel, handle some reads and writes.
var trs, ok = cmn.Parallel(
func(_ int) (interface{}, error, bool) {
// Node writes:
for _, nodeWrite := range nodeWrites { for _, nodeWrite := range nodeWrites {
n, err := nodeSecretConn.Write([]byte(nodeWrite)) n, err := nodeSecretConn.Write([]byte(nodeWrite))
if err != nil { if err != nil {
t.Errorf("Failed to write to nodeSecretConn: %v", err) t.Errorf("Failed to write to nodeSecretConn: %v", err)
return
return nil, err, true
} }
if n != len(nodeWrite) { if n != len(nodeWrite) {
t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
return
err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
t.Error(err)
return nil, err, true
} }
} }
if err := nodeConn.PipeWriter.Close(); err != nil { if err := nodeConn.PipeWriter.Close(); err != nil {
t.Error(err) t.Error(err)
return nil, err, true
} }
return nil, nil, false
}, },
func() {
// Node reads
func(_ int) (interface{}, error, bool) {
// Node reads:
readBuffer := make([]byte, dataMaxSize) readBuffer := make([]byte, dataMaxSize)
for { for {
n, err := nodeSecretConn.Read(readBuffer) n, err := nodeSecretConn.Read(readBuffer)
if err == io.EOF { if err == io.EOF {
return
return nil, nil, false
} else if err != nil { } else if err != nil {
t.Errorf("Failed to read from nodeSecretConn: %v", err) t.Errorf("Failed to read from nodeSecretConn: %v", err)
return
return nil, err, true
} }
*nodeReads = append(*nodeReads, string(readBuffer[:n])) *nodeReads = append(*nodeReads, string(readBuffer[:n]))
} }
if err := nodeConn.PipeReader.Close(); err != nil { if err := nodeConn.PipeReader.Close(); err != nil {
t.Error(err) t.Error(err)
return nil, err, true
} }
})
return nil, nil, false
},
)
assert.True(t, ok, "Unexpected task abortion")
// If error:
if trs.FirstError() != nil {
return nil, trs.FirstError(), true
} else if trs.FirstPanic() != nil {
err = fmt.Errorf("Panic in task: %v", trs.FirstPanic())
return nil, err, true
}
// Otherwise:
return nil, nil, false
} }
} }
// Run foo & bar in parallel // Run foo & bar in parallel
cmn.Parallel(
genNodeRunner(fooConn, fooWrites, &fooReads),
genNodeRunner(barConn, barWrites, &barReads),
var trs, ok = cmn.Parallel(
genNodeRunner("foo", fooConn, fooWrites, &fooReads),
genNodeRunner("bar", barConn, barWrites, &barReads),
) )
require.Nil(t, trs.FirstPanic())
require.Nil(t, trs.FirstError())
require.True(t, ok, "unexpected task abortion")
// A helper to ensure that the writes and reads match. // A helper to ensure that the writes and reads match.
// Additionally, small writes (<= dataMaxSize) must be atomically read. // Additionally, small writes (<= dataMaxSize) must be atomically read.
@ -209,3 +243,11 @@ func BenchmarkSecretConnection(b *testing.B) {
} }
//barSecConn.Close() race condition //barSecConn.Close() race condition
} }
func fingerprint(bz []byte) []byte {
if len(bz) < 40 {
return bz
} else {
return bz[:40]
}
}

+ 14
- 0
p2p/conn/wire.go View File

@ -0,0 +1,14 @@
package conn
import (
"github.com/tendermint/go-amino"
"github.com/tendermint/go-crypto"
)
var cdc *amino.Codec
func init() {
cdc = amino.NewCodec()
crypto.RegisterAmino(cdc)
RegisterPacket(cdc)
}

+ 0
- 60
wire/wire.go View File

@ -1,60 +0,0 @@
package wire
import (
"github.com/tendermint/go-wire"
)
/*
// Expose access to a global wire codec
// TODO: maybe introduce some Context object
// containing logger, config, codec that can
// be threaded through everything to avoid this global
var cdc *wire.Codec
func init() {
cdc = wire.NewCodec()
crypto.RegisterWire(cdc)
}
*/
// Just a flow through to go-wire.
// To be used later for the global codec
func MarshalBinary(o interface{}) ([]byte, error) {
return wire.MarshalBinary(o)
}
func UnmarshalBinary(bz []byte, ptr interface{}) error {
return wire.UnmarshalBinary(bz, ptr)
}
func MarshalJSON(o interface{}) ([]byte, error) {
return wire.MarshalJSON(o)
}
func UnmarshalJSON(jsonBz []byte, ptr interface{}) error {
return wire.UnmarshalJSON(jsonBz, ptr)
}
type ConcreteType = wire.ConcreteType
func RegisterInterface(o interface{}, ctypes ...ConcreteType) *wire.TypeInfo {
return wire.RegisterInterface(o, ctypes...)
}
const RFC3339Millis = wire.RFC3339Millis
/*
func RegisterInterface(ptr interface{}, opts *wire.InterfaceOptions) {
cdc.RegisterInterface(ptr, opts)
}
func RegisterConcrete(o interface{}, name string, opts *wire.ConcreteOptions) {
cdc.RegisterConcrete(o, name, opts)
}
//-------------------------------
const RFC3339Millis = wire.RFC3339Millis
*/

Loading…
Cancel
Save