diff --git a/mempool/mempool.go b/mempool/mempool.go index acc52ae6e..2b7be1ddb 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -59,8 +59,7 @@ type Mempool struct { // Keep a cache of already-seen txs. // This reduces the pressure on the proxyApp. - cacheMap map[string]struct{} - cacheList *list.List // to remove oldest tx when cache gets too big + cache *txCache } func NewMempool(config cfg.Config, proxyAppConn proxy.AppConnMempool) *Mempool { @@ -74,8 +73,7 @@ func NewMempool(config cfg.Config, proxyAppConn proxy.AppConnMempool) *Mempool { recheckCursor: nil, recheckEnd: nil, - cacheMap: make(map[string]struct{}, cacheSize), - cacheList: list.New(), + cache: newTxCache(cacheSize), } proxyAppConn.SetResponseCallback(mempool.resCb) return mempool @@ -100,8 +98,7 @@ func (mem *Mempool) Flush() { mem.proxyMtx.Lock() defer mem.proxyMtx.Unlock() - mem.cacheMap = make(map[string]struct{}, cacheSize) - mem.cacheList.Init() + mem.cache.Reset() for e := mem.txs.Front(); e != nil; e = e.Next() { mem.txs.Remove(e) @@ -125,7 +122,7 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*tmsp.Response)) (err error) { defer mem.proxyMtx.Unlock() // CACHE - if _, exists := mem.cacheMap[string(tx)]; exists { + if mem.cache.Exists(tx) { if cb != nil { cb(&tmsp.Response{ Value: &tmsp.Response_CheckTx{ @@ -138,16 +135,7 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*tmsp.Response)) (err error) { } return nil } - if mem.cacheList.Len() >= cacheSize { - popped := mem.cacheList.Front() - poppedTx := popped.Value.(types.Tx) - // NOTE: the tx may have already been removed from the map - // but deleting a non-existant element is fine - delete(mem.cacheMap, string(poppedTx)) - mem.cacheList.Remove(popped) - } - mem.cacheMap[string(tx)] = struct{}{} - mem.cacheList.PushBack(tx) + mem.cache.Push(tx) // END CACHE // NOTE: proxyAppConn may error if tx buffer is full @@ -162,13 +150,6 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*tmsp.Response)) (err error) { return nil } -func (mem *Mempool) removeTxFromCacheMap(tx []byte) { - mem.proxyMtx.Lock() - // NOTE tx not removed from cacheList - delete(mem.cacheMap, string(tx)) - mem.proxyMtx.Unlock() -} - // TMSP callback function func (mem *Mempool) resCb(req *tmsp.Request, res *tmsp.Response) { if mem.recheckCursor == nil { @@ -194,9 +175,7 @@ func (mem *Mempool) resCbNormal(req *tmsp.Request, res *tmsp.Response) { log.Info("Bad Transaction", "res", r) // remove from cache (it might be good later) - // note this is an async callback, - // so we need to grab the lock in removeTxFromCacheMap - mem.removeTxFromCacheMap(req.GetCheckTx().Tx) + mem.cache.Remove(req.GetCheckTx().Tx) // TODO: handle other retcodes } @@ -221,7 +200,7 @@ func (mem *Mempool) resCbRecheck(req *tmsp.Request, res *tmsp.Response) { mem.recheckCursor.DetachPrev() // remove from cache (it might be good later) - mem.removeTxFromCacheMap(req.GetCheckTx().Tx) + mem.cache.Remove(req.GetCheckTx().Tx) } if mem.recheckCursor == mem.recheckEnd { mem.recheckCursor = nil @@ -348,3 +327,62 @@ type mempoolTx struct { func (memTx *mempoolTx) Height() int { return int(atomic.LoadInt64(&memTx.height)) } + +//-------------------------------------------------------------------------------- + +type txCache struct { + mtx sync.Mutex + size int + map_ map[string]struct{} + list *list.List // to remove oldest tx when cache gets too big +} + +func newTxCache(cacheSize int) *txCache { + return &txCache{ + size: cacheSize, + map_: make(map[string]struct{}, cacheSize), + list: list.New(), + } +} + +func (cache *txCache) Reset() { + cache.mtx.Lock() + cache.map_ = make(map[string]struct{}, cacheSize) + cache.list.Init() + cache.mtx.Unlock() +} + +func (cache *txCache) Exists(tx types.Tx) bool { + cache.mtx.Lock() + _, exists := cache.map_[string(tx)] + cache.mtx.Unlock() + return exists +} + +// Returns false if tx is in cache. +func (cache *txCache) Push(tx types.Tx) bool { + cache.mtx.Lock() + defer cache.mtx.Unlock() + + if _, exists := cache.map_[string(tx)]; exists { + return false + } + + if cache.list.Len() >= cache.size { + popped := cache.list.Front() + poppedTx := popped.Value.(types.Tx) + // NOTE: the tx may have already been removed from the map + // but deleting a non-existant element is fine + delete(cache.map_, string(poppedTx)) + cache.list.Remove(popped) + } + cache.map_[string(tx)] = struct{}{} + cache.list.PushBack(tx) + return true +} + +func (cache *txCache) Remove(tx types.Tx) { + cache.mtx.Lock() + delete(cache.map_, string(tx)) + cache.mtx.Unlock() +}