Browse Source

fix tests

pull/9/head
Jae Kwon 10 years ago
parent
commit
7652c5d0de
15 changed files with 236 additions and 197 deletions
  1. +0
    -2
      blocks/block_test.go
  2. +3
    -3
      blocks/tx.go
  3. +4
    -0
      common/random.go
  4. +8
    -8
      consensus/consensus.go
  5. +10
    -10
      consensus/pol.go
  6. +2
    -2
      consensus/state.go
  7. +3
    -4
      consensus/vote_set.go
  8. +1
    -1
      merkle/iavl_tree.go
  9. +2
    -2
      merkle/types.go
  10. +21
    -3
      state/account.go
  11. +2
    -2
      state/account_test.go
  12. +26
    -37
      state/state.go
  13. +13
    -11
      state/state_test.go
  14. +24
    -112
      state/validator.go
  15. +117
    -0
      state/validator_set.go

+ 0
- 2
blocks/block_test.go View File

@ -39,13 +39,11 @@ func TestBlock(t *testing.T) {
BaseTx: randBaseTx(), BaseTx: randBaseTx(),
Fee: RandUInt64Exp(), Fee: RandUInt64Exp(),
UnbondTo: RandUInt64Exp(), UnbondTo: RandUInt64Exp(),
Amount: RandUInt64Exp(),
} }
unbondTx := &UnbondTx{ unbondTx := &UnbondTx{
BaseTx: randBaseTx(), BaseTx: randBaseTx(),
Fee: RandUInt64Exp(), Fee: RandUInt64Exp(),
Amount: RandUInt64Exp(),
} }
timeoutTx := &TimeoutTx{ timeoutTx := &TimeoutTx{


+ 3
- 3
blocks/tx.go View File

@ -20,7 +20,7 @@ Validation Txs:
type Tx interface { type Tx interface {
Signable Signable
GetSequence() uint64
GetSequence() uint
} }
const ( const (
@ -83,7 +83,7 @@ func ReadTx(r io.Reader, n *int64, err *error) Tx {
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
type BaseTx struct { type BaseTx struct {
Sequence uint64
Sequence uint
Signature Signature
} }
@ -100,7 +100,7 @@ func (tx BaseTx) WriteTo(w io.Writer) (n int64, err error) {
return return
} }
func (tx *BaseTx) GetSequence() uint64 {
func (tx *BaseTx) GetSequence() uint {
return tx.Sequence return tx.Sequence
} }


+ 4
- 0
common/random.go View File

@ -58,6 +58,10 @@ func RandUInt64() uint64 {
return uint64(rand.Uint32())<<32 + uint64(rand.Uint32()) return uint64(rand.Uint32())<<32 + uint64(rand.Uint32())
} }
func RandUInt() uint {
return uint(rand.Int())
}
// Distributed pseudo-exponentially to test for various cases // Distributed pseudo-exponentially to test for various cases
func RandUInt16Exp() uint16 { func RandUInt16Exp() uint16 {
bits := rand.Uint32() % 16 bits := rand.Uint32() % 16


+ 8
- 8
consensus/consensus.go View File

@ -229,8 +229,8 @@ func (conR *ConsensusReactor) Receive(chId byte, peer *p2p.Peer, msgBytes []byte
if vote.Height != rs.Height || vote.Height != ps.Height { if vote.Height != rs.Height || vote.Height != ps.Height {
return return
} }
index, ok := rs.Validators.GetIndexById(vote.SignerId)
if !ok {
index, val := rs.Validators.GetById(vote.SignerId)
if val == nil {
log.Warning("Peer gave us an invalid vote.") log.Warning("Peer gave us an invalid vote.")
return return
} }
@ -348,8 +348,8 @@ OUTER_LOOP:
if prs.Step <= RoundStepVote { if prs.Step <= RoundStepVote {
index, ok := rs.Votes.BitArray().Sub(prs.Votes).PickRandom() index, ok := rs.Votes.BitArray().Sub(prs.Votes).PickRandom()
if ok { if ok {
valId, ok := rs.Validators.GetIdByIndex(uint32(index))
if ok {
valId, val := rs.Validators.GetByIndex(uint32(index))
if val != nil {
vote := rs.Votes.GetVote(valId) vote := rs.Votes.GetVote(valId)
msg := p2p.TypedMessage{msgTypeVote, vote} msg := p2p.TypedMessage{msgTypeVote, vote}
peer.Send(VoteCh, msg) peer.Send(VoteCh, msg)
@ -365,8 +365,8 @@ OUTER_LOOP:
if prs.Step <= RoundStepPrecommit { if prs.Step <= RoundStepPrecommit {
index, ok := rs.Precommits.BitArray().Sub(prs.Precommits).PickRandom() index, ok := rs.Precommits.BitArray().Sub(prs.Precommits).PickRandom()
if ok { if ok {
valId, ok := rs.Validators.GetIdByIndex(uint32(index))
if ok {
valId, val := rs.Validators.GetByIndex(uint32(index))
if val != nil {
vote := rs.Precommits.GetVote(valId) vote := rs.Precommits.GetVote(valId)
msg := p2p.TypedMessage{msgTypeVote, vote} msg := p2p.TypedMessage{msgTypeVote, vote}
peer.Send(VoteCh, msg) peer.Send(VoteCh, msg)
@ -381,8 +381,8 @@ OUTER_LOOP:
// If there are any commits to send... // If there are any commits to send...
index, ok := rs.Commits.BitArray().Sub(prs.Commits).PickRandom() index, ok := rs.Commits.BitArray().Sub(prs.Commits).PickRandom()
if ok { if ok {
valId, ok := rs.Validators.GetIdByIndex(uint32(index))
if ok {
valId, val := rs.Validators.GetByIndex(uint32(index))
if val != nil {
vote := rs.Commits.GetVote(valId) vote := rs.Commits.GetVote(valId)
msg := p2p.TypedMessage{msgTypeVote, vote} msg := p2p.TypedMessage{msgTypeVote, vote}
peer.Send(VoteCh, msg) peer.Send(VoteCh, msg)


+ 10
- 10
consensus/pol.go View File

@ -55,17 +55,17 @@ func (pol *POL) Verify(vset *ValidatorSet) error {
if _, seen := seenValidators[sig.SignerId]; seen { if _, seen := seenValidators[sig.SignerId]; seen {
return Errorf("Duplicate validator for vote %v for POL %v", sig, pol) return Errorf("Duplicate validator for vote %v for POL %v", sig, pol)
} }
validator := vset.GetById(sig.SignerId)
if validator == nil {
_, val := vset.GetById(sig.SignerId)
if val == nil {
return Errorf("Invalid validator for vote %v for POL %v", sig, pol) return Errorf("Invalid validator for vote %v for POL %v", sig, pol)
} }
if !validator.VerifyBytes(voteDoc, sig) {
if !val.VerifyBytes(voteDoc, sig) {
return Errorf("Invalid signature for vote %v for POL %v", sig, pol) return Errorf("Invalid signature for vote %v for POL %v", sig, pol)
} }
// Tally // Tally
seenValidators[validator.Id] = struct{}{}
talliedVotingPower += validator.VotingPower
seenValidators[val.Id] = struct{}{}
talliedVotingPower += val.VotingPower
} }
for i, sig := range pol.Commits { for i, sig := range pol.Commits {
@ -75,20 +75,20 @@ func (pol *POL) Verify(vset *ValidatorSet) error {
if _, seen := seenValidators[sig.SignerId]; seen { if _, seen := seenValidators[sig.SignerId]; seen {
return Errorf("Duplicate validator for commit %v for POL %v", sig, pol) return Errorf("Duplicate validator for commit %v for POL %v", sig, pol)
} }
validator := vset.GetById(sig.SignerId)
if validator == nil {
_, val := vset.GetById(sig.SignerId)
if val == nil {
return Errorf("Invalid validator for commit %v for POL %v", sig, pol) return Errorf("Invalid validator for commit %v for POL %v", sig, pol)
} }
commitDoc := BinaryBytes(&Vote{Height: pol.Height, Round: round, commitDoc := BinaryBytes(&Vote{Height: pol.Height, Round: round,
Type: VoteTypeCommit, BlockHash: pol.BlockHash}) // TODO cache Type: VoteTypeCommit, BlockHash: pol.BlockHash}) // TODO cache
if !validator.VerifyBytes(commitDoc, sig) {
if !val.VerifyBytes(commitDoc, sig) {
return Errorf("Invalid signature for commit %v for POL %v", sig, pol) return Errorf("Invalid signature for commit %v for POL %v", sig, pol)
} }
// Tally // Tally
seenValidators[validator.Id] = struct{}{}
talliedVotingPower += validator.VotingPower
seenValidators[val.Id] = struct{}{}
talliedVotingPower += val.VotingPower
} }
if talliedVotingPower > vset.TotalVotingPower()*2/3 { if talliedVotingPower > vset.TotalVotingPower()*2/3 {


+ 2
- 2
consensus/state.go View File

@ -92,7 +92,7 @@ func (cs *ConsensusState) updateToState(state *State) {
cs.Step = RoundStepStart cs.Step = RoundStepStart
cs.StartTime = state.CommitTime.Add(newBlockWaitDuration) cs.StartTime = state.CommitTime.Add(newBlockWaitDuration)
cs.Validators = validators cs.Validators = validators
cs.Proposer = validators.GetProposer()
cs.Proposer = validators.Proposer()
cs.Proposal = nil cs.Proposal = nil
cs.ProposalBlock = nil cs.ProposalBlock = nil
cs.ProposalBlockPartSet = nil cs.ProposalBlockPartSet = nil
@ -135,7 +135,7 @@ func (cs *ConsensusState) setupRound(round uint16) {
cs.Round = round cs.Round = round
cs.Step = RoundStepStart cs.Step = RoundStepStart
cs.Validators = validators cs.Validators = validators
cs.Proposer = validators.GetProposer()
cs.Proposer = validators.Proposer()
cs.Proposal = nil cs.Proposal = nil
cs.ProposalBlock = nil cs.ProposalBlock = nil
cs.ProposalBlockPartSet = nil cs.ProposalBlockPartSet = nil


+ 3
- 4
consensus/vote_set.go View File

@ -63,7 +63,7 @@ func (vs *VoteSet) AddVote(vote *Vote) (bool, error) {
} }
// Ensure that signer is a validator. // Ensure that signer is a validator.
val := vs.vset.GetById(vote.SignerId)
_, val := vs.vset.GetById(vote.SignerId)
if val == nil { if val == nil {
return false, ErrVoteInvalidAccount return false, ErrVoteInvalidAccount
} }
@ -89,12 +89,11 @@ func (vs *VoteSet) addVote(vote *Vote) (bool, error) {
// Add vote. // Add vote.
vs.votes[vote.SignerId] = vote vs.votes[vote.SignerId] = vote
voterIndex, ok := vs.vset.GetIndexById(vote.SignerId)
if !ok {
voterIndex, val := vs.vset.GetById(vote.SignerId)
if val == nil {
return false, ErrVoteInvalidAccount return false, ErrVoteInvalidAccount
} }
vs.votesBitArray.SetIndex(uint(voterIndex), true) vs.votesBitArray.SetIndex(uint(voterIndex), true)
val := vs.vset.GetById(vote.SignerId)
totalBlockHashVotes := vs.votesByBlockHash[string(vote.BlockHash)] + val.VotingPower totalBlockHashVotes := vs.votesByBlockHash[string(vote.BlockHash)] + val.VotingPower
vs.votesByBlockHash[string(vote.BlockHash)] = totalBlockHashVotes vs.votesByBlockHash[string(vote.BlockHash)] = totalBlockHashVotes
vs.totalVotes += val.VotingPower vs.totalVotes += val.VotingPower


+ 1
- 1
merkle/iavl_tree.go View File

@ -39,7 +39,7 @@ func NewIAVLTree(keyCodec, valueCodec Codec, cacheSize int, db DB) *IAVLTree {
// The returned tree and the original tree are goroutine independent. // The returned tree and the original tree are goroutine independent.
// That is, they can each run in their own goroutine. // That is, they can each run in their own goroutine.
func (t *IAVLTree) Copy() *IAVLTree {
func (t *IAVLTree) Copy() Tree {
if t.ndb != nil && !t.root.persisted { if t.ndb != nil && !t.root.persisted {
panic("It is unsafe to Copy() an unpersisted tree.") panic("It is unsafe to Copy() an unpersisted tree.")
// Saving a tree finalizes all the nodes. // Saving a tree finalizes all the nodes.


+ 2
- 2
merkle/types.go View File

@ -11,8 +11,8 @@ type Tree interface {
HashWithCount() (hash []byte, count uint64) HashWithCount() (hash []byte, count uint64)
Hash() (hash []byte) Hash() (hash []byte)
Save() (hash []byte) Save() (hash []byte)
Checkpoint() (checkpoint interface{})
Restore(checkpoint interface{})
Load(hash []byte)
Copy() Tree
Iterate(func(key interface{}, value interface{}) (stop bool)) (stopped bool) Iterate(func(key interface{}, value interface{}) (stop bool)) (stopped bool)
} }


+ 21
- 3
state/account.go View File

@ -57,7 +57,7 @@ func (account Account) Verify(o Signable) bool {
type AccountDetail struct { type AccountDetail struct {
Account Account
Sequence uint64
Sequence uint
Balance uint64 Balance uint64
Status byte Status byte
} }
@ -65,7 +65,7 @@ type AccountDetail struct {
func ReadAccountDetail(r io.Reader, n *int64, err *error) *AccountDetail { func ReadAccountDetail(r io.Reader, n *int64, err *error) *AccountDetail {
return &AccountDetail{ return &AccountDetail{
Account: ReadAccount(r, n, err), Account: ReadAccount(r, n, err),
Sequence: ReadUInt64(r, n, err),
Sequence: ReadUVarInt(r, n, err),
Balance: ReadUInt64(r, n, err), Balance: ReadUInt64(r, n, err),
Status: ReadByte(r, n, err), Status: ReadByte(r, n, err),
} }
@ -73,12 +73,30 @@ func ReadAccountDetail(r io.Reader, n *int64, err *error) *AccountDetail {
func (accDet AccountDetail) WriteTo(w io.Writer) (n int64, err error) { func (accDet AccountDetail) WriteTo(w io.Writer) (n int64, err error) {
WriteBinary(w, accDet.Account, &n, &err) WriteBinary(w, accDet.Account, &n, &err)
WriteUInt64(w, accDet.Sequence, &n, &err)
WriteUVarInt(w, accDet.Sequence, &n, &err)
WriteUInt64(w, accDet.Balance, &n, &err) WriteUInt64(w, accDet.Balance, &n, &err)
WriteByte(w, accDet.Status, &n, &err) WriteByte(w, accDet.Status, &n, &err)
return return
} }
//-------------------------------------
var AccountDetailCodec = accountDetailCodec{}
type accountDetailCodec struct{}
func (abc accountDetailCodec) Encode(accDet interface{}, w io.Writer, n *int64, err *error) {
WriteBinary(w, accDet.(*AccountDetail), n, err)
}
func (abc accountDetailCodec) Decode(r io.Reader, n *int64, err *error) interface{} {
return ReadAccountDetail(r, n, err)
}
func (abc accountDetailCodec) Compare(o1 interface{}, o2 interface{}) int {
panic("AccountDetailCodec.Compare not implemented")
}
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
type PrivAccount struct { type PrivAccount struct {


+ 2
- 2
state/account_test.go View File

@ -15,14 +15,14 @@ func TestSignAndValidate(t *testing.T) {
t.Logf("msg: %X, sig: %X", msg, sig) t.Logf("msg: %X, sig: %X", msg, sig)
// Test the signature // Test the signature
if !account.Verify(msg, sig) {
if !account.VerifyBytes(msg, sig) {
t.Errorf("Account message signature verification failed") t.Errorf("Account message signature verification failed")
} }
// Mutate the signature, just one bit. // Mutate the signature, just one bit.
sig.Bytes[0] ^= byte(0x01) sig.Bytes[0] ^= byte(0x01)
if account.Verify(msg, sig) {
if account.VerifyBytes(msg, sig) {
t.Errorf("Account message signature verification should have failed but passed instead") t.Errorf("Account message signature verification should have failed but passed instead")
} }
} }

+ 26
- 37
state/state.go View File

@ -20,23 +20,11 @@ var (
ErrStateInvalidAccountStateHash = errors.New("Error State invalid AccountStateHash") ErrStateInvalidAccountStateHash = errors.New("Error State invalid AccountStateHash")
ErrStateInsufficientFunds = errors.New("Error State insufficient funds") ErrStateInsufficientFunds = errors.New("Error State insufficient funds")
stateKey = []byte("stateKey")
minBondAmount = uint64(1) // TODO adjust
stateKey = []byte("stateKey")
minBondAmount = uint64(1) // TODO adjust
defaultAccountDetailsCacheCapacity = 1000 // TODO adjust
) )
type accountDetailCodec struct{}
func (abc accountDetailCodec) Write(accDet interface{}) (accDetBytes []byte, err error) {
w := new(bytes.Buffer)
_, err = accDet.(*AccountDetail).WriteTo(w)
return w.Bytes(), err
}
func (abc accountDetailCodec) Read(accDetBytes []byte) (interface{}, error) {
n, err, r := new(int64), new(error), bytes.NewBuffer(accDetBytes)
return ReadAccountDetail(r, n, err), *err
}
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// NOTE: not goroutine-safe. // NOTE: not goroutine-safe.
@ -45,25 +33,31 @@ type State struct {
Height uint32 // Last known block height Height uint32 // Last known block height
BlockHash []byte // Last known block hash BlockHash []byte // Last known block hash
CommitTime time.Time CommitTime time.Time
AccountDetails *merkle.TypedTree
AccountDetails merkle.Tree
Validators *ValidatorSet Validators *ValidatorSet
} }
func GenesisState(db DB, genesisTime time.Time, accDets []*AccountDetail) *State { func GenesisState(db DB, genesisTime time.Time, accDets []*AccountDetail) *State {
// TODO: Use "uint64Codec" instead of BasicCodec // TODO: Use "uint64Codec" instead of BasicCodec
accountDetails := merkle.NewTypedTree(merkle.NewIAVLTree(db), BasicCodec, accountDetailCodec{})
validators := map[uint64]*Validator{}
accountDetails := merkle.NewIAVLTree(BasicCodec, AccountDetailCodec, defaultAccountDetailsCacheCapacity, db)
validators := []*Validator{}
for _, accDet := range accDets { for _, accDet := range accDets {
accountDetails.Set(accDet.Id, accDet) accountDetails.Set(accDet.Id, accDet)
validators[accDet.Id] = &Validator{
Account: accDet.Account,
BondHeight: 0,
VotingPower: accDet.Balance,
Accum: 0,
if accDet.Status == AccountDetailStatusBonded {
validators = append(validators, &Validator{
Account: accDet.Account,
BondHeight: 0,
VotingPower: accDet.Balance,
Accum: 0,
})
} }
} }
if len(validators) == 0 {
panic("Must have some validators")
}
validatorSet := NewValidatorSet(validators) validatorSet := NewValidatorSet(validators)
return &State{ return &State{
@ -89,16 +83,13 @@ func LoadState(db DB) *State {
s.CommitTime = ReadTime(reader, &n, &err) s.CommitTime = ReadTime(reader, &n, &err)
s.BlockHash = ReadByteSlice(reader, &n, &err) s.BlockHash = ReadByteSlice(reader, &n, &err)
accountDetailsHash := ReadByteSlice(reader, &n, &err) accountDetailsHash := ReadByteSlice(reader, &n, &err)
s.AccountDetails = merkle.NewTypedTree(merkle.LoadIAVLTreeFromHash(db, accountDetailsHash), BasicCodec, accountDetailCodec{})
var validators = map[uint64]*Validator{}
for reader.Len() > 0 {
validator := ReadValidator(reader, &n, &err)
validators[validator.Id] = validator
}
s.Validators = NewValidatorSet(validators)
s.AccountDetails = merkle.NewIAVLTree(BasicCodec, AccountDetailCodec, defaultAccountDetailsCacheCapacity, db)
s.AccountDetails.Load(accountDetailsHash)
s.Validators = ReadValidatorSet(reader, &n, &err)
if err != nil { if err != nil {
panic(err) panic(err)
} }
// TODO: ensure that buf is completely read.
} }
return s return s
} }
@ -108,17 +99,15 @@ func LoadState(db DB) *State {
// is saved here. // is saved here.
func (s *State) Save(commitTime time.Time) { func (s *State) Save(commitTime time.Time) {
s.CommitTime = commitTime s.CommitTime = commitTime
s.AccountDetails.Tree.Save()
s.AccountDetails.Save()
var buf bytes.Buffer var buf bytes.Buffer
var n int64 var n int64
var err error var err error
WriteUInt32(&buf, s.Height, &n, &err) WriteUInt32(&buf, s.Height, &n, &err)
WriteTime(&buf, commitTime, &n, &err) WriteTime(&buf, commitTime, &n, &err)
WriteByteSlice(&buf, s.BlockHash, &n, &err) WriteByteSlice(&buf, s.BlockHash, &n, &err)
WriteByteSlice(&buf, s.AccountDetails.Tree.Hash(), &n, &err)
for _, validator := range s.Validators.Map() {
WriteBinary(&buf, validator, &n, &err)
}
WriteByteSlice(&buf, s.AccountDetails.Hash(), &n, &err)
WriteBinary(&buf, s.Validators, &n, &err)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -225,7 +214,7 @@ func (s *State) AppendBlock(b *Block) error {
if !bytes.Equal(s.Validators.Hash(), b.ValidationStateHash) { if !bytes.Equal(s.Validators.Hash(), b.ValidationStateHash) {
return ErrStateInvalidValidationStateHash return ErrStateInvalidValidationStateHash
} }
if !bytes.Equal(s.AccountDetails.Tree.Hash(), b.AccountStateHash) {
if !bytes.Equal(s.AccountDetails.Hash(), b.AccountStateHash) {
return ErrStateInvalidAccountStateHash return ErrStateInvalidAccountStateHash
} }
@ -235,7 +224,7 @@ func (s *State) AppendBlock(b *Block) error {
} }
func (s *State) GetAccountDetail(accountId uint64) *AccountDetail { func (s *State) GetAccountDetail(accountId uint64) *AccountDetail {
accDet := s.AccountDetails.Get(accountId)
_, accDet := s.AccountDetails.Get(accountId)
if accDet == nil { if accDet == nil {
return nil return nil
} }


+ 13
- 11
state/state_test.go View File

@ -11,29 +11,30 @@ import (
"time" "time"
) )
func randAccountBalance(id uint64, status byte) *AccountBalance {
return &AccountBalance{
func randAccountDetail(id uint64, status byte) *AccountDetail {
return &AccountDetail{
Account: Account{ Account: Account{
Id: id, Id: id,
PubKey: CRandBytes(32), PubKey: CRandBytes(32),
}, },
Balance: RandUInt64(),
Status: status,
Sequence: RandUInt(),
Balance: RandUInt64(),
Status: status,
} }
} }
// The first numValidators accounts are validators. // The first numValidators accounts are validators.
func randGenesisState(numAccounts int, numValidators int) *State { func randGenesisState(numAccounts int, numValidators int) *State {
db := NewMemDB() db := NewMemDB()
accountBalances := make([]*AccountBalance, numAccounts)
accountDetails := make([]*AccountDetail, numAccounts)
for i := 0; i < numAccounts; i++ { for i := 0; i < numAccounts; i++ {
if i < numValidators { if i < numValidators {
accountBalances[i] = randAccountBalance(uint64(i), AccountBalanceStatusNominal)
accountDetails[i] = randAccountDetail(uint64(i), AccountDetailStatusNominal)
} else { } else {
accountBalances[i] = randAccountBalance(uint64(i), AccountBalanceStatusBonded)
accountDetails[i] = randAccountDetail(uint64(i), AccountDetailStatusBonded)
} }
} }
s0 := GenesisState(db, time.Now(), accountBalances)
s0 := GenesisState(db, time.Now(), accountDetails)
return s0 return s0
} }
@ -42,10 +43,11 @@ func TestGenesisSaveLoad(t *testing.T) {
// Generate a state, save & load it. // Generate a state, save & load it.
s0 := randGenesisState(10, 5) s0 := randGenesisState(10, 5)
// Figure out what the next state hashes should be. // Figure out what the next state hashes should be.
s0.Validators.Hash()
s0ValsCopy := s0.Validators.Copy() s0ValsCopy := s0.Validators.Copy()
s0ValsCopy.IncrementAccum() s0ValsCopy.IncrementAccum()
nextValidationStateHash := s0ValsCopy.Hash() nextValidationStateHash := s0ValsCopy.Hash()
nextAccountStateHash := s0.AccountBalances.Tree.Hash()
nextAccountStateHash := s0.AccountDetails.Hash()
// Mutate the state to append one empty block. // Mutate the state to append one empty block.
block := &Block{ block := &Block{
Header: Header{ Header: Header{
@ -97,7 +99,7 @@ func TestGenesisSaveLoad(t *testing.T) {
if s0.Validators.TotalVotingPower() != s1.Validators.TotalVotingPower() { if s0.Validators.TotalVotingPower() != s1.Validators.TotalVotingPower() {
t.Error("Validators TotalVotingPower mismatch") t.Error("Validators TotalVotingPower mismatch")
} }
if !bytes.Equal(s0.AccountBalances.Tree.Hash(), s1.AccountBalances.Tree.Hash()) {
t.Error("AccountBalance mismatch")
if !bytes.Equal(s0.AccountDetails.Hash(), s1.AccountDetails.Hash()) {
t.Error("AccountDetail mismatch")
} }
} }

+ 24
- 112
state/validator.go View File

@ -4,8 +4,6 @@ import (
"io" "io"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
"github.com/tendermint/tendermint/merkle"
) )
// Holds state for a Validator at a given height+round. // Holds state for a Validator at a given height+round.
@ -47,126 +45,40 @@ func (v *Validator) WriteTo(w io.Writer) (n int64, err error) {
return return
} }
//-----------------------------------------------------------------------------
// Not goroutine-safe.
type ValidatorSet struct {
validators map[uint64]*Validator
indexToId map[uint32]uint64 // bitarray index to validator id
idToIndex map[uint64]uint32 // validator id to bitarray index
totalVotingPower uint64
}
func NewValidatorSet(validators map[uint64]*Validator) *ValidatorSet {
if validators == nil {
validators = make(map[uint64]*Validator)
}
ids := []uint64{}
indexToId := map[uint32]uint64{}
idToIndex := map[uint64]uint32{}
totalVotingPower := uint64(0)
for id, val := range validators {
ids = append(ids, id)
totalVotingPower += val.VotingPower
// Returns the one with higher Accum.
func (v *Validator) CompareAccum(other *Validator) *Validator {
if v == nil {
return other
} }
UInt64Slice(ids).Sort()
for i, id := range ids {
indexToId[uint32(i)] = id
idToIndex[id] = uint32(i)
}
return &ValidatorSet{
validators: validators,
indexToId: indexToId,
idToIndex: idToIndex,
totalVotingPower: totalVotingPower,
}
}
func (vset *ValidatorSet) IncrementAccum() {
totalDelta := int64(0)
for _, validator := range vset.validators {
validator.Accum += int64(validator.VotingPower)
totalDelta += int64(validator.VotingPower)
}
proposer := vset.GetProposer()
proposer.Accum -= totalDelta
// NOTE: sum(v) here should be zero.
if true {
totalAccum := int64(0)
for _, validator := range vset.validators {
totalAccum += validator.Accum
}
if totalAccum != 0 {
Panicf("Total Accum of validators did not equal 0. Got: ", totalAccum)
if v.Accum > other.Accum {
return v
} else if v.Accum < other.Accum {
return other
} else {
if v.Id < other.Id {
return v
} else if v.Id > other.Id {
return other
} else {
panic("Cannot compare identical validators")
} }
} }
} }
func (vset *ValidatorSet) Copy() *ValidatorSet {
validators := map[uint64]*Validator{}
for id, val := range vset.validators {
validators[id] = val.Copy()
}
return &ValidatorSet{
validators: validators,
indexToId: vset.indexToId,
idToIndex: vset.idToIndex,
totalVotingPower: vset.totalVotingPower,
}
}
func (vset *ValidatorSet) GetById(id uint64) *Validator {
return vset.validators[id]
}
//-------------------------------------
func (vset *ValidatorSet) GetIndexById(id uint64) (uint32, bool) {
index, ok := vset.idToIndex[id]
return index, ok
}
var ValidatorCodec = validatorCodec{}
func (vset *ValidatorSet) GetIdByIndex(index uint32) (uint64, bool) {
id, ok := vset.indexToId[index]
return id, ok
}
type validatorCodec struct{}
func (vset *ValidatorSet) Map() map[uint64]*Validator {
return vset.validators
func (vc validatorCodec) Encode(o interface{}, w io.Writer, n *int64, err *error) {
WriteBinary(w, o.(*Validator), n, err)
} }
func (vset *ValidatorSet) Size() uint {
return uint(len(vset.validators))
func (vc validatorCodec) Decode(r io.Reader, n *int64, err *error) interface{} {
return ReadValidator(r, n, err)
} }
func (vset *ValidatorSet) TotalVotingPower() uint64 {
return vset.totalVotingPower
}
// TODO: cache proposer. invalidate upon increment.
func (vset *ValidatorSet) GetProposer() (proposer *Validator) {
highestAccum := int64(0)
for _, validator := range vset.validators {
if validator.Accum > highestAccum {
highestAccum = validator.Accum
proposer = validator
} else if validator.Accum == highestAccum {
if validator.Id < proposer.Id { // Seniority
proposer = validator
}
}
}
return
}
// Should uniquely determine the state of the ValidatorSet.
func (vset *ValidatorSet) Hash() []byte {
ids := []uint64{}
for id, _ := range vset.validators {
ids = append(ids, id)
}
UInt64Slice(ids).Sort()
sortedValidators := make([]Binary, len(ids))
for i, id := range ids {
sortedValidators[i] = vset.validators[id]
}
return merkle.HashFromBinaries(sortedValidators)
func (vc validatorCodec) Compare(o1 interface{}, o2 interface{}) int {
panic("ValidatorCodec.Compare not implemented")
} }

+ 117
- 0
state/validator_set.go View File

@ -0,0 +1,117 @@
package state
import (
"io"
. "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/merkle"
)
// Not goroutine-safe.
type ValidatorSet struct {
validators merkle.Tree
proposer *Validator // Whoever has the highest Accum.
totalVotingPower uint64
}
func NewValidatorSet(vals []*Validator) *ValidatorSet {
validators := merkle.NewIAVLTree(BasicCodec, ValidatorCodec, 0, nil) // In memory
var proposer *Validator
totalVotingPower := uint64(0)
for _, val := range vals {
validators.Set(val.Id, val)
proposer = proposer.CompareAccum(val)
totalVotingPower += val.VotingPower
}
return &ValidatorSet{
validators: validators,
proposer: proposer,
totalVotingPower: totalVotingPower,
}
}
func ReadValidatorSet(r io.Reader, n *int64, err *error) *ValidatorSet {
size := ReadUInt64(r, n, err)
validators := []*Validator{}
for i := uint64(0); i < size; i++ {
validator := ReadValidator(r, n, err)
validators = append(validators, validator)
}
return NewValidatorSet(validators)
}
func (vset *ValidatorSet) WriteTo(w io.Writer) (n int64, err error) {
WriteUInt64(w, uint64(vset.validators.Size()), &n, &err)
vset.validators.Iterate(func(key_ interface{}, val_ interface{}) bool {
val := val_.(*Validator)
WriteBinary(w, val, &n, &err)
return false
})
return
}
func (vset *ValidatorSet) IncrementAccum() {
// Decrement from previous proposer
vset.proposer.Accum -= int64(vset.totalVotingPower)
var proposer *Validator
// Increment accum and find proposer
vset.validators.Iterate(func(key_ interface{}, val_ interface{}) bool {
val := val_.(*Validator)
val.Accum += int64(val.VotingPower)
proposer = proposer.CompareAccum(val)
return false
})
vset.proposer = proposer
}
func (vset *ValidatorSet) Copy() *ValidatorSet {
return &ValidatorSet{
validators: vset.validators.Copy(),
proposer: vset.proposer,
totalVotingPower: vset.totalVotingPower,
}
}
func (vset *ValidatorSet) GetById(id uint64) (index uint32, val *Validator) {
index_, val_ := vset.validators.Get(id)
index, val = uint32(index_), val_.(*Validator)
return
}
func (vset *ValidatorSet) GetByIndex(index uint32) (id uint64, val *Validator) {
id_, val_ := vset.validators.GetByIndex(uint64(index))
id, val = id_.(uint64), val_.(*Validator)
return
}
func (vset *ValidatorSet) Size() uint {
return uint(vset.validators.Size())
}
func (vset *ValidatorSet) TotalVotingPower() uint64 {
return vset.totalVotingPower
}
func (vset *ValidatorSet) Proposer() (proposer *Validator) {
return vset.proposer
}
func (vset *ValidatorSet) Hash() []byte {
return vset.validators.Hash()
}
func (vset *ValidatorSet) AddValidator(val *Validator) (added bool) {
if val.Accum != 0 {
panic("AddValidator only accepts validators with zero accumpower")
}
if vset.validators.Has(val.Id) {
return false
}
updated := vset.validators.Set(val.Id, val)
return !updated
}
func (vset *ValidatorSet) RemoveValidator(validatorId uint64) (removed bool) {
_, removed = vset.validators.Remove(validatorId)
return removed
}

Loading…
Cancel
Save