Browse Source

remove mutex from state.

pull/9/head
Jae Kwon 10 years ago
parent
commit
31c1a1bbf7
3 changed files with 55 additions and 100 deletions
  1. +3
    -3
      consensus/state.go
  2. +40
    -83
      state/state.go
  3. +12
    -14
      state/state_test.go

+ 3
- 3
consensus/state.go View File

@ -51,12 +51,11 @@ type RoundState struct {
// Tracks consensus state across block heights and rounds. // Tracks consensus state across block heights and rounds.
type ConsensusState struct { type ConsensusState struct {
mtx sync.Mutex
RoundState
blockStore *BlockStore blockStore *BlockStore
mempool *Mempool mempool *Mempool
mtx sync.Mutex
RoundState
state *State // State until height-1. state *State // State until height-1.
stagedBlock *Block // Cache last staged block. stagedBlock *Block // Cache last staged block.
stagedState *State // Cache result of staged block. stagedState *State // Cache result of staged block.
@ -105,6 +104,7 @@ func (cs *ConsensusState) updateToState(state *State) {
cs.Precommits = NewVoteSet(height, 0, VoteTypePrecommit, validators) cs.Precommits = NewVoteSet(height, 0, VoteTypePrecommit, validators)
cs.Commits = NewVoteSet(height, 0, VoteTypeCommit, validators) cs.Commits = NewVoteSet(height, 0, VoteTypeCommit, validators)
cs.state = state
cs.stagedBlock = nil cs.stagedBlock = nil
cs.stagedState = nil cs.stagedState = nil


+ 40
- 83
state/state.go View File

@ -3,7 +3,6 @@ package state
import ( import (
"bytes" "bytes"
"errors" "errors"
"sync"
"time" "time"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
@ -35,15 +34,14 @@ func (abc accountBalanceCodec) Read(accBalBytes []byte) (interface{}, error) {
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// TODO: make it unsafe, remove mtx, and export fields?
// NOTE: not goroutine-safe.
type State struct { type State struct {
mtx sync.Mutex
db DB
height uint32 // Last known block height
blockHash []byte // Last known block hash
commitTime time.Time
accountBalances *merkle.TypedTree
validators *ValidatorSet
DB DB
Height uint32 // Last known block height
BlockHash []byte // Last known block hash
CommitTime time.Time
AccountBalances *merkle.TypedTree
Validators *ValidatorSet
} }
func GenesisState(db DB, genesisTime time.Time, accBals []*AccountBalance) *State { func GenesisState(db DB, genesisTime time.Time, accBals []*AccountBalance) *State {
@ -64,17 +62,17 @@ func GenesisState(db DB, genesisTime time.Time, accBals []*AccountBalance) *Stat
validatorSet := NewValidatorSet(validators) validatorSet := NewValidatorSet(validators)
return &State{ return &State{
db: db,
height: 0,
blockHash: nil,
commitTime: genesisTime,
accountBalances: accountBalances,
validators: validatorSet,
DB: db,
Height: 0,
BlockHash: nil,
CommitTime: genesisTime,
AccountBalances: accountBalances,
Validators: validatorSet,
} }
} }
func LoadState(db DB) *State { func LoadState(db DB) *State {
s := &State{db: db}
s := &State{DB: db}
buf := db.Get(stateKey) buf := db.Get(stateKey)
if len(buf) == 0 { if len(buf) == 0 {
return nil return nil
@ -82,17 +80,17 @@ func LoadState(db DB) *State {
reader := bytes.NewReader(buf) reader := bytes.NewReader(buf)
var n int64 var n int64
var err error var err error
s.height = ReadUInt32(reader, &n, &err)
s.commitTime = ReadTime(reader, &n, &err)
s.blockHash = ReadByteSlice(reader, &n, &err)
s.Height = ReadUInt32(reader, &n, &err)
s.CommitTime = ReadTime(reader, &n, &err)
s.BlockHash = ReadByteSlice(reader, &n, &err)
accountBalancesHash := ReadByteSlice(reader, &n, &err) accountBalancesHash := ReadByteSlice(reader, &n, &err)
s.accountBalances = merkle.NewTypedTree(merkle.LoadIAVLTreeFromHash(db, accountBalancesHash), BasicCodec, accountBalanceCodec{})
s.AccountBalances = merkle.NewTypedTree(merkle.LoadIAVLTreeFromHash(db, accountBalancesHash), BasicCodec, accountBalanceCodec{})
var validators = map[uint64]*Validator{} var validators = map[uint64]*Validator{}
for reader.Len() > 0 { for reader.Len() > 0 {
validator := ReadValidator(reader, &n, &err) validator := ReadValidator(reader, &n, &err)
validators[validator.Id] = validator validators[validator.Id] = validator
} }
s.validators = NewValidatorSet(validators)
s.Validators = NewValidatorSet(validators)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -104,48 +102,38 @@ func LoadState(db DB) *State {
// For convenience, the commitTime (required by ConsensusAgent) // For convenience, the commitTime (required by ConsensusAgent)
// is saved here. // is saved here.
func (s *State) Save(commitTime time.Time) { func (s *State) Save(commitTime time.Time) {
s.mtx.Lock()
defer s.mtx.Unlock()
s.commitTime = commitTime
s.accountBalances.Tree.Save()
s.CommitTime = commitTime
s.AccountBalances.Tree.Save()
var buf bytes.Buffer var buf bytes.Buffer
var n int64 var n int64
var err error var err error
WriteUInt32(&buf, s.height, &n, &err)
WriteUInt32(&buf, s.Height, &n, &err)
WriteTime(&buf, commitTime, &n, &err) WriteTime(&buf, commitTime, &n, &err)
WriteByteSlice(&buf, s.blockHash, &n, &err)
WriteByteSlice(&buf, s.accountBalances.Tree.Hash(), &n, &err)
for _, validator := range s.validators.Map() {
WriteByteSlice(&buf, s.BlockHash, &n, &err)
WriteByteSlice(&buf, s.AccountBalances.Tree.Hash(), &n, &err)
for _, validator := range s.Validators.Map() {
WriteBinary(&buf, validator, &n, &err) WriteBinary(&buf, validator, &n, &err)
} }
if err != nil { if err != nil {
panic(err) panic(err)
} }
s.db.Set(stateKey, buf.Bytes())
s.DB.Set(stateKey, buf.Bytes())
} }
func (s *State) Copy() *State { func (s *State) Copy() *State {
s.mtx.Lock()
defer s.mtx.Unlock()
return &State{ return &State{
db: s.db,
height: s.height,
commitTime: s.commitTime,
blockHash: s.blockHash,
accountBalances: s.accountBalances.Copy(),
validators: s.validators.Copy(),
DB: s.DB,
Height: s.Height,
CommitTime: s.CommitTime,
BlockHash: s.BlockHash,
AccountBalances: s.AccountBalances.Copy(),
Validators: s.Validators.Copy(),
} }
} }
// If the tx is invalid, an error will be returned. // If the tx is invalid, an error will be returned.
// Unlike AppendBlock(), state will not be altered. // Unlike AppendBlock(), state will not be altered.
func (s *State) ExecTx(tx Tx) error { func (s *State) ExecTx(tx Tx) error {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.execTx(tx)
}
func (s *State) execTx(tx Tx) error {
/* /*
// Get the signer's incr // Get the signer's incr
signerId := tx.Signature().SignerId signerId := tx.Signature().SignerId
@ -161,69 +149,38 @@ func (s *State) execTx(tx Tx) error {
// NOTE: If an error occurs during block execution, state will be left // NOTE: If an error occurs during block execution, state will be left
// at an invalid state. Copy the state before calling Commit! // at an invalid state. Copy the state before calling Commit!
func (s *State) AppendBlock(b *Block) error { func (s *State) AppendBlock(b *Block) error {
s.mtx.Lock()
defer s.mtx.Unlock()
// Basic block validation. // Basic block validation.
err := b.ValidateBasic(s.height, s.blockHash)
err := b.ValidateBasic(s.Height, s.BlockHash)
if err != nil { if err != nil {
return err return err
} }
// Commit each tx // Commit each tx
for _, tx := range b.Data.Txs { for _, tx := range b.Data.Txs {
err := s.execTx(tx)
err := s.ExecTx(tx)
if err != nil { if err != nil {
return err return err
} }
} }
// Increment validator AccumPowers // Increment validator AccumPowers
s.validators.IncrementAccum()
s.Validators.IncrementAccum()
// State hashes should match // State hashes should match
if !bytes.Equal(s.validators.Hash(), b.ValidationStateHash) {
if !bytes.Equal(s.Validators.Hash(), b.ValidationStateHash) {
return ErrStateInvalidValidationStateHash return ErrStateInvalidValidationStateHash
} }
if !bytes.Equal(s.accountBalances.Tree.Hash(), b.AccountStateHash) {
if !bytes.Equal(s.AccountBalances.Tree.Hash(), b.AccountStateHash) {
return ErrStateInvalidAccountStateHash return ErrStateInvalidAccountStateHash
} }
s.height = b.Height
s.blockHash = b.Hash()
s.Height = b.Height
s.BlockHash = b.Hash()
return nil return nil
} }
func (s *State) Height() uint32 {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.height
}
func (s *State) CommitTime() time.Time {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.commitTime
}
// The returned ValidatorSet gets mutated upon s.ExecTx() and s.AppendBlock().
// Caller should copy the returned set before mutating.
func (s *State) Validators() *ValidatorSet {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.validators
}
func (s *State) BlockHash() []byte {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.blockHash
}
func (s *State) AccountBalance(accountId uint64) *AccountBalance { func (s *State) AccountBalance(accountId uint64) *AccountBalance {
s.mtx.Lock()
defer s.mtx.Unlock()
accBal := s.accountBalances.Get(accountId)
accBal := s.AccountBalances.Get(accountId)
if accBal == nil { if accBal == nil {
return nil return nil
} }


+ 12
- 14
state/state_test.go View File

@ -42,10 +42,10 @@ func TestGenesisSaveLoad(t *testing.T) {
// Generate a state, save & load it. // Generate a state, save & load it.
s0 := randGenesisState(10, 5) s0 := randGenesisState(10, 5)
// Figure out what the next state hashes should be. // Figure out what the next state hashes should be.
s0ValsCopy := s0.Validators().Copy()
s0ValsCopy := s0.Validators.Copy()
s0ValsCopy.IncrementAccum() s0ValsCopy.IncrementAccum()
nextValidationStateHash := s0ValsCopy.Hash() nextValidationStateHash := s0ValsCopy.Hash()
nextAccountStateHash := s0.accountBalances.Tree.Hash()
nextAccountStateHash := s0.AccountBalances.Tree.Hash()
// Mutate the state to append one empty block. // Mutate the state to append one empty block.
block := &Block{ block := &Block{
Header: Header{ Header: Header{
@ -66,33 +66,31 @@ func TestGenesisSaveLoad(t *testing.T) {
// Save s0, load s1. // Save s0, load s1.
commitTime := time.Now() commitTime := time.Now()
s0.Save(commitTime) s0.Save(commitTime)
// s0.db.(*MemDB).Print()
s1 := LoadState(s0.db)
//s0.DB.(*MemDB).Print()
s1 := LoadState(s0.DB)
// Compare CommitTime // Compare CommitTime
if commitTime.Unix() != s1.CommitTime().Unix() {
if commitTime.Unix() != s1.CommitTime.Unix() {
t.Error("CommitTime was not the same") t.Error("CommitTime was not the same")
} }
// Compare height & blockHash // Compare height & blockHash
if s0.Height() != 1 {
t.Error("s0 Height should be 1, got", s0.Height())
if s0.Height != 1 {
t.Error("s0 Height should be 1, got", s0.Height)
} }
if s0.Height() != s1.Height() {
if s0.Height != s1.Height {
t.Error("Height mismatch") t.Error("Height mismatch")
} }
if !bytes.Equal(s0.BlockHash(), s1.BlockHash()) {
if !bytes.Equal(s0.BlockHash, s1.BlockHash) {
t.Error("BlockHash mismatch") t.Error("BlockHash mismatch")
} }
// Compare Validators // Compare Validators
s0Vals := s0.Validators()
s1Vals := s1.Validators()
if s0Vals.Size() != s1Vals.Size() {
if s0.Validators.Size() != s1.Validators.Size() {
t.Error("Validators Size changed") t.Error("Validators Size changed")
} }
if s0Vals.TotalVotingPower() == 0 {
if s0.Validators.TotalVotingPower() == 0 {
t.Error("s0 Validators TotalVotingPower should not be 0") t.Error("s0 Validators TotalVotingPower should not be 0")
} }
if s0Vals.TotalVotingPower() != s1Vals.TotalVotingPower() {
if s0.Validators.TotalVotingPower() != s1.Validators.TotalVotingPower() {
t.Error("Validators TotalVotingPower changed") t.Error("Validators TotalVotingPower changed")
} }
// TODO Compare accountBalances, height, blockHash // TODO Compare accountBalances, height, blockHash


Loading…
Cancel
Save