Browse Source

p2p: mconn track last message for pongs (#7995)

* p2p: mconn track last message for pongs

* fix spell

* cr feedback

* test fix part one

* cleanup tests

* fix comment

Co-authored-by: M. J. Fromberger <fromberger@interchain.io>
pull/8007/head
Sam Kleinman 2 years ago
committed by GitHub
parent
commit
c85e3e4ba8
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 106 deletions
  1. +47
    -36
      internal/p2p/conn/connection.go
  2. +52
    -70
      internal/p2p/conn/connection_test.go

+ 47
- 36
internal/p2p/conn/connection.go View File

@ -108,8 +108,10 @@ type MConnection struct {
pingTimer *time.Ticker // send pings periodically
// close conn if pong is not received in pongTimeout
pongTimer *time.Timer
pongTimeoutCh chan bool // true - timeout, false - peer sent pong
lastMsgRecv struct {
sync.Mutex
at time.Time
}
chStatsTimer *time.Ticker // update channel stats periodically
@ -161,10 +163,6 @@ func NewMConnection(
onError errorCbFunc,
config MConnConfig,
) *MConnection {
if config.PongTimeout >= config.PingInterval {
panic("pongTimeout must be less than pingInterval (otherwise, next ping will reset pong timer)")
}
mconn := &MConnection{
logger: logger,
conn: conn,
@ -205,16 +203,28 @@ func NewMConnection(
// OnStart allocates the connection's timers and signalling channels, seeds
// the last-message-received timestamp, and then launches the send and receive
// goroutines. It always returns nil.
func (c *MConnection) OnStart(ctx context.Context) error {
c.flushTimer = timer.NewThrottleTimer("flush", c.config.FlushThrottle)
c.pingTimer = time.NewTicker(c.config.PingInterval)
// Capacity 1: the senders on this channel elsewhere in this file use a
// non-blocking select (with a default case).
c.pongTimeoutCh = make(chan bool, 1)
c.chStatsTimer = time.NewTicker(updateStats)
c.quitSendRoutine = make(chan struct{})
c.doneSendRoutine = make(chan struct{})
c.quitRecvRoutine = make(chan struct{})
// Seed the timestamp with the start time so that sendRoutine's check of
// time.Since(getLastMessageAt()) against PongTimeout is measured from
// startup rather than from the zero time.Time.
c.setRecvLastMsgAt(time.Now())
// Everything above must be initialized before the routines start using it.
go c.sendRoutine(ctx)
go c.recvRoutine(ctx)
return nil
}
// setRecvLastMsgAt stores t as the last-message-received timestamp. The write
// happens while holding the mutex embedded in lastMsgRecv, so it may be called
// concurrently with getLastMessageAt.
func (c *MConnection) setRecvLastMsgAt(t time.Time) {
	c.lastMsgRecv.Lock()
	c.lastMsgRecv.at = t
	c.lastMsgRecv.Unlock()
}
// getLastMessageAt returns the last-message-received timestamp. The value is
// copied out while holding the mutex embedded in lastMsgRecv, so it may be
// called concurrently with setRecvLastMsgAt.
func (c *MConnection) getLastMessageAt() time.Time {
	c.lastMsgRecv.Lock()
	last := c.lastMsgRecv.at
	c.lastMsgRecv.Unlock()
	return last
}
// stopServices stops the BaseService and timers and closes the quitSendRoutine.
// if the quitSendRoutine was already closed, it returns true, otherwise it returns false.
// It uses the stopMtx to ensure only one of FlushStop and OnStop can do this at a time.
@ -323,6 +333,8 @@ func (c *MConnection) sendRoutine(ctx context.Context) {
defer c._recover(ctx)
protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter)
pongTimeout := time.NewTicker(c.config.PongTimeout)
defer pongTimeout.Stop()
FOR_LOOP:
for {
var _n int
@ -344,20 +356,7 @@ FOR_LOOP:
break SELECTION
}
c.sendMonitor.Update(_n)
c.logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
select {
case c.pongTimeoutCh <- true:
default:
}
})
c.flush()
case timeout := <-c.pongTimeoutCh:
if timeout {
err = errors.New("pong timeout")
} else {
c.stopPongTimer()
}
case <-c.pong:
_n, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
if err != nil {
@ -370,6 +369,14 @@ FOR_LOOP:
break FOR_LOOP
case <-c.quitSendRoutine:
break FOR_LOOP
case <-pongTimeout.C:
// the point of the pong timer is to check to
// see if we've seen a message recently, so we
// want to make sure that we escape this
// select statement on an interval to ensure
// that we avoid hanging on to dead
// connections for too long.
break SELECTION
case <-c.send:
// Send some PacketMsgs
eof := c.sendSomePacketMsgs(ctx)
@ -382,18 +389,21 @@ FOR_LOOP:
}
}
if !c.IsRunning() {
break FOR_LOOP
if time.Since(c.getLastMessageAt()) > c.config.PongTimeout {
err = errors.New("pong timeout")
}
if err != nil {
c.logger.Error("Connection failed @ sendRoutine", "conn", c, "err", err)
c.stopForError(ctx, err)
break FOR_LOOP
}
if !c.IsRunning() {
break FOR_LOOP
}
}
// Cleanup
c.stopPongTimer()
close(c.doneSendRoutine)
}
@ -462,6 +472,14 @@ func (c *MConnection) recvRoutine(ctx context.Context) {
FOR_LOOP:
for {
select {
case <-ctx.Done():
break FOR_LOOP
case <-c.doneSendRoutine:
break FOR_LOOP
default:
}
// Block until .recvMonitor says we can read.
c.recvMonitor.Limit(c._maxPacketMsgSize, atomic.LoadInt64(&c.config.RecvRate), true)
@ -505,6 +523,9 @@ FOR_LOOP:
break FOR_LOOP
}
// record for pong/heartbeat
c.setRecvLastMsgAt(time.Now())
// Read more depending on packet type.
switch pkt := packet.Sum.(type) {
case *tmp2p.Packet_PacketPing:
@ -516,11 +537,9 @@ FOR_LOOP:
// never block
}
case *tmp2p.Packet_PacketPong:
select {
case c.pongTimeoutCh <- false:
default:
// never block
}
// do nothing, we updated the "last message
// received" timestamp above, so we can ignore
// this message
case *tmp2p.Packet_PacketMsg:
channelID := ChannelID(pkt.PacketMsg.ChannelID)
channel, ok := c.channelsIdx[channelID]
@ -559,14 +578,6 @@ FOR_LOOP:
}
}
// stopPongTimer stops the pending pong timer, if there is one, and clears the
// field so that a later call is a no-op.
//
// not goroutine-safe
func (c *MConnection) stopPongTimer() {
	if c.pongTimer == nil {
		return
	}
	// The return value only reports whether the timer was still pending; it
	// is deliberately ignored.
	_ = c.pongTimer.Stop()
	c.pongTimer = nil
}
// maxPacketMsgSize returns a maximum size of PacketMsg
func (c *MConnection) maxPacketMsgSize() int {
bz, err := proto.Marshal(mustWrapPacket(&tmp2p.PacketMsg{


+ 52
- 70
internal/p2p/conn/connection_test.go View File

@ -3,6 +3,7 @@ package conn
import (
"context"
"encoding/hex"
"io"
"net"
"sync"
"testing"
@ -39,8 +40,8 @@ func createMConnectionWithCallbacks(
onError func(ctx context.Context, r interface{}),
) *MConnection {
cfg := DefaultMConnConfig()
cfg.PingInterval = 90 * time.Millisecond
cfg.PongTimeout = 45 * time.Millisecond
cfg.PingInterval = 250 * time.Millisecond
cfg.PongTimeout = 500 * time.Millisecond
chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
c := NewMConnection(logger, conn, chDescs, onReceive, onError, cfg)
return c
@ -160,51 +161,44 @@ func TestMConnectionReceive(t *testing.T) {
}
}
func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
func TestMConnectionWillEventuallyTimeout(t *testing.T) {
server, client := net.Pipe()
t.Cleanup(closeAll(t, client, server))
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
select {
case receivedCh <- msgBytes:
case <-ctx.Done():
}
}
onError := func(ctx context.Context, r interface{}) {
select {
case errorsCh <- r:
case <-ctx.Done():
}
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, nil, nil)
err := mconn.Start(ctx)
require.NoError(t, err)
t.Cleanup(waitAll(mconn))
require.True(t, mconn.IsRunning())
serverGotPing := make(chan struct{})
go func() {
// read ping
var pkt tmp2p.Packet
_, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
require.NoError(t, err)
serverGotPing <- struct{}{}
// keep draining the server side of the pipe so that
// the connection's send routine doesn't get blocked.
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
_, _ = io.ReadAll(server)
case <-ctx.Done():
return
}
}
}()
<-serverGotPing
pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
// wait for the send routine to die because it doesn't
// receive any message from the peer within the pong timeout
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected error, but got %v", msgBytes)
case err := <-errorsCh:
assert.NotNil(t, err)
case <-time.After(pongTimerExpired):
t.Fatalf("Expected to receive error after %v", pongTimerExpired)
case <-mconn.doneSendRoutine:
require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout,
"the connection state reflects that we've passed the pong timeout")
// since we hit the timeout, things should be shutdown
require.False(t, mconn.IsRunning())
case <-time.After(2 * mconn.config.PongTimeout):
t.Fatal("connection did not hit timeout", mconn.config.PongTimeout)
}
}
@ -247,19 +241,14 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
serverGotPing := make(chan struct{})
go func() {
// read ping (one byte)
var packet tmp2p.Packet
_, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
require.NoError(t, err)
serverGotPing <- struct{}{}
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
}()
<-serverGotPing
// read ping (one byte)
var packet tmp2p.Packet
_, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
select {
@ -355,36 +344,29 @@ func TestMConnectionPingPongs(t *testing.T) {
require.NoError(t, err)
t.Cleanup(waitAll(mconn))
serverGotPing := make(chan struct{})
go func() {
protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
protoWriter := protoio.NewDelimitedWriter(server)
var pkt tmp2p.PacketPing
protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
protoWriter := protoio.NewDelimitedWriter(server)
var pkt tmp2p.PacketPing
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
serverGotPing <- struct{}{}
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
time.Sleep(mconn.config.PingInterval)
time.Sleep(mconn.config.PingInterval)
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
serverGotPing <- struct{}{}
// read ping
_, err = protoReader.ReadMsg(&pkt)
require.NoError(t, err)
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
}()
<-serverGotPing
<-serverGotPing
// respond with pong
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
require.NoError(t, err)
pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected no data, but got %v", msgBytes)


Loading…
Cancel
Save