diff --git a/internal/state/execution.go b/internal/state/execution.go index 173e3f8ad..61691976a 100644 --- a/internal/state/execution.go +++ b/internal/state/execution.go @@ -156,16 +156,12 @@ func (blockExec *BlockExecutor) CreateProposalBlock( } } for _, atx := range txrSet.GetAddedTxs() { - if err := blockExec.mempool.CheckTx(ctx, *atx, nil, mempool.TxInfo{}); err != nil { + if err := blockExec.mempool.CheckTx(ctx, atx, nil, mempool.TxInfo{}); err != nil { blockExec.logger.Error("error adding tx to the mempool", "error", err, "tx hash", atx.Hash()) } } - itxs := append(txrSet.GetAddedTxs(), txrSet.GetUnmodifiedTxs()...) - txs = make([]types.Tx, len(itxs)) - for i, tx := range itxs { - txs[i] = *tx - } - return state.MakeBlock(height, txs, commit, evidence, proposerAddr), nil + ftxs := append(txrSet.GetAddedTxs(), txrSet.GetUnmodifiedTxs()...) + return state.MakeBlock(height, ftxs, commit, evidence, proposerAddr), nil } func (blockExec *BlockExecutor) ProcessProposal( diff --git a/types/tx.go b/types/tx.go index c9a55d90f..6df08526a 100644 --- a/types/tx.go +++ b/types/tx.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "errors" "fmt" + "sort" abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/crypto/merkle" @@ -14,11 +15,11 @@ import ( ) type TxRecordSet struct { - txs []Tx - unknownIdx []*Tx - unmodifiedIdx []*Tx - addedIdx []*Tx - removedIdx []*Tx + txs Txs + added Txs + unmodified Txs + removed Txs + unknown Txs } // Tx is an arbitrary byte array. @@ -70,6 +71,12 @@ func (txs Txs) IndexByHash(hash []byte) int { return -1 } +func (txs Txs) Len() int { return len(txs) } +func (txs Txs) Swap(i, j int) { txs[i], txs[j] = txs[j], txs[i] } +func (txs Txs) Less(i, j int) bool { + return bytes.Compare(txs[i], txs[j]) == -1 +} + // ToSliceOfBytes converts a Txs to slice of byte slices. // // NOTE: This method should become obsolete once Txs is switched to [][]byte. @@ -83,14 +90,6 @@ func (txs Txs) ToSliceOfBytes() [][]byte { return txBzs } -func (txs Txs) ToSet() map[string]struct{} { - m := make(map[string]struct{}, len(txs)) - for _, tx := range txs { - m[string(tx.Hash())] = struct{}{} - } - return m -} - // ToTxs converts a raw slice of byte slices into a Txs type. // TODO This function is to disappear when TxRecord is introduced func ToTxs(txs [][]byte) Txs { @@ -115,76 +114,104 @@ func NewTxRecordSet(trs []*abci.TxRecord) TxRecordSet { txrSet.txs = make([]Tx, len(trs)) for i, tr := range trs { txrSet.txs[i] = Tx(tr.Tx) - if tr.Action == abci.TxRecord_UNKNOWN { - txrSet.unknownIdx = append(txrSet.unknownIdx, &txrSet.txs[i]) - } - if tr.Action == abci.TxRecord_UNMODIFIED { - txrSet.unmodifiedIdx = append(txrSet.unmodifiedIdx, &txrSet.txs[i]) - } - if tr.Action == abci.TxRecord_ADDED { - txrSet.addedIdx = append(txrSet.addedIdx, &txrSet.txs[i]) - } - if tr.Action == abci.TxRecord_REMOVED { - txrSet.removedIdx = append(txrSet.removedIdx, &txrSet.txs[i]) + switch tr.GetAction() { + case abci.TxRecord_UNKNOWN: + txrSet.unknown = append(txrSet.unknown, txrSet.txs[i]) + case abci.TxRecord_UNMODIFIED: + txrSet.unmodified = append(txrSet.unmodified, txrSet.txs[i]) + case abci.TxRecord_ADDED: + txrSet.added = append(txrSet.added, txrSet.txs[i]) + case abci.TxRecord_REMOVED: + txrSet.removed = append(txrSet.removed, txrSet.txs[i]) } } return txrSet } - -func (t TxRecordSet) GetUnmodifiedTxs() []*Tx { - return t.unmodifiedIdx +func (t TxRecordSet) GetAddedTxs() []Tx { + return nil } -func (t TxRecordSet) GetAddedTxs() []*Tx { - return t.addedIdx +func (t TxRecordSet) GetRemovedTxs() []Tx { + return nil } -func (t TxRecordSet) GetRemovedTxs() []*Tx { - return t.removedIdx +func (t TxRecordSet) GetUnknownTxs() []Tx { + return nil } - -func (t TxRecordSet) GetUnknownTxs() []*Tx { - return t.unknownIdx +func (t TxRecordSet) GetUnmodifiedTxs() []Tx { + return nil } func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { - otxSet := otxs.ToSet() - ntxSet := map[string]struct{}{} + if len(t.unknown) > 0 { + return fmt.Errorf("transaction incorrectly marked as unknown, transaction hash: %x", t.unknown[0].Hash()) + } var size int64 - for _, tx := range t.GetAddedTxs() { - size += int64(len(*tx)) + cp := make([]Tx, len(t.txs)) + for i := range t.txs { + size += int64(len(t.txs[i])) if size > maxSizeBytes { return fmt.Errorf("transaction data size %d exceeds maximum %d", size, maxSizeBytes) } - hash := tx.Hash() - if _, ok := otxSet[string(hash)]; ok { - return fmt.Errorf("unmodified transaction incorrectly marked as %s, transaction hash: %x", abci.TxRecord_ADDED, hash) + cp[i] = t.txs[i] + } + sort.Sort(Txs(cp)) + + // duplicate validation + for i := 0; i < len(cp)-1; i++ { + if bytes.Equal(cp[i], cp[i+1]) { + return fmt.Errorf("TxRecords contains duplicate transaction, transaction hash: %x", cp[i].Hash()) } - if _, ok := ntxSet[string(hash)]; ok { - return fmt.Errorf("TxRecords contains duplicate transaction, transaction hash: %x", hash) + } + sort.Sort(otxs) + sort.Sort(Txs(t.added)) + sort.Sort(Txs(t.removed)) + sort.Sort(Txs(t.unmodified)) + + for i, j := 0, 0; i < len(t.added) && j < len(otxs); { + switch bytes.Compare(t.added[i], otxs[j]) { + case 0: + return fmt.Errorf("existing transaction incorrectly marked as added, transaction hash: %x", otxs[j].Hash()) + case -1: + i++ + case 1: + j++ } - ntxSet[string(hash)] = struct{}{} } - for _, tx := range t.GetUnmodifiedTxs() { - size += int64(len(*tx)) - if size > maxSizeBytes { - return fmt.Errorf("transaction data size %d exceeds maximum %d", size, maxSizeBytes) + + for i, j := 0, 0; i < len(t.removed); { + if j >= len(otxs) { + // we reached the end of the original txs without finding a match for + // all of the removed elements + return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", t.removed[i].Hash()) } - hash := tx.Hash() - if _, ok := otxSet[string(hash)]; !ok { - return fmt.Errorf("new transaction incorrectly marked as %s, transaction hash: %x", abci.TxRecord_UNMODIFIED, hash) + switch bytes.Compare(t.added[i], otxs[j]) { + case 0: + i++ + j++ + case -1: + return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", t.removed[i].Hash()) + case 1: + j++ } } - for _, tx := range t.GetRemovedTxs() { - hash := tx.Hash() - if _, ok := otxSet[string(hash)]; !ok { - return fmt.Errorf("new transaction incorrectly marked as %s, transaction hash: %x", abci.TxRecord_REMOVED, hash) + for i, j := 0, 0; i < len(t.unmodified); { + if j >= len(otxs) { + // we reached the end of the original txs without finding a match for + // all of the unmodified elements + return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", t.unmodified[i].Hash()) + } + switch bytes.Compare(t.unmodified[i], otxs[j]) { + case 0: + i++ + j++ + case -1: + return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", t.unmodified[i].Hash()) + case 1: + j++ } } - if len(t.GetUnknownTxs()) > 0 { - utx := t.GetUnknownTxs()[0] - return fmt.Errorf("transaction incorrectly marked as %s, transaction hash: %x", utx, utx.Hash()) - } + return nil }