diff --git a/p2p/conn/connection.go b/p2p/conn/connection.go index b49a45db0..23c8edc9d 100644 --- a/p2p/conn/connection.go +++ b/p2p/conn/connection.go @@ -34,8 +34,8 @@ const ( defaultSendRate = int64(512000) // 500KB/s defaultRecvRate = int64(512000) // 500KB/s defaultSendTimeout = 10 * time.Second - defaultPingTimeout = 40 * time.Second - defaultPongTimeout = 60 * time.Second + defaultPingInterval = 40 * time.Second + defaultPongTimeout = 35 * time.Second ) type receiveCbFunc func(chID byte, msgBytes []byte) @@ -86,7 +86,7 @@ type MConnection struct { quit chan struct{} flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. pingTimer *cmn.RepeatTimer // send pings periodically - pongTimer *cmn.ThrottleTimer // close conn if pong not recv in 1 min + pongTimer *time.Timer // close conn if pong is not received in pongTimeout chStatsTimer *cmn.RepeatTimer // update channel stats periodically created time.Time // time of creation @@ -101,8 +101,8 @@ type MConnConfig struct { FlushThrottle time.Duration - pingTimeout time.Duration - pongTimeout time.Duration + pingInterval time.Duration + pongTimeout time.Duration } func (cfg *MConnConfig) maxMsgPacketTotalSize() int { @@ -116,7 +116,7 @@ func DefaultMConnConfig() *MConnConfig { RecvRate: defaultRecvRate, MaxMsgPacketPayloadSize: defaultMaxMsgPacketPayloadSize, FlushThrottle: defaultFlushThrottle, - pingTimeout: defaultPingTimeout, + pingInterval: defaultPingInterval, pongTimeout: defaultPongTimeout, } } @@ -133,6 +133,10 @@ func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive recei // NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection { + if config.pongTimeout >= config.pingInterval { + panic("pongTimeout must be less than pingInterval") + } + mconn := &MConnection{ conn: conn, bufReader: bufio.NewReaderSize(conn, minReadBufferSize), @@ -176,9 +180,12 @@ func (c *MConnection) OnStart() error { return err } c.quit = make(chan struct{}) - c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle) - c.pingTimer = cmn.NewRepeatTimer("ping", c.config.pingTimeout) - c.pongTimer = cmn.NewThrottleTimer("pong", c.config.pongTimeout) + c.flushTimer = cmn.NewThrottleTimer("flush", c.config.flushThrottle) + c.pingTimer = cmn.NewRepeatTimer("ping", c.config.pingInterval) + c.pongTimer = time.NewTimer(c.config.pongTimeout) + // we start timer once we've send ping; needed here because we use start + // listening in recvRoutine + _ = c.pongTimer.Stop() c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats) go c.sendRoutine() go c.recvRoutine() @@ -190,7 +197,7 @@ func (c *MConnection) OnStop() { c.BaseService.OnStop() c.flushTimer.Stop() c.pingTimer.Stop() - c.pongTimer.Stop() + _ = c.pongTimer.Stop() c.chStatsTimer.Stop() if c.quit != nil { close(c.quit) @@ -325,12 +332,12 @@ FOR_LOOP: c.Logger.Debug("Send Ping") wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) c.sendMonitor.Update(int(n)) - // should be c.flush go c.flush() c.Logger.Debug("Starting pong timer") - c.pongTimer.Set() - case <-c.pongTimer.Ch: + c.pongTimer.Reset(c.config.pongTimeout) + case <-c.pongTimer.C: c.Logger.Debug("Pong timeout") + // XXX: should we decrease peer score instead of closing connection? err = errors.New("pong timeout") case <-c.pong: c.Logger.Debug("Send Pong") @@ -471,8 +478,9 @@ FOR_LOOP: } case packetTypePong: c.Logger.Debug("Receive Pong") - // Should we unset pongTimer if we get other packet? - c.pongTimer.Unset() + if !c.pongTimer.Stop() { + <-c.pongTimer.C + } case packetTypeMsg: pkt, n, err := msgPacket{}, int(0), error(nil) wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err) diff --git a/p2p/conn/connection_test.go b/p2p/conn/connection_test.go index 5686af6a6..5570331e7 100644 --- a/p2p/conn/connection_test.go +++ b/p2p/conn/connection_test.go @@ -24,8 +24,8 @@ func createTestMConnection(conn net.Conn) *MConnection { func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection { chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} cfg := DefaultMConnConfig() - cfg.pingTimeout = 40 * time.Millisecond - cfg.pongTimeout = 60 * time.Millisecond + cfg.pingInterval = 40 * time.Millisecond + cfg.pongTimeout = 35 * time.Millisecond c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg) c.SetLogger(log.TestingLogger()) return c @@ -119,9 +119,7 @@ func TestMConnectionStatus(t *testing.T) { assert.Zero(status.Channels[0].SendQueueSize) } -func TestPingPongTimeout(t *testing.T) { - assert, require := assert.New(t), require.New(t) - +func TestPongTimeoutResultsInError(t *testing.T) { server, client := net.Pipe() defer server.Close() defer client.Close() @@ -135,18 +133,18 @@ func TestPingPongTimeout(t *testing.T) { errorsCh <- r } mconn := createMConnectionWithCallbacks(client, onReceive, onError) - _, err := mconn.Start() - require.Nil(err) + err := mconn.Start() + require.Nil(t, err) defer mconn.Stop() + expectErrorAfter := 10*time.Millisecond + mconn.config.pingInterval + mconn.config.pongTimeout select { - case receivedBytes := <-receivedCh: - t.Fatalf("Expected error, got %v", receivedBytes) + case msgBytes := <-receivedCh: + t.Fatalf("Expected error, but got %v", msgBytes) case err := <-errorsCh: - assert.NotNil(err) - assert.False(mconn.IsRunning()) - case <-time.After(10*time.Millisecond + mconn.config.pingTimeout + mconn.config.pongTimeout): - t.Fatal("Did not receive error in ~(pingTimeout + pongTimeout) seconds") + assert.NotNil(t, err) + case <-time.After(expectErrorAfter): + t.Fatalf("Expected to receive error after %v", expectErrorAfter) } }