diff --git a/internal/state/execution.go b/internal/state/execution.go index c1ecc3602..173e3f8ad 100644 --- a/internal/state/execution.go +++ b/internal/state/execution.go @@ -141,25 +141,31 @@ func (blockExec *BlockExecutor) CreateProposalBlock( panic(err) } - if err := rpp.Validate(maxDataBytes, block.Txs.ToSliceOfBytes()); err != nil { - return nil, err - } - if !rpp.ModifiedTx { return block, nil } + txrSet := types.NewTxRecordSet(rpp.TxRecords) - for _, rtx := range rpp.RemovedTxs() { - if err := blockExec.mempool.RemoveTxByKey(types.Tx(rtx.Tx).Key()); err != nil { - blockExec.logger.Debug("error removing transaction from the mempool", "error", err) + if err := txrSet.Validate(maxDataBytes, block.Txs); err != nil { + return nil, err + } + + for _, rtx := range txrSet.GetRemovedTxs() { + if err := blockExec.mempool.RemoveTxByKey(rtx.Key()); err != nil { + blockExec.logger.Debug("error removing transaction from the mempool", "error", err, "tx hash", rtx.Hash()) } } - for _, rtx := range rpp.AddedTxs() { - if err := blockExec.mempool.CheckTx(ctx, rtx.Tx, nil, mempool.TxInfo{}); err != nil { - blockExec.logger.Error("error adding tx to the mempool", "error", err) + for _, atx := range txrSet.GetAddedTxs() { + if err := blockExec.mempool.CheckTx(ctx, *atx, nil, mempool.TxInfo{}); err != nil { + blockExec.logger.Error("error adding tx to the mempool", "error", err, "tx hash", atx.Hash()) } } - return state.MakeBlock(height, types.TxRecordsToTxs(rpp.IncludedTxs()), commit, evidence, proposerAddr), nil + itxs := append(txrSet.GetAddedTxs(), txrSet.GetUnmodifiedTxs()...) + txs = make([]types.Tx, len(itxs)) + for i, tx := range itxs { + txs[i] = *tx + } + return state.MakeBlock(height, txs, commit, evidence, proposerAddr), nil } func (blockExec *BlockExecutor) ProcessProposal( diff --git a/types/tx.go b/types/tx.go index d142fb82a..c9a55d90f 100644 --- a/types/tx.go +++ b/types/tx.go @@ -13,6 +13,14 @@ import ( tmproto "github.com/tendermint/tendermint/proto/tendermint/types" ) +type TxRecordSet struct { + txs []Tx + unknownIdx []*Tx + unmodifiedIdx []*Tx + addedIdx []*Tx + removedIdx []*Tx +} + // Tx is an arbitrary byte array. // NOTE: Tx has no types at this level, so when wire encoded it's just length-prefixed. // Might we want types here ? @@ -75,6 +83,14 @@ func (txs Txs) ToSliceOfBytes() [][]byte { return txBzs } +func (txs Txs) ToSet() map[string]struct{} { + m := make(map[string]struct{}, len(txs)) + for _, tx := range txs { + m[string(tx.Hash())] = struct{}{} + } + return m +} + // ToTxs converts a raw slice of byte slices into a Txs type. // TODO This function is to disappear when TxRecord is introduced func ToTxs(txs [][]byte) Txs { @@ -94,6 +110,88 @@ func TxRecordsToTxs(trs []*abci.TxRecord) Txs { return txs } +func NewTxRecordSet(trs []*abci.TxRecord) TxRecordSet { + txrSet := TxRecordSet{} + txrSet.txs = make([]Tx, len(trs)) + for i, tr := range trs { + txrSet.txs[i] = Tx(tr.Tx) + if tr.Action == abci.TxRecord_UNKNOWN { + txrSet.unknownIdx = append(txrSet.unknownIdx, &txrSet.txs[i]) + } + if tr.Action == abci.TxRecord_UNMODIFIED { + txrSet.unmodifiedIdx = append(txrSet.unmodifiedIdx, &txrSet.txs[i]) + } + if tr.Action == abci.TxRecord_ADDED { + txrSet.addedIdx = append(txrSet.addedIdx, &txrSet.txs[i]) + } + if tr.Action == abci.TxRecord_REMOVED { + txrSet.removedIdx = append(txrSet.removedIdx, &txrSet.txs[i]) + } + } + return txrSet +} + +func (t TxRecordSet) GetUnmodifiedTxs() []*Tx { + return t.unmodifiedIdx +} + +func (t TxRecordSet) GetAddedTxs() []*Tx { + return t.addedIdx +} + +func (t TxRecordSet) GetRemovedTxs() []*Tx { + return t.removedIdx +} + +func (t TxRecordSet) GetUnknownTxs() []*Tx { + return t.unknownIdx +} + +func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { + otxSet := otxs.ToSet() + ntxSet := map[string]struct{}{} + var size int64 + for _, tx := range t.GetAddedTxs() { + size += int64(len(*tx)) + if size > maxSizeBytes { + return fmt.Errorf("transaction data size %d exceeds maximum %d", size, maxSizeBytes) + } + hash := tx.Hash() + if _, ok := otxSet[string(hash)]; ok { + return fmt.Errorf("unmodified transaction incorrectly marked as %s, transaction hash: %x", abci.TxRecord_ADDED, hash) + } + if _, ok := ntxSet[string(hash)]; ok { + return fmt.Errorf("TxRecords contains duplicate transaction, transaction hash: %x", hash) + } + ntxSet[string(hash)] = struct{}{} + } + for _, tx := range t.GetUnmodifiedTxs() { + size += int64(len(*tx)) + if size > maxSizeBytes { + return fmt.Errorf("transaction data size %d exceeds maximum %d", size, maxSizeBytes) + } + hash := tx.Hash() + if _, ok := otxSet[string(hash)]; !ok { + return fmt.Errorf("new transaction incorrectly marked as %s, transaction hash: %x", abci.TxRecord_UNMODIFIED, hash) + } + } + for _, tx := range t.GetRemovedTxs() { + hash := tx.Hash() + if _, ok := otxSet[string(hash)]; !ok { + return fmt.Errorf("new transaction incorrectly marked as %s, transaction hash: %x", abci.TxRecord_REMOVED, hash) + } + } + if len(t.GetUnknownTxs()) > 0 { + utx := t.GetUnknownTxs()[0] + return fmt.Errorf("transaction incorrectly marked as %s, transaction hash: %x", utx, utx.Hash()) + } + return nil +} + +func (t TxRecordSet) GetTxs() []Tx { + return t.txs +} + // TxsToTxRecords converts from a list of Txs to a list of TxRecords. All of the // resulting TxRecords are returned with the status TxRecord_UNMODIFIED. func TxsToTxRecords(txs []Tx) []*abci.TxRecord { diff --git a/types/tx_test.go b/types/tx_test.go index d77ba00e8..5add513a2 100644 --- a/types/tx_test.go +++ b/types/tx_test.go @@ -4,7 +4,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + abci "github.com/tendermint/tendermint/abci/types" tmrand "github.com/tendermint/tendermint/libs/rand" ) @@ -41,3 +43,111 @@ func TestTxIndexByHash(t *testing.T) { assert.Equal(t, -1, txs.IndexByHash(Tx("foodnwkf").Hash())) } } + +func TestValidateResponsePrepareProposal(t *testing.T) { + t.Run("should error on total transaction size exceeding max data size", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{6, 7, 8, 9, 10}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(9, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on duplicate transactions with the same action", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{100}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{200}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on duplicate transactions with mixed actions", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{100}), + }, + { + Action: abci.TxRecord_REMOVED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{200}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on new transactions marked UNMODIFIED", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_UNMODIFIED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on new transactions marked REMOVED", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_REMOVED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on existing transaction marked as ADDED", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{{1, 2, 3, 4, 5}}) + require.Error(t, err) + }) + t.Run("should error if any transaction marked as UNKNOWN", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_UNKNOWN, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) +}