diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 7cf3ab4e5..bebc3e6a6 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -20,6 +20,8 @@ ### IMPROVEMENTS: +- [mempool] \#2778 No longer send txs back to peers who sent it to you + ### BUG FIXES: - [blockchain] \#2699 update the maxHeight when a peer is removed diff --git a/docs/spec/reactors/mempool/reactor.md b/docs/spec/reactors/mempool/reactor.md index fa25eeb3e..d0b19f7ca 100644 --- a/docs/spec/reactors/mempool/reactor.md +++ b/docs/spec/reactors/mempool/reactor.md @@ -12,3 +12,5 @@ for details. Sending incorrectly encoded data or data exceeding `maxMsgSize` will result in stopping the peer. + +The mempool will not send a tx back to any peer which it received it from. \ No newline at end of file diff --git a/mempool/bench_test.go b/mempool/bench_test.go index 8936f8dfb..0cd394cd6 100644 --- a/mempool/bench_test.go +++ b/mempool/bench_test.go @@ -26,6 +26,19 @@ func BenchmarkReap(b *testing.B) { } } +func BenchmarkCheckTx(b *testing.B) { + app := kvstore.NewKVStoreApplication() + cc := proxy.NewLocalClientCreator(app) + mempool, cleanup := newMempoolWithApp(cc) + defer cleanup() + + for i := 0; i < b.N; i++ { + tx := make([]byte, 8) + binary.BigEndian.PutUint64(tx, uint64(i)) + mempool.CheckTx(tx, nil) + } +} + func BenchmarkCacheInsertTime(b *testing.B) { cache := newMapTxCache(b.N) txs := make([][]byte, b.N) diff --git a/mempool/cache_test.go b/mempool/cache_test.go new file mode 100644 index 000000000..26e560b6e --- /dev/null +++ b/mempool/cache_test.go @@ -0,0 +1,101 @@ +package mempool + +import ( + "crypto/rand" + "crypto/sha256" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/tendermint/tendermint/abci/example/kvstore" + "github.com/tendermint/tendermint/proxy" + "github.com/tendermint/tendermint/types" +) + +func TestCacheRemove(t *testing.T) { + cache := newMapTxCache(100) + numTxs := 10 + txs := make([][]byte, numTxs) + for i := 0; i < numTxs; i++ { + // probability of collision is 2**-256 + txBytes := make([]byte, 32) + rand.Read(txBytes) + txs[i] = txBytes + cache.Push(txBytes) + // make sure its added to both the linked list and the map + require.Equal(t, i+1, len(cache.map_)) + require.Equal(t, i+1, cache.list.Len()) + } + for i := 0; i < numTxs; i++ { + cache.Remove(txs[i]) + // make sure its removed from both the map and the linked list + require.Equal(t, numTxs-(i+1), len(cache.map_)) + require.Equal(t, numTxs-(i+1), cache.list.Len()) + } +} + +func TestCacheAfterUpdate(t *testing.T) { + app := kvstore.NewKVStoreApplication() + cc := proxy.NewLocalClientCreator(app) + mempool, cleanup := newMempoolWithApp(cc) + defer cleanup() + + // reAddIndices & txsInCache can have elements > numTxsToCreate + // also assumes max index is 255 for convenience + // txs in cache also checks order of elements + tests := []struct { + numTxsToCreate int + updateIndices []int + reAddIndices []int + txsInCache []int + }{ + {1, []int{}, []int{1}, []int{1, 0}}, // adding new txs works + {2, []int{1}, []int{}, []int{1, 0}}, // update doesn't remove tx from cache + {2, []int{2}, []int{}, []int{2, 1, 0}}, // update adds new tx to cache + {2, []int{1}, []int{1}, []int{1, 0}}, // re-adding after update doesn't make dupe + } + for tcIndex, tc := range tests { + for i := 0; i < tc.numTxsToCreate; i++ { + tx := types.Tx{byte(i)} + err := mempool.CheckTx(tx, nil) + require.NoError(t, err) + } + + updateTxs := []types.Tx{} + for _, v := range tc.updateIndices { + tx := types.Tx{byte(v)} + updateTxs = 
append(updateTxs, tx) + } + mempool.Update(int64(tcIndex), updateTxs, nil, nil) + + for _, v := range tc.reAddIndices { + tx := types.Tx{byte(v)} + _ = mempool.CheckTx(tx, nil) + } + + cache := mempool.cache.(*mapTxCache) + node := cache.list.Front() + counter := 0 + for node != nil { + require.NotEqual(t, len(tc.txsInCache), counter, + "cache larger than expected on testcase %d", tcIndex) + + nodeVal := node.Value.([sha256.Size]byte) + expectedBz := sha256.Sum256([]byte{byte(tc.txsInCache[len(tc.txsInCache)-counter-1])}) + // Reference for reading the errors: + // >>> sha256('\x00').hexdigest() + // '6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d' + // >>> sha256('\x01').hexdigest() + // '4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a' + // >>> sha256('\x02').hexdigest() + // 'dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986' + + require.Equal(t, expectedBz, nodeVal, "Equality failed on index %d, tc %d", counter, tcIndex) + counter++ + node = node.Next() + } + require.Equal(t, len(tc.txsInCache), counter, + "cache smaller than expected on testcase %d", tcIndex) + mempool.Flush() + } +} diff --git a/mempool/mempool.go b/mempool/mempool.go index 41ee59cb4..2064b7bce 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -31,6 +31,14 @@ type PreCheckFunc func(types.Tx) error // transaction doesn't require more gas than available for the block. type PostCheckFunc func(types.Tx, *abci.ResponseCheckTx) error +// TxInfo are parameters that get passed when attempting to add a tx to the +// mempool. +type TxInfo struct { + // We don't use p2p.ID here because it's too big. The gain is to store max 2 + // bytes with each tx to identify the sender rather than 20 bytes. + PeerID uint16 +} + /* The mempool pushes new txs onto the proxyAppConn. @@ -148,9 +156,12 @@ func TxID(tx []byte) string { type Mempool struct { config *cfg.MempoolConfig - proxyMtx sync.Mutex - proxyAppConn proxy.AppConnMempool - txs *clist.CList // concurrent linked-list of good txs + proxyMtx sync.Mutex + proxyAppConn proxy.AppConnMempool + txs *clist.CList // concurrent linked-list of good txs + // map for quick access to txs + // Used in CheckTx to record the tx sender. + txsMap map[[sha256.Size]byte]*clist.CElement height int64 // the last block Update()'d to rechecking int32 // for re-checking filtered txs on Update() recheckCursor *clist.CElement // next expected response @@ -161,7 +172,10 @@ type Mempool struct { postCheck PostCheckFunc // Atomic integers - txsBytes int64 // see TxsBytes + + // Used to check if the mempool size is bigger than the allowed limit. + // See TxsBytes + txsBytes int64 // Keep a cache of already-seen txs. // This reduces the pressure on the proxyApp. @@ -189,6 +203,7 @@ func NewMempool( config: config, proxyAppConn: proxyAppConn, txs: clist.New(), + txsMap: make(map[[sha256.Size]byte]*clist.CElement), height: height, rechecking: 0, recheckCursor: nil, @@ -286,8 +301,8 @@ func (mem *Mempool) TxsBytes() int64 { return atomic.LoadInt64(&mem.txsBytes) } -// FlushAppConn flushes the mempool connection to ensure async resCb calls are -// done e.g. from CheckTx. +// FlushAppConn flushes the mempool connection to ensure async reqResCb calls are +// done. E.g. from CheckTx. 
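The txsMap added above pairs the existing clist of txs with a sha256-keyed index, so CheckTx can find a tx's entry (and later record its sender) without walking the list. The sketch below is illustrative only — txIndex and its helpers are invented names, and it uses the standard container/list instead of the mempool's concurrent clist — but it shows the same pattern of keeping a map of hashes to list elements alongside the list itself.

package main

import (
	"container/list"
	"crypto/sha256"
	"fmt"
)

// txIndex keeps an ordered list of raw txs plus a hash-keyed map into that
// list, so a tx can be located in O(1) without scanning the list.
type txIndex struct {
	order *list.List                           // insertion order, like mem.txs
	byKey map[[sha256.Size]byte]*list.Element  // like mem.txsMap
}

func newTxIndex() *txIndex {
	return &txIndex{
		order: list.New(),
		byKey: make(map[[sha256.Size]byte]*list.Element),
	}
}

func (idx *txIndex) add(tx []byte) {
	e := idx.order.PushBack(tx)
	idx.byKey[sha256.Sum256(tx)] = e
}

func (idx *txIndex) lookup(tx []byte) (*list.Element, bool) {
	e, ok := idx.byKey[sha256.Sum256(tx)]
	return e, ok
}

func main() {
	idx := newTxIndex()
	idx.add([]byte("tx-1"))
	idx.add([]byte("tx-2"))

	if e, ok := idx.lookup([]byte("tx-1")); ok {
		fmt.Printf("found %s without scanning the list\n", e.Value.([]byte))
	}
}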
func (mem *Mempool) FlushAppConn() error { return mem.proxyAppConn.FlushSync() } @@ -304,6 +319,7 @@ func (mem *Mempool) Flush() { e.DetachPrev() } + mem.txsMap = make(map[[sha256.Size]byte]*clist.CElement) _ = atomic.SwapInt64(&mem.txsBytes, 0) } @@ -327,6 +343,13 @@ func (mem *Mempool) TxsWaitChan() <-chan struct{} { // It gets called from another goroutine. // CONTRACT: Either cb will get called, or err returned. func (mem *Mempool) CheckTx(tx types.Tx, cb func(*abci.Response)) (err error) { + return mem.CheckTxWithInfo(tx, cb, TxInfo{PeerID: UnknownPeerID}) +} + +// CheckTxWithInfo performs the same operation as CheckTx, but with extra meta data about the tx. +// Currently this metadata is the peer who sent it, +// used to prevent the tx from being gossiped back to them. +func (mem *Mempool) CheckTxWithInfo(tx types.Tx, cb func(*abci.Response), txInfo TxInfo) (err error) { mem.proxyMtx.Lock() // use defer to unlock mutex because application (*local client*) might panic defer mem.proxyMtx.Unlock() @@ -357,6 +380,17 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*abci.Response)) (err error) { // CACHE if !mem.cache.Push(tx) { + // record the sender + e, ok := mem.txsMap[sha256.Sum256(tx)] + if ok { // tx may be in cache, but not in the mempool + memTx := e.Value.(*mempoolTx) + if _, loaded := memTx.senders.LoadOrStore(txInfo.PeerID, true); loaded { + // TODO: consider punishing peer for dups, + // its non-trivial since invalid txs can become valid, + // but they can spam the same tx with little cost to them atm. + } + } + return ErrTxInCache } // END CACHE @@ -381,27 +415,77 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*abci.Response)) (err error) { } reqRes := mem.proxyAppConn.CheckTxAsync(tx) if cb != nil { - reqRes.SetCallback(cb) + composedCallback := func(res *abci.Response) { + mem.reqResCb(tx, txInfo.PeerID)(res) + cb(res) + } + reqRes.SetCallback(composedCallback) + } else { + reqRes.SetCallback(mem.reqResCb(tx, txInfo.PeerID)) } return nil } -// ABCI callback function +// Global callback, which is called in the absence of the specific callback. +// +// In recheckTxs because no reqResCb (specific) callback is set, this callback +// will be called. func (mem *Mempool) resCb(req *abci.Request, res *abci.Response) { if mem.recheckCursor == nil { - mem.resCbNormal(req, res) - } else { - mem.metrics.RecheckTimes.Add(1) - mem.resCbRecheck(req, res) + return } + + mem.metrics.RecheckTimes.Add(1) + mem.resCbRecheck(req, res) + + // update metrics mem.metrics.Size.Set(float64(mem.Size())) } -func (mem *Mempool) resCbNormal(req *abci.Request, res *abci.Response) { +// Specific callback, which allows us to incorporate local information, like +// the peer that sent us this tx, so we can avoid sending it back to the same +// peer. +// +// Used in CheckTxWithInfo to record PeerID who sent us the tx. 
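CheckTxWithInfo wires its own bookkeeping callback (reqResCb) in front of any caller-supplied callback by composing the two, so both see the same ABCI response. A minimal sketch of that composition follows; Response here is only a stand-in for *abci.Response, and compose is a hypothetical helper, not a function in the codebase.

package main

import "fmt"

// Response stands in for *abci.Response in this sketch.
type Response struct{ Code uint32 }

// compose returns a callback that runs the internal bookkeeping callback
// first and then the optional caller-supplied one, mirroring how
// CheckTxWithInfo puts reqResCb ahead of the user callback.
func compose(internal, user func(*Response)) func(*Response) {
	return func(res *Response) {
		internal(res)
		if user != nil {
			user(res)
		}
	}
}

func main() {
	internal := func(res *Response) { fmt.Println("record sender, update metrics; code:", res.Code) }
	user := func(res *Response) { fmt.Println("caller sees the same response; code:", res.Code) }

	cb := compose(internal, user)
	cb(&Response{Code: 0})
}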
+func (mem *Mempool) reqResCb(tx []byte, peerID uint16) func(res *abci.Response) { + return func(res *abci.Response) { + if mem.recheckCursor != nil { + return + } + + mem.resCbFirstTime(tx, peerID, res) + + // update metrics + mem.metrics.Size.Set(float64(mem.Size())) + } +} + +func (mem *Mempool) addTx(memTx *mempoolTx) { + e := mem.txs.PushBack(memTx) + mem.txsMap[sha256.Sum256(memTx.tx)] = e + atomic.AddInt64(&mem.txsBytes, int64(len(memTx.tx))) + mem.metrics.TxSizeBytes.Observe(float64(len(memTx.tx))) +} + +func (mem *Mempool) removeTx(tx types.Tx, elem *clist.CElement, removeFromCache bool) { + mem.txs.Remove(elem) + elem.DetachPrev() + delete(mem.txsMap, sha256.Sum256(tx)) + atomic.AddInt64(&mem.txsBytes, int64(-len(tx))) + + if removeFromCache { + mem.cache.Remove(tx) + } +} + +// callback, which is called after the app checked the tx for the first time. +// +// The case where the app checks the tx for the second and subsequent times is +// handled by the resCbRecheck callback. +func (mem *Mempool) resCbFirstTime(tx []byte, peerID uint16, res *abci.Response) { switch r := res.Value.(type) { case *abci.Response_CheckTx: - tx := req.GetCheckTx().Tx var postCheckErr error if mem.postCheck != nil { postCheckErr = mem.postCheck(tx, r.CheckTx) @@ -412,15 +496,14 @@ func (mem *Mempool) resCbNormal(req *abci.Request, res *abci.Response) { gasWanted: r.CheckTx.GasWanted, tx: tx, } - mem.txs.PushBack(memTx) - atomic.AddInt64(&mem.txsBytes, int64(len(tx))) + memTx.senders.Store(peerID, true) + mem.addTx(memTx) mem.logger.Info("Added good transaction", "tx", TxID(tx), "res", r, "height", memTx.height, "total", mem.Size(), ) - mem.metrics.TxSizeBytes.Observe(float64(len(tx))) mem.notifyTxsAvailable() } else { // ignore bad transaction @@ -434,6 +517,10 @@ func (mem *Mempool) resCbNormal(req *abci.Request, res *abci.Response) { } } +// callback, which is called after the app rechecked the tx. +// +// The case where the app checks the tx for the first time is handled by the +// resCbFirstTime callback. func (mem *Mempool) resCbRecheck(req *abci.Request, res *abci.Response) { switch r := res.Value.(type) { case *abci.Response_CheckTx: @@ -454,12 +541,8 @@ func (mem *Mempool) resCbRecheck(req *abci.Request, res *abci.Response) { } else { // Tx became invalidated due to newly committed block. mem.logger.Info("Tx is no longer valid", "tx", TxID(tx), "res", r, "err", postCheckErr) - mem.txs.Remove(mem.recheckCursor) - atomic.AddInt64(&mem.txsBytes, int64(-len(tx))) - mem.recheckCursor.DetachPrev() - - // remove from cache (it might be good later) - mem.cache.Remove(tx) + // NOTE: we remove tx from the cache because it might be good later + mem.removeTx(tx, mem.recheckCursor, true) } if mem.recheckCursor == mem.recheckEnd { mem.recheckCursor = nil @@ -627,12 +710,9 @@ func (mem *Mempool) removeTxs(txs types.Txs) []types.Tx { memTx := e.Value.(*mempoolTx) // Remove the tx if it's already in a block. if _, ok := txsMap[string(memTx.tx)]; ok { - // remove from clist - mem.txs.Remove(e) - atomic.AddInt64(&mem.txsBytes, int64(-len(memTx.tx))) - e.DetachPrev() - // NOTE: we don't remove committed txs from the cache. + mem.removeTx(memTx.tx, e, false) + continue } txsLeft = append(txsLeft, memTx.tx) @@ -650,7 +730,7 @@ func (mem *Mempool) recheckTxs(txs []types.Tx) { mem.recheckEnd = mem.txs.Back() // Push txs to proxyAppConn - // NOTE: resCb() may be called concurrently. + // NOTE: reqResCb may be called concurrently. 
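resCbFirstTime stores the sending peer's ID in the tx's senders map, and the broadcast routine later consults the same map to avoid echoing the tx back to that peer. The standalone sketch below shows the sync.Map usage this relies on; the record and shouldSend closures are invented for illustration.

package main

import (
	"fmt"
	"sync"
)

func main() {
	// senders plays the role of mempoolTx.senders: the set of peer IDs that
	// have sent us this particular tx.
	var senders sync.Map

	// record mirrors what resCbFirstTime and the cache-hit path do:
	// LoadOrStore reports whether the peer was already known, i.e. a
	// duplicate submission of the same tx.
	record := func(peerID uint16) {
		if _, loaded := senders.LoadOrStore(peerID, true); loaded {
			fmt.Printf("peer %d already sent this tx (duplicate)\n", peerID)
			return
		}
		fmt.Printf("recorded peer %d as a sender\n", peerID)
	}

	// shouldSend is the broadcast-side check: never gossip a tx back to a
	// peer that is recorded as one of its senders.
	shouldSend := func(peerID uint16) bool {
		_, isSender := senders.Load(peerID)
		return !isSender
	}

	record(7)
	record(7)
	fmt.Println("send to peer 7?", shouldSend(7)) // false: 7 sent it to us
	fmt.Println("send to peer 9?", shouldSend(9)) // true: 9 never sent it
}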
for _, tx := range txs { mem.proxyAppConn.CheckTxAsync(tx) } @@ -663,6 +743,7 @@ func (mem *Mempool) recheckTxs(txs []types.Tx) { type mempoolTx struct { height int64 // height that this tx had been validated in gasWanted int64 // amount of gas this tx states it will require + senders sync.Map // ids of peers who've sent us this tx (as a map for quick lookups) tx types.Tx // } @@ -679,13 +760,13 @@ type txCache interface { Remove(tx types.Tx) } -// mapTxCache maintains a cache of transactions. This only stores -// the hash of the tx, due to memory concerns. +// mapTxCache maintains a LRU cache of transactions. This only stores the hash +// of the tx, due to memory concerns. type mapTxCache struct { mtx sync.Mutex size int map_ map[[sha256.Size]byte]*list.Element - list *list.List // to remove oldest tx when cache gets too big + list *list.List } var _ txCache = (*mapTxCache)(nil) @@ -707,8 +788,8 @@ func (cache *mapTxCache) Reset() { cache.mtx.Unlock() } -// Push adds the given tx to the cache and returns true. It returns false if tx -// is already in the cache. +// Push adds the given tx to the cache and returns true. It returns +// false if tx is already in the cache. func (cache *mapTxCache) Push(tx types.Tx) bool { cache.mtx.Lock() defer cache.mtx.Unlock() @@ -728,8 +809,8 @@ func (cache *mapTxCache) Push(tx types.Tx) bool { cache.list.Remove(popped) } } - cache.list.PushBack(txHash) - cache.map_[txHash] = cache.list.Back() + e := cache.list.PushBack(txHash) + cache.map_[txHash] = e return true } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 5928fbc56..dc7d595af 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -12,9 +12,10 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + amino "github.com/tendermint/go-amino" + "github.com/tendermint/tendermint/abci/example/counter" "github.com/tendermint/tendermint/abci/example/kvstore" abci "github.com/tendermint/tendermint/abci/types" @@ -63,8 +64,9 @@ func ensureFire(t *testing.T, ch <-chan struct{}, timeoutMS int) { } } -func checkTxs(t *testing.T, mempool *Mempool, count int) types.Txs { +func checkTxs(t *testing.T, mempool *Mempool, count int, peerID uint16) types.Txs { txs := make(types.Txs, count) + txInfo := TxInfo{PeerID: peerID} for i := 0; i < count; i++ { txBytes := make([]byte, 20) txs[i] = txBytes @@ -72,7 +74,7 @@ func checkTxs(t *testing.T, mempool *Mempool, count int) types.Txs { if err != nil { t.Error(err) } - if err := mempool.CheckTx(txBytes, nil); err != nil { + if err := mempool.CheckTxWithInfo(txBytes, nil, txInfo); err != nil { // Skip invalid txs. // TestMempoolFilters will fail otherwise. It asserts a number of txs // returned. @@ -92,7 +94,7 @@ func TestReapMaxBytesMaxGas(t *testing.T) { defer cleanup() // Ensure gas calculation behaves as expected - checkTxs(t, mempool, 1) + checkTxs(t, mempool, 1, UnknownPeerID) tx0 := mempool.TxsFront().Value.(*mempoolTx) // assert that kv store has gas wanted = 1. 
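mapTxCache bounds memory by caching only tx hashes and evicting the oldest hash once the cache is full; keeping the *list.Element returned by PushBack (as the change above now does) is what lets Remove stay O(1). A simplified, single-threaded sketch of that bounded hash cache follows — boundedHashCache is an invented name, and it omits the mutex and Reset of the real cache.

package main

import (
	"container/list"
	"crypto/sha256"
	"fmt"
)

// boundedHashCache remembers only tx hashes, bounded at size, evicting the
// oldest hash once the bound is reached.
type boundedHashCache struct {
	size int
	m    map[[sha256.Size]byte]*list.Element
	list *list.List
}

func newBoundedHashCache(size int) *boundedHashCache {
	return &boundedHashCache{
		size: size,
		m:    make(map[[sha256.Size]byte]*list.Element, size),
		list: list.New(),
	}
}

// push returns false if tx is already cached, true if it was added,
// evicting the oldest hash when the cache is full.
func (c *boundedHashCache) push(tx []byte) bool {
	h := sha256.Sum256(tx)
	if _, exists := c.m[h]; exists {
		return false
	}
	if c.list.Len() >= c.size {
		oldest := c.list.Front()
		delete(c.m, oldest.Value.([sha256.Size]byte))
		c.list.Remove(oldest)
	}
	e := c.list.PushBack(h) // keep the element so remove stays O(1)
	c.m[h] = e
	return true
}

func (c *boundedHashCache) remove(tx []byte) {
	h := sha256.Sum256(tx)
	if e, ok := c.m[h]; ok {
		c.list.Remove(e)
		delete(c.m, h)
	}
}

func main() {
	c := newBoundedHashCache(2)
	fmt.Println(c.push([]byte("a"))) // true
	fmt.Println(c.push([]byte("a"))) // false: already seen
	fmt.Println(c.push([]byte("b"))) // true
	fmt.Println(c.push([]byte("c"))) // true: evicts "a"
	fmt.Println(c.push([]byte("a"))) // true again: "a" was evicted
}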
require.Equal(t, app.CheckTx(tx0.tx).GasWanted, int64(1), "KVStore had a gas value neq to 1") @@ -126,7 +128,7 @@ func TestReapMaxBytesMaxGas(t *testing.T) { {20, 20000, 30, 20}, } for tcIndex, tt := range tests { - checkTxs(t, mempool, tt.numTxsToCreate) + checkTxs(t, mempool, tt.numTxsToCreate, UnknownPeerID) got := mempool.ReapMaxBytesMaxGas(tt.maxBytes, tt.maxGas) assert.Equal(t, tt.expectedNumTxs, len(got), "Got %d txs, expected %d, tc #%d", len(got), tt.expectedNumTxs, tcIndex) @@ -167,7 +169,7 @@ func TestMempoolFilters(t *testing.T) { } for tcIndex, tt := range tests { mempool.Update(1, emptyTxArr, tt.preFilter, tt.postFilter) - checkTxs(t, mempool, tt.numTxsToCreate) + checkTxs(t, mempool, tt.numTxsToCreate, UnknownPeerID) require.Equal(t, tt.expectedNumTxs, mempool.Size(), "mempool had the incorrect size, on test case %d", tcIndex) mempool.Flush() } @@ -198,7 +200,7 @@ func TestTxsAvailable(t *testing.T) { ensureNoFire(t, mempool.TxsAvailable(), timeoutMS) // send a bunch of txs, it should only fire once - txs := checkTxs(t, mempool, 100) + txs := checkTxs(t, mempool, 100, UnknownPeerID) ensureFire(t, mempool.TxsAvailable(), timeoutMS) ensureNoFire(t, mempool.TxsAvailable(), timeoutMS) @@ -213,7 +215,7 @@ func TestTxsAvailable(t *testing.T) { ensureNoFire(t, mempool.TxsAvailable(), timeoutMS) // send a bunch more txs. we already fired for this height so it shouldnt fire again - moreTxs := checkTxs(t, mempool, 50) + moreTxs := checkTxs(t, mempool, 50, UnknownPeerID) ensureNoFire(t, mempool.TxsAvailable(), timeoutMS) // now call update with all the txs. it should not fire as there are no txs left @@ -224,7 +226,7 @@ func TestTxsAvailable(t *testing.T) { ensureNoFire(t, mempool.TxsAvailable(), timeoutMS) // send a bunch more txs, it should only fire once - checkTxs(t, mempool, 100) + checkTxs(t, mempool, 100, UnknownPeerID) ensureFire(t, mempool.TxsAvailable(), timeoutMS) ensureNoFire(t, mempool.TxsAvailable(), timeoutMS) } @@ -340,28 +342,6 @@ func TestSerialReap(t *testing.T) { reapCheck(600) } -func TestCacheRemove(t *testing.T) { - cache := newMapTxCache(100) - numTxs := 10 - txs := make([][]byte, numTxs) - for i := 0; i < numTxs; i++ { - // probability of collision is 2**-256 - txBytes := make([]byte, 32) - rand.Read(txBytes) - txs[i] = txBytes - cache.Push(txBytes) - // make sure its added to both the linked list and the map - require.Equal(t, i+1, len(cache.map_)) - require.Equal(t, i+1, cache.list.Len()) - } - for i := 0; i < numTxs; i++ { - cache.Remove(txs[i]) - // make sure its removed from both the map and the linked list - require.Equal(t, numTxs-(i+1), len(cache.map_)) - require.Equal(t, numTxs-(i+1), cache.list.Len()) - } -} - func TestMempoolCloseWAL(t *testing.T) { // 1. Create the temporary directory for mempool and WAL testing. 
rootDir, err := ioutil.TempDir("", "mempool-test") diff --git a/mempool/reactor.go b/mempool/reactor.go index ff87f0506..555f38b8b 100644 --- a/mempool/reactor.go +++ b/mempool/reactor.go @@ -3,13 +3,14 @@ package mempool import ( "fmt" "reflect" + "sync" "time" amino "github.com/tendermint/go-amino" - "github.com/tendermint/tendermint/libs/clist" - "github.com/tendermint/tendermint/libs/log" cfg "github.com/tendermint/tendermint/config" + "github.com/tendermint/tendermint/libs/clist" + "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -21,13 +22,70 @@ const ( maxTxSize = maxMsgSize - 8 // account for amino overhead of TxMessage peerCatchupSleepIntervalMS = 100 // If peer is behind, sleep this amount + + // UnknownPeerID is the peer ID to use when running CheckTx when there is + // no peer (e.g. RPC) + UnknownPeerID uint16 = 0 ) // MempoolReactor handles mempool tx broadcasting amongst peers. +// It maintains a map from peer ID to counter, to prevent gossiping txs to the +// peers you received it from. type MempoolReactor struct { p2p.BaseReactor config *cfg.MempoolConfig Mempool *Mempool + ids *mempoolIDs +} + +type mempoolIDs struct { + mtx sync.RWMutex + peerMap map[p2p.ID]uint16 + nextID uint16 // assumes that a node will never have over 65536 active peers + activeIDs map[uint16]struct{} // used to check if a given peerID key is used, the value doesn't matter +} + +// Reserve searches for the next unused ID and assignes it to the peer. +func (ids *mempoolIDs) ReserveForPeer(peer p2p.Peer) { + ids.mtx.Lock() + defer ids.mtx.Unlock() + + curID := ids.nextPeerID() + ids.peerMap[peer.ID()] = curID + ids.activeIDs[curID] = struct{}{} +} + +// nextPeerID returns the next unused peer ID to use. +// This assumes that ids's mutex is already locked. +func (ids *mempoolIDs) nextPeerID() uint16 { + _, idExists := ids.activeIDs[ids.nextID] + for idExists { + ids.nextID++ + _, idExists = ids.activeIDs[ids.nextID] + } + curID := ids.nextID + ids.nextID++ + return curID +} + +// Reclaim returns the ID reserved for the peer back to unused pool. +func (ids *mempoolIDs) Reclaim(peer p2p.Peer) { + ids.mtx.Lock() + defer ids.mtx.Unlock() + + removedID, ok := ids.peerMap[peer.ID()] + if ok { + delete(ids.activeIDs, removedID) + delete(ids.peerMap, peer.ID()) + } +} + +// GetForPeer returns an ID reserved for the peer. +func (ids *mempoolIDs) GetForPeer(peer p2p.Peer) uint16 { + ids.mtx.RLock() + defer ids.mtx.RUnlock() + + return ids.peerMap[peer.ID()] } // NewMempoolReactor returns a new MempoolReactor with the given config and mempool. @@ -35,6 +93,11 @@ func NewMempoolReactor(config *cfg.MempoolConfig, mempool *Mempool) *MempoolReac memR := &MempoolReactor{ config: config, Mempool: mempool, + ids: &mempoolIDs{ + peerMap: make(map[p2p.ID]uint16), + activeIDs: map[uint16]struct{}{0: {}}, + nextID: 1, // reserve unknownPeerID(0) for mempoolReactor.BroadcastTx + }, } memR.BaseReactor = *p2p.NewBaseReactor("MempoolReactor", memR) return memR @@ -68,11 +131,13 @@ func (memR *MempoolReactor) GetChannels() []*p2p.ChannelDescriptor { // AddPeer implements Reactor. // It starts a broadcast routine ensuring all txs are forwarded to the given peer. func (memR *MempoolReactor) AddPeer(peer p2p.Peer) { + memR.ids.ReserveForPeer(peer) go memR.broadcastTxRoutine(peer) } // RemovePeer implements Reactor. 
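mempoolIDs maps each connected peer to a compact uint16 so that only two bytes per tx are needed to remember senders, with ID 0 reserved for UnknownPeerID (e.g. RPC submissions). The sketch below shows the reserve/reclaim bookkeeping in isolation; idPool is an invented name that keys peers by plain strings instead of p2p.ID, and like the real code it only scans forward for the next free ID rather than immediately reusing reclaimed ones.

package main

import (
	"fmt"
	"sync"
)

// idPool hands out compact uint16 IDs for peer keys and recycles them when
// peers disconnect, mirroring the mempoolIDs bookkeeping.
type idPool struct {
	mtx     sync.Mutex
	peerMap map[string]uint16
	active  map[uint16]struct{}
	nextID  uint16
}

func newIDPool() *idPool {
	return &idPool{
		peerMap: make(map[string]uint16),
		active:  map[uint16]struct{}{0: {}}, // 0 is reserved for "no peer"
		nextID:  1,
	}
}

func (p *idPool) reserve(peer string) uint16 {
	p.mtx.Lock()
	defer p.mtx.Unlock()
	// advance past IDs that are still in use
	for {
		if _, used := p.active[p.nextID]; !used {
			break
		}
		p.nextID++
	}
	id := p.nextID
	p.nextID++
	p.peerMap[peer] = id
	p.active[id] = struct{}{}
	return id
}

func (p *idPool) reclaim(peer string) {
	p.mtx.Lock()
	defer p.mtx.Unlock()
	if id, ok := p.peerMap[peer]; ok {
		delete(p.active, id)
		delete(p.peerMap, peer)
	}
}

func main() {
	pool := newIDPool()
	a := pool.reserve("peer-a") // 1
	b := pool.reserve("peer-b") // 2
	pool.reclaim("peer-a")
	c := pool.reserve("peer-c") // 3: freed IDs are only revisited after wraparound
	fmt.Println(a, b, c)
}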
func (memR *MempoolReactor) RemovePeer(peer p2p.Peer, reason interface{}) { + memR.ids.Reclaim(peer) // broadcast routine checks if peer is gone and returns } @@ -89,7 +154,8 @@ func (memR *MempoolReactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { switch msg := msg.(type) { case *TxMessage: - err := memR.Mempool.CheckTx(msg.Tx, nil) + peerID := memR.ids.GetForPeer(src) + err := memR.Mempool.CheckTxWithInfo(msg.Tx, nil, TxInfo{PeerID: peerID}) if err != nil { memR.Logger.Info("Could not check tx", "tx", TxID(msg.Tx), "err", err) } @@ -110,6 +176,7 @@ func (memR *MempoolReactor) broadcastTxRoutine(peer p2p.Peer) { return } + peerID := memR.ids.GetForPeer(peer) var next *clist.CElement for { // This happens because the CElement we were looking at got garbage @@ -146,12 +213,15 @@ func (memR *MempoolReactor) broadcastTxRoutine(peer p2p.Peer) { continue } - // send memTx - msg := &TxMessage{Tx: memTx.tx} - success := peer.Send(MempoolChannel, cdc.MustMarshalBinaryBare(msg)) - if !success { - time.Sleep(peerCatchupSleepIntervalMS * time.Millisecond) - continue + // ensure peer hasn't already sent us this tx + if _, ok := memTx.senders.Load(peerID); !ok { + // send memTx + msg := &TxMessage{Tx: memTx.tx} + success := peer.Send(MempoolChannel, cdc.MustMarshalBinaryBare(msg)) + if !success { + time.Sleep(peerCatchupSleepIntervalMS * time.Millisecond) + continue + } } select { diff --git a/mempool/reactor_test.go b/mempool/reactor_test.go index 51d130187..f16f84479 100644 --- a/mempool/reactor_test.go +++ b/mempool/reactor_test.go @@ -7,15 +7,13 @@ import ( "time" "github.com/fortytw2/leaktest" + "github.com/go-kit/kit/log/term" "github.com/pkg/errors" "github.com/stretchr/testify/assert" - "github.com/go-kit/kit/log/term" - "github.com/tendermint/tendermint/abci/example/kvstore" - "github.com/tendermint/tendermint/libs/log" - cfg "github.com/tendermint/tendermint/config" + "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/p2p" "github.com/tendermint/tendermint/proxy" "github.com/tendermint/tendermint/types" @@ -102,6 +100,12 @@ func _waitForTxs(t *testing.T, wg *sync.WaitGroup, txs types.Txs, reactorIdx int wg.Done() } +// ensure no txs on reactor after some timeout +func ensureNoTxs(t *testing.T, reactor *MempoolReactor, timeout time.Duration) { + time.Sleep(timeout) // wait for the txs in all mempools + assert.Zero(t, reactor.Mempool.Size()) +} + const ( NUM_TXS = 1000 TIMEOUT = 120 * time.Second // ridiculously high because CircleCI is slow @@ -124,10 +128,26 @@ func TestReactorBroadcastTxMessage(t *testing.T) { // send a bunch of txs to the first reactor's mempool // and wait for them all to be received in the others - txs := checkTxs(t, reactors[0].Mempool, NUM_TXS) + txs := checkTxs(t, reactors[0].Mempool, NUM_TXS, UnknownPeerID) waitForTxs(t, txs, reactors) } +func TestReactorNoBroadcastToSender(t *testing.T) { + config := cfg.TestConfig() + const N = 2 + reactors := makeAndConnectMempoolReactors(config, N) + defer func() { + for _, r := range reactors { + r.Stop() + } + }() + + // send a bunch of txs to the first reactor's mempool, claiming it came from peer + // ensure peer gets no txs + checkTxs(t, reactors[0].Mempool, NUM_TXS, 1) + ensureNoTxs(t, reactors[1], 100*time.Millisecond) +} + func TestBroadcastTxForPeerStopsWhenPeerStops(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") diff --git a/state/services.go b/state/services.go index 02c3aa7d1..07d12c5a1 100644 --- a/state/services.go +++ b/state/services.go @@ -23,6 +23,7 
@@ type Mempool interface { Size() int CheckTx(types.Tx, func(*abci.Response)) error + CheckTxWithInfo(types.Tx, func(*abci.Response), mempool.TxInfo) error ReapMaxBytesMaxGas(maxBytes, maxGas int64) types.Txs Update(int64, types.Txs, mempool.PreCheckFunc, mempool.PostCheckFunc) error Flush() @@ -37,11 +38,17 @@ type MockMempool struct{} var _ Mempool = MockMempool{} -func (MockMempool) Lock() {} -func (MockMempool) Unlock() {} -func (MockMempool) Size() int { return 0 } -func (MockMempool) CheckTx(_ types.Tx, _ func(*abci.Response)) error { return nil } -func (MockMempool) ReapMaxBytesMaxGas(_, _ int64) types.Txs { return types.Txs{} } +func (MockMempool) Lock() {} +func (MockMempool) Unlock() {} +func (MockMempool) Size() int { return 0 } +func (MockMempool) CheckTx(_ types.Tx, _ func(*abci.Response)) error { + return nil +} +func (MockMempool) CheckTxWithInfo(_ types.Tx, _ func(*abci.Response), + _ mempool.TxInfo) error { + return nil +} +func (MockMempool) ReapMaxBytesMaxGas(_, _ int64) types.Txs { return types.Txs{} } func (MockMempool) Update( _ int64, _ types.Txs,
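Callers that have no peer context — the RPC server, the state package's Mempool consumers, tests — keep using CheckTx, which simply forwards to CheckTxWithInfo with the reserved unknown-peer ID; that is why the interface (and MockMempool) now carries both methods. A tiny sketch of that delegation shape, with invented txChecker and txInfo types standing in for the real ones:

package main

import "fmt"

// unknownPeerID mirrors the reserved ID 0 used when no sender is known.
const unknownPeerID uint16 = 0

type txInfo struct{ peerID uint16 }

type txChecker struct{}

// checkTxWithInfo is the peer-aware entry point.
func (txChecker) checkTxWithInfo(tx []byte, info txInfo) error {
	fmt.Printf("checking %q, sender id %d\n", tx, info.peerID)
	return nil
}

// checkTx is the peer-agnostic entry point: it forwards with the reserved
// unknown-peer ID, the same fallback CheckTx uses before handing off to
// CheckTxWithInfo.
func (c txChecker) checkTx(tx []byte) error {
	return c.checkTxWithInfo(tx, txInfo{peerID: unknownPeerID})
}

func main() {
	var c txChecker
	_ = c.checkTx([]byte("from rpc"))
	_ = c.checkTxWithInfo([]byte("from peer"), txInfo{peerID: 3})
}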