Browse Source

p2p: backport changes in ping/pong tolerances (#8009)

pull/8012/head
Sam Kleinman 2 years ago
committed by GitHub
parent
commit
a0321633b0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 117 additions and 106 deletions
  1. +49
    -39
      internal/p2p/conn/connection.go
  2. +68
    -67
      internal/p2p/conn/connection_test.go

+ 49
- 39
internal/p2p/conn/connection.go View File

@ -9,6 +9,7 @@ import (
"net"
"reflect"
"runtime/debug"
"sync"
"sync/atomic"
"time"
@ -45,7 +46,7 @@ const (
defaultRecvRate = int64(512000) // 500KB/s
defaultSendTimeout = 10 * time.Second
defaultPingInterval = 60 * time.Second
defaultPongTimeout = 45 * time.Second
defaultPongTimeout = 90 * time.Second
)
type receiveCbFunc func(chID byte, msgBytes []byte)
@ -108,8 +109,10 @@ type MConnection struct {
pingTimer *time.Ticker // send pings periodically
// close conn if pong is not received in pongTimeout
pongTimer *time.Timer
pongTimeoutCh chan bool // true - timeout, false - peer sent pong
lastMsgRecv struct {
sync.Mutex
at time.Time
}
chStatsTimer *time.Ticker // update channel stats periodically
@ -171,10 +174,6 @@ func NewMConnectionWithConfig(
onError errorCbFunc,
config MConnConfig,
) *MConnection {
if config.PongTimeout >= config.PingInterval {
panic("pongTimeout must be less than pingInterval (otherwise, next ping will reset pong timer)")
}
mconn := &MConnection{
conn: conn,
bufConnReader: bufio.NewReaderSize(conn, minReadBufferSize),
@ -223,16 +222,28 @@ func (c *MConnection) OnStart() error {
}
c.flushTimer = timer.NewThrottleTimer("flush", c.config.FlushThrottle)
c.pingTimer = time.NewTicker(c.config.PingInterval)
c.pongTimeoutCh = make(chan bool, 1)
c.chStatsTimer = time.NewTicker(updateStats)
c.quitSendRoutine = make(chan struct{})
c.doneSendRoutine = make(chan struct{})
c.quitRecvRoutine = make(chan struct{})
c.setRecvLastMsgAt(time.Now())
go c.sendRoutine()
go c.recvRoutine()
return nil
}
func (c *MConnection) setRecvLastMsgAt(t time.Time) {
c.lastMsgRecv.Lock()
defer c.lastMsgRecv.Unlock()
c.lastMsgRecv.at = t
}
func (c *MConnection) getLastMessageAt() time.Time {
c.lastMsgRecv.Lock()
defer c.lastMsgRecv.Unlock()
return c.lastMsgRecv.at
}
// stopServices stops the BaseService and timers and closes the quitSendRoutine.
// if the quitSendRoutine was already closed, it returns true, otherwise it returns false.
// It uses the stopMtx to ensure only one of FlushStop and OnStop can do this at a time.
@ -423,6 +434,8 @@ func (c *MConnection) sendRoutine() {
defer c._recover()
protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter)
pongTimeout := time.NewTicker(c.config.PongTimeout)
defer pongTimeout.Stop()
FOR_LOOP:
for {
var _n int
@ -445,21 +458,7 @@ FOR_LOOP:
break SELECTION
}
c.sendMonitor.Update(_n)
c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
select {
case c.pongTimeoutCh <- true:
default:
}
})
c.flush()
case timeout := <-c.pongTimeoutCh:
if timeout {
c.Logger.Debug("Pong timeout")
err = errors.New("pong timeout")
} else {
c.stopPongTimer()
}
case <-c.pong:
c.Logger.Debug("Send Pong")
_n, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
@ -471,6 +470,14 @@ FOR_LOOP:
c.flush()
case <-c.quitSendRoutine:
break FOR_LOOP
case <-pongTimeout.C:
// the point of the pong timer is to check to
// see if we've seen a message recently, so we
// want to make sure that we escape this
// select statement on an interval to ensure
// that we avoid hanging on to dead
// connections for too long.
break SELECTION
case <-c.send:
// Send some PacketMsgs
eof := c.sendSomePacketMsgs()
@ -483,18 +490,21 @@ FOR_LOOP:
}
}
if !c.IsRunning() {
break FOR_LOOP
if time.Since(c.getLastMessageAt()) > c.config.PongTimeout {
err = errors.New("pong timeout")
}
if err != nil {
c.Logger.Error("Connection failed @ sendRoutine", "conn", c, "err", err)
c.stopForError(err)
break FOR_LOOP
}
if !c.IsRunning() {
break FOR_LOOP
}
}
// Cleanup
c.stopPongTimer()
close(c.doneSendRoutine)
}
@ -563,6 +573,14 @@ func (c *MConnection) recvRoutine() {
FOR_LOOP:
for {
select {
case <-c.quitRecvRoutine:
break FOR_LOOP
case <-c.doneSendRoutine:
break FOR_LOOP
default:
}
// Block until .recvMonitor says we can read.
c.recvMonitor.Limit(c._maxPacketMsgSize, atomic.LoadInt64(&c.config.RecvRate), true)
@ -605,6 +623,9 @@ FOR_LOOP:
break FOR_LOOP
}
// record for pong/heartbeat
c.setRecvLastMsgAt(time.Now())
// Read more depending on packet type.
switch pkt := packet.Sum.(type) {
case *tmp2p.Packet_PacketPing:
@ -617,12 +638,9 @@ FOR_LOOP:
// never block
}
case *tmp2p.Packet_PacketPong:
c.Logger.Debug("Receive Pong")
select {
case c.pongTimeoutCh <- false:
default:
// never block
}
// do nothing, we updated the "last message
// received" timestamp above, so we can ignore
// this message
case *tmp2p.Packet_PacketMsg:
channelID := byte(pkt.PacketMsg.ChannelID)
channel, ok := c.channelsIdx[channelID]
@ -661,14 +679,6 @@ FOR_LOOP:
}
}
// not goroutine-safe
func (c *MConnection) stopPongTimer() {
if c.pongTimer != nil {
_ = c.pongTimer.Stop()
c.pongTimer = nil
}
}
// maxPacketMsgSize returns a maximum size of PacketMsg
func (c *MConnection) maxPacketMsgSize() int {
bz, err := proto.Marshal(mustWrapPacket(&tmp2p.PacketMsg{


+ 68
- 67
internal/p2p/conn/connection_test.go View File

@ -1,7 +1,9 @@
package conn
import (
"context"
"encoding/hex"
"io"
"net"
"testing"
"time"
@ -35,8 +37,8 @@ func createMConnectionWithCallbacks(
onError func(r interface{}),
) *MConnection {
cfg := DefaultMConnConfig()
cfg.PingInterval = 90 * time.Millisecond
cfg.PongTimeout = 45 * time.Millisecond
cfg.PingInterval = 250 * time.Millisecond
cfg.PongTimeout = 500 * time.Millisecond
chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg)
c.SetLogger(log.TestingLogger())
@ -159,41 +161,43 @@ func TestMConnectionStatus(t *testing.T) {
assert.Zero(t, status.Channels[0].SendQueueSize)
}
func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
func TestMConnectionWillEventuallyTimeout(t *testing.T) {
server, client := net.Pipe()
t.Cleanup(closeAll(t, client, server))
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID byte, msgBytes []byte) {
receivedCh <- msgBytes
}
onError := func(r interface{}) {
errorsCh <- r
}
mconn := createMConnectionWithCallbacks(client, onReceive, onError)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mconn := createMConnectionWithCallbacks(client, nil, nil)
err := mconn.Start()
require.Nil(t, err)
t.Cleanup(stopAll(t, mconn))
require.NoError(t, err)
require.True(t, mconn.IsRunning())
serverGotPing := make(chan struct{})
go func() {
// read ping
var pkt tmp2p.Packet
_, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
require.NoError(t, err)
serverGotPing <- struct{}{}
// read the send buffer so that the send receive
// doesn't get blocked.
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
_, _ = io.ReadAll(server)
case <-ctx.Done():
return
}
}
}()
<-serverGotPing
pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
// wait for the send routine to die because it doesn't
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected error, but got %v", msgBytes)
case err := <-errorsCh:
assert.NotNil(t, err)
case <-time.After(pongTimerExpired):
t.Fatalf("Expected to receive error after %v", pongTimerExpired)
case <-mconn.doneSendRoutine:
require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout,
"the connection state reflects that we've passed the pong timeout")
// since we hit the timeout, things should be shutdown
require.False(t, mconn.IsRunning())
case <-time.After(2 * mconn.config.PongTimeout):
t.Fatal("connection did not hit timeout", mconn.config.PongTimeout)
}
}
@ -226,19 +230,14 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
serverGotPing := make(chan struct{})
go func() {
// read ping (one byte)
var packet tmp2p.Packet
_, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
require.NoError(t, err)
serverGotPing <- struct{}{}
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
}()
<-serverGotPing
// read ping (one byte)
var packet tmp2p.Packet
_, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
select {
@ -299,52 +298,54 @@ func TestMConnectionPingPongs(t *testing.T) {
// check that we are not leaking any go-routines
t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server, client := net.Pipe()
t.Cleanup(closeAll(t, client, server))
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID byte, msgBytes []byte) {
receivedCh <- msgBytes
select {
case <-ctx.Done():
case receivedCh <- msgBytes:
}
}
onError := func(r interface{}) {
errorsCh <- r
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
mconn := createMConnectionWithCallbacks(client, onReceive, onError)
err := mconn.Start()
require.Nil(t, err)
t.Cleanup(stopAll(t, mconn))
serverGotPing := make(chan struct{})
go func() {
protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
protoWriter := protoio.NewDelimitedWriter(server)
var pkt tmp2p.PacketPing
protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
protoWriter := protoio.NewDelimitedWriter(server)
var pkt tmp2p.PacketPing
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
serverGotPing <- struct{}{}
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
time.Sleep(mconn.config.PingInterval)
time.Sleep(mconn.config.PingInterval)
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
serverGotPing <- struct{}{}
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
}()
<-serverGotPing
<-serverGotPing
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected no data, but got %v", msgBytes)


Loading…
Cancel
Save