From 810aeb7bcb72a9d4ba6f7e1547c1347913ee0969 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Sun, 12 Oct 2014 21:14:10 -0700 Subject: [PATCH] fix tests --- blocks/block_test.go | 12 +- blocks/tx.go | 28 ++-- consensus/state.go | 2 +- merkle/iavl_tree.go | 3 + state/account.go | 7 +- state/state.go | 253 +++++++++++++++++++++++++++--------- state/state_test.go | 22 ++-- state/validator_set.go | 23 +++- state/validator_set_test.go | 45 +++++++ 9 files changed, 288 insertions(+), 107 deletions(-) create mode 100644 state/validator_set_test.go diff --git a/blocks/block_test.go b/blocks/block_test.go index a50bde6f6..5c93897c8 100644 --- a/blocks/block_test.go +++ b/blocks/block_test.go @@ -12,7 +12,7 @@ func randSig() Signature { } func randBaseTx() BaseTx { - return BaseTx{0, randSig()} + return BaseTx{0, RandUInt64Exp(), randSig()} } func TestBlock(t *testing.T) { @@ -21,14 +21,12 @@ func TestBlock(t *testing.T) { sendTx := &SendTx{ BaseTx: randBaseTx(), - Fee: RandUInt64Exp(), To: RandUInt64Exp(), Amount: RandUInt64Exp(), } nameTx := &NameTx{ BaseTx: randBaseTx(), - Fee: RandUInt64Exp(), Name: string(RandBytes(12)), PubKey: RandBytes(32), } @@ -36,14 +34,12 @@ func TestBlock(t *testing.T) { // Validation Txs bondTx := &BondTx{ - BaseTx: randBaseTx(), - Fee: RandUInt64Exp(), - UnbondTo: RandUInt64Exp(), + BaseTx: randBaseTx(), + //UnbondTo: RandUInt64Exp(), } unbondTx := &UnbondTx{ BaseTx: randBaseTx(), - Fee: RandUInt64Exp(), } dupeoutTx := &DupeoutTx{ @@ -79,7 +75,7 @@ func TestBlock(t *testing.T) { Signatures: []Signature{randSig(), randSig()}, }, Data: Data{ - Txs: []Tx{sendTx, nameTx, bondTx, unbondTx, timeoutTx, dupeoutTx}, + Txs: []Tx{sendTx, nameTx, bondTx, unbondTx, dupeoutTx}, }, } diff --git a/blocks/tx.go b/blocks/tx.go index b0655de8f..a262d6854 100644 --- a/blocks/tx.go +++ b/blocks/tx.go @@ -20,6 +20,7 @@ Validation Txs: type Tx interface { Signable GetSequence() uint + GetFee() uint64 } const ( @@ -38,27 +39,23 @@ func ReadTx(r io.Reader, n *int64, err *error) Tx { case TxTypeSend: return &SendTx{ BaseTx: ReadBaseTx(r, n, err), - Fee: ReadUInt64(r, n, err), To: ReadUInt64(r, n, err), Amount: ReadUInt64(r, n, err), } case TxTypeName: return &NameTx{ BaseTx: ReadBaseTx(r, n, err), - Fee: ReadUInt64(r, n, err), Name: ReadString(r, n, err), PubKey: ReadByteSlice(r, n, err), } case TxTypeBond: return &BondTx{ - BaseTx: ReadBaseTx(r, n, err), - Fee: ReadUInt64(r, n, err), - UnbondTo: ReadUInt64(r, n, err), + BaseTx: ReadBaseTx(r, n, err), + //UnbondTo: ReadUInt64(r, n, err), } case TxTypeUnbond: return &UnbondTx{ BaseTx: ReadBaseTx(r, n, err), - Fee: ReadUInt64(r, n, err), } case TxTypeDupeout: return &DupeoutTx{ @@ -76,18 +73,21 @@ func ReadTx(r io.Reader, n *int64, err *error) Tx { type BaseTx struct { Sequence uint + Fee uint64 Signature } func ReadBaseTx(r io.Reader, n *int64, err *error) BaseTx { return BaseTx{ Sequence: ReadUVarInt(r, n, err), + Fee: ReadUInt64(r, n, err), Signature: ReadSignature(r, n, err), } } func (tx BaseTx) WriteTo(w io.Writer) (n int64, err error) { WriteUVarInt(w, tx.Sequence, &n, &err) + WriteUInt64(w, tx.Fee, &n, &err) WriteBinary(w, tx.Signature, &n, &err) return } @@ -100,6 +100,10 @@ func (tx *BaseTx) GetSignature() Signature { return tx.Signature } +func (tx *BaseTx) GetFee() uint64 { + return tx.Fee +} + func (tx *BaseTx) SetSignature(sig Signature) { tx.Signature = sig } @@ -108,7 +112,6 @@ func (tx *BaseTx) SetSignature(sig Signature) { type SendTx struct { BaseTx - Fee uint64 To uint64 Amount uint64 } @@ -116,7 +119,6 @@ type SendTx struct { func (tx *SendTx) WriteTo(w io.Writer) (n int64, err error) { WriteByte(w, TxTypeSend, &n, &err) WriteBinary(w, tx.BaseTx, &n, &err) - WriteUInt64(w, tx.Fee, &n, &err) WriteUInt64(w, tx.To, &n, &err) WriteUInt64(w, tx.Amount, &n, &err) return @@ -126,7 +128,6 @@ func (tx *SendTx) WriteTo(w io.Writer) (n int64, err error) { type NameTx struct { BaseTx - Fee uint64 Name string PubKey []byte } @@ -134,7 +135,6 @@ type NameTx struct { func (tx *NameTx) WriteTo(w io.Writer) (n int64, err error) { WriteByte(w, TxTypeName, &n, &err) WriteBinary(w, tx.BaseTx, &n, &err) - WriteUInt64(w, tx.Fee, &n, &err) WriteString(w, tx.Name, &n, &err) WriteByteSlice(w, tx.PubKey, &n, &err) return @@ -144,15 +144,13 @@ func (tx *NameTx) WriteTo(w io.Writer) (n int64, err error) { type BondTx struct { BaseTx - Fee uint64 - UnbondTo uint64 + //UnbondTo uint64 } func (tx *BondTx) WriteTo(w io.Writer) (n int64, err error) { WriteByte(w, TxTypeBond, &n, &err) WriteBinary(w, tx.BaseTx, &n, &err) - WriteUInt64(w, tx.Fee, &n, &err) - WriteUInt64(w, tx.UnbondTo, &n, &err) + //WriteUInt64(w, tx.UnbondTo, &n, &err) return } @@ -160,13 +158,11 @@ func (tx *BondTx) WriteTo(w io.Writer) (n int64, err error) { type UnbondTx struct { BaseTx - Fee uint64 } func (tx *UnbondTx) WriteTo(w io.Writer) (n int64, err error) { WriteByte(w, TxTypeUnbond, &n, &err) WriteBinary(w, tx.BaseTx, &n, &err) - WriteUInt64(w, tx.Fee, &n, &err) return } diff --git a/consensus/state.go b/consensus/state.go index f79548e4c..8eda7cb48 100644 --- a/consensus/state.go +++ b/consensus/state.go @@ -86,7 +86,7 @@ func (cs *ConsensusState) updateToState(state *State) { // Reset fields based on state. height := state.Height - validators := state.Validators + validators := state.BondedValidators cs.Height = height cs.Round = 0 cs.Step = RoundStepStart diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index 0c743542b..76cb31443 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -148,6 +148,9 @@ func (t *IAVLTree) Remove(key interface{}) (value interface{}, removed bool) { } func (t *IAVLTree) Iterate(fn func(key interface{}, value interface{}) bool) (stopped bool) { + if t.root == nil { + return false + } return t.root.traverse(t, func(node *IAVLNode) bool { if node.height == 0 { return fn(node.key, node.value) diff --git a/state/account.go b/state/account.go index f1e4fd89b..a4770aeaa 100644 --- a/state/account.go +++ b/state/account.go @@ -9,9 +9,10 @@ import ( ) const ( - AccountDetailStatusNominal = byte(0x00) - AccountDetailStatusBonded = byte(0x01) - AccountDetailStatusUnbonding = byte(0x02) + AccountStatusNominal = byte(0x00) + AccountStatusBonded = byte(0x01) + AccountStatusUnbonding = byte(0x02) + AccountStatusDupedOut = byte(0x03) ) type Account struct { diff --git a/state/state.go b/state/state.go index 04fb74448..f7b8d8719 100644 --- a/state/state.go +++ b/state/state.go @@ -3,39 +3,52 @@ package state import ( "bytes" "errors" + "fmt" "time" . "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/blocks" + . "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/db" "github.com/tendermint/tendermint/merkle" ) var ( - ErrStateInvalidAccountId = errors.New("Error State invalid account id") - ErrStateInvalidSignature = errors.New("Error State invalid signature") - ErrStateInvalidSequenceNumber = errors.New("Error State invalid sequence number") - ErrStateInvalidAccountState = errors.New("Error State invalid account state") - ErrStateInvalidValidationStateHash = errors.New("Error State invalid ValidationStateHash") - ErrStateInvalidAccountStateHash = errors.New("Error State invalid AccountStateHash") - ErrStateInsufficientFunds = errors.New("Error State insufficient funds") + ErrStateInvalidAccountId = errors.New("Error State invalid account id") + ErrStateInvalidSignature = errors.New("Error State invalid signature") + ErrStateInvalidSequenceNumber = errors.New("Error State invalid sequence number") + ErrStateInvalidAccountState = errors.New("Error State invalid account state") + ErrStateInsufficientFunds = errors.New("Error State insufficient funds") stateKey = []byte("stateKey") - minBondAmount = uint64(1) // TODO adjust - defaultAccountDetailsCacheCapacity = 1000 // TODO adjust + minBondAmount = uint64(1) // TODO adjust + defaultAccountDetailsCacheCapacity = 1000 // TODO adjust + unbondingPeriodBlocks = uint32(60 * 24 * 365) // TODO probably better to make it time based. + validatorTimeoutBlocks = uint32(10) // TODO adjust ) //----------------------------------------------------------------------------- +type InvalidTxError struct { + Tx Tx + Reason error +} + +func (txErr InvalidTxError) Error() string { + return fmt.Sprintf("Invalid tx: [%v] reason: [%v]", txErr.Tx, txErr.Reason) +} + +//----------------------------------------------------------------------------- + // NOTE: not goroutine-safe. type State struct { - DB DB - Height uint32 // Last known block height - BlockHash []byte // Last known block hash - CommitTime time.Time - AccountDetails merkle.Tree - BondedValidators *ValidatorSet - UnbondedValidators *ValidatorSet + DB DB + Height uint32 // Last known block height + BlockHash []byte // Last known block hash + CommitTime time.Time + AccountDetails merkle.Tree + BondedValidators *ValidatorSet + UnbondingValidators *ValidatorSet } func GenesisState(db DB, genesisTime time.Time, accDets []*AccountDetail) *State { @@ -46,7 +59,7 @@ func GenesisState(db DB, genesisTime time.Time, accDets []*AccountDetail) *State for _, accDet := range accDets { accountDetails.Set(accDet.Id, accDet) - if accDet.Status == AccountDetailStatusBonded { + if accDet.Status == AccountStatusBonded { validators = append(validators, &Validator{ Account: accDet.Account, BondHeight: 0, @@ -61,13 +74,13 @@ func GenesisState(db DB, genesisTime time.Time, accDets []*AccountDetail) *State } return &State{ - DB: db, - Height: 0, - BlockHash: nil, - CommitTime: genesisTime, - AccountDetails: accountDetails, - BondedValidators: NewValidatorSet(validators), - UnbondedValidators: NewValidatorSet(nil), + DB: db, + Height: 0, + BlockHash: nil, + CommitTime: genesisTime, + AccountDetails: accountDetails, + BondedValidators: NewValidatorSet(validators), + UnbondingValidators: NewValidatorSet(nil), } } @@ -87,7 +100,7 @@ func LoadState(db DB) *State { s.AccountDetails = merkle.NewIAVLTree(BasicCodec, AccountDetailCodec, defaultAccountDetailsCacheCapacity, db) s.AccountDetails.Load(accountDetailsHash) s.BondedValidators = ReadValidatorSet(reader, &n, &err) - s.UnbondedValidators = ReadValidatorSet(reader, &n, &err) + s.UnbondingValidators = ReadValidatorSet(reader, &n, &err) if err != nil { panic(err) } @@ -110,7 +123,7 @@ func (s *State) Save(commitTime time.Time) { WriteByteSlice(&buf, s.BlockHash, &n, &err) WriteByteSlice(&buf, s.AccountDetails.Hash(), &n, &err) WriteBinary(&buf, s.BondedValidators, &n, &err) - WriteBinary(&buf, s.UnbondedValidators, &n, &err) + WriteBinary(&buf, s.UnbondingValidators, &n, &err) if err != nil { panic(err) } @@ -119,13 +132,13 @@ func (s *State) Save(commitTime time.Time) { func (s *State) Copy() *State { return &State{ - DB: s.DB, - Height: s.Height, - CommitTime: s.CommitTime, - BlockHash: s.BlockHash, - AccountDetails: s.AccountDetails.Copy(), - BondedValidators: s.BondedValidators.Copy(), - UnbondedValidators: s.UnbondedValidators.Copy(), + DB: s.DB, + Height: s.Height, + CommitTime: s.CommitTime, + BlockHash: s.BlockHash, + AccountDetails: s.AccountDetails.Copy(), + BondedValidators: s.BondedValidators.Copy(), + UnbondingValidators: s.UnbondingValidators.Copy(), } } @@ -144,20 +157,26 @@ func (s *State) ExecTx(tx Tx) error { if tx.GetSequence() <= accDet.Sequence { return ErrStateInvalidSequenceNumber } + // Subtract fee from balance. + if accDet.Balance < tx.GetFee() { + return ErrStateInsufficientFunds + } else { + accDet.Balance -= tx.GetFee() + } // Exec tx switch tx.(type) { case *SendTx: stx := tx.(*SendTx) toAccDet := s.GetAccountDetail(stx.To) // Accounts must be nominal - if accDet.Status != AccountDetailStatusNominal { + if accDet.Status != AccountStatusNominal { return ErrStateInvalidAccountState } - if toAccDet.Status != AccountDetailStatusNominal { + if toAccDet.Status != AccountStatusNominal { return ErrStateInvalidAccountState } // Check account balance - if accDet.Balance < stx.Fee+stx.Amount { + if accDet.Balance < stx.Amount { return ErrStateInsufficientFunds } // Check existence of destination account @@ -165,27 +184,26 @@ func (s *State) ExecTx(tx Tx) error { return ErrStateInvalidAccountId } // Good! - accDet.Balance -= (stx.Fee + stx.Amount) - toAccDet.Balance += (stx.Amount) + accDet.Balance -= stx.Amount + toAccDet.Balance += stx.Amount s.SetAccountDetail(accDet) s.SetAccountDetail(toAccDet) + return nil //case *NameTx case *BondTx: - btx := tx.(*BondTx) + //btx := tx.(*BondTx) // Account must be nominal - if accDet.Status != AccountDetailStatusNominal { + if accDet.Status != AccountStatusNominal { return ErrStateInvalidAccountState } // Check account balance if accDet.Balance < minBondAmount { return ErrStateInsufficientFunds } - // TODO: max number of validators? // Good! - accDet.Balance -= btx.Fee // remaining balance are bonded coins. - accDet.Status = AccountDetailStatusBonded + accDet.Status = AccountStatusBonded s.SetAccountDetail(accDet) - added := s.BondednValidators.Add(&Validator{ + added := s.BondedValidators.Add(&Validator{ Account: accDet.Account, BondHeight: s.Height, VotingPower: accDet.Balance, @@ -194,29 +212,106 @@ func (s *State) ExecTx(tx Tx) error { if !added { panic("Failed to add validator") } + return nil case *UnbondTx: - utx := tx.(*UnbondTx) + //utx := tx.(*UnbondTx) // Account must be bonded. - if accDet.Status != AccountDetailStatusBonded { + if accDet.Status != AccountStatusBonded { return ErrStateInvalidAccountState } // Good! - accDet.Status = AccountDetailStatusUnbonding + s.unbondValidator(accDet.Id, accDet) s.SetAccountDetail(accDet) - val, removed := s.BondedValidators.Remove(accDet.Id) - if !removed { - panic("Failed to remove validator") + return nil + case *DupeoutTx: + { + // NOTE: accDet is the one who created this transaction. + // Subtract any fees, save, and forget. + s.SetAccountDetail(accDet) + accDet = nil } - val.UnbondHeight = s.Height - added := s.UnbondedValidators.Add(val) - if !added { - panic("Failed to add validator") + dtx := tx.(*DupeoutTx) + // Verify the signatures + if dtx.VoteA.SignerId != dtx.VoteB.SignerId { + return ErrStateInvalidSignature } - case *DupeoutTx: - // XXX + accused := s.GetAccountDetail(dtx.VoteA.SignerId) + if !accused.Verify(&dtx.VoteA) || !accused.Verify(&dtx.VoteB) { + return ErrStateInvalidSignature + } + // Verify equivocation + if dtx.VoteA.Height != dtx.VoteB.Height { + return errors.New("DupeoutTx height must be the same.") + } + if dtx.VoteA.Type == VoteTypeCommit && dtx.VoteA.Round < dtx.VoteB.Round { + // Check special case. + // Validators should not sign another vote after committing. + } else { + if dtx.VoteA.Round != dtx.VoteB.Round { + return errors.New("DupeoutTx rounds don't match") + } + if dtx.VoteA.Type != dtx.VoteB.Type { + return errors.New("DupeoutTx types don't match") + } + if bytes.Equal(dtx.VoteA.BlockHash, dtx.VoteB.BlockHash) { + return errors.New("DupeoutTx blockhash shouldn't match") + } + } + // Good! (Bad validator!) + if accused.Status == AccountStatusBonded { + _, removed := s.BondedValidators.Remove(accused.Id) + if !removed { + panic("Failed to remove accused validator") + } + } else if accused.Status == AccountStatusUnbonding { + _, removed := s.UnbondingValidators.Remove(accused.Id) + if !removed { + panic("Failed to remove accused validator") + } + } else { + panic("Couldn't find accused validator") + } + accused.Status = AccountStatusDupedOut + updated := s.SetAccountDetail(accused) + if !updated { + panic("Failed to update accused validator account") + } + return nil + default: + panic("Unknown Tx type") + } +} + +// accDet optional +func (s *State) unbondValidator(accountId uint64, accDet *AccountDetail) { + if accDet == nil { + accDet = s.GetAccountDetail(accountId) + } + accDet.Status = AccountStatusUnbonding + s.SetAccountDetail(accDet) + val, removed := s.BondedValidators.Remove(accDet.Id) + if !removed { + panic("Failed to remove validator") + } + val.UnbondHeight = s.Height + added := s.UnbondingValidators.Add(val) + if !added { + panic("Failed to add validator") + } +} + +func (s *State) releaseValidator(accountId uint64) { + accDet := s.GetAccountDetail(accountId) + if accDet.Status != AccountStatusUnbonding { + panic("Cannot release validator") + } + accDet.Status = AccountStatusNominal + // TODO: move balance to designated address, UnbondTo. + s.SetAccountDetail(accDet) + _, removed := s.UnbondingValidators.Remove(accountId) + if !removed { + panic("Couldn't release validator") } - panic("Implement ExecTx()") - return nil } // NOTE: If an error occurs during block execution, state will be left @@ -232,25 +327,61 @@ func (s *State) AppendBlock(b *Block) error { for _, tx := range b.Data.Txs { err := s.ExecTx(tx) if err != nil { - return err + return InvalidTxError{tx, err} + } + } + + // Update LastCommitHeight as necessary. + for _, sig := range b.Validation.Signatures { + _, val := s.BondedValidators.GetById(sig.SignerId) + if val == nil { + return ErrStateInvalidSignature + } + val.LastCommitHeight = b.Height + updated := s.BondedValidators.Update(val) + if !updated { + panic("Failed to update validator LastCommitHeight") } } // If any unbonding periods are over, // reward account with bonded coins. + toRelease := []*Validator{} + s.UnbondingValidators.Iterate(func(val *Validator) bool { + if val.UnbondHeight+unbondingPeriodBlocks < b.Height { + toRelease = append(toRelease, val) + } + return false + }) + for _, val := range toRelease { + s.releaseValidator(val.Id) + } // If any validators haven't signed in a while, // unbond them, they have timed out. + toTimeout := []*Validator{} + s.BondedValidators.Iterate(func(val *Validator) bool { + if val.LastCommitHeight+validatorTimeoutBlocks < b.Height { + toTimeout = append(toTimeout, val) + } + return false + }) + for _, val := range toTimeout { + s.unbondValidator(val.Id, nil) + } // Increment validator AccumPowers s.BondedValidators.IncrementAccum() // State hashes should match + // XXX include UnbondingValidators.Hash(). if !bytes.Equal(s.BondedValidators.Hash(), b.ValidationStateHash) { - return ErrStateInvalidValidationStateHash + return Errorf("Invalid ValidationStateHash. Got %X, block says %X", + s.BondedValidators.Hash(), b.ValidationStateHash) } if !bytes.Equal(s.AccountDetails.Hash(), b.AccountStateHash) { - return ErrStateInvalidAccountStateHash + return Errorf("Invalid AccountStateHash. Got %X, block says %X", + s.AccountDetails.Hash(), b.AccountStateHash) } s.Height = b.Height diff --git a/state/state_test.go b/state/state_test.go index ed20efcc4..66c43fc5d 100644 --- a/state/state_test.go +++ b/state/state_test.go @@ -29,9 +29,9 @@ func randGenesisState(numAccounts int, numValidators int) *State { accountDetails := make([]*AccountDetail, numAccounts) for i := 0; i < numAccounts; i++ { if i < numValidators { - accountDetails[i] = randAccountDetail(uint64(i), AccountDetailStatusNominal) + accountDetails[i] = randAccountDetail(uint64(i), AccountStatusNominal) } else { - accountDetails[i] = randAccountDetail(uint64(i), AccountDetailStatusBonded) + accountDetails[i] = randAccountDetail(uint64(i), AccountStatusBonded) } } s0 := GenesisState(db, time.Now(), accountDetails) @@ -43,8 +43,8 @@ func TestGenesisSaveLoad(t *testing.T) { // Generate a state, save & load it. s0 := randGenesisState(10, 5) // Figure out what the next state hashes should be. - s0.Validators.Hash() - s0ValsCopy := s0.Validators.Copy() + s0.BondedValidators.Hash() + s0ValsCopy := s0.BondedValidators.Copy() s0ValsCopy.IncrementAccum() nextValidationStateHash := s0ValsCopy.Hash() nextAccountStateHash := s0.AccountDetails.Hash() @@ -71,8 +71,8 @@ func TestGenesisSaveLoad(t *testing.T) { // Sanity check s0 //s0.DB.(*MemDB).Print() - if s0.Validators.TotalVotingPower() == 0 { - t.Error("s0 Validators TotalVotingPower should not be 0") + if s0.BondedValidators.TotalVotingPower() == 0 { + t.Error("s0 BondedValidators TotalVotingPower should not be 0") } if s0.Height != 1 { t.Error("s0 Height should be 1, got", s0.Height) @@ -92,12 +92,12 @@ func TestGenesisSaveLoad(t *testing.T) { if !bytes.Equal(s0.BlockHash, s1.BlockHash) { t.Error("BlockHash mismatch") } - // Compare Validators - if s0.Validators.Size() != s1.Validators.Size() { - t.Error("Validators Size mismatch") + // Compare BondedValidators + if s0.BondedValidators.Size() != s1.BondedValidators.Size() { + t.Error("BondedValidators Size mismatch") } - if s0.Validators.TotalVotingPower() != s1.Validators.TotalVotingPower() { - t.Error("Validators TotalVotingPower mismatch") + if s0.BondedValidators.TotalVotingPower() != s1.BondedValidators.TotalVotingPower() { + t.Error("BondedValidators TotalVotingPower mismatch") } if !bytes.Equal(s0.AccountDetails.Hash(), s1.AccountDetails.Hash()) { t.Error("AccountDetail mismatch") diff --git a/state/validator_set.go b/state/validator_set.go index 1cacbb1d3..9b5de2439 100644 --- a/state/validator_set.go +++ b/state/validator_set.go @@ -101,17 +101,26 @@ func (vset *ValidatorSet) Hash() []byte { } func (vset *ValidatorSet) Add(val *Validator) (added bool) { - if val.Accum != 0 { - panic("AddValidator only accepts validators with zero accumpower") - } if vset.validators.Has(val.Id) { return false } - updated := vset.validators.Set(val.Id, val) - return !updated + return !vset.validators.Set(val.Id, val) +} + +func (vset *ValidatorSet) Update(val *Validator) (updated bool) { + if !vset.validators.Has(val.Id) { + return false + } + return vset.validators.Set(val.Id, val) } func (vset *ValidatorSet) Remove(validatorId uint64) (val *Validator, removed bool) { - val, removed = vset.validators.Remove(validatorId) - return val.(*Validator), removed + val_, removed := vset.validators.Remove(validatorId) + return val_.(*Validator), removed +} + +func (vset *ValidatorSet) Iterate(fn func(val *Validator) bool) { + vset.validators.Iterate(func(key_ interface{}, val_ interface{}) bool { + return fn(val_.(*Validator)) + }) } diff --git a/state/validator_set_test.go b/state/validator_set_test.go new file mode 100644 index 000000000..51e936ca2 --- /dev/null +++ b/state/validator_set_test.go @@ -0,0 +1,45 @@ +package state + +import ( + . "github.com/tendermint/tendermint/common" + + "bytes" + "testing" +) + +func randValidator() *Validator { + return &Validator{ + Account: Account{ + Id: RandUInt64(), + PubKey: CRandBytes(32), + }, + BondHeight: RandUInt32(), + UnbondHeight: RandUInt32(), + LastCommitHeight: RandUInt32(), + VotingPower: RandUInt64(), + Accum: int64(RandUInt64()), + } +} + +func randValidatorSet(numValidators int) *ValidatorSet { + validators := make([]*Validator, numValidators) + for i := 0; i < numValidators; i++ { + validators[i] = randValidator() + } + return NewValidatorSet(validators) +} + +func TestCopy(t *testing.T) { + vset := randValidatorSet(10) + vsetHash := vset.Hash() + if len(vsetHash) == 0 { + t.Fatalf("ValidatorSet had unexpected zero hash") + } + + vsetCopy := vset.Copy() + vsetCopyHash := vsetCopy.Hash() + + if !bytes.Equal(vsetHash, vsetCopyHash) { + t.Fatalf("ValidatorSet copy had wrong hash. Orig: %X, Copy: %X", vsetHash, vsetCopyHash) + } +}