Browse Source

add safeAdd & safeSub plus quickcheck tests

pull/997/head
Anton Kaliaev 7 years ago
parent
commit
69c3a7640b
No known key found for this signature in database GPG Key ID: 7B6881D965918214
2 changed files with 59 additions and 15 deletions
  1. +31
    -8
      types/validator_set.go
  2. +28
    -7
      types/validator_set_test.go

+ 31
- 8
types/validator_set.go View File

@ -52,10 +52,15 @@ func (valSet *ValidatorSet) IncrementAccum(times int) {
// Add VotingPower * times to each validator and order into heap.
validatorsHeap := cmn.NewHeap()
for _, val := range valSet.Validators {
res, overflow := signedMulWithOverflowCheck(val.VotingPower, int64(times))
// check for overflow both multiplication and sum
if !overflow && val.Accum <= mostPositive-res {
val.Accum += res
res, overflow := safeMul(val.VotingPower, int64(times))
if !overflow {
res2, overflow2 := safeAdd(val.Accum, res)
if !overflow2 {
val.Accum = res2
} else {
val.Accum = mostPositive
}
} else {
val.Accum = mostPositive
}
@ -70,8 +75,9 @@ func (valSet *ValidatorSet) IncrementAccum(times int) {
}
// mind underflow
if mostest.Accum >= mostNegative+valSet.TotalVotingPower() {
mostest.Accum -= valSet.TotalVotingPower()
res, underflow := safeSub(mostest.Accum, valSet.TotalVotingPower())
if !underflow {
mostest.Accum = res
} else {
mostest.Accum = mostNegative
}
@ -129,8 +135,9 @@ func (valSet *ValidatorSet) TotalVotingPower() int64 {
if valSet.totalVotingPower == 0 {
for _, val := range valSet.Validators {
// mind overflow
if valSet.totalVotingPower <= mostPositive-val.VotingPower {
valSet.totalVotingPower += val.VotingPower
res, overflow := safeAdd(valSet.totalVotingPower, val.VotingPower)
if !overflow {
valSet.totalVotingPower = res
} else {
valSet.totalVotingPower = mostPositive
return valSet.totalVotingPower
@ -443,10 +450,13 @@ func RandValidatorSet(numValidators int, votingPower int64) (*ValidatorSet, []*P
return valSet, privValidators
}
///////////////////////////////////////////////////////////////////////////////
// Safe multiplication and addition/subtraction
const mostNegative int64 = -mostPositive - 1
const mostPositive int64 = 1<<63 - 1
func signedMulWithOverflowCheck(a, b int64) (int64, bool) {
func safeMul(a, b int64) (int64, bool) {
if a == 0 || b == 0 {
return 0, false
}
@ -462,3 +472,16 @@ func signedMulWithOverflowCheck(a, b int64) (int64, bool) {
c := a * b
return c, c/b != a
}
func safeAdd(a, b int64) (int64, bool) {
if b > 0 && a > mostPositive-b {
return -1, true
} else if b < 0 && a < mostNegative-b {
return -1, true
}
return a + b, false
}
func safeSub(a, b int64) (int64, bool) {
return safeAdd(a, -b)
}

+ 28
- 7
types/validator_set_test.go View File

@ -4,6 +4,7 @@ import (
"bytes"
"strings"
"testing"
"testing/quick"
"github.com/stretchr/testify/assert"
crypto "github.com/tendermint/go-crypto"
@ -191,6 +192,16 @@ func TestProposerSelection3(t *testing.T) {
}
}
func TestValidatorSetTotalVotingPowerOverflows(t *testing.T) {
vset := NewValidatorSet([]*Validator{
{Address: []byte("a"), VotingPower: mostPositive, Accum: 0},
{Address: []byte("b"), VotingPower: mostPositive, Accum: 0},
{Address: []byte("c"), VotingPower: mostPositive, Accum: 0},
})
assert.Equal(t, mostPositive, vset.TotalVotingPower())
}
func TestValidatorSetIncrementAccumOverflows(t *testing.T) {
// NewValidatorSet calls IncrementAccum(1)
vset := NewValidatorSet([]*Validator{
@ -220,14 +231,24 @@ func TestValidatorSetIncrementAccumUnderflows(t *testing.T) {
assert.Equal(t, mostNegative, vset.Validators[1].Accum, "1")
}
func TestValidatorSetTotalVotingPowerOverflows(t *testing.T) {
vset := NewValidatorSet([]*Validator{
{Address: []byte("a"), VotingPower: mostPositive, Accum: 0},
{Address: []byte("b"), VotingPower: mostPositive, Accum: 0},
{Address: []byte("c"), VotingPower: mostPositive, Accum: 0},
})
func TestSafeMul(t *testing.T) {
f := func(a, b int64) bool {
c, overflow := safeMul(a, b)
return overflow || (!overflow && c == a*b)
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
assert.Equal(t, mostPositive, vset.TotalVotingPower())
func TestSafeAdd(t *testing.T) {
f := func(a, b int64) bool {
c, overflow := safeAdd(a, b)
return overflow || (!overflow && c == a+b)
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
func BenchmarkValidatorSetCopy(b *testing.B) {


Loading…
Cancel
Save