Browse Source

rewrite pong timer to use time.AfterFunc

pull/1095/head
Anton Kaliaev 7 years ago
parent
commit
f4ff66de30
No known key found for this signature in database GPG Key ID: 7B6881D965918214
2 changed files with 126 additions and 23 deletions
  1. +23
    -19
      p2p/conn/connection.go
  2. +103
    -4
      p2p/conn/connection_test.go

+ 23
- 19
p2p/conn/connection.go View File

@ -83,11 +83,15 @@ type MConnection struct {
errored uint32 errored uint32
config *MConnConfig config *MConnConfig
quit chan struct{}
flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled.
pingTimer *cmn.RepeatTimer // send pings periodically
pongTimer *time.Timer // close conn if pong is not received in pongTimeout
chStatsTimer *cmn.RepeatTimer // update channel stats periodically
quit chan struct{}
flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled.
pingTimer *cmn.RepeatTimer // send pings periodically
// close conn if pong is not received in pongTimeout
pongTimer *time.Timer
pongTimeoutCh chan struct{}
chStatsTimer *cmn.RepeatTimer // update channel stats periodically
created time.Time // time of creation created time.Time // time of creation
} }
@ -187,10 +191,7 @@ func (c *MConnection) OnStart() error {
c.quit = make(chan struct{}) c.quit = make(chan struct{})
c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle) c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle)
c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval) c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval)
c.pongTimer = time.NewTimer(c.config.PongTimeout)
// we start timer once we've send ping; needed here because we use start
// listening in recvRoutine
_ = c.pongTimer.Stop()
c.pongTimeoutCh = make(chan struct{})
c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats) c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats)
go c.sendRoutine() go c.sendRoutine()
go c.recvRoutine() go c.recvRoutine()
@ -200,13 +201,12 @@ func (c *MConnection) OnStart() error {
// OnStop implements BaseService // OnStop implements BaseService
func (c *MConnection) OnStop() { func (c *MConnection) OnStop() {
c.BaseService.OnStop() c.BaseService.OnStop()
c.flushTimer.Stop()
c.pingTimer.Stop()
_ = c.pongTimer.Stop()
c.chStatsTimer.Stop()
if c.quit != nil { if c.quit != nil {
close(c.quit) close(c.quit)
} }
c.flushTimer.Stop()
c.pingTimer.Stop()
c.chStatsTimer.Stop()
c.conn.Close() // nolint: errcheck c.conn.Close() // nolint: errcheck
// We can't close pong safely here because // We can't close pong safely here because
@ -337,12 +337,13 @@ FOR_LOOP:
c.Logger.Debug("Send Ping") c.Logger.Debug("Send Ping")
wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
c.sendMonitor.Update(int(n)) c.sendMonitor.Update(int(n))
c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
c.pongTimeoutCh <- struct{}{}
})
c.flush() c.flush()
c.Logger.Debug("Starting pong timer")
c.pongTimer.Reset(c.config.PongTimeout)
case <-c.pongTimer.C:
case <-c.pongTimeoutCh:
c.Logger.Debug("Pong timeout") c.Logger.Debug("Pong timeout")
// XXX: should we decrease peer score instead of closing connection?
err = errors.New("pong timeout") err = errors.New("pong timeout")
case <-c.pong: case <-c.pong:
c.Logger.Debug("Send Pong") c.Logger.Debug("Send Pong")
@ -350,6 +351,9 @@ FOR_LOOP:
c.sendMonitor.Update(int(n)) c.sendMonitor.Update(int(n))
c.flush() c.flush()
case <-c.quit: case <-c.quit:
if c.pongTimer != nil {
_ = c.pongTimer.Stop()
}
break FOR_LOOP break FOR_LOOP
case <-c.send: case <-c.send:
// Send some msgPackets // Send some msgPackets
@ -483,8 +487,8 @@ FOR_LOOP:
} }
case packetTypePong: case packetTypePong:
c.Logger.Debug("Receive Pong") c.Logger.Debug("Receive Pong")
if !c.pongTimer.Stop() {
<-c.pongTimer.C
if c.pongTimer != nil {
_ = c.pongTimer.Stop()
} }
case packetTypeMsg: case packetTypeMsg:
pkt, n, err := msgPacket{}, int(0), error(nil) pkt, n, err := msgPacket{}, int(0), error(nil)


+ 103
- 4
p2p/conn/connection_test.go View File

@ -24,7 +24,7 @@ func createTestMConnection(conn net.Conn) *MConnection {
func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection { func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection {
chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
cfg := DefaultMConnConfig() cfg := DefaultMConnConfig()
cfg.PingInterval = 60 * time.Millisecond
cfg.PingInterval = 90 * time.Millisecond
cfg.PongTimeout = 45 * time.Millisecond cfg.PongTimeout = 45 * time.Millisecond
c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg) c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg)
c.SetLogger(log.TestingLogger()) c.SetLogger(log.TestingLogger())
@ -137,19 +137,118 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
defer mconn.Stop() defer mconn.Stop()
serverGotPing := make(chan struct{})
go func() { go func() {
// read ping // read ping
server.Read(make([]byte, 1)) server.Read(make([]byte, 1))
serverGotPing <- struct{}{}
}() }()
<-serverGotPing
expectErrorAfter := (mconn.config.PingInterval + mconn.config.PongTimeout) * 2
pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond
select { select {
case msgBytes := <-receivedCh: case msgBytes := <-receivedCh:
t.Fatalf("Expected error, but got %v", msgBytes) t.Fatalf("Expected error, but got %v", msgBytes)
case err := <-errorsCh: case err := <-errorsCh:
assert.NotNil(t, err) assert.NotNil(t, err)
case <-time.After(expectErrorAfter):
t.Fatalf("Expected to receive error after %v", expectErrorAfter)
case <-time.After(pongTimerExpired):
t.Fatalf("Expected to receive error after %v", pongTimerExpired)
}
}
func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID byte, msgBytes []byte) {
receivedCh <- msgBytes
}
onError := func(r interface{}) {
errorsCh <- r
}
mconn := createMConnectionWithCallbacks(client, onReceive, onError)
err := mconn.Start()
require.Nil(t, err)
defer mconn.Stop()
// sending 3 pongs in a row
_, err = server.Write([]byte{packetTypePong})
require.Nil(t, err)
_, err = server.Write([]byte{packetTypePong})
require.Nil(t, err)
_, err = server.Write([]byte{packetTypePong})
require.Nil(t, err)
serverGotPing := make(chan struct{})
go func() {
// read ping
server.Read(make([]byte, 1))
serverGotPing <- struct{}{}
// respond with pong
_, err = server.Write([]byte{packetTypePong})
require.Nil(t, err)
}()
<-serverGotPing
pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected no data, but got %v", msgBytes)
case err := <-errorsCh:
t.Fatalf("Expected no error, but got %v", err)
case <-time.After(pongTimerExpired):
assert.True(t, mconn.IsRunning())
}
}
func TestMConnectionPingPongs(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID byte, msgBytes []byte) {
receivedCh <- msgBytes
}
onError := func(r interface{}) {
errorsCh <- r
}
mconn := createMConnectionWithCallbacks(client, onReceive, onError)
err := mconn.Start()
require.Nil(t, err)
defer mconn.Stop()
serverGotPing := make(chan struct{})
go func() {
// read ping
server.Read(make([]byte, 1))
serverGotPing <- struct{}{}
// respond with pong
_, err = server.Write([]byte{packetTypePong})
require.Nil(t, err)
time.Sleep(mconn.config.PingInterval)
// read ping
server.Read(make([]byte, 1))
// respond with pong
_, err = server.Write([]byte{packetTypePong})
require.Nil(t, err)
}()
<-serverGotPing
pongTimerExpired := (mconn.config.PongTimeout + 10*time.Millisecond) * 2
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected no data, but got %v", msgBytes)
case err := <-errorsCh:
t.Fatalf("Expected no error, but got %v", err)
case <-time.After(2 * pongTimerExpired):
assert.True(t, mconn.IsRunning())
} }
} }


Loading…
Cancel
Save