From 75bad132fc71e08b13dbf3b1b15b6fa8026d7adf Mon Sep 17 00:00:00 2001 From: Ethan Buchman Date: Thu, 20 Apr 2017 17:29:43 -0400 Subject: [PATCH] msgCountByPeer is a CMap --- pex_reactor.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/pex_reactor.go b/pex_reactor.go index 0244416d1..4b6129762 100644 --- a/pex_reactor.go +++ b/pex_reactor.go @@ -49,7 +49,7 @@ type PEXReactor struct { ensurePeersPeriod time.Duration // tracks message count by peer, so we can prevent abuse - msgCountByPeer map[string]uint16 + msgCountByPeer *cmn.CMap maxMsgCountByPeer uint16 } @@ -58,7 +58,7 @@ func NewPEXReactor(b *AddrBook) *PEXReactor { r := &PEXReactor{ book: b, ensurePeersPeriod: defaultEnsurePeersPeriod, - msgCountByPeer: make(map[string]uint16), + msgCountByPeer: cmn.NewCMap(), maxMsgCountByPeer: defaultMaxMsgCountByPeer, } r.BaseReactor = *NewBaseReactor(log, "PEXReactor", r) @@ -122,7 +122,8 @@ func (r *PEXReactor) RemovePeer(p *Peer, reason interface{}) { func (r *PEXReactor) Receive(chID byte, src *Peer, msgBytes []byte) { srcAddr := src.Connection().RemoteAddress srcAddrStr := srcAddr.String() - r.msgCountByPeer[srcAddrStr]++ + + r.IncrementMsgCountForPeer(srcAddrStr) if r.ReachedMaxMsgCountForPeer(srcAddrStr) { log.Warn("Maximum number of messages reached for peer", "peer", srcAddrStr) // TODO remove src from peers? @@ -175,8 +176,20 @@ func (r *PEXReactor) SetMaxMsgCountByPeer(v uint16) { // ReachedMaxMsgCountForPeer returns true if we received too many // messages from peer with address `addr`. +// NOTE: assumes the value in the CMap is non-nil func (r *PEXReactor) ReachedMaxMsgCountForPeer(addr string) bool { - return r.msgCountByPeer[addr] >= r.maxMsgCountByPeer + return r.msgCountByPeer.Get(addr).(uint16) >= r.maxMsgCountByPeer +} + +// Increment or initialize the msg count for the peer in the CMap +func (r *PEXReactor) IncrementMsgCountForPeer(addr string) { + var count uint16 + countI := r.msgCountByPeer.Get(addr) + if countI != nil { + count = countI.(uint16) + } + count++ + r.msgCountByPeer.Set(addr, count) } // Ensures that sufficient peers are connected. (continuous) @@ -288,7 +301,7 @@ func (r *PEXReactor) flushMsgCountByPeer() { for { select { case <-ticker.C: - r.msgCountByPeer = make(map[string]uint16) + r.msgCountByPeer.Clear() case <-r.Quit: ticker.Stop() return