diff --git a/common/bit_array.go b/common/bit_array.go index 68201bad6..7cc84705e 100644 --- a/common/bit_array.go +++ b/common/bit_array.go @@ -99,8 +99,14 @@ func (bA *BitArray) copyBits(bits int) *BitArray { // Returns a BitArray of larger bits size. func (bA *BitArray) Or(o *BitArray) *BitArray { - if bA == nil { - o.Copy() + if bA == nil && o == nil { + return nil + } + if bA == nil && o != nil { + return o.Copy() + } + if o == nil { + return bA.Copy() } bA.mtx.Lock() defer bA.mtx.Unlock() @@ -113,7 +119,7 @@ func (bA *BitArray) Or(o *BitArray) *BitArray { // Returns a BitArray of smaller bit size. func (bA *BitArray) And(o *BitArray) *BitArray { - if bA == nil { + if bA == nil || o == nil { return nil } bA.mtx.Lock() @@ -143,7 +149,8 @@ func (bA *BitArray) Not() *BitArray { } func (bA *BitArray) Sub(o *BitArray) *BitArray { - if bA == nil { + if bA == nil || o == nil { + // TODO: Decide if we should do 1's complement here? return nil } bA.mtx.Lock() diff --git a/common/bit_array_test.go b/common/bit_array_test.go index e4ac8bf6f..94a312b7e 100644 --- a/common/bit_array_test.go +++ b/common/bit_array_test.go @@ -3,6 +3,8 @@ package common import ( "bytes" "testing" + + "github.com/stretchr/testify/require" ) func randBitArray(bits int) (*BitArray, []byte) { @@ -26,6 +28,11 @@ func TestAnd(t *testing.T) { bA2, _ := randBitArray(31) bA3 := bA1.And(bA2) + var bNil *BitArray + require.Equal(t, bNil.And(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.And(nil), (*BitArray)(nil)) + require.Equal(t, bNil.And(nil), (*BitArray)(nil)) + if bA3.Bits != 31 { t.Error("Expected min bits", bA3.Bits) } @@ -46,6 +53,11 @@ func TestOr(t *testing.T) { bA2, _ := randBitArray(31) bA3 := bA1.Or(bA2) + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Or(bA1), bA1) + require.Equal(t, bA1.Or(nil), bA1) + require.Equal(t, bNil.Or(nil), (*BitArray)(nil)) + if bA3.Bits != 51 { t.Error("Expected max bits") } @@ -66,6 +78,11 @@ func TestSub1(t *testing.T) { bA2, _ := randBitArray(51) bA3 := bA1.Sub(bA2) + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) + require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + if bA3.Bits != bA1.Bits { t.Error("Expected bA1 bits") } @@ -89,6 +106,11 @@ func TestSub2(t *testing.T) { bA2, _ := randBitArray(31) bA3 := bA1.Sub(bA2) + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) + require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + if bA3.Bits != bA1.Bits { t.Error("Expected bA1 bits") }