diff --git a/blocks/block_test.go b/blocks/block_test.go index 26fe5eba1..972ae8547 100644 --- a/blocks/block_test.go +++ b/blocks/block_test.go @@ -39,13 +39,11 @@ func TestBlock(t *testing.T) { BaseTx: randBaseTx(), Fee: RandUInt64Exp(), UnbondTo: RandUInt64Exp(), - Amount: RandUInt64Exp(), } unbondTx := &UnbondTx{ BaseTx: randBaseTx(), Fee: RandUInt64Exp(), - Amount: RandUInt64Exp(), } timeoutTx := &TimeoutTx{ diff --git a/blocks/tx.go b/blocks/tx.go index e43f3b8b9..e7d248e95 100644 --- a/blocks/tx.go +++ b/blocks/tx.go @@ -20,7 +20,7 @@ Validation Txs: type Tx interface { Signable - GetSequence() uint64 + GetSequence() uint } const ( @@ -83,7 +83,7 @@ func ReadTx(r io.Reader, n *int64, err *error) Tx { //----------------------------------------------------------------------------- type BaseTx struct { - Sequence uint64 + Sequence uint Signature } @@ -100,7 +100,7 @@ func (tx BaseTx) WriteTo(w io.Writer) (n int64, err error) { return } -func (tx *BaseTx) GetSequence() uint64 { +func (tx *BaseTx) GetSequence() uint { return tx.Sequence } diff --git a/common/random.go b/common/random.go index 2588db623..3b46ef67e 100644 --- a/common/random.go +++ b/common/random.go @@ -58,6 +58,10 @@ func RandUInt64() uint64 { return uint64(rand.Uint32())<<32 + uint64(rand.Uint32()) } +func RandUInt() uint { + return uint(rand.Int()) +} + // Distributed pseudo-exponentially to test for various cases func RandUInt16Exp() uint16 { bits := rand.Uint32() % 16 diff --git a/consensus/consensus.go b/consensus/consensus.go index 4e616ad5c..de7db1440 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -229,8 +229,8 @@ func (conR *ConsensusReactor) Receive(chId byte, peer *p2p.Peer, msgBytes []byte if vote.Height != rs.Height || vote.Height != ps.Height { return } - index, ok := rs.Validators.GetIndexById(vote.SignerId) - if !ok { + index, val := rs.Validators.GetById(vote.SignerId) + if val == nil { log.Warning("Peer gave us an invalid vote.") return } @@ -348,8 +348,8 @@ OUTER_LOOP: if prs.Step <= RoundStepVote { index, ok := rs.Votes.BitArray().Sub(prs.Votes).PickRandom() if ok { - valId, ok := rs.Validators.GetIdByIndex(uint32(index)) - if ok { + valId, val := rs.Validators.GetByIndex(uint32(index)) + if val != nil { vote := rs.Votes.GetVote(valId) msg := p2p.TypedMessage{msgTypeVote, vote} peer.Send(VoteCh, msg) @@ -365,8 +365,8 @@ OUTER_LOOP: if prs.Step <= RoundStepPrecommit { index, ok := rs.Precommits.BitArray().Sub(prs.Precommits).PickRandom() if ok { - valId, ok := rs.Validators.GetIdByIndex(uint32(index)) - if ok { + valId, val := rs.Validators.GetByIndex(uint32(index)) + if val != nil { vote := rs.Precommits.GetVote(valId) msg := p2p.TypedMessage{msgTypeVote, vote} peer.Send(VoteCh, msg) @@ -381,8 +381,8 @@ OUTER_LOOP: // If there are any commits to send... index, ok := rs.Commits.BitArray().Sub(prs.Commits).PickRandom() if ok { - valId, ok := rs.Validators.GetIdByIndex(uint32(index)) - if ok { + valId, val := rs.Validators.GetByIndex(uint32(index)) + if val != nil { vote := rs.Commits.GetVote(valId) msg := p2p.TypedMessage{msgTypeVote, vote} peer.Send(VoteCh, msg) diff --git a/consensus/pol.go b/consensus/pol.go index 1308b3744..d1eb888e8 100644 --- a/consensus/pol.go +++ b/consensus/pol.go @@ -55,17 +55,17 @@ func (pol *POL) Verify(vset *ValidatorSet) error { if _, seen := seenValidators[sig.SignerId]; seen { return Errorf("Duplicate validator for vote %v for POL %v", sig, pol) } - validator := vset.GetById(sig.SignerId) - if validator == nil { + _, val := vset.GetById(sig.SignerId) + if val == nil { return Errorf("Invalid validator for vote %v for POL %v", sig, pol) } - if !validator.VerifyBytes(voteDoc, sig) { + if !val.VerifyBytes(voteDoc, sig) { return Errorf("Invalid signature for vote %v for POL %v", sig, pol) } // Tally - seenValidators[validator.Id] = struct{}{} - talliedVotingPower += validator.VotingPower + seenValidators[val.Id] = struct{}{} + talliedVotingPower += val.VotingPower } for i, sig := range pol.Commits { @@ -75,20 +75,20 @@ func (pol *POL) Verify(vset *ValidatorSet) error { if _, seen := seenValidators[sig.SignerId]; seen { return Errorf("Duplicate validator for commit %v for POL %v", sig, pol) } - validator := vset.GetById(sig.SignerId) - if validator == nil { + _, val := vset.GetById(sig.SignerId) + if val == nil { return Errorf("Invalid validator for commit %v for POL %v", sig, pol) } commitDoc := BinaryBytes(&Vote{Height: pol.Height, Round: round, Type: VoteTypeCommit, BlockHash: pol.BlockHash}) // TODO cache - if !validator.VerifyBytes(commitDoc, sig) { + if !val.VerifyBytes(commitDoc, sig) { return Errorf("Invalid signature for commit %v for POL %v", sig, pol) } // Tally - seenValidators[validator.Id] = struct{}{} - talliedVotingPower += validator.VotingPower + seenValidators[val.Id] = struct{}{} + talliedVotingPower += val.VotingPower } if talliedVotingPower > vset.TotalVotingPower()*2/3 { diff --git a/consensus/state.go b/consensus/state.go index 74b25218b..f79548e4c 100644 --- a/consensus/state.go +++ b/consensus/state.go @@ -92,7 +92,7 @@ func (cs *ConsensusState) updateToState(state *State) { cs.Step = RoundStepStart cs.StartTime = state.CommitTime.Add(newBlockWaitDuration) cs.Validators = validators - cs.Proposer = validators.GetProposer() + cs.Proposer = validators.Proposer() cs.Proposal = nil cs.ProposalBlock = nil cs.ProposalBlockPartSet = nil @@ -135,7 +135,7 @@ func (cs *ConsensusState) setupRound(round uint16) { cs.Round = round cs.Step = RoundStepStart cs.Validators = validators - cs.Proposer = validators.GetProposer() + cs.Proposer = validators.Proposer() cs.Proposal = nil cs.ProposalBlock = nil cs.ProposalBlockPartSet = nil diff --git a/consensus/vote_set.go b/consensus/vote_set.go index 851b15ec8..8dd2b8855 100644 --- a/consensus/vote_set.go +++ b/consensus/vote_set.go @@ -63,7 +63,7 @@ func (vs *VoteSet) AddVote(vote *Vote) (bool, error) { } // Ensure that signer is a validator. - val := vs.vset.GetById(vote.SignerId) + _, val := vs.vset.GetById(vote.SignerId) if val == nil { return false, ErrVoteInvalidAccount } @@ -89,12 +89,11 @@ func (vs *VoteSet) addVote(vote *Vote) (bool, error) { // Add vote. vs.votes[vote.SignerId] = vote - voterIndex, ok := vs.vset.GetIndexById(vote.SignerId) - if !ok { + voterIndex, val := vs.vset.GetById(vote.SignerId) + if val == nil { return false, ErrVoteInvalidAccount } vs.votesBitArray.SetIndex(uint(voterIndex), true) - val := vs.vset.GetById(vote.SignerId) totalBlockHashVotes := vs.votesByBlockHash[string(vote.BlockHash)] + val.VotingPower vs.votesByBlockHash[string(vote.BlockHash)] = totalBlockHashVotes vs.totalVotes += val.VotingPower diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index 514170259..0c743542b 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -39,7 +39,7 @@ func NewIAVLTree(keyCodec, valueCodec Codec, cacheSize int, db DB) *IAVLTree { // The returned tree and the original tree are goroutine independent. // That is, they can each run in their own goroutine. -func (t *IAVLTree) Copy() *IAVLTree { +func (t *IAVLTree) Copy() Tree { if t.ndb != nil && !t.root.persisted { panic("It is unsafe to Copy() an unpersisted tree.") // Saving a tree finalizes all the nodes. diff --git a/merkle/types.go b/merkle/types.go index a3fa9835a..bba0eee11 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -11,8 +11,8 @@ type Tree interface { HashWithCount() (hash []byte, count uint64) Hash() (hash []byte) Save() (hash []byte) - Checkpoint() (checkpoint interface{}) - Restore(checkpoint interface{}) + Load(hash []byte) + Copy() Tree Iterate(func(key interface{}, value interface{}) (stop bool)) (stopped bool) } diff --git a/state/account.go b/state/account.go index 91405209f..f1e4fd89b 100644 --- a/state/account.go +++ b/state/account.go @@ -57,7 +57,7 @@ func (account Account) Verify(o Signable) bool { type AccountDetail struct { Account - Sequence uint64 + Sequence uint Balance uint64 Status byte } @@ -65,7 +65,7 @@ type AccountDetail struct { func ReadAccountDetail(r io.Reader, n *int64, err *error) *AccountDetail { return &AccountDetail{ Account: ReadAccount(r, n, err), - Sequence: ReadUInt64(r, n, err), + Sequence: ReadUVarInt(r, n, err), Balance: ReadUInt64(r, n, err), Status: ReadByte(r, n, err), } @@ -73,12 +73,30 @@ func ReadAccountDetail(r io.Reader, n *int64, err *error) *AccountDetail { func (accDet AccountDetail) WriteTo(w io.Writer) (n int64, err error) { WriteBinary(w, accDet.Account, &n, &err) - WriteUInt64(w, accDet.Sequence, &n, &err) + WriteUVarInt(w, accDet.Sequence, &n, &err) WriteUInt64(w, accDet.Balance, &n, &err) WriteByte(w, accDet.Status, &n, &err) return } +//------------------------------------- + +var AccountDetailCodec = accountDetailCodec{} + +type accountDetailCodec struct{} + +func (abc accountDetailCodec) Encode(accDet interface{}, w io.Writer, n *int64, err *error) { + WriteBinary(w, accDet.(*AccountDetail), n, err) +} + +func (abc accountDetailCodec) Decode(r io.Reader, n *int64, err *error) interface{} { + return ReadAccountDetail(r, n, err) +} + +func (abc accountDetailCodec) Compare(o1 interface{}, o2 interface{}) int { + panic("AccountDetailCodec.Compare not implemented") +} + //----------------------------------------------------------------------------- type PrivAccount struct { diff --git a/state/account_test.go b/state/account_test.go index a9af1761e..20ce6021f 100644 --- a/state/account_test.go +++ b/state/account_test.go @@ -15,14 +15,14 @@ func TestSignAndValidate(t *testing.T) { t.Logf("msg: %X, sig: %X", msg, sig) // Test the signature - if !account.Verify(msg, sig) { + if !account.VerifyBytes(msg, sig) { t.Errorf("Account message signature verification failed") } // Mutate the signature, just one bit. sig.Bytes[0] ^= byte(0x01) - if account.Verify(msg, sig) { + if account.VerifyBytes(msg, sig) { t.Errorf("Account message signature verification should have failed but passed instead") } } diff --git a/state/state.go b/state/state.go index b5d26d613..d7568797f 100644 --- a/state/state.go +++ b/state/state.go @@ -20,23 +20,11 @@ var ( ErrStateInvalidAccountStateHash = errors.New("Error State invalid AccountStateHash") ErrStateInsufficientFunds = errors.New("Error State insufficient funds") - stateKey = []byte("stateKey") - minBondAmount = uint64(1) // TODO adjust + stateKey = []byte("stateKey") + minBondAmount = uint64(1) // TODO adjust + defaultAccountDetailsCacheCapacity = 1000 // TODO adjust ) -type accountDetailCodec struct{} - -func (abc accountDetailCodec) Write(accDet interface{}) (accDetBytes []byte, err error) { - w := new(bytes.Buffer) - _, err = accDet.(*AccountDetail).WriteTo(w) - return w.Bytes(), err -} - -func (abc accountDetailCodec) Read(accDetBytes []byte) (interface{}, error) { - n, err, r := new(int64), new(error), bytes.NewBuffer(accDetBytes) - return ReadAccountDetail(r, n, err), *err -} - //----------------------------------------------------------------------------- // NOTE: not goroutine-safe. @@ -45,25 +33,31 @@ type State struct { Height uint32 // Last known block height BlockHash []byte // Last known block hash CommitTime time.Time - AccountDetails *merkle.TypedTree + AccountDetails merkle.Tree Validators *ValidatorSet } func GenesisState(db DB, genesisTime time.Time, accDets []*AccountDetail) *State { // TODO: Use "uint64Codec" instead of BasicCodec - accountDetails := merkle.NewTypedTree(merkle.NewIAVLTree(db), BasicCodec, accountDetailCodec{}) - validators := map[uint64]*Validator{} + accountDetails := merkle.NewIAVLTree(BasicCodec, AccountDetailCodec, defaultAccountDetailsCacheCapacity, db) + validators := []*Validator{} for _, accDet := range accDets { accountDetails.Set(accDet.Id, accDet) - validators[accDet.Id] = &Validator{ - Account: accDet.Account, - BondHeight: 0, - VotingPower: accDet.Balance, - Accum: 0, + if accDet.Status == AccountDetailStatusBonded { + validators = append(validators, &Validator{ + Account: accDet.Account, + BondHeight: 0, + VotingPower: accDet.Balance, + Accum: 0, + }) } } + + if len(validators) == 0 { + panic("Must have some validators") + } validatorSet := NewValidatorSet(validators) return &State{ @@ -89,16 +83,13 @@ func LoadState(db DB) *State { s.CommitTime = ReadTime(reader, &n, &err) s.BlockHash = ReadByteSlice(reader, &n, &err) accountDetailsHash := ReadByteSlice(reader, &n, &err) - s.AccountDetails = merkle.NewTypedTree(merkle.LoadIAVLTreeFromHash(db, accountDetailsHash), BasicCodec, accountDetailCodec{}) - var validators = map[uint64]*Validator{} - for reader.Len() > 0 { - validator := ReadValidator(reader, &n, &err) - validators[validator.Id] = validator - } - s.Validators = NewValidatorSet(validators) + s.AccountDetails = merkle.NewIAVLTree(BasicCodec, AccountDetailCodec, defaultAccountDetailsCacheCapacity, db) + s.AccountDetails.Load(accountDetailsHash) + s.Validators = ReadValidatorSet(reader, &n, &err) if err != nil { panic(err) } + // TODO: ensure that buf is completely read. } return s } @@ -108,17 +99,15 @@ func LoadState(db DB) *State { // is saved here. func (s *State) Save(commitTime time.Time) { s.CommitTime = commitTime - s.AccountDetails.Tree.Save() + s.AccountDetails.Save() var buf bytes.Buffer var n int64 var err error WriteUInt32(&buf, s.Height, &n, &err) WriteTime(&buf, commitTime, &n, &err) WriteByteSlice(&buf, s.BlockHash, &n, &err) - WriteByteSlice(&buf, s.AccountDetails.Tree.Hash(), &n, &err) - for _, validator := range s.Validators.Map() { - WriteBinary(&buf, validator, &n, &err) - } + WriteByteSlice(&buf, s.AccountDetails.Hash(), &n, &err) + WriteBinary(&buf, s.Validators, &n, &err) if err != nil { panic(err) } @@ -225,7 +214,7 @@ func (s *State) AppendBlock(b *Block) error { if !bytes.Equal(s.Validators.Hash(), b.ValidationStateHash) { return ErrStateInvalidValidationStateHash } - if !bytes.Equal(s.AccountDetails.Tree.Hash(), b.AccountStateHash) { + if !bytes.Equal(s.AccountDetails.Hash(), b.AccountStateHash) { return ErrStateInvalidAccountStateHash } @@ -235,7 +224,7 @@ func (s *State) AppendBlock(b *Block) error { } func (s *State) GetAccountDetail(accountId uint64) *AccountDetail { - accDet := s.AccountDetails.Get(accountId) + _, accDet := s.AccountDetails.Get(accountId) if accDet == nil { return nil } diff --git a/state/state_test.go b/state/state_test.go index d45d01396..ed20efcc4 100644 --- a/state/state_test.go +++ b/state/state_test.go @@ -11,29 +11,30 @@ import ( "time" ) -func randAccountBalance(id uint64, status byte) *AccountBalance { - return &AccountBalance{ +func randAccountDetail(id uint64, status byte) *AccountDetail { + return &AccountDetail{ Account: Account{ Id: id, PubKey: CRandBytes(32), }, - Balance: RandUInt64(), - Status: status, + Sequence: RandUInt(), + Balance: RandUInt64(), + Status: status, } } // The first numValidators accounts are validators. func randGenesisState(numAccounts int, numValidators int) *State { db := NewMemDB() - accountBalances := make([]*AccountBalance, numAccounts) + accountDetails := make([]*AccountDetail, numAccounts) for i := 0; i < numAccounts; i++ { if i < numValidators { - accountBalances[i] = randAccountBalance(uint64(i), AccountBalanceStatusNominal) + accountDetails[i] = randAccountDetail(uint64(i), AccountDetailStatusNominal) } else { - accountBalances[i] = randAccountBalance(uint64(i), AccountBalanceStatusBonded) + accountDetails[i] = randAccountDetail(uint64(i), AccountDetailStatusBonded) } } - s0 := GenesisState(db, time.Now(), accountBalances) + s0 := GenesisState(db, time.Now(), accountDetails) return s0 } @@ -42,10 +43,11 @@ func TestGenesisSaveLoad(t *testing.T) { // Generate a state, save & load it. s0 := randGenesisState(10, 5) // Figure out what the next state hashes should be. + s0.Validators.Hash() s0ValsCopy := s0.Validators.Copy() s0ValsCopy.IncrementAccum() nextValidationStateHash := s0ValsCopy.Hash() - nextAccountStateHash := s0.AccountBalances.Tree.Hash() + nextAccountStateHash := s0.AccountDetails.Hash() // Mutate the state to append one empty block. block := &Block{ Header: Header{ @@ -97,7 +99,7 @@ func TestGenesisSaveLoad(t *testing.T) { if s0.Validators.TotalVotingPower() != s1.Validators.TotalVotingPower() { t.Error("Validators TotalVotingPower mismatch") } - if !bytes.Equal(s0.AccountBalances.Tree.Hash(), s1.AccountBalances.Tree.Hash()) { - t.Error("AccountBalance mismatch") + if !bytes.Equal(s0.AccountDetails.Hash(), s1.AccountDetails.Hash()) { + t.Error("AccountDetail mismatch") } } diff --git a/state/validator.go b/state/validator.go index 0995b75c1..f57b3aaf6 100644 --- a/state/validator.go +++ b/state/validator.go @@ -4,8 +4,6 @@ import ( "io" . "github.com/tendermint/tendermint/binary" - . "github.com/tendermint/tendermint/common" - "github.com/tendermint/tendermint/merkle" ) // Holds state for a Validator at a given height+round. @@ -47,126 +45,40 @@ func (v *Validator) WriteTo(w io.Writer) (n int64, err error) { return } -//----------------------------------------------------------------------------- - -// Not goroutine-safe. -type ValidatorSet struct { - validators map[uint64]*Validator - indexToId map[uint32]uint64 // bitarray index to validator id - idToIndex map[uint64]uint32 // validator id to bitarray index - totalVotingPower uint64 -} - -func NewValidatorSet(validators map[uint64]*Validator) *ValidatorSet { - if validators == nil { - validators = make(map[uint64]*Validator) - } - ids := []uint64{} - indexToId := map[uint32]uint64{} - idToIndex := map[uint64]uint32{} - totalVotingPower := uint64(0) - for id, val := range validators { - ids = append(ids, id) - totalVotingPower += val.VotingPower +// Returns the one with higher Accum. +func (v *Validator) CompareAccum(other *Validator) *Validator { + if v == nil { + return other } - UInt64Slice(ids).Sort() - for i, id := range ids { - indexToId[uint32(i)] = id - idToIndex[id] = uint32(i) - } - return &ValidatorSet{ - validators: validators, - indexToId: indexToId, - idToIndex: idToIndex, - totalVotingPower: totalVotingPower, - } -} - -func (vset *ValidatorSet) IncrementAccum() { - totalDelta := int64(0) - for _, validator := range vset.validators { - validator.Accum += int64(validator.VotingPower) - totalDelta += int64(validator.VotingPower) - } - proposer := vset.GetProposer() - proposer.Accum -= totalDelta - // NOTE: sum(v) here should be zero. - if true { - totalAccum := int64(0) - for _, validator := range vset.validators { - totalAccum += validator.Accum - } - if totalAccum != 0 { - Panicf("Total Accum of validators did not equal 0. Got: ", totalAccum) + if v.Accum > other.Accum { + return v + } else if v.Accum < other.Accum { + return other + } else { + if v.Id < other.Id { + return v + } else if v.Id > other.Id { + return other + } else { + panic("Cannot compare identical validators") } } } -func (vset *ValidatorSet) Copy() *ValidatorSet { - validators := map[uint64]*Validator{} - for id, val := range vset.validators { - validators[id] = val.Copy() - } - return &ValidatorSet{ - validators: validators, - indexToId: vset.indexToId, - idToIndex: vset.idToIndex, - totalVotingPower: vset.totalVotingPower, - } -} - -func (vset *ValidatorSet) GetById(id uint64) *Validator { - return vset.validators[id] -} +//------------------------------------- -func (vset *ValidatorSet) GetIndexById(id uint64) (uint32, bool) { - index, ok := vset.idToIndex[id] - return index, ok -} +var ValidatorCodec = validatorCodec{} -func (vset *ValidatorSet) GetIdByIndex(index uint32) (uint64, bool) { - id, ok := vset.indexToId[index] - return id, ok -} +type validatorCodec struct{} -func (vset *ValidatorSet) Map() map[uint64]*Validator { - return vset.validators +func (vc validatorCodec) Encode(o interface{}, w io.Writer, n *int64, err *error) { + WriteBinary(w, o.(*Validator), n, err) } -func (vset *ValidatorSet) Size() uint { - return uint(len(vset.validators)) +func (vc validatorCodec) Decode(r io.Reader, n *int64, err *error) interface{} { + return ReadValidator(r, n, err) } -func (vset *ValidatorSet) TotalVotingPower() uint64 { - return vset.totalVotingPower -} - -// TODO: cache proposer. invalidate upon increment. -func (vset *ValidatorSet) GetProposer() (proposer *Validator) { - highestAccum := int64(0) - for _, validator := range vset.validators { - if validator.Accum > highestAccum { - highestAccum = validator.Accum - proposer = validator - } else if validator.Accum == highestAccum { - if validator.Id < proposer.Id { // Seniority - proposer = validator - } - } - } - return -} - -// Should uniquely determine the state of the ValidatorSet. -func (vset *ValidatorSet) Hash() []byte { - ids := []uint64{} - for id, _ := range vset.validators { - ids = append(ids, id) - } - UInt64Slice(ids).Sort() - sortedValidators := make([]Binary, len(ids)) - for i, id := range ids { - sortedValidators[i] = vset.validators[id] - } - return merkle.HashFromBinaries(sortedValidators) +func (vc validatorCodec) Compare(o1 interface{}, o2 interface{}) int { + panic("ValidatorCodec.Compare not implemented") } diff --git a/state/validator_set.go b/state/validator_set.go new file mode 100644 index 000000000..38b9a4e0e --- /dev/null +++ b/state/validator_set.go @@ -0,0 +1,117 @@ +package state + +import ( + "io" + + . "github.com/tendermint/tendermint/binary" + "github.com/tendermint/tendermint/merkle" +) + +// Not goroutine-safe. +type ValidatorSet struct { + validators merkle.Tree + proposer *Validator // Whoever has the highest Accum. + totalVotingPower uint64 +} + +func NewValidatorSet(vals []*Validator) *ValidatorSet { + validators := merkle.NewIAVLTree(BasicCodec, ValidatorCodec, 0, nil) // In memory + var proposer *Validator + totalVotingPower := uint64(0) + for _, val := range vals { + validators.Set(val.Id, val) + proposer = proposer.CompareAccum(val) + totalVotingPower += val.VotingPower + } + return &ValidatorSet{ + validators: validators, + proposer: proposer, + totalVotingPower: totalVotingPower, + } +} + +func ReadValidatorSet(r io.Reader, n *int64, err *error) *ValidatorSet { + size := ReadUInt64(r, n, err) + validators := []*Validator{} + for i := uint64(0); i < size; i++ { + validator := ReadValidator(r, n, err) + validators = append(validators, validator) + } + return NewValidatorSet(validators) +} + +func (vset *ValidatorSet) WriteTo(w io.Writer) (n int64, err error) { + WriteUInt64(w, uint64(vset.validators.Size()), &n, &err) + vset.validators.Iterate(func(key_ interface{}, val_ interface{}) bool { + val := val_.(*Validator) + WriteBinary(w, val, &n, &err) + return false + }) + return +} + +func (vset *ValidatorSet) IncrementAccum() { + // Decrement from previous proposer + vset.proposer.Accum -= int64(vset.totalVotingPower) + var proposer *Validator + // Increment accum and find proposer + vset.validators.Iterate(func(key_ interface{}, val_ interface{}) bool { + val := val_.(*Validator) + val.Accum += int64(val.VotingPower) + proposer = proposer.CompareAccum(val) + return false + }) + vset.proposer = proposer +} + +func (vset *ValidatorSet) Copy() *ValidatorSet { + return &ValidatorSet{ + validators: vset.validators.Copy(), + proposer: vset.proposer, + totalVotingPower: vset.totalVotingPower, + } +} + +func (vset *ValidatorSet) GetById(id uint64) (index uint32, val *Validator) { + index_, val_ := vset.validators.Get(id) + index, val = uint32(index_), val_.(*Validator) + return +} + +func (vset *ValidatorSet) GetByIndex(index uint32) (id uint64, val *Validator) { + id_, val_ := vset.validators.GetByIndex(uint64(index)) + id, val = id_.(uint64), val_.(*Validator) + return +} + +func (vset *ValidatorSet) Size() uint { + return uint(vset.validators.Size()) +} + +func (vset *ValidatorSet) TotalVotingPower() uint64 { + return vset.totalVotingPower +} + +func (vset *ValidatorSet) Proposer() (proposer *Validator) { + return vset.proposer +} + +func (vset *ValidatorSet) Hash() []byte { + return vset.validators.Hash() +} + +func (vset *ValidatorSet) AddValidator(val *Validator) (added bool) { + if val.Accum != 0 { + panic("AddValidator only accepts validators with zero accumpower") + } + if vset.validators.Has(val.Id) { + return false + } + updated := vset.validators.Set(val.Id, val) + return !updated +} + +func (vset *ValidatorSet) RemoveValidator(validatorId uint64) (removed bool) { + _, removed = vset.validators.Remove(validatorId) + return removed +}