diff --git a/types/validator_set.go b/types/validator_set.go index a9b986590..3876c19d7 100644 --- a/types/validator_set.go +++ b/types/validator_set.go @@ -3,6 +3,7 @@ package types import ( "bytes" "fmt" + "math" "sort" "strings" @@ -53,17 +54,7 @@ func (valSet *ValidatorSet) IncrementAccum(times int) { validatorsHeap := cmn.NewHeap() for _, val := range valSet.Validators { // check for overflow both multiplication and sum - res, overflow := safeMul(val.VotingPower, int64(times)) - if !overflow { - res2, overflow2 := safeAdd(val.Accum, res) - if !overflow2 { - val.Accum = res2 - } else { - val.Accum = mostPositive - } - } else { - val.Accum = mostPositive - } + val.Accum = safeAddClip(val.Accum, safeMulClip(val.VotingPower, int64(times))) validatorsHeap.Push(val, accumComparable{val}) } @@ -75,12 +66,7 @@ func (valSet *ValidatorSet) IncrementAccum(times int) { } // mind underflow - res, underflow := safeSub(mostest.Accum, valSet.TotalVotingPower()) - if !underflow { - mostest.Accum = res - } else { - mostest.Accum = mostNegative - } + mostest.Accum = safeSubClip(mostest.Accum, valSet.TotalVotingPower()) validatorsHeap.Update(mostest, accumComparable{mostest}) } } @@ -135,13 +121,7 @@ func (valSet *ValidatorSet) TotalVotingPower() int64 { if valSet.totalVotingPower == 0 { for _, val := range valSet.Validators { // mind overflow - res, overflow := safeAdd(valSet.totalVotingPower, val.VotingPower) - if !overflow { - valSet.totalVotingPower = res - } else { - valSet.totalVotingPower = mostPositive - return valSet.totalVotingPower - } + valSet.totalVotingPower = safeAddClip(valSet.totalVotingPower, val.VotingPower) } } return valSet.totalVotingPower @@ -453,9 +433,6 @@ func RandValidatorSet(numValidators int, votingPower int64) (*ValidatorSet, []*P /////////////////////////////////////////////////////////////////////////////// // Safe multiplication and addition/subtraction -const mostNegative int64 = -mostPositive - 1 -const mostPositive int64 = 1<<63 - 1 - func safeMul(a, b int64) (int64, bool) { if a == 0 || b == 0 { return 0, false @@ -466,7 +443,7 @@ func safeMul(a, b int64) (int64, bool) { if b == 1 { return a, false } - if a == mostNegative || b == mostNegative { + if a == math.MinInt64 || b == math.MinInt64 { return -1, true } c := a * b @@ -474,14 +451,55 @@ func safeMul(a, b int64) (int64, bool) { } func safeAdd(a, b int64) (int64, bool) { - if b > 0 && a > mostPositive-b { + if b > 0 && a > math.MaxInt64-b { return -1, true - } else if b < 0 && a < mostNegative-b { + } else if b < 0 && a < math.MinInt64-b { return -1, true } return a + b, false } func safeSub(a, b int64) (int64, bool) { - return safeAdd(a, -b) + if b > 0 && a < math.MinInt64+b { + return -1, true + } else if b < 0 && a > math.MaxInt64+b { + return -1, true + } + return a - b, false +} + +func safeMulClip(a, b int64) int64 { + c, overflow := safeMul(a, b) + if overflow { + if (a < 0 || b < 0) && !(a < 0 && b < 0) { + return math.MinInt64 + } else { + return math.MaxInt64 + } + } + return c +} + +func safeAddClip(a, b int64) int64 { + c, overflow := safeAdd(a, b) + if overflow { + if b < 0 { + return math.MinInt64 + } else { + return math.MaxInt64 + } + } + return c +} + +func safeSubClip(a, b int64) int64 { + c, overflow := safeSub(a, b) + if overflow { + if b > 0 { + return math.MinInt64 + } else { + return math.MaxInt64 + } + } + return c } diff --git a/types/validator_set_test.go b/types/validator_set_test.go index dd2a59999..9c7512378 100644 --- a/types/validator_set_test.go +++ b/types/validator_set_test.go @@ -2,6 +2,7 @@ package types import ( "bytes" + "math" "strings" "testing" "testing/quick" @@ -194,41 +195,41 @@ func TestProposerSelection3(t *testing.T) { func TestValidatorSetTotalVotingPowerOverflows(t *testing.T) { vset := NewValidatorSet([]*Validator{ - {Address: []byte("a"), VotingPower: mostPositive, Accum: 0}, - {Address: []byte("b"), VotingPower: mostPositive, Accum: 0}, - {Address: []byte("c"), VotingPower: mostPositive, Accum: 0}, + {Address: []byte("a"), VotingPower: math.MaxInt64, Accum: 0}, + {Address: []byte("b"), VotingPower: math.MaxInt64, Accum: 0}, + {Address: []byte("c"), VotingPower: math.MaxInt64, Accum: 0}, }) - assert.Equal(t, mostPositive, vset.TotalVotingPower()) + assert.EqualValues(t, math.MaxInt64, vset.TotalVotingPower()) } func TestValidatorSetIncrementAccumOverflows(t *testing.T) { // NewValidatorSet calls IncrementAccum(1) vset := NewValidatorSet([]*Validator{ // too much voting power - 0: {Address: []byte("a"), VotingPower: mostPositive, Accum: 0}, + 0: {Address: []byte("a"), VotingPower: math.MaxInt64, Accum: 0}, // too big accum - 1: {Address: []byte("b"), VotingPower: 10, Accum: mostPositive}, + 1: {Address: []byte("b"), VotingPower: 10, Accum: math.MaxInt64}, // almost too big accum - 2: {Address: []byte("c"), VotingPower: 10, Accum: mostPositive - 5}, + 2: {Address: []byte("c"), VotingPower: 10, Accum: math.MaxInt64 - 5}, }) assert.Equal(t, int64(0), vset.Validators[0].Accum, "0") // because we decrement val with most voting power - assert.Equal(t, mostPositive, vset.Validators[1].Accum, "1") - assert.Equal(t, mostPositive, vset.Validators[2].Accum, "2") + assert.EqualValues(t, math.MaxInt64, vset.Validators[1].Accum, "1") + assert.EqualValues(t, math.MaxInt64, vset.Validators[2].Accum, "2") } func TestValidatorSetIncrementAccumUnderflows(t *testing.T) { // NewValidatorSet calls IncrementAccum(1) vset := NewValidatorSet([]*Validator{ - 0: {Address: []byte("a"), VotingPower: mostPositive, Accum: mostNegative}, - 1: {Address: []byte("b"), VotingPower: 1, Accum: mostNegative}, + 0: {Address: []byte("a"), VotingPower: math.MaxInt64, Accum: math.MinInt64}, + 1: {Address: []byte("b"), VotingPower: 1, Accum: math.MinInt64}, }) vset.IncrementAccum(5) - assert.Equal(t, mostNegative, vset.Validators[0].Accum, "0") - assert.Equal(t, mostNegative, vset.Validators[1].Accum, "1") + assert.EqualValues(t, math.MinInt64, vset.Validators[0].Accum, "0") + assert.EqualValues(t, math.MinInt64, vset.Validators[1].Accum, "1") } func TestSafeMul(t *testing.T) { @@ -251,6 +252,26 @@ func TestSafeAdd(t *testing.T) { } } +func TestSafeMulClip(t *testing.T) { + assert.EqualValues(t, math.MaxInt64, safeMulClip(math.MinInt64, math.MinInt64)) + assert.EqualValues(t, math.MinInt64, safeMulClip(math.MaxInt64, math.MinInt64)) + assert.EqualValues(t, math.MinInt64, safeMulClip(math.MinInt64, math.MaxInt64)) + assert.EqualValues(t, math.MaxInt64, safeMulClip(math.MaxInt64, 2)) +} + +func TestSafeAddClip(t *testing.T) { + assert.EqualValues(t, math.MaxInt64, safeAddClip(math.MaxInt64, 10)) + assert.EqualValues(t, math.MaxInt64, safeAddClip(math.MaxInt64, math.MaxInt64)) + assert.EqualValues(t, math.MinInt64, safeAddClip(math.MinInt64, -10)) +} + +func TestSafeSubClip(t *testing.T) { + assert.EqualValues(t, math.MinInt64, safeSubClip(math.MinInt64, 10)) + assert.EqualValues(t, 0, safeSubClip(math.MinInt64, math.MinInt64)) + assert.EqualValues(t, math.MinInt64, safeSubClip(math.MinInt64, math.MaxInt64)) + assert.EqualValues(t, math.MaxInt64, safeSubClip(math.MaxInt64, -10)) +} + func BenchmarkValidatorSetCopy(b *testing.B) { b.StopTimer() vset := NewValidatorSet([]*Validator{})