diff --git a/internal/state/execution.go b/internal/state/execution.go index c1ecc3602..547b74ce4 100644 --- a/internal/state/execution.go +++ b/internal/state/execution.go @@ -141,25 +141,27 @@ func (blockExec *BlockExecutor) CreateProposalBlock( panic(err) } - if err := rpp.Validate(maxDataBytes, block.Txs.ToSliceOfBytes()); err != nil { - return nil, err - } - if !rpp.ModifiedTx { return block, nil } + txrSet := types.NewTxRecordSet(rpp.TxRecords) + + if err := txrSet.Validate(maxDataBytes, block.Txs); err != nil { + return nil, err + } - for _, rtx := range rpp.RemovedTxs() { - if err := blockExec.mempool.RemoveTxByKey(types.Tx(rtx.Tx).Key()); err != nil { - blockExec.logger.Debug("error removing transaction from the mempool", "error", err) + for _, rtx := range txrSet.GetRemovedTxs() { + if err := blockExec.mempool.RemoveTxByKey(rtx.Key()); err != nil { + blockExec.logger.Debug("error removing transaction from the mempool", "error", err, "tx hash", rtx.Hash()) } } - for _, rtx := range rpp.AddedTxs() { - if err := blockExec.mempool.CheckTx(ctx, rtx.Tx, nil, mempool.TxInfo{}); err != nil { - blockExec.logger.Error("error adding tx to the mempool", "error", err) + for _, atx := range txrSet.GetAddedTxs() { + if err := blockExec.mempool.CheckTx(ctx, atx, nil, mempool.TxInfo{}); err != nil { + blockExec.logger.Error("error adding tx to the mempool", "error", err, "tx hash", atx.Hash()) } } - return state.MakeBlock(height, types.TxRecordsToTxs(rpp.IncludedTxs()), commit, evidence, proposerAddr), nil + itxs := txrSet.GetIncludedTxs() + return state.MakeBlock(height, itxs, commit, evidence, proposerAddr), nil } func (blockExec *BlockExecutor) ProcessProposal( diff --git a/types/tx.go b/types/tx.go index d142fb82a..9b86a181e 100644 --- a/types/tx.go +++ b/types/tx.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "errors" "fmt" + "sort" abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/crypto/merkle" @@ -62,6 +63,12 @@ func (txs Txs) IndexByHash(hash []byte) int { return -1 } +func (txs Txs) Len() int { return len(txs) } +func (txs Txs) Swap(i, j int) { txs[i], txs[j] = txs[j], txs[i] } +func (txs Txs) Less(i, j int) bool { + return bytes.Compare(txs[i], txs[j]) == -1 +} + // ToSliceOfBytes converts a Txs to slice of byte slices. // // NOTE: This method should become obsolete once Txs is switched to [][]byte. @@ -94,6 +101,139 @@ func TxRecordsToTxs(trs []*abci.TxRecord) Txs { return txs } +// TxRecordSet contains indexes into an underlying set of transactions. +// These indexes are useful for validating and working with a list of TxRecords +// from the PrepareProposal response. +type TxRecordSet struct { + txs Txs + + added Txs + unmodified Txs + included Txs + removed Txs + unknown Txs +} + +func NewTxRecordSet(trs []*abci.TxRecord) TxRecordSet { + txrSet := TxRecordSet{} + txrSet.txs = make([]Tx, len(trs)) + for i, tr := range trs { + txrSet.txs[i] = Tx(tr.Tx) + switch tr.GetAction() { + case abci.TxRecord_UNKNOWN: + txrSet.unknown = append(txrSet.unknown, txrSet.txs[i]) + case abci.TxRecord_UNMODIFIED: + txrSet.unmodified = append(txrSet.unmodified, txrSet.txs[i]) + txrSet.included = append(txrSet.included, txrSet.txs[i]) + case abci.TxRecord_ADDED: + txrSet.added = append(txrSet.added, txrSet.txs[i]) + txrSet.included = append(txrSet.included, txrSet.txs[i]) + case abci.TxRecord_REMOVED: + txrSet.removed = append(txrSet.removed, txrSet.txs[i]) + } + } + return txrSet +} + +// GetAddedTxs returns the transactions marked for inclusion in a block. +func (t TxRecordSet) GetIncludedTxs() []Tx { + return t.included +} + +// GetAddedTxs returns the transactions added by the application. +func (t TxRecordSet) GetAddedTxs() []Tx { + return t.added +} + +// GetRemovedTxs returns the transactions marked for removal by the application. +func (t TxRecordSet) GetRemovedTxs() []Tx { + return t.removed +} + +// Validate checks that the record set was correctly constructed from the original +// list of transactions. +func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { + if len(t.unknown) > 0 { + return fmt.Errorf("transaction incorrectly marked as unknown, transaction hash: %x", t.unknown[0].Hash()) + } + + var size int64 + cp := make([]Tx, len(t.txs)) + copy(cp, t.txs) + sort.Sort(Txs(cp)) + + for i := 0; i < len(cp); i++ { + size += int64(len(cp[i])) + if size > maxSizeBytes { + return fmt.Errorf("transaction data size %d exceeds maximum %d", size, maxSizeBytes) + } + if i < len(cp)-1 && bytes.Equal(cp[i], cp[i+1]) { + return fmt.Errorf("TxRecords contains duplicate transaction, transaction hash: %x", cp[i].Hash()) + } + } + + addedCopy := make([]Tx, len(t.added)) + copy(addedCopy, t.added) + removedCopy := make([]Tx, len(t.removed)) + copy(removedCopy, t.removed) + unmodifiedCopy := make([]Tx, len(t.unmodified)) + copy(unmodifiedCopy, t.unmodified) + + sort.Sort(otxs) + sort.Sort(Txs(addedCopy)) + sort.Sort(Txs(removedCopy)) + sort.Sort(Txs(unmodifiedCopy)) + unmodifiedIdx, addedIdx, removedIdx := 0, 0, 0 + for i := 0; i < len(otxs); i++ { + if addedIdx == len(addedCopy) && + removedIdx == len(removedCopy) && + unmodifiedIdx == len(unmodifiedCopy) { + break + } + + LOOP: + for addedIdx < len(addedCopy) { + switch bytes.Compare(addedCopy[addedIdx], otxs[i]) { + case 0: + return fmt.Errorf("existing transaction incorrectly marked as added, transaction hash: %x", otxs[i].Hash()) + case -1: + addedIdx++ + case 1: + break LOOP + } + } + if removedIdx < len(removedCopy) { + switch bytes.Compare(removedCopy[removedIdx], otxs[i]) { + case 0: + removedIdx++ + case -1: + return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", removedCopy[i].Hash()) + } + } + if unmodifiedIdx < len(unmodifiedCopy) { + switch bytes.Compare(unmodifiedCopy[unmodifiedIdx], otxs[i]) { + case 0: + unmodifiedIdx++ + case -1: + return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", removedCopy[i].Hash()) + } + } + } + + if unmodifiedIdx != len(unmodifiedCopy) { + return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", unmodifiedCopy[unmodifiedIdx].Hash()) + } + if removedIdx != len(removedCopy) { + return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", removedCopy[removedIdx].Hash()) + } + + return nil +} + +func (t TxRecordSet) GetTxs() []Tx { + return t.txs +} + // TxsToTxRecords converts from a list of Txs to a list of TxRecords. All of the // resulting TxRecords are returned with the status TxRecord_UNMODIFIED. func TxsToTxRecords(txs []Tx) []*abci.TxRecord { diff --git a/types/tx_test.go b/types/tx_test.go index d77ba00e8..c50e51f61 100644 --- a/types/tx_test.go +++ b/types/tx_test.go @@ -4,7 +4,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + abci "github.com/tendermint/tendermint/abci/types" tmrand "github.com/tendermint/tendermint/libs/rand" ) @@ -41,3 +43,119 @@ func TestTxIndexByHash(t *testing.T) { assert.Equal(t, -1, txs.IndexByHash(Tx("foodnwkf").Hash())) } } + +func TestValidateTxRecordSet(t *testing.T) { + t.Run("should error on total transaction size exceeding max data size", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{6, 7, 8, 9, 10}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(9, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on duplicate transactions with the same action", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{100}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{200}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on duplicate transactions with mixed actions", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{100}), + }, + { + Action: abci.TxRecord_REMOVED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{200}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on new transactions marked UNMODIFIED", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_UNMODIFIED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on new transactions marked REMOVED", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_REMOVED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) + t.Run("should error on existing transaction marked as ADDED", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{5, 4, 3, 2, 1}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{6}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{{0}, {1, 2, 3, 4, 5}}) + require.Error(t, err) + }) + t.Run("should error if any transaction marked as UNKNOWN", func(t *testing.T) { + trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_UNKNOWN, + Tx: Tx([]byte{1, 2, 3, 4, 5}), + }, + } + txrSet := NewTxRecordSet(trs) + err := txrSet.Validate(100, []Tx{}) + require.Error(t, err) + }) +}