diff --git a/libs/common/bit_array.go b/libs/common/bit_array.go index abf6110d8..aa470bbdb 100644 --- a/libs/common/bit_array.go +++ b/libs/common/bit_array.go @@ -189,7 +189,7 @@ func (bA *BitArray) Sub(o *BitArray) *BitArray { if bA.Bits > o.Bits { c := bA.copy() for i := 0; i < len(o.Elems)-1; i++ { - c.Elems[i] &= ^c.Elems[i] + c.Elems[i] &= ^o.Elems[i] } i := len(o.Elems) - 1 if i >= 0 { diff --git a/libs/common/bit_array_test.go b/libs/common/bit_array_test.go index b1efd3f62..3e2f17ce1 100644 --- a/libs/common/bit_array_test.go +++ b/libs/common/bit_array_test.go @@ -131,6 +131,34 @@ func TestSub2(t *testing.T) { } } +func TestSub3(t *testing.T) { + + bA1, _ := randBitArray(231) + bA2, _ := randBitArray(81) + bA3 := bA1.Sub(bA2) + + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) + require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + + if bA3.Bits != bA1.Bits { + t.Error("Expected bA1 bits") + } + if len(bA3.Elems) != len(bA1.Elems) { + t.Error("Expected bA1 elems length") + } + for i := 0; i < bA3.Bits; i++ { + expected := bA1.GetIndex(i) + if i < bA2.Bits && bA2.GetIndex(i){ + expected = false + } + if bA3.GetIndex(i) != expected { + t.Error("Wrong bit from bA3") + } + } +} + func TestPickRandom(t *testing.T) { for idx := 0; idx < 123; idx++ { bA1 := NewBitArray(123)