diff --git a/p2p/conn/connection.go b/p2p/conn/connection.go index 6fbb425e7..938c3eb2c 100644 --- a/p2p/conn/connection.go +++ b/p2p/conn/connection.go @@ -89,7 +89,7 @@ type MConnection struct { // close conn if pong is not received in pongTimeout pongTimer *time.Timer - pongTimeoutCh chan struct{} + pongTimeoutCh chan bool // true - timeout, false - peer sent pong chStatsTimer *cmn.RepeatTimer // update channel stats periodically @@ -191,7 +191,7 @@ func (c *MConnection) OnStart() error { c.quit = make(chan struct{}) c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle) c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval) - c.pongTimeoutCh = make(chan struct{}) + c.pongTimeoutCh = make(chan bool) c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats) go c.sendRoutine() go c.recvRoutine() @@ -339,19 +339,22 @@ FOR_LOOP: c.sendMonitor.Update(int(n)) c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout) c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() { - c.pongTimeoutCh <- struct{}{} + c.pongTimeoutCh <- true }) c.flush() - case <-c.pongTimeoutCh: - c.Logger.Debug("Pong timeout") - err = errors.New("pong timeout") + case timeout := <-c.pongTimeoutCh: + if timeout { + c.Logger.Debug("Pong timeout") + err = errors.New("pong timeout") + } else { + c.stopPongTimer() + } case <-c.pong: c.Logger.Debug("Send Pong") wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) c.sendMonitor.Update(int(n)) c.flush() case <-c.quit: - c.stopPongTimer() break FOR_LOOP case <-c.send: // Send some msgPackets @@ -376,6 +379,7 @@ FOR_LOOP: } // Cleanup + c.stopPongTimer() } // Returns true if messages from channels were exhausted. @@ -486,7 +490,11 @@ FOR_LOOP: } case packetTypePong: c.Logger.Debug("Receive Pong") - c.stopPongTimer() + select { + case c.pongTimeoutCh <- false: + case <-c.quit: + break FOR_LOOP + } case packetTypeMsg: pkt, n, err := msgPacket{}, int(0), error(nil) wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err) @@ -534,12 +542,14 @@ FOR_LOOP: } } +// not goroutine-safe func (c *MConnection) stopPongTimer() { if c.pongTimer != nil { if !c.pongTimer.Stop() { <-c.pongTimer.C } drain(c.pongTimeoutCh) + c.pongTimer = nil } } @@ -771,7 +781,7 @@ func (p msgPacket) String() string { return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) } -func drain(ch <-chan struct{}) { +func drain(ch <-chan bool) { for { select { case <-ch: