diff --git a/consensus/mempool_test.go b/consensus/mempool_test.go index 2a063be75..b2246ce54 100644 --- a/consensus/mempool_test.go +++ b/consensus/mempool_test.go @@ -2,7 +2,6 @@ package consensus import ( "encoding/binary" - // "math/rand" "testing" "time" @@ -52,6 +51,67 @@ func TestTxConcurrentWithCommit(t *testing.T) { } } +func TestRmBadTx(t *testing.T) { + state, privVals := randGenesisState(1, false, 10) + app := NewCounterApplication() + cs := newConsensusState(state, privVals[0], app) + + // increment the counter by 1 + txBytes := make([]byte, 8) + binary.BigEndian.PutUint64(txBytes, uint64(0)) + app.AppendTx(txBytes) + app.Commit() + + ch := make(chan struct{}) + cbCh := make(chan struct{}) + go func() { + // Try to send the tx through the mempool. + // CheckTx should not err, but the app should return a bad tmsp code + // and the tx should get removed from the pool + err := cs.mempool.CheckTx(txBytes, func(r *tmsp.Response) { + if r.GetCheckTx().Code != tmsp.CodeType_BadNonce { + t.Fatalf("expected checktx to return bad nonce, got %v", r) + } + cbCh <- struct{}{} + }) + if err != nil { + t.Fatal("Error after CheckTx: %v", err) + } + + // check for the tx + for { + time.Sleep(time.Second) + select { + case <-ch: + default: + txs := cs.mempool.Reap(1) + if len(txs) == 0 { + ch <- struct{}{} + } + + } + } + }() + + // Wait until the tx returns + ticker := time.After(time.Second * 5) + select { + case <-cbCh: + // success + case <-ticker: + t.Fatalf("Timed out waiting for tx to return") + } + + // Wait until the tx is removed + ticker = time.After(time.Second * 5) + select { + case <-ch: + // success + case <-ticker: + t.Fatalf("Timed out waiting for tx to be removed") + } +} + // CounterApplication that maintains a mempool state and resets it upon commit type CounterApplication struct { txCount int @@ -84,11 +144,7 @@ func runTx(tx []byte, countPtr *int) tmsp.Result { copy(tx8[len(tx8)-len(tx):], tx) txValue := binary.BigEndian.Uint64(tx8) if txValue != uint64(count) { - return tmsp.Result{ - Code: tmsp.CodeType_BadNonce, - Data: nil, - Log: Fmt("Invalid nonce. Expected %v, got %v", count, txValue), - } + return tmsp.ErrBadNonce.AppendLog(Fmt("Invalid nonce. Expected %v, got %v", count, txValue)) } *countPtr += 1 return tmsp.OK diff --git a/mempool/mempool.go b/mempool/mempool.go index acc52ae6e..2b7be1ddb 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -59,8 +59,7 @@ type Mempool struct { // Keep a cache of already-seen txs. // This reduces the pressure on the proxyApp. - cacheMap map[string]struct{} - cacheList *list.List // to remove oldest tx when cache gets too big + cache *txCache } func NewMempool(config cfg.Config, proxyAppConn proxy.AppConnMempool) *Mempool { @@ -74,8 +73,7 @@ func NewMempool(config cfg.Config, proxyAppConn proxy.AppConnMempool) *Mempool { recheckCursor: nil, recheckEnd: nil, - cacheMap: make(map[string]struct{}, cacheSize), - cacheList: list.New(), + cache: newTxCache(cacheSize), } proxyAppConn.SetResponseCallback(mempool.resCb) return mempool @@ -100,8 +98,7 @@ func (mem *Mempool) Flush() { mem.proxyMtx.Lock() defer mem.proxyMtx.Unlock() - mem.cacheMap = make(map[string]struct{}, cacheSize) - mem.cacheList.Init() + mem.cache.Reset() for e := mem.txs.Front(); e != nil; e = e.Next() { mem.txs.Remove(e) @@ -125,7 +122,7 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*tmsp.Response)) (err error) { defer mem.proxyMtx.Unlock() // CACHE - if _, exists := mem.cacheMap[string(tx)]; exists { + if mem.cache.Exists(tx) { if cb != nil { cb(&tmsp.Response{ Value: &tmsp.Response_CheckTx{ @@ -138,16 +135,7 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*tmsp.Response)) (err error) { } return nil } - if mem.cacheList.Len() >= cacheSize { - popped := mem.cacheList.Front() - poppedTx := popped.Value.(types.Tx) - // NOTE: the tx may have already been removed from the map - // but deleting a non-existant element is fine - delete(mem.cacheMap, string(poppedTx)) - mem.cacheList.Remove(popped) - } - mem.cacheMap[string(tx)] = struct{}{} - mem.cacheList.PushBack(tx) + mem.cache.Push(tx) // END CACHE // NOTE: proxyAppConn may error if tx buffer is full @@ -162,13 +150,6 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*tmsp.Response)) (err error) { return nil } -func (mem *Mempool) removeTxFromCacheMap(tx []byte) { - mem.proxyMtx.Lock() - // NOTE tx not removed from cacheList - delete(mem.cacheMap, string(tx)) - mem.proxyMtx.Unlock() -} - // TMSP callback function func (mem *Mempool) resCb(req *tmsp.Request, res *tmsp.Response) { if mem.recheckCursor == nil { @@ -194,9 +175,7 @@ func (mem *Mempool) resCbNormal(req *tmsp.Request, res *tmsp.Response) { log.Info("Bad Transaction", "res", r) // remove from cache (it might be good later) - // note this is an async callback, - // so we need to grab the lock in removeTxFromCacheMap - mem.removeTxFromCacheMap(req.GetCheckTx().Tx) + mem.cache.Remove(req.GetCheckTx().Tx) // TODO: handle other retcodes } @@ -221,7 +200,7 @@ func (mem *Mempool) resCbRecheck(req *tmsp.Request, res *tmsp.Response) { mem.recheckCursor.DetachPrev() // remove from cache (it might be good later) - mem.removeTxFromCacheMap(req.GetCheckTx().Tx) + mem.cache.Remove(req.GetCheckTx().Tx) } if mem.recheckCursor == mem.recheckEnd { mem.recheckCursor = nil @@ -348,3 +327,62 @@ type mempoolTx struct { func (memTx *mempoolTx) Height() int { return int(atomic.LoadInt64(&memTx.height)) } + +//-------------------------------------------------------------------------------- + +type txCache struct { + mtx sync.Mutex + size int + map_ map[string]struct{} + list *list.List // to remove oldest tx when cache gets too big +} + +func newTxCache(cacheSize int) *txCache { + return &txCache{ + size: cacheSize, + map_: make(map[string]struct{}, cacheSize), + list: list.New(), + } +} + +func (cache *txCache) Reset() { + cache.mtx.Lock() + cache.map_ = make(map[string]struct{}, cacheSize) + cache.list.Init() + cache.mtx.Unlock() +} + +func (cache *txCache) Exists(tx types.Tx) bool { + cache.mtx.Lock() + _, exists := cache.map_[string(tx)] + cache.mtx.Unlock() + return exists +} + +// Returns false if tx is in cache. +func (cache *txCache) Push(tx types.Tx) bool { + cache.mtx.Lock() + defer cache.mtx.Unlock() + + if _, exists := cache.map_[string(tx)]; exists { + return false + } + + if cache.list.Len() >= cache.size { + popped := cache.list.Front() + poppedTx := popped.Value.(types.Tx) + // NOTE: the tx may have already been removed from the map + // but deleting a non-existant element is fine + delete(cache.map_, string(poppedTx)) + cache.list.Remove(popped) + } + cache.map_[string(tx)] = struct{}{} + cache.list.PushBack(tx) + return true +} + +func (cache *txCache) Remove(tx types.Tx) { + cache.mtx.Lock() + delete(cache.map_, string(tx)) + cache.mtx.Unlock() +} diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index d5bd6b130..4755bf096 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -2,12 +2,11 @@ package mempool import ( "encoding/binary" - "sync" "testing" "github.com/tendermint/tendermint/config/tendermint_test" + "github.com/tendermint/tendermint/proxy" "github.com/tendermint/tendermint/types" - tmspcli "github.com/tendermint/tmsp/client" "github.com/tendermint/tmsp/example/counter" ) @@ -16,9 +15,9 @@ func TestSerialReap(t *testing.T) { app := counter.NewCounterApplication(true) app.SetOption("serial", "on") - mtx := new(sync.Mutex) - appConnMem := tmspcli.NewLocalClient(mtx, app) - appConnCon := tmspcli.NewLocalClient(mtx, app) + cc := proxy.NewLocalClientCreator(app) + appConnMem, _ := cc.NewTMSPClient() + appConnCon, _ := cc.NewTMSPClient() mempool := NewMempool(config, appConnMem) appendTxsRange := func(start, end int) { @@ -66,13 +65,13 @@ func TestSerialReap(t *testing.T) { for i := start; i < end; i++ { txBytes := make([]byte, 8) binary.BigEndian.PutUint64(txBytes, uint64(i)) - res := appConnCon.AppendTx(txBytes) + res := appConnCon.AppendTxSync(txBytes) if !res.IsOK() { t.Errorf("Error committing tx. Code:%v result:%X log:%v", res.Code, res.Data, res.Log) } } - res := appConnCon.Commit() + res := appConnCon.CommitSync() if len(res.Data) != 8 { t.Errorf("Error committing. Hash:%X log:%v", res.Data, res.Log) }