Browse Source

math: remove panics in safe math ops (#7962)

* math: remove panics in safe math ops

* fix docs

* fix lint
pull/7962/merge
Sam Kleinman 3 years ago
committed by GitHub
parent
commit
912751cf93
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 26 deletions
  1. +13
    -2
      internal/consensus/state.go
  2. +5
    -1
      internal/consensus/types/height_vote_set.go
  3. +5
    -1
      internal/state/store.go
  4. +17
    -22
      libs/math/safemath.go

+ 13
- 2
internal/consensus/state.go View File

@ -1086,6 +1086,8 @@ func (cs *State) handleTxsAvailable(ctx context.Context) {
// Enter: +2/3 prevotes any or +2/3 precommits for block or any from (height, round) // Enter: +2/3 prevotes any or +2/3 precommits for block or any from (height, round)
// NOTE: cs.StartTime was already set for height. // NOTE: cs.StartTime was already set for height.
func (cs *State) enterNewRound(ctx context.Context, height int64, round int32) { func (cs *State) enterNewRound(ctx context.Context, height int64, round int32) {
// TODO: remove panics in this function and return an error
logger := cs.logger.With("height", height, "round", round) logger := cs.logger.With("height", height, "round", round)
if cs.Height != height || round < cs.Round || (cs.Round == round && cs.Step != cstypes.RoundStepNewHeight) { if cs.Height != height || round < cs.Round || (cs.Round == round && cs.Step != cstypes.RoundStepNewHeight) {
@ -1106,7 +1108,11 @@ func (cs *State) enterNewRound(ctx context.Context, height int64, round int32) {
validators := cs.Validators validators := cs.Validators
if cs.Round < round { if cs.Round < round {
validators = validators.Copy() validators = validators.Copy()
validators.IncrementProposerPriority(tmmath.SafeSubInt32(round, cs.Round))
r, err := tmmath.SafeSubInt32(round, cs.Round)
if err != nil {
panic(err)
}
validators.IncrementProposerPriority(r)
} }
// Setup new round // Setup new round
@ -1126,7 +1132,12 @@ func (cs *State) enterNewRound(ctx context.Context, height int64, round int32) {
cs.ProposalBlockParts = nil cs.ProposalBlockParts = nil
} }
cs.Votes.SetRound(tmmath.SafeAddInt32(round, 1)) // also track next round (round+1) to allow round-skipping
r, err := tmmath.SafeAddInt32(round, 1)
if err != nil {
panic(err)
}
cs.Votes.SetRound(r) // also track next round (round+1) to allow round-skipping
cs.TriggeredTimeoutPrecommit = false cs.TriggeredTimeoutPrecommit = false
if err := cs.eventBus.PublishEventNewRound(ctx, cs.NewRoundEvent()); err != nil { if err := cs.eventBus.PublishEventNewRound(ctx, cs.NewRoundEvent()); err != nil {


+ 5
- 1
internal/consensus/types/height_vote_set.go View File

@ -85,7 +85,11 @@ func (hvs *HeightVoteSet) Round() int32 {
func (hvs *HeightVoteSet) SetRound(round int32) { func (hvs *HeightVoteSet) SetRound(round int32) {
hvs.mtx.Lock() hvs.mtx.Lock()
defer hvs.mtx.Unlock() defer hvs.mtx.Unlock()
newRound := tmmath.SafeSubInt32(hvs.round, 1)
newRound, err := tmmath.SafeSubInt32(hvs.round, 1)
if err != nil {
panic(err)
}
if hvs.round != 0 && (round < newRound) { if hvs.round != 0 && (round < newRound) {
panic("SetRound() must increment hvs.round") panic("SetRound() must increment hvs.round")
} }


+ 5
- 1
internal/state/store.go View File

@ -504,8 +504,12 @@ func (store dbStore) LoadValidators(height int64) (*types.ValidatorSet, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
h, err := tmmath.SafeConvertInt32(height - lastStoredHeight)
if err != nil {
return nil, err
}
vs.IncrementProposerPriority(tmmath.SafeConvertInt32(height - lastStoredHeight)) // mutate
vs.IncrementProposerPriority(h) // mutate
vi2, err := vs.ToProto() vi2, err := vs.ToProto()
if err != nil { if err != nil {
return nil, err return nil, err


+ 17
- 22
libs/math/safemath.go View File

@ -9,41 +9,37 @@ var ErrOverflowInt32 = errors.New("int32 overflow")
var ErrOverflowUint8 = errors.New("uint8 overflow") var ErrOverflowUint8 = errors.New("uint8 overflow")
var ErrOverflowInt8 = errors.New("int8 overflow") var ErrOverflowInt8 = errors.New("int8 overflow")
// SafeAddInt32 adds two int32 integers
// If there is an overflow this will panic
func SafeAddInt32(a, b int32) int32 {
// SafeAddInt32 adds two int32 integers.
func SafeAddInt32(a, b int32) (int32, error) {
if b > 0 && (a > math.MaxInt32-b) { if b > 0 && (a > math.MaxInt32-b) {
panic(ErrOverflowInt32)
return 0, ErrOverflowInt32
} else if b < 0 && (a < math.MinInt32-b) { } else if b < 0 && (a < math.MinInt32-b) {
panic(ErrOverflowInt32)
return 0, ErrOverflowInt32
} }
return a + b
return a + b, nil
} }
// SafeSubInt32 subtracts two int32 integers
// If there is an overflow this will panic
func SafeSubInt32(a, b int32) int32 {
// SafeSubInt32 subtracts two int32 integers.
func SafeSubInt32(a, b int32) (int32, error) {
if b > 0 && (a < math.MinInt32+b) { if b > 0 && (a < math.MinInt32+b) {
panic(ErrOverflowInt32)
return 0, ErrOverflowInt32
} else if b < 0 && (a > math.MaxInt32+b) { } else if b < 0 && (a > math.MaxInt32+b) {
panic(ErrOverflowInt32)
return 0, ErrOverflowInt32
} }
return a - b
return a - b, nil
} }
// SafeConvertInt32 takes a int and checks if it overflows
// If there is an overflow this will panic
func SafeConvertInt32(a int64) int32 {
// SafeConvertInt32 takes a int and checks if it overflows.
func SafeConvertInt32(a int64) (int32, error) {
if a > math.MaxInt32 { if a > math.MaxInt32 {
panic(ErrOverflowInt32)
return 0, ErrOverflowInt32
} else if a < math.MinInt32 { } else if a < math.MinInt32 {
panic(ErrOverflowInt32)
return 0, ErrOverflowInt32
} }
return int32(a)
return int32(a), nil
} }
// SafeConvertUint8 takes an int64 and checks if it overflows
// If there is an overflow it returns an error
// SafeConvertUint8 takes an int64 and checks if it overflows.
func SafeConvertUint8(a int64) (uint8, error) { func SafeConvertUint8(a int64) (uint8, error) {
if a > math.MaxUint8 { if a > math.MaxUint8 {
return 0, ErrOverflowUint8 return 0, ErrOverflowUint8
@ -53,8 +49,7 @@ func SafeConvertUint8(a int64) (uint8, error) {
return uint8(a), nil return uint8(a), nil
} }
// SafeConvertInt8 takes an int64 and checks if it overflows
// If there is an overflow it returns an error
// SafeConvertInt8 takes an int64 and checks if it overflows.
func SafeConvertInt8(a int64) (int8, error) { func SafeConvertInt8(a int64) (int8, error) {
if a > math.MaxInt8 { if a > math.MaxInt8 {
return 0, ErrOverflowInt8 return 0, ErrOverflowInt8


Loading…
Cancel
Save