From 0755a5203da8fd5aab74373f2d2d537c3f17bf8a Mon Sep 17 00:00:00 2001 From: ValarDragon Date: Tue, 2 Oct 2018 16:03:59 -0700 Subject: [PATCH] bit_array: Simplify subtraction also, fix potential bug in Or function --- CHANGELOG_PENDING.md | 3 +- libs/common/bit_array.go | 46 +++++++-------- libs/common/bit_array_test.go | 103 +++++++++------------------------- 3 files changed, 48 insertions(+), 104 deletions(-) diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 6d9813350..81380e7c6 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -11,7 +11,7 @@ BREAKING CHANGES: * [rpc] \#2298 `/abci_query` takes `prove` argument instead of `trusted` and switches the default behaviour to `prove=false` * [privval] \#2459 Split `SocketPVMsg`s implementations into Request and Response, where the Response may contain a error message (returned by the remote signer) - + * Apps * [abci] \#2298 ResponseQuery.Proof is now a structured merkle.Proof, not just arbitrary bytes @@ -40,3 +40,4 @@ BUG FIXES: - [autofile] \#2428 Group.RotateFile need call Flush() before rename (@goolAdapter) - [node] \#2434 Make node respond to signal interrupts while sleeping for genesis time - [evidence] \#2515 fix db iter leak (@goolAdapter) +- [common/bit_array] Fixed a bug in the `Or` function diff --git a/libs/common/bit_array.go b/libs/common/bit_array.go index aa470bbdb..161f21fce 100644 --- a/libs/common/bit_array.go +++ b/libs/common/bit_array.go @@ -119,14 +119,13 @@ func (bA *BitArray) Or(o *BitArray) *BitArray { } bA.mtx.Lock() o.mtx.Lock() - defer func() { - bA.mtx.Unlock() - o.mtx.Unlock() - }() c := bA.copyBits(MaxInt(bA.Bits, o.Bits)) - for i := 0; i < len(c.Elems); i++ { + smaller := MinInt(len(bA.Elems), len(o.Elems)) + for i := 0; i < smaller; i++ { c.Elems[i] |= o.Elems[i] } + bA.mtx.Unlock() + o.mtx.Unlock() return c } @@ -173,8 +172,9 @@ func (bA *BitArray) not() *BitArray { } // Sub subtracts the two bit-arrays bitwise, without carrying the bits. -// This is essentially bA.And(o.Not()). -// If bA is longer than o, o is right padded with zeroes. +// Note that carryless subtraction of a - b is (a and not b). +// The output is the same as bA, regardless of o's size. +// If bA is longer than o, o is right padded with zeroes func (bA *BitArray) Sub(o *BitArray) *BitArray { if bA == nil || o == nil { // TODO: Decide if we should do 1's complement here? @@ -182,24 +182,20 @@ func (bA *BitArray) Sub(o *BitArray) *BitArray { } bA.mtx.Lock() o.mtx.Lock() - defer func() { - bA.mtx.Unlock() - o.mtx.Unlock() - }() - if bA.Bits > o.Bits { - c := bA.copy() - for i := 0; i < len(o.Elems)-1; i++ { - c.Elems[i] &= ^o.Elems[i] - } - i := len(o.Elems) - 1 - if i >= 0 { - for idx := i * 64; idx < o.Bits; idx++ { - c.setIndex(idx, c.getIndex(idx) && !o.getIndex(idx)) - } - } - return c - } - return bA.and(o.not()) // Note degenerate case where o == nil + // output is the same size as bA + c := bA.copyBits(bA.Bits) + // Only iterate to the minimum size between the two. + // If o is longer, those bits are ignored. + // If bA is longer, then skipping those iterations is equivalent + // to right padding with 0's + smaller := MinInt(len(bA.Elems), len(o.Elems)) + for i := 0; i < smaller; i++ { + // &^ is and not in golang + c.Elems[i] &^= o.Elems[i] + } + bA.mtx.Unlock() + o.mtx.Unlock() + return c } // IsEmpty returns true iff all bits in the bit array are 0 diff --git a/libs/common/bit_array_test.go b/libs/common/bit_array_test.go index 3e2f17ce1..bc117b2a0 100644 --- a/libs/common/bit_array_test.go +++ b/libs/common/bit_array_test.go @@ -75,87 +75,34 @@ func TestOr(t *testing.T) { } } -func TestSub1(t *testing.T) { - - bA1, _ := randBitArray(31) - bA2, _ := randBitArray(51) - bA3 := bA1.Sub(bA2) - - bNil := (*BitArray)(nil) - require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) - require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) - require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) - - if bA3.Bits != bA1.Bits { - t.Error("Expected bA1 bits") - } - if len(bA3.Elems) != len(bA1.Elems) { - t.Error("Expected bA1 elems length") - } - for i := 0; i < bA3.Bits; i++ { - expected := bA1.GetIndex(i) - if bA2.GetIndex(i) { - expected = false - } - if bA3.GetIndex(i) != expected { - t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i)) - } - } -} - -func TestSub2(t *testing.T) { - - bA1, _ := randBitArray(51) - bA2, _ := randBitArray(31) - bA3 := bA1.Sub(bA2) - - bNil := (*BitArray)(nil) - require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) - require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) - require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) - - if bA3.Bits != bA1.Bits { - t.Error("Expected bA1 bits") - } - if len(bA3.Elems) != len(bA1.Elems) { - t.Error("Expected bA1 elems length") - } - for i := 0; i < bA3.Bits; i++ { - expected := bA1.GetIndex(i) - if i < bA2.Bits && bA2.GetIndex(i) { - expected = false - } - if bA3.GetIndex(i) != expected { - t.Error("Wrong bit from bA3") - } +func TestSub(t *testing.T) { + testCases := []struct { + initBA string + subtractingBA string + expectedBA string + }{ + {`null`, `null`, `null`}, + {`"x"`, `null`, `null`}, + {`null`, `"x"`, `null`}, + {`"x"`, `"x"`, `"_"`}, + {`"xxxxxx"`, `"x_x_x_"`, `"_x_x_x"`}, + {`"x_x_x_"`, `"xxxxxx"`, `"______"`}, + {`"xxxxxx"`, `"x_x_x_xxxx"`, `"_x_x_x"`}, + {`"x_x_x_xxxx"`, `"xxxxxx"`, `"______xxxx"`}, + {`"xxxxxxxxxx"`, `"x_x_x_"`, `"_x_x_xxxxx"`}, + {`"x_x_x_"`, `"xxxxxxxxxx"`, `"______"`}, } -} - -func TestSub3(t *testing.T) { - - bA1, _ := randBitArray(231) - bA2, _ := randBitArray(81) - bA3 := bA1.Sub(bA2) + for _, tc := range testCases { + var bA *BitArray + err := json.Unmarshal([]byte(tc.initBA), &bA) + require.Nil(t, err) - bNil := (*BitArray)(nil) - require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) - require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) - require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + var o *BitArray + err = json.Unmarshal([]byte(tc.subtractingBA), &o) + require.Nil(t, err) - if bA3.Bits != bA1.Bits { - t.Error("Expected bA1 bits") - } - if len(bA3.Elems) != len(bA1.Elems) { - t.Error("Expected bA1 elems length") - } - for i := 0; i < bA3.Bits; i++ { - expected := bA1.GetIndex(i) - if i < bA2.Bits && bA2.GetIndex(i){ - expected = false - } - if bA3.GetIndex(i) != expected { - t.Error("Wrong bit from bA3") - } + got, _ := json.Marshal(bA.Sub(o)) + require.Equal(t, tc.expectedBA, string(got), "%s minus %s doesn't equal %s", tc.initBA, tc.subtractingBA, tc.expectedBA) } }