From fa9b029a265b859f85d9bf9b7847bfd3e294b25a Mon Sep 17 00:00:00 2001 From: William Banfield Date: Mon, 14 Mar 2022 18:01:09 -0400 Subject: [PATCH] defensive copy + layout change --- types/tx.go | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/types/tx.go b/types/tx.go index 0d4dda75e..14e11aff3 100644 --- a/types/tx.go +++ b/types/tx.go @@ -211,17 +211,24 @@ func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { // indexes can be preserved. addedCopy := make([]Tx, len(t.added)) copy(addedCopy, t.added) + sort.Sort(Txs(addedCopy)) + removedCopy := make([]Tx, len(t.removed)) copy(removedCopy, t.removed) + sort.Sort(Txs(removedCopy)) + unmodifiedCopy := make([]Tx, len(t.unmodified)) copy(unmodifiedCopy, t.unmodified) - - sort.Sort(otxs) - sort.Sort(Txs(addedCopy)) - sort.Sort(Txs(removedCopy)) sort.Sort(Txs(unmodifiedCopy)) + + // make a defensive copy of otxs so that the order of + // the caller's data is not altered. + otxsCopy := make([]Tx, len(otxs)) + copy(otxsCopy, otxs) + sort.Sort(Txs(otxsCopy)) + unmodifiedIdx, addedIdx, removedIdx := 0, 0, 0 - for i := 0; i < len(otxs); i++ { + for i := 0; i < len(otxsCopy); i++ { if addedIdx == len(addedCopy) && removedIdx == len(removedCopy) && unmodifiedIdx == len(unmodifiedCopy) { @@ -234,9 +241,9 @@ func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { // iterate over the sorted addedIndex until we reach a value that sorts // higher than the value we are examining in the original list. for addedIdx < len(addedCopy) { - switch bytes.Compare(addedCopy[addedIdx], otxs[i]) { + switch bytes.Compare(addedCopy[addedIdx], otxsCopy[i]) { case 0: - return fmt.Errorf("existing transaction incorrectly marked as added, transaction hash: %x", otxs[i].Hash()) + return fmt.Errorf("existing transaction incorrectly marked as added, transaction hash: %x", otxsCopy[i].Hash()) case -1: addedIdx++ case 1: @@ -258,7 +265,7 @@ func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { // The same logic applies for the unmodified check. if removedIdx < len(removedCopy) { - switch bytes.Compare(removedCopy[removedIdx], otxs[i]) { + switch bytes.Compare(removedCopy[removedIdx], otxsCopy[i]) { case 0: removedIdx++ case -1: @@ -266,7 +273,7 @@ func (t TxRecordSet) Validate(maxSizeBytes int64, otxs Txs) error { } } if unmodifiedIdx < len(unmodifiedCopy) { - switch bytes.Compare(unmodifiedCopy[unmodifiedIdx], otxs[i]) { + switch bytes.Compare(unmodifiedCopy[unmodifiedIdx], otxsCopy[i]) { case 0: unmodifiedIdx++ case -1: