diff --git a/types/validator_set.go b/types/validator_set.go index 134e4e06e..3876c19d7 100644 --- a/types/validator_set.go +++ b/types/validator_set.go @@ -3,6 +3,7 @@ package types import ( "bytes" "fmt" + "math" "sort" "strings" @@ -48,12 +49,12 @@ func NewValidatorSet(vals []*Validator) *ValidatorSet { } // incrementAccum and update the proposer -// TODO: mind the overflow when times and votingPower shares too large. func (valSet *ValidatorSet) IncrementAccum(times int) { // Add VotingPower * times to each validator and order into heap. validatorsHeap := cmn.NewHeap() for _, val := range valSet.Validators { - val.Accum += val.VotingPower * int64(times) // TODO: mind overflow + // check for overflow both multiplication and sum + val.Accum = safeAddClip(val.Accum, safeMulClip(val.VotingPower, int64(times))) validatorsHeap.Push(val, accumComparable{val}) } @@ -63,7 +64,9 @@ func (valSet *ValidatorSet) IncrementAccum(times int) { if i == times-1 { valSet.Proposer = mostest } - mostest.Accum -= int64(valSet.TotalVotingPower()) + + // mind underflow + mostest.Accum = safeSubClip(mostest.Accum, valSet.TotalVotingPower()) validatorsHeap.Update(mostest, accumComparable{mostest}) } } @@ -117,7 +120,8 @@ func (valSet *ValidatorSet) Size() int { func (valSet *ValidatorSet) TotalVotingPower() int64 { if valSet.totalVotingPower == 0 { for _, val := range valSet.Validators { - valSet.totalVotingPower += val.VotingPower + // mind overflow + valSet.totalVotingPower = safeAddClip(valSet.totalVotingPower, val.VotingPower) } } return valSet.totalVotingPower @@ -425,3 +429,77 @@ func RandValidatorSet(numValidators int, votingPower int64) (*ValidatorSet, []*P sort.Sort(PrivValidatorsByAddress(privValidators)) return valSet, privValidators } + +/////////////////////////////////////////////////////////////////////////////// +// Safe multiplication and addition/subtraction + +func safeMul(a, b int64) (int64, bool) { + if a == 0 || b == 0 { + return 0, false + } + if a == 1 { + return b, false + } + if b == 1 { + return a, false + } + if a == math.MinInt64 || b == math.MinInt64 { + return -1, true + } + c := a * b + return c, c/b != a +} + +func safeAdd(a, b int64) (int64, bool) { + if b > 0 && a > math.MaxInt64-b { + return -1, true + } else if b < 0 && a < math.MinInt64-b { + return -1, true + } + return a + b, false +} + +func safeSub(a, b int64) (int64, bool) { + if b > 0 && a < math.MinInt64+b { + return -1, true + } else if b < 0 && a > math.MaxInt64+b { + return -1, true + } + return a - b, false +} + +func safeMulClip(a, b int64) int64 { + c, overflow := safeMul(a, b) + if overflow { + if (a < 0 || b < 0) && !(a < 0 && b < 0) { + return math.MinInt64 + } else { + return math.MaxInt64 + } + } + return c +} + +func safeAddClip(a, b int64) int64 { + c, overflow := safeAdd(a, b) + if overflow { + if b < 0 { + return math.MinInt64 + } else { + return math.MaxInt64 + } + } + return c +} + +func safeSubClip(a, b int64) int64 { + c, overflow := safeSub(a, b) + if overflow { + if b > 0 { + return math.MinInt64 + } else { + return math.MaxInt64 + } + } + return c +} diff --git a/types/validator_set_test.go b/types/validator_set_test.go index 572b7b007..9c7512378 100644 --- a/types/validator_set_test.go +++ b/types/validator_set_test.go @@ -2,11 +2,14 @@ package types import ( "bytes" + "math" "strings" "testing" + "testing/quick" - "github.com/tendermint/go-crypto" - "github.com/tendermint/go-wire" + "github.com/stretchr/testify/assert" + crypto "github.com/tendermint/go-crypto" + wire "github.com/tendermint/go-wire" cmn "github.com/tendermint/tmlibs/common" ) @@ -190,6 +193,85 @@ func TestProposerSelection3(t *testing.T) { } } +func TestValidatorSetTotalVotingPowerOverflows(t *testing.T) { + vset := NewValidatorSet([]*Validator{ + {Address: []byte("a"), VotingPower: math.MaxInt64, Accum: 0}, + {Address: []byte("b"), VotingPower: math.MaxInt64, Accum: 0}, + {Address: []byte("c"), VotingPower: math.MaxInt64, Accum: 0}, + }) + + assert.EqualValues(t, math.MaxInt64, vset.TotalVotingPower()) +} + +func TestValidatorSetIncrementAccumOverflows(t *testing.T) { + // NewValidatorSet calls IncrementAccum(1) + vset := NewValidatorSet([]*Validator{ + // too much voting power + 0: {Address: []byte("a"), VotingPower: math.MaxInt64, Accum: 0}, + // too big accum + 1: {Address: []byte("b"), VotingPower: 10, Accum: math.MaxInt64}, + // almost too big accum + 2: {Address: []byte("c"), VotingPower: 10, Accum: math.MaxInt64 - 5}, + }) + + assert.Equal(t, int64(0), vset.Validators[0].Accum, "0") // because we decrement val with most voting power + assert.EqualValues(t, math.MaxInt64, vset.Validators[1].Accum, "1") + assert.EqualValues(t, math.MaxInt64, vset.Validators[2].Accum, "2") +} + +func TestValidatorSetIncrementAccumUnderflows(t *testing.T) { + // NewValidatorSet calls IncrementAccum(1) + vset := NewValidatorSet([]*Validator{ + 0: {Address: []byte("a"), VotingPower: math.MaxInt64, Accum: math.MinInt64}, + 1: {Address: []byte("b"), VotingPower: 1, Accum: math.MinInt64}, + }) + + vset.IncrementAccum(5) + + assert.EqualValues(t, math.MinInt64, vset.Validators[0].Accum, "0") + assert.EqualValues(t, math.MinInt64, vset.Validators[1].Accum, "1") +} + +func TestSafeMul(t *testing.T) { + f := func(a, b int64) bool { + c, overflow := safeMul(a, b) + return overflow || (!overflow && c == a*b) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func TestSafeAdd(t *testing.T) { + f := func(a, b int64) bool { + c, overflow := safeAdd(a, b) + return overflow || (!overflow && c == a+b) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func TestSafeMulClip(t *testing.T) { + assert.EqualValues(t, math.MaxInt64, safeMulClip(math.MinInt64, math.MinInt64)) + assert.EqualValues(t, math.MinInt64, safeMulClip(math.MaxInt64, math.MinInt64)) + assert.EqualValues(t, math.MinInt64, safeMulClip(math.MinInt64, math.MaxInt64)) + assert.EqualValues(t, math.MaxInt64, safeMulClip(math.MaxInt64, 2)) +} + +func TestSafeAddClip(t *testing.T) { + assert.EqualValues(t, math.MaxInt64, safeAddClip(math.MaxInt64, 10)) + assert.EqualValues(t, math.MaxInt64, safeAddClip(math.MaxInt64, math.MaxInt64)) + assert.EqualValues(t, math.MinInt64, safeAddClip(math.MinInt64, -10)) +} + +func TestSafeSubClip(t *testing.T) { + assert.EqualValues(t, math.MinInt64, safeSubClip(math.MinInt64, 10)) + assert.EqualValues(t, 0, safeSubClip(math.MinInt64, math.MinInt64)) + assert.EqualValues(t, math.MinInt64, safeSubClip(math.MinInt64, math.MaxInt64)) + assert.EqualValues(t, math.MaxInt64, safeSubClip(math.MaxInt64, -10)) +} + func BenchmarkValidatorSetCopy(b *testing.B) { b.StopTimer() vset := NewValidatorSet([]*Validator{})