diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 2a0b58f60..450079072 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -45,3 +45,4 @@ timeoutPrecommit before starting next round - [evidence] \#2515 fix db iter leak (@goolAdapter) - [common/bit_array] Fixed a bug in the `Or` function - [common/bit_array] Fixed a bug in the `Sub` function (@bradyjoestar) +- [common] \#2534 make bit array's PickRandom choose uniformly from true bits diff --git a/libs/common/bit_array.go b/libs/common/bit_array.go index 161f21fce..ebd6cc4a0 100644 --- a/libs/common/bit_array.go +++ b/libs/common/bit_array.go @@ -234,49 +234,53 @@ func (bA *BitArray) IsFull() bool { return (lastElem+1)&((uint64(1)< 0 { - randBitStart := RandIntn(64) - for j := 0; j < 64; j++ { - bitIdx := ((j + randBitStart) % 64) - if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { - return 64*elemIdx + bitIdx, true - } - } - PanicSanity("should not happen") - } - } else { - // Special case for last elem, to ignore straggler bits - elemBits := bA.Bits % 64 - if elemBits == 0 { - elemBits = 64 - } - randBitStart := RandIntn(elemBits) - for j := 0; j < elemBits; j++ { - bitIdx := ((j + randBitStart) % elemBits) - if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { - return 64*elemIdx + bitIdx, true - } + + return trueIndices[RandIntn(len(trueIndices))], true +} + +func (bA *BitArray) getTrueIndices() []int { + trueIndices := make([]int, 0, bA.Bits) + curBit := 0 + numElems := len(bA.Elems) + // set all true indices + for i := 0; i < numElems-1; i++ { + elem := bA.Elems[i] + if elem == 0 { + curBit += 64 + continue + } + for j := 0; j < 64; j++ { + if (elem & (uint64(1) << uint64(j))) > 0 { + trueIndices = append(trueIndices, curBit) } + curBit++ + } + } + // handle last element + lastElem := bA.Elems[numElems-1] + numFinalBits := bA.Bits - curBit + for i := 0; i < numFinalBits; i++ { + if (lastElem & (uint64(1) << uint64(i))) > 0 { + trueIndices = append(trueIndices, curBit) } + curBit++ } - return 0, false + return trueIndices } // String returns a string representation of BitArray: BA{}, diff --git a/libs/common/bit_array_test.go b/libs/common/bit_array_test.go index bc117b2a0..09ec8af25 100644 --- a/libs/common/bit_array_test.go +++ b/libs/common/bit_array_test.go @@ -107,16 +107,29 @@ func TestSub(t *testing.T) { } func TestPickRandom(t *testing.T) { - for idx := 0; idx < 123; idx++ { - bA1 := NewBitArray(123) - bA1.SetIndex(idx, true) - index, ok := bA1.PickRandom() - if !ok { - t.Fatal("Expected to pick element but got none") - } - if index != idx { - t.Fatalf("Expected to pick element at %v but got wrong index", idx) - } + empty16Bits := "________________" + empty64Bits := empty16Bits + empty16Bits + empty16Bits + empty16Bits + testCases := []struct { + bA string + ok bool + }{ + {`null`, false}, + {`"x"`, true}, + {`"` + empty16Bits + `"`, false}, + {`"x` + empty16Bits + `"`, true}, + {`"` + empty16Bits + `x"`, true}, + {`"x` + empty16Bits + `x"`, true}, + {`"` + empty64Bits + `"`, false}, + {`"x` + empty64Bits + `"`, true}, + {`"` + empty64Bits + `x"`, true}, + {`"x` + empty64Bits + `x"`, true}, + } + for _, tc := range testCases { + var bitArr *BitArray + err := json.Unmarshal([]byte(tc.bA), &bitArr) + require.NoError(t, err) + _, ok := bitArr.PickRandom() + require.Equal(t, tc.ok, ok, "PickRandom got an unexpected result on input %s", tc.bA) } }