From 7ec123c9687cbb7a6450f7853c30152d5af12049 Mon Sep 17 00:00:00 2001 From: Aleksandr Bezobchuk Date: Wed, 2 Jun 2021 09:53:57 -0400 Subject: [PATCH] improvement: update TxInfo (#6529) Remove `Context` from the `TxInfo` type and instead require the caller to pass a `Context` to `CheckTx` which is idiomatic. closes: #6497 --- CHANGELOG_PENDING.md | 1 + consensus/mempool_test.go | 5 ++-- consensus/reactor_test.go | 12 +++++++-- consensus/replay_stubs.go | 4 ++- consensus/replay_test.go | 14 +++++----- mempool/mempool.go | 3 ++- mempool/mock/mempool.go | 4 ++- mempool/tx.go | 4 --- mempool/v0/bench_test.go | 11 ++++---- mempool/v0/cache_test.go | 5 ++-- mempool/v0/clist_mempool.go | 13 +++++++--- mempool/v0/clist_mempool_test.go | 44 ++++++++++++++++++-------------- mempool/v0/reactor.go | 3 ++- mempool/v0/reactor_test.go | 12 +++++++-- mempool/v1/mempool.go | 9 +++++-- mempool/v1/mempool_bench_test.go | 3 ++- mempool/v1/mempool_test.go | 13 +++++----- mempool/v1/reactor.go | 3 ++- node/node_test.go | 8 +++--- rpc/client/rpc_test.go | 9 +++---- rpc/core/mempool.go | 26 +++++++++++++------ test/fuzz/mempool/checktx.go | 4 ++- 22 files changed, 131 insertions(+), 79 deletions(-) diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 38b0954e6..1323307c7 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -32,6 +32,7 @@ Friendly reminder: We have a [bug bounty program](https://hackerone.com/tendermi - P2P Protocol - Go API + - [mempool] \#6529 The `Context` field has been removed from the `TxInfo` type. `CheckTx` now requires a `Context` argument. (@alexanderbez) - [abci/client, proxy] \#5673 `Async` funcs return an error, `Sync` and `Async` funcs accept `context.Context` (@melekes) - [p2p] Removed unused function `MakePoWTarget`. (@erikgrinaker) - [libs/bits] \#5720 Validate `BitArray` in `FromProto`, which now returns an error (@melekes) diff --git a/consensus/mempool_test.go b/consensus/mempool_test.go index ec301af53..ca13507b9 100644 --- a/consensus/mempool_test.go +++ b/consensus/mempool_test.go @@ -1,6 +1,7 @@ package consensus import ( + "context" "encoding/binary" "fmt" "os" @@ -111,7 +112,7 @@ func deliverTxsRange(cs *State, start, end int) { for i := start; i < end; i++ { txBytes := make([]byte, 8) binary.BigEndian.PutUint64(txBytes, uint64(i)) - err := assertMempool(cs.txNotifier).CheckTx(txBytes, nil, mempl.TxInfo{}) + err := assertMempool(cs.txNotifier).CheckTx(context.Background(), txBytes, nil, mempl.TxInfo{}) if err != nil { panic(fmt.Sprintf("Error after CheckTx: %v", err)) } @@ -171,7 +172,7 @@ func TestMempoolRmBadTx(t *testing.T) { // Try to send the tx through the mempool. // CheckTx should not err, but the app should return a bad abci code // and the tx should get removed from the pool - err := assertMempool(cs.txNotifier).CheckTx(txBytes, func(r *abci.Response) { + err := assertMempool(cs.txNotifier).CheckTx(context.Background(), txBytes, func(r *abci.Response) { if r.GetCheckTx().Code != code.CodeTypeBadNonce { t.Errorf("expected checktx to return bad nonce, got %v", r) return diff --git a/consensus/reactor_test.go b/consensus/reactor_test.go index f88a43c45..b4051c9a9 100644 --- a/consensus/reactor_test.go +++ b/consensus/reactor_test.go @@ -155,7 +155,7 @@ func waitForAndValidateBlock( require.NoError(t, validateBlock(newBlock, activeVals)) for _, tx := range txs { - require.NoError(t, assertMempool(states[j].txNotifier).CheckTx(tx, nil, mempool.TxInfo{})) + require.NoError(t, assertMempool(states[j].txNotifier).CheckTx(context.Background(), tx, nil, mempool.TxInfo{})) } } @@ -401,7 +401,15 @@ func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { } // send a tx - require.NoError(t, assertMempool(states[3].txNotifier).CheckTx([]byte{1, 2, 3}, nil, mempool.TxInfo{})) + require.NoError( + t, + assertMempool(states[3].txNotifier).CheckTx( + context.Background(), + []byte{1, 2, 3}, + nil, + mempool.TxInfo{}, + ), + ) var wg sync.WaitGroup for _, sub := range rts.subs { diff --git a/consensus/replay_stubs.go b/consensus/replay_stubs.go index aa5b7eeae..aad99553d 100644 --- a/consensus/replay_stubs.go +++ b/consensus/replay_stubs.go @@ -1,6 +1,8 @@ package consensus import ( + "context" + abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/internal/libs/clist" mempl "github.com/tendermint/tendermint/mempool" @@ -18,7 +20,7 @@ var _ mempl.Mempool = emptyMempool{} func (emptyMempool) Lock() {} func (emptyMempool) Unlock() {} func (emptyMempool) Size() int { return 0 } -func (emptyMempool) CheckTx(_ types.Tx, _ func(*abci.Response), _ mempl.TxInfo) error { +func (emptyMempool) CheckTx(_ context.Context, _ types.Tx, _ func(*abci.Response), _ mempl.TxInfo) error { return nil } func (emptyMempool) ReapMaxBytesMaxGas(_, _ int64) types.Txs { return types.Txs{} } diff --git a/consensus/replay_test.go b/consensus/replay_test.go index ec328730c..e7290d6cb 100644 --- a/consensus/replay_test.go +++ b/consensus/replay_test.go @@ -98,7 +98,7 @@ func sendTxs(ctx context.Context, cs *State) { return default: tx := []byte{byte(i)} - if err := assertMempool(cs.txNotifier).CheckTx(tx, nil, mempl.TxInfo{}); err != nil { + if err := assertMempool(cs.txNotifier).CheckTx(context.Background(), tx, nil, mempl.TxInfo{}); err != nil { panic(err) } i++ @@ -358,7 +358,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { valPubKey1ABCI, err := cryptoenc.PubKeyToProto(newValidatorPubKey1) require.NoError(t, err) newValidatorTx1 := kvstore.MakeValSetChangeTx(valPubKey1ABCI, testMinPower) - err = assertMempool(css[0].txNotifier).CheckTx(newValidatorTx1, nil, mempl.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), newValidatorTx1, nil, mempl.TxInfo{}) assert.Nil(t, err) propBlock, _ := css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts := propBlock.MakePartSet(partSize) @@ -390,7 +390,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { updatePubKey1ABCI, err := cryptoenc.PubKeyToProto(updateValidatorPubKey1) require.NoError(t, err) updateValidatorTx1 := kvstore.MakeValSetChangeTx(updatePubKey1ABCI, 25) - err = assertMempool(css[0].txNotifier).CheckTx(updateValidatorTx1, nil, mempl.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), updateValidatorTx1, nil, mempl.TxInfo{}) assert.Nil(t, err) propBlock, _ = css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts = propBlock.MakePartSet(partSize) @@ -422,14 +422,14 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { newVal2ABCI, err := cryptoenc.PubKeyToProto(newValidatorPubKey2) require.NoError(t, err) newValidatorTx2 := kvstore.MakeValSetChangeTx(newVal2ABCI, testMinPower) - err = assertMempool(css[0].txNotifier).CheckTx(newValidatorTx2, nil, mempl.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), newValidatorTx2, nil, mempl.TxInfo{}) assert.Nil(t, err) newValidatorPubKey3, err := css[nVals+2].privValidator.GetPubKey(context.Background()) require.NoError(t, err) newVal3ABCI, err := cryptoenc.PubKeyToProto(newValidatorPubKey3) require.NoError(t, err) newValidatorTx3 := kvstore.MakeValSetChangeTx(newVal3ABCI, testMinPower) - err = assertMempool(css[0].txNotifier).CheckTx(newValidatorTx3, nil, mempl.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), newValidatorTx3, nil, mempl.TxInfo{}) assert.Nil(t, err) propBlock, _ = css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts = propBlock.MakePartSet(partSize) @@ -469,7 +469,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { ensureNewProposal(proposalCh, height, round) removeValidatorTx2 := kvstore.MakeValSetChangeTx(newVal2ABCI, 0) - err = assertMempool(css[0].txNotifier).CheckTx(removeValidatorTx2, nil, mempl.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), removeValidatorTx2, nil, mempl.TxInfo{}) assert.Nil(t, err) rs = css[0].GetRoundState() @@ -508,7 +508,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { height++ incrementHeight(vss...) removeValidatorTx3 := kvstore.MakeValSetChangeTx(newVal3ABCI, 0) - err = assertMempool(css[0].txNotifier).CheckTx(removeValidatorTx3, nil, mempl.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), removeValidatorTx3, nil, mempl.TxInfo{}) assert.Nil(t, err) propBlock, _ = css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts = propBlock.MakePartSet(partSize) diff --git a/mempool/mempool.go b/mempool/mempool.go index ba9b7a138..ae6252f13 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -1,6 +1,7 @@ package mempool import ( + "context" "fmt" "math" @@ -29,7 +30,7 @@ const ( type Mempool interface { // CheckTx executes a new transaction against the application to determine // its validity and whether it should be added to the mempool. - CheckTx(tx types.Tx, callback func(*abci.Response), txInfo TxInfo) error + CheckTx(ctx context.Context, tx types.Tx, callback func(*abci.Response), txInfo TxInfo) error // ReapMaxBytesMaxGas reaps transactions from the mempool up to maxBytes // bytes total with the condition that the total gasWanted must be less than diff --git a/mempool/mock/mempool.go b/mempool/mock/mempool.go index 723ce791a..9cec2f757 100644 --- a/mempool/mock/mempool.go +++ b/mempool/mock/mempool.go @@ -1,6 +1,8 @@ package mock import ( + "context" + abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/internal/libs/clist" mempl "github.com/tendermint/tendermint/mempool" @@ -15,7 +17,7 @@ var _ mempl.Mempool = Mempool{} func (Mempool) Lock() {} func (Mempool) Unlock() {} func (Mempool) Size() int { return 0 } -func (Mempool) CheckTx(_ types.Tx, _ func(*abci.Response), _ mempl.TxInfo) error { +func (Mempool) CheckTx(_ context.Context, _ types.Tx, _ func(*abci.Response), _ mempl.TxInfo) error { return nil } func (Mempool) ReapMaxBytesMaxGas(_, _ int64) types.Txs { return types.Txs{} } diff --git a/mempool/tx.go b/mempool/tx.go index 8bdc82294..a040de5d9 100644 --- a/mempool/tx.go +++ b/mempool/tx.go @@ -1,7 +1,6 @@ package mempool import ( - "context" "crypto/sha256" "github.com/tendermint/tendermint/p2p" @@ -31,7 +30,4 @@ type TxInfo struct { // SenderNodeID is the actual p2p.NodeID of the sender. SenderNodeID p2p.NodeID - - // Context is the optional context to cancel CheckTx - Context context.Context } diff --git a/mempool/v0/bench_test.go b/mempool/v0/bench_test.go index 43fc44c32..40ffcee44 100644 --- a/mempool/v0/bench_test.go +++ b/mempool/v0/bench_test.go @@ -1,6 +1,7 @@ package v0 import ( + "context" "encoding/binary" "sync/atomic" "testing" @@ -21,7 +22,7 @@ func BenchmarkReap(b *testing.B) { for i := 0; i < size; i++ { tx := make([]byte, 8) binary.BigEndian.PutUint64(tx, uint64(i)) - if err := mp.CheckTx(tx, nil, mempool.TxInfo{}); err != nil { + if err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}); err != nil { b.Fatal(err) } } @@ -47,7 +48,7 @@ func BenchmarkCheckTx(b *testing.B) { binary.BigEndian.PutUint64(tx, uint64(i)) b.StartTimer() - if err := mp.CheckTx(tx, nil, mempool.TxInfo{}); err != nil { + if err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}); err != nil { b.Fatal(err) } } @@ -71,7 +72,7 @@ func BenchmarkParallelCheckTx(b *testing.B) { for pb.Next() { tx := make([]byte, 8) binary.BigEndian.PutUint64(tx, next()) - if err := mp.CheckTx(tx, nil, mempool.TxInfo{}); err != nil { + if err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}); err != nil { b.Fatal(err) } } @@ -89,11 +90,11 @@ func BenchmarkCheckDuplicateTx(b *testing.B) { for i := 0; i < b.N; i++ { tx := make([]byte, 8) binary.BigEndian.PutUint64(tx, uint64(i)) - if err := mp.CheckTx(tx, nil, mempool.TxInfo{}); err != nil { + if err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}); err != nil { b.Fatal(err) } - if err := mp.CheckTx(tx, nil, mempool.TxInfo{}); err == nil { + if err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}); err == nil { b.Fatal("tx should be duplicate") } } diff --git a/mempool/v0/cache_test.go b/mempool/v0/cache_test.go index fab6a6011..c393f60b6 100644 --- a/mempool/v0/cache_test.go +++ b/mempool/v0/cache_test.go @@ -1,6 +1,7 @@ package v0 import ( + "context" "crypto/sha256" "testing" @@ -36,7 +37,7 @@ func TestCacheAfterUpdate(t *testing.T) { for tcIndex, tc := range tests { for i := 0; i < tc.numTxsToCreate; i++ { tx := types.Tx{byte(i)} - err := mp.CheckTx(tx, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) require.NoError(t, err) } @@ -50,7 +51,7 @@ func TestCacheAfterUpdate(t *testing.T) { for _, v := range tc.reAddIndices { tx := types.Tx{byte(v)} - _ = mp.CheckTx(tx, nil, mempool.TxInfo{}) + _ = mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) } cache := mp.cache.(*mempool.LRUTxCache) diff --git a/mempool/v0/clist_mempool.go b/mempool/v0/clist_mempool.go index 8b415310c..13606fd5d 100644 --- a/mempool/v0/clist_mempool.go +++ b/mempool/v0/clist_mempool.go @@ -199,7 +199,13 @@ func (mem *CListMempool) TxsWaitChan() <-chan struct{} { // CONTRACT: Either cb will get called, or err returned. // // Safe for concurrent use by multiple goroutines. -func (mem *CListMempool) CheckTx(tx types.Tx, cb func(*abci.Response), txInfo mempool.TxInfo) error { +func (mem *CListMempool) CheckTx( + ctx context.Context, + tx types.Tx, + cb func(*abci.Response), + txInfo mempool.TxInfo, +) error { + mem.updateMtx.RLock() // use defer to unlock mutex because application (*local client*) might panic defer mem.updateMtx.RUnlock() @@ -250,9 +256,8 @@ func (mem *CListMempool) CheckTx(tx types.Tx, cb func(*abci.Response), txInfo me return nil } - ctx := context.Background() - if txInfo.Context != nil { - ctx = txInfo.Context + if ctx == nil { + ctx = context.Background() } reqRes, err := mem.proxyAppConn.CheckTxAsync(ctx, abci.RequestCheckTx{Tx: tx}) diff --git a/mempool/v0/clist_mempool_test.go b/mempool/v0/clist_mempool_test.go index 0fd39103a..ae839b506 100644 --- a/mempool/v0/clist_mempool_test.go +++ b/mempool/v0/clist_mempool_test.go @@ -78,7 +78,7 @@ func checkTxs(t *testing.T, mp mempool.Mempool, count int, peerID uint16) types. if err != nil { t.Error(err) } - if err := mp.CheckTx(txBytes, nil, txInfo); err != nil { + if err := mp.CheckTx(context.Background(), txBytes, nil, txInfo); err != nil { // Skip invalid txs. // TestMempoolFilters will fail otherwise. It asserts a number of txs // returned. @@ -189,13 +189,13 @@ func TestMempoolUpdate(t *testing.T) { { err := mp.Update(1, []types.Tx{[]byte{0x01}}, abciResponses(1, abci.CodeTypeOK), nil, nil) require.NoError(t, err) - err = mp.CheckTx([]byte{0x01}, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), []byte{0x01}, nil, mempool.TxInfo{}) require.NoError(t, err) } // 2. Removes valid txs from the mempool { - err := mp.CheckTx([]byte{0x02}, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), []byte{0x02}, nil, mempool.TxInfo{}) require.NoError(t, err) err = mp.Update(1, []types.Tx{[]byte{0x02}}, abciResponses(1, abci.CodeTypeOK), nil, nil) require.NoError(t, err) @@ -204,13 +204,13 @@ func TestMempoolUpdate(t *testing.T) { // 3. Removes invalid transactions from the cache and the mempool (if present) { - err := mp.CheckTx([]byte{0x03}, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), []byte{0x03}, nil, mempool.TxInfo{}) require.NoError(t, err) err = mp.Update(1, []types.Tx{[]byte{0x03}}, abciResponses(1, 1), nil, nil) require.NoError(t, err) assert.Zero(t, mp.Size()) - err = mp.CheckTx([]byte{0x03}, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), []byte{0x03}, nil, mempool.TxInfo{}) require.NoError(t, err) } } @@ -231,7 +231,7 @@ func TestMempool_KeepInvalidTxsInCache(t *testing.T) { b := make([]byte, 8) binary.BigEndian.PutUint64(b, 1) - err := mp.CheckTx(b, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), b, nil, mempool.TxInfo{}) require.NoError(t, err) // simulate new block @@ -242,11 +242,11 @@ func TestMempool_KeepInvalidTxsInCache(t *testing.T) { require.NoError(t, err) // a must be added to the cache - err = mp.CheckTx(a, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), a, nil, mempool.TxInfo{}) require.NoError(t, err) // b must remain in the cache - err = mp.CheckTx(b, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), b, nil, mempool.TxInfo{}) require.NoError(t, err) } @@ -258,7 +258,7 @@ func TestMempool_KeepInvalidTxsInCache(t *testing.T) { // remove a from the cache to test (2) mp.cache.Remove(a) - err := mp.CheckTx(a, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), a, nil, mempool.TxInfo{}) require.NoError(t, err) } } @@ -327,7 +327,7 @@ func TestSerialReap(t *testing.T) { // This will succeed txBytes := make([]byte, 8) binary.BigEndian.PutUint64(txBytes, uint64(i)) - err := mp.CheckTx(txBytes, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), txBytes, nil, mempool.TxInfo{}) _, cached := cacheMap[string(txBytes)] if cached { require.NotNil(t, err, "expected error for cached tx") @@ -337,7 +337,7 @@ func TestSerialReap(t *testing.T) { cacheMap[string(txBytes)] = struct{}{} // Duplicates are cached and should return error - err = mp.CheckTx(txBytes, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), txBytes, nil, mempool.TxInfo{}) require.NotNil(t, err, "Expected error after CheckTx on duplicated tx") } } @@ -446,7 +446,7 @@ func TestMempool_CheckTxChecksTxSize(t *testing.T) { tx := tmrand.Bytes(testCase.len) - err := mempl.CheckTx(tx, nil, mempool.TxInfo{}) + err := mempl.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) bv := gogotypes.BytesValue{Value: tx} bz, err2 := bv.Marshal() require.NoError(t, err2) @@ -475,7 +475,7 @@ func TestMempoolTxsBytes(t *testing.T) { assert.EqualValues(t, 0, mp.SizeBytes()) // 2. len(tx) after CheckTx - err := mp.CheckTx([]byte{0x01}, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), []byte{0x01}, nil, mempool.TxInfo{}) require.NoError(t, err) assert.EqualValues(t, 1, mp.SizeBytes()) @@ -485,7 +485,7 @@ func TestMempoolTxsBytes(t *testing.T) { assert.EqualValues(t, 0, mp.SizeBytes()) // 4. zero after Flush - err = mp.CheckTx([]byte{0x02, 0x03}, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), []byte{0x02, 0x03}, nil, mempool.TxInfo{}) require.NoError(t, err) assert.EqualValues(t, 2, mp.SizeBytes()) @@ -493,9 +493,15 @@ func TestMempoolTxsBytes(t *testing.T) { assert.EqualValues(t, 0, mp.SizeBytes()) // 5. ErrMempoolIsFull is returned when/if MaxTxsBytes limit is reached. - err = mp.CheckTx([]byte{0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04}, nil, mempool.TxInfo{}) + err = mp.CheckTx( + context.Background(), + []byte{0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04}, + nil, + mempool.TxInfo{}, + ) require.NoError(t, err) - err = mp.CheckTx([]byte{0x05}, nil, mempool.TxInfo{}) + + err = mp.CheckTx(context.Background(), []byte{0x05}, nil, mempool.TxInfo{}) if assert.Error(t, err) { assert.IsType(t, mempool.ErrMempoolIsFull{}, err) } @@ -509,7 +515,7 @@ func TestMempoolTxsBytes(t *testing.T) { txBytes := make([]byte, 8) binary.BigEndian.PutUint64(txBytes, uint64(0)) - err = mp.CheckTx(txBytes, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), txBytes, nil, mempool.TxInfo{}) require.NoError(t, err) assert.EqualValues(t, 8, mp.SizeBytes()) @@ -536,7 +542,7 @@ func TestMempoolTxsBytes(t *testing.T) { assert.EqualValues(t, 0, mp.SizeBytes()) // 7. Test RemoveTxByKey function - err = mp.CheckTx([]byte{0x06}, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), []byte{0x06}, nil, mempool.TxInfo{}) require.NoError(t, err) assert.EqualValues(t, 1, mp.SizeBytes()) mp.RemoveTxByKey(mempool.TxKey([]byte{0x07}), true) @@ -580,7 +586,7 @@ func TestMempoolRemoteAppConcurrency(t *testing.T) { tx := txs[txNum] // this will err with ErrTxInCache many times ... - mp.CheckTx(tx, nil, mempool.TxInfo{SenderID: uint16(peerID)}) //nolint: errcheck // will error + mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{SenderID: uint16(peerID)}) //nolint: errcheck // will error } err := mp.FlushAppConn() require.NoError(t, err) diff --git a/mempool/v0/reactor.go b/mempool/v0/reactor.go index 33d101475..27fc36cfe 100644 --- a/mempool/v0/reactor.go +++ b/mempool/v0/reactor.go @@ -1,6 +1,7 @@ package v0 import ( + "context" "errors" "fmt" "sync" @@ -168,7 +169,7 @@ func (r *Reactor) handleMempoolMessage(envelope p2p.Envelope) error { } for _, tx := range protoTxs { - if err := r.mempool.CheckTx(types.Tx(tx), nil, txInfo); err != nil { + if err := r.mempool.CheckTx(context.Background(), types.Tx(tx), nil, txInfo); err != nil { logger.Error("checktx failed for tx", "tx", fmt.Sprintf("%X", mempool.TxHashFromBytes(tx)), "err", err) } } diff --git a/mempool/v0/reactor_test.go b/mempool/v0/reactor_test.go index 2f1cae3a2..cc292e335 100644 --- a/mempool/v0/reactor_test.go +++ b/mempool/v0/reactor_test.go @@ -1,6 +1,7 @@ package v0 import ( + "context" "sync" "testing" "time" @@ -276,7 +277,14 @@ func TestReactor_MaxTxBytes(t *testing.T) { // Broadcast a tx, which has the max size and ensure it's received by the // second reactor. tx1 := tmrand.Bytes(config.Mempool.MaxTxBytes) - err := rts.reactors[primary].mempool.CheckTx(tx1, nil, mempool.TxInfo{SenderID: mempool.UnknownPeerID}) + err := rts.reactors[primary].mempool.CheckTx( + context.Background(), + tx1, + nil, + mempool.TxInfo{ + SenderID: mempool.UnknownPeerID, + }, + ) require.NoError(t, err) rts.start(t) @@ -290,7 +298,7 @@ func TestReactor_MaxTxBytes(t *testing.T) { // broadcast a tx, which is beyond the max size and ensure it's not sent tx2 := tmrand.Bytes(config.Mempool.MaxTxBytes + 1) - err = rts.mempools[primary].CheckTx(tx2, nil, mempool.TxInfo{SenderID: mempool.UnknownPeerID}) + err = rts.mempools[primary].CheckTx(context.Background(), tx2, nil, mempool.TxInfo{SenderID: mempool.UnknownPeerID}) require.Error(t, err) rts.assertMempoolChannelsDrained(t) diff --git a/mempool/v1/mempool.go b/mempool/v1/mempool.go index f49c182dd..f9c5ce9e8 100644 --- a/mempool/v1/mempool.go +++ b/mempool/v1/mempool.go @@ -212,7 +212,13 @@ func (txmp *TxMempool) TxsAvailable() <-chan struct{} { // NOTE: // - The applications' CheckTx implementation may panic. // - The caller is not to explicitly require any locks for executing CheckTx. -func (txmp *TxMempool) CheckTx(tx types.Tx, cb func(*abci.Response), txInfo mempool.TxInfo) error { +func (txmp *TxMempool) CheckTx( + ctx context.Context, + tx types.Tx, + cb func(*abci.Response), + txInfo mempool.TxInfo, +) error { + txmp.mtx.RLock() defer txmp.mtx.RUnlock() @@ -253,7 +259,6 @@ func (txmp *TxMempool) CheckTx(tx types.Tx, cb func(*abci.Response), txInfo memp return nil } - ctx := txInfo.Context if ctx == nil { ctx = context.Background() } diff --git a/mempool/v1/mempool_bench_test.go b/mempool/v1/mempool_bench_test.go index bad8ec8ab..b3239d13f 100644 --- a/mempool/v1/mempool_bench_test.go +++ b/mempool/v1/mempool_bench_test.go @@ -1,6 +1,7 @@ package v1 import ( + "context" "fmt" "math/rand" "testing" @@ -26,6 +27,6 @@ func BenchmarkTxMempool_CheckTx(b *testing.B) { tx := []byte(fmt.Sprintf("%X=%d", prefix, priority)) b.StartTimer() - require.NoError(b, txmp.CheckTx(tx, nil, mempool.TxInfo{})) + require.NoError(b, txmp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{})) } } diff --git a/mempool/v1/mempool_test.go b/mempool/v1/mempool_test.go index 6f7b149c2..2864e55c4 100644 --- a/mempool/v1/mempool_test.go +++ b/mempool/v1/mempool_test.go @@ -2,6 +2,7 @@ package v1 import ( "bytes" + "context" "fmt" "math/rand" "os" @@ -111,7 +112,7 @@ func checkTxs(t *testing.T, txmp *TxMempool, numTxs int, peerID uint16) []testTx tx: []byte(fmt.Sprintf("sender-%d=%X=%d", i, prefix, priority)), priority: priority, } - require.NoError(t, txmp.CheckTx(txs[i].tx, nil, txInfo)) + require.NoError(t, txmp.CheckTx(context.Background(), txs[i].tx, nil, txInfo)) } return txs @@ -327,7 +328,7 @@ func TestTxMempool_CheckTxExceedsMaxSize(t *testing.T) { _, err := rng.Read(tx) require.NoError(t, err) - require.Error(t, txmp.CheckTx(tx, nil, mempool.TxInfo{SenderID: 0})) + require.Error(t, txmp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{SenderID: 0})) } func TestTxMempool_CheckTxSamePeer(t *testing.T) { @@ -341,8 +342,8 @@ func TestTxMempool_CheckTxSamePeer(t *testing.T) { tx := []byte(fmt.Sprintf("sender-0=%X=%d", prefix, 50)) - require.NoError(t, txmp.CheckTx(tx, nil, mempool.TxInfo{SenderID: peerID})) - require.Error(t, txmp.CheckTx(tx, nil, mempool.TxInfo{SenderID: peerID})) + require.NoError(t, txmp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{SenderID: peerID})) + require.Error(t, txmp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{SenderID: peerID})) } func TestTxMempool_CheckTxSameSender(t *testing.T) { @@ -361,9 +362,9 @@ func TestTxMempool_CheckTxSameSender(t *testing.T) { tx1 := []byte(fmt.Sprintf("sender-0=%X=%d", prefix1, 50)) tx2 := []byte(fmt.Sprintf("sender-0=%X=%d", prefix2, 50)) - require.NoError(t, txmp.CheckTx(tx1, nil, mempool.TxInfo{SenderID: peerID})) + require.NoError(t, txmp.CheckTx(context.Background(), tx1, nil, mempool.TxInfo{SenderID: peerID})) require.Equal(t, 1, txmp.Size()) - require.NoError(t, txmp.CheckTx(tx2, nil, mempool.TxInfo{SenderID: peerID})) + require.NoError(t, txmp.CheckTx(context.Background(), tx2, nil, mempool.TxInfo{SenderID: peerID})) require.Equal(t, 1, txmp.Size()) } diff --git a/mempool/v1/reactor.go b/mempool/v1/reactor.go index cb9df868d..83443031b 100644 --- a/mempool/v1/reactor.go +++ b/mempool/v1/reactor.go @@ -1,6 +1,7 @@ package v1 import ( + "context" "errors" "fmt" "sync" @@ -167,7 +168,7 @@ func (r *Reactor) handleMempoolMessage(envelope p2p.Envelope) error { } for _, tx := range protoTxs { - if err := r.mempool.CheckTx(types.Tx(tx), nil, txInfo); err != nil { + if err := r.mempool.CheckTx(context.Background(), types.Tx(tx), nil, txInfo); err != nil { logger.Error("checktx failed for tx", "tx", fmt.Sprintf("%X", mempool.TxHashFromBytes(tx)), "err", err) } } diff --git a/node/node_test.go b/node/node_test.go index 3679fbb17..14d914352 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -262,7 +262,7 @@ func TestCreateProposalBlock(t *testing.T) { txLength := 100 for i := 0; i <= maxBytes/txLength; i++ { tx := tmrand.Bytes(txLength) - err := mp.CheckTx(tx, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) assert.NoError(t, err) } @@ -330,7 +330,7 @@ func TestMaxTxsProposalBlockSize(t *testing.T) { // fill the mempool with one txs just below the maximum size txLength := int(types.MaxDataBytesNoEvidence(maxBytes, 1)) tx := tmrand.Bytes(txLength - 4) // to account for the varint - err = mp.CheckTx(tx, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) assert.NoError(t, err) blockExec := sm.NewBlockExecutor( @@ -388,13 +388,13 @@ func TestMaxProposalBlockSize(t *testing.T) { // fill the mempool with one txs just below the maximum size txLength := int(types.MaxDataBytesNoEvidence(maxBytes, types.MaxVotesCount)) tx := tmrand.Bytes(txLength - 6) // to account for the varint - err = mp.CheckTx(tx, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) assert.NoError(t, err) // now produce more txs than what a normal block can hold with 10 smaller txs // At the end of the test, only the single big tx should be added for i := 0; i < 10; i++ { tx := tmrand.Bytes(10) - err = mp.CheckTx(tx, nil, mempool.TxInfo{}) + err = mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) assert.NoError(t, err) } diff --git a/rpc/client/rpc_test.go b/rpc/client/rpc_test.go index 72ef54b32..8779d8a77 100644 --- a/rpc/client/rpc_test.go +++ b/rpc/client/rpc_test.go @@ -411,12 +411,11 @@ func TestBroadcastTxCommit(t *testing.T) { func TestUnconfirmedTxs(t *testing.T) { _, _, tx := MakeTxKV() - ch := make(chan *abci.Response, 1) - n := NodeSuite(t) mempool := n.Mempool() - err := mempool.CheckTx(tx, func(resp *abci.Response) { ch <- resp }, mempl.TxInfo{}) + + err := mempool.CheckTx(context.Background(), tx, func(resp *abci.Response) { ch <- resp }, mempl.TxInfo{}) require.NoError(t, err) // wait for tx to arrive in mempoool. @@ -443,11 +442,11 @@ func TestUnconfirmedTxs(t *testing.T) { func TestNumUnconfirmedTxs(t *testing.T) { _, _, tx := MakeTxKV() - n := NodeSuite(t) ch := make(chan *abci.Response, 1) mempool := n.Mempool() - err := mempool.CheckTx(tx, func(resp *abci.Response) { ch <- resp }, mempl.TxInfo{}) + + err := mempool.CheckTx(context.Background(), tx, func(resp *abci.Response) { ch <- resp }, mempl.TxInfo{}) require.NoError(t, err) // wait for tx to arrive in mempoool. diff --git a/rpc/core/mempool.go b/rpc/core/mempool.go index 9e40429f8..07743365c 100644 --- a/rpc/core/mempool.go +++ b/rpc/core/mempool.go @@ -20,11 +20,11 @@ import ( // CheckTx nor DeliverTx results. // More: https://docs.tendermint.com/master/rpc/#/Tx/broadcast_tx_async func (env *Environment) BroadcastTxAsync(ctx *rpctypes.Context, tx types.Tx) (*ctypes.ResultBroadcastTx, error) { - err := env.Mempool.CheckTx(tx, nil, mempl.TxInfo{Context: ctx.Context()}) - + err := env.Mempool.CheckTx(ctx.Context(), tx, nil, mempl.TxInfo{}) if err != nil { return nil, err } + return &ctypes.ResultBroadcastTx{Hash: tx.Hash()}, nil } @@ -33,14 +33,19 @@ func (env *Environment) BroadcastTxAsync(ctx *rpctypes.Context, tx types.Tx) (*c // More: https://docs.tendermint.com/master/rpc/#/Tx/broadcast_tx_sync func (env *Environment) BroadcastTxSync(ctx *rpctypes.Context, tx types.Tx) (*ctypes.ResultBroadcastTx, error) { resCh := make(chan *abci.Response, 1) - err := env.Mempool.CheckTx(tx, func(res *abci.Response) { - resCh <- res - }, mempl.TxInfo{Context: ctx.Context()}) + err := env.Mempool.CheckTx( + ctx.Context(), + tx, + func(res *abci.Response) { resCh <- res }, + mempl.TxInfo{}, + ) if err != nil { return nil, err } + res := <-resCh r := res.GetCheckTx() + return &ctypes.ResultBroadcastTx{ Code: r.Code, Data: r.Data, @@ -79,15 +84,20 @@ func (env *Environment) BroadcastTxCommit(ctx *rpctypes.Context, tx types.Tx) (* // Broadcast tx and wait for CheckTx result checkTxResCh := make(chan *abci.Response, 1) - err = env.Mempool.CheckTx(tx, func(res *abci.Response) { - checkTxResCh <- res - }, mempl.TxInfo{Context: ctx.Context()}) + err = env.Mempool.CheckTx( + ctx.Context(), + tx, + func(res *abci.Response) { checkTxResCh <- res }, + mempl.TxInfo{}, + ) if err != nil { env.Logger.Error("Error on broadcastTxCommit", "err", err) return nil, fmt.Errorf("error on broadcastTxCommit: %v", err) } + checkTxResMsg := <-checkTxResCh checkTxRes := checkTxResMsg.GetCheckTx() + if checkTxRes.Code != abci.CodeTypeOK { return &ctypes.ResultBroadcastTxCommit{ CheckTx: *checkTxRes, diff --git a/test/fuzz/mempool/checktx.go b/test/fuzz/mempool/checktx.go index e72b077c0..6af446a10 100644 --- a/test/fuzz/mempool/checktx.go +++ b/test/fuzz/mempool/checktx.go @@ -1,6 +1,8 @@ package checktx import ( + "context" + "github.com/tendermint/tendermint/abci/example/kvstore" "github.com/tendermint/tendermint/config" "github.com/tendermint/tendermint/mempool" @@ -26,7 +28,7 @@ func init() { } func Fuzz(data []byte) int { - err := mp.CheckTx(data, nil, mempool.TxInfo{}) + err := mp.CheckTx(context.Background(), data, nil, mempool.TxInfo{}) if err != nil { return 0 }