diff --git a/types/tx.go b/types/tx.go index 6df08526a..f43a4bb18 100644 --- a/types/tx.go +++ b/types/tx.go @@ -127,91 +127,96 @@ func NewTxRecordSet(trs []*abci.TxRecord) TxRecordSet { } return txrSet } + func (t TxRecordSet) GetAddedTxs() []Tx { - return nil + return t.added } - func (t TxRecordSet) GetRemovedTxs() []Tx { - return nil + return t.removed } - func (t TxRecordSet) GetUnknownTxs() []Tx { - return nil + return t.unknown } func (t TxRecordSet) GetUnmodifiedTxs() []Tx { - return nil + return t.unmodified } func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { if len(t.unknown) > 0 { return fmt.Errorf("transaction incorrectly marked as unknown, transaction hash: %x", t.unknown[0].Hash()) } + var size int64 cp := make([]Tx, len(t.txs)) - for i := range t.txs { - size += int64(len(t.txs[i])) - if size > maxSizeBytes { - return fmt.Errorf("transaction data size %d exceeds maximum %d", size, maxSizeBytes) - } - cp[i] = t.txs[i] - } + copy(cp, t.txs) sort.Sort(Txs(cp)) // duplicate validation - for i := 0; i < len(cp)-1; i++ { - if bytes.Equal(cp[i], cp[i+1]) { + for i := 0; i < len(cp); i++ { + size += int64(len(cp[i])) + if size > maxSizeBytes { + return fmt.Errorf("transaction data size %d exceeds maximum %d", size, maxSizeBytes) + } + if i < len(cp)-1 && bytes.Equal(cp[i], cp[i+1]) { return fmt.Errorf("TxRecords contains duplicate transaction, transaction hash: %x", cp[i].Hash()) } } + + addedCopy := make([]Tx, len(t.added)) + copy(addedCopy, t.added) + removedCopy := make([]Tx, len(t.removed)) + copy(removedCopy, t.removed) + unmodifiedCopy := make([]Tx, len(t.unmodified)) + copy(unmodifiedCopy, t.unmodified) + sort.Sort(otxs) - sort.Sort(Txs(t.added)) - sort.Sort(Txs(t.removed)) - sort.Sort(Txs(t.unmodified)) - - for i, j := 0, 0; i < len(t.added) && j < len(otxs); { - switch bytes.Compare(t.added[i], otxs[j]) { - case 0: - return fmt.Errorf("existing transaction incorrectly marked as added, transaction hash: %x", otxs[j].Hash()) - case -1: - i++ - case 1: - j++ + sort.Sort(Txs(addedCopy)) + sort.Sort(Txs(removedCopy)) + sort.Sort(Txs(unmodifiedCopy)) + unmodifiedIdx, addedIdx, removedIdx := 0, 0, 0 + for i := 0; i < len(otxs); i++ { + if addedIdx == len(addedCopy) && + removedIdx == len(removedCopy) && + unmodifiedIdx == len(unmodifiedCopy) { + break } - } - for i, j := 0, 0; i < len(t.removed); { - if j >= len(otxs) { - // we reached the end of the original txs without finding a match for - // all of the removed elements - return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", t.removed[i].Hash()) + LOOP: + for addedIdx < len(addedCopy) { + switch bytes.Compare(addedCopy[addedIdx], otxs[i]) { + case 0: + return fmt.Errorf("existing transaction incorrectly marked as added, transaction hash: %x", otxs[i].Hash()) + case -1: + addedIdx++ + case 1: + break LOOP + } } - switch bytes.Compare(t.added[i], otxs[j]) { - case 0: - i++ - j++ - case -1: - return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", t.removed[i].Hash()) - case 1: - j++ + if removedIdx < len(removedCopy) { + switch bytes.Compare(removedCopy[removedIdx], otxs[i]) { + case 0: + removedIdx++ + case -1: + return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", removedCopy[i].Hash()) + } } - } - for i, j := 0, 0; i < len(t.unmodified); { - if j >= len(otxs) { - // we reached the end of the original txs without finding a match for - // all of the unmodified elements - return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", t.unmodified[i].Hash()) - } - switch bytes.Compare(t.unmodified[i], otxs[j]) { - case 0: - i++ - j++ - case -1: - return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", t.unmodified[i].Hash()) - case 1: - j++ + if unmodifiedIdx < len(unmodifiedCopy) { + switch bytes.Compare(unmodifiedCopy[unmodifiedIdx], otxs[i]) { + case 0: + unmodifiedIdx++ + case -1: + return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", removedCopy[i].Hash()) + } } } + if unmodifiedIdx != len(unmodifiedCopy) { + return fmt.Errorf("new transaction incorrectly marked as unmodified, transaction hash: %x", unmodifiedCopy[unmodifiedIdx].Hash()) + } + if removedIdx != len(removedCopy) { + return fmt.Errorf("new transaction incorrectly marked as removed, transaction hash: %x", removedCopy[removedIdx].Hash()) + } + return nil } diff --git a/types/tx_test.go b/types/tx_test.go index f65d37551..c50e51f61 100644 --- a/types/tx_test.go +++ b/types/tx_test.go @@ -130,13 +130,21 @@ func TestValidateTxRecordSet(t *testing.T) { }) t.Run("should error on existing transaction marked as ADDED", func(t *testing.T) { trs := []*abci.TxRecord{ + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{5, 4, 3, 2, 1}), + }, + { + Action: abci.TxRecord_ADDED, + Tx: Tx([]byte{6}), + }, { Action: abci.TxRecord_ADDED, Tx: Tx([]byte{1, 2, 3, 4, 5}), }, } txrSet := NewTxRecordSet(trs) - err := txrSet.Validate(100, []Tx{{1, 2, 3, 4, 5}}) + err := txrSet.Validate(100, []Tx{{0}, {1, 2, 3, 4, 5}}) require.Error(t, err) }) t.Run("should error if any transaction marked as UNKNOWN", func(t *testing.T) {