Browse Source

fix race by sending signal instead of stopping pongTimer

pull/1095/head
Anton Kaliaev 7 years ago
parent
commit
45750e1b29
No known key found for this signature in database GPG Key ID: 7B6881D965918214
1 changed files with 19 additions and 9 deletions
  1. +19
    -9
      p2p/conn/connection.go

+ 19
- 9
p2p/conn/connection.go View File

@ -89,7 +89,7 @@ type MConnection struct {
// close conn if pong is not received in pongTimeout // close conn if pong is not received in pongTimeout
pongTimer *time.Timer pongTimer *time.Timer
pongTimeoutCh chan struct{}
pongTimeoutCh chan bool // true - timeout, false - peer sent pong
chStatsTimer *cmn.RepeatTimer // update channel stats periodically chStatsTimer *cmn.RepeatTimer // update channel stats periodically
@ -191,7 +191,7 @@ func (c *MConnection) OnStart() error {
c.quit = make(chan struct{}) c.quit = make(chan struct{})
c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle) c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle)
c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval) c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval)
c.pongTimeoutCh = make(chan struct{})
c.pongTimeoutCh = make(chan bool)
c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats) c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats)
go c.sendRoutine() go c.sendRoutine()
go c.recvRoutine() go c.recvRoutine()
@ -339,19 +339,22 @@ FOR_LOOP:
c.sendMonitor.Update(int(n)) c.sendMonitor.Update(int(n))
c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout) c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() { c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
c.pongTimeoutCh <- struct{}{}
c.pongTimeoutCh <- true
}) })
c.flush() c.flush()
case <-c.pongTimeoutCh:
c.Logger.Debug("Pong timeout")
err = errors.New("pong timeout")
case timeout := <-c.pongTimeoutCh:
if timeout {
c.Logger.Debug("Pong timeout")
err = errors.New("pong timeout")
} else {
c.stopPongTimer()
}
case <-c.pong: case <-c.pong:
c.Logger.Debug("Send Pong") c.Logger.Debug("Send Pong")
wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
c.sendMonitor.Update(int(n)) c.sendMonitor.Update(int(n))
c.flush() c.flush()
case <-c.quit: case <-c.quit:
c.stopPongTimer()
break FOR_LOOP break FOR_LOOP
case <-c.send: case <-c.send:
// Send some msgPackets // Send some msgPackets
@ -376,6 +379,7 @@ FOR_LOOP:
} }
// Cleanup // Cleanup
c.stopPongTimer()
} }
// Returns true if messages from channels were exhausted. // Returns true if messages from channels were exhausted.
@ -486,7 +490,11 @@ FOR_LOOP:
} }
case packetTypePong: case packetTypePong:
c.Logger.Debug("Receive Pong") c.Logger.Debug("Receive Pong")
c.stopPongTimer()
select {
case c.pongTimeoutCh <- false:
case <-c.quit:
break FOR_LOOP
}
case packetTypeMsg: case packetTypeMsg:
pkt, n, err := msgPacket{}, int(0), error(nil) pkt, n, err := msgPacket{}, int(0), error(nil)
wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err) wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err)
@ -534,12 +542,14 @@ FOR_LOOP:
} }
} }
// not goroutine-safe
func (c *MConnection) stopPongTimer() { func (c *MConnection) stopPongTimer() {
if c.pongTimer != nil { if c.pongTimer != nil {
if !c.pongTimer.Stop() { if !c.pongTimer.Stop() {
<-c.pongTimer.C <-c.pongTimer.C
} }
drain(c.pongTimeoutCh) drain(c.pongTimeoutCh)
c.pongTimer = nil
} }
} }
@ -771,7 +781,7 @@ func (p msgPacket) String() string {
return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
} }
func drain(ch <-chan struct{}) {
func drain(ch <-chan bool) {
for { for {
select { select {
case <-ch: case <-ch:


Loading…
Cancel
Save