From 8e452aa0d28d1aa230b4d383e28922baf3221496 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Sun, 14 Sep 2014 15:37:32 -0700 Subject: [PATCH] fixed algorithm --- binary/byteslice.go | 30 +- binary/int.go | 28 + blocks/block.go | 90 +-- blocks/block_part_set.go | 158 ----- blocks/document.go | 46 ++ blocks/signature.go | 30 +- blocks/store.go | 52 +- blocks/tx.go | 52 +- blocks/vote.go | 54 ++ common/bit_array.go | 155 +++++ common/int.go | 20 + consensus/README.md | 3 +- consensus/consensus.go | 1290 ++++++++++++++--------------------- consensus/document.go | 41 -- consensus/part_set.go | 177 +++++ consensus/pol.go | 98 +++ consensus/priv_validator.go | 25 +- consensus/proposal.go | 64 ++ consensus/state.go | 483 +++++++++---- consensus/vote.go | 184 ----- consensus/vote_set.go | 180 +++++ main.go | 60 +- mempool/agent.go | 169 ----- mempool/mempool.go | 2 +- mempool/reactor.go | 141 ++++ merkle/types.go | 4 + merkle/util.go | 236 ++++++- p2p/README.md | 114 +--- p2p/connection.go | 44 +- p2p/peer.go | 25 +- p2p/pex_agent.go | 278 -------- p2p/pex_reactor.go | 267 ++++++++ p2p/switch.go | 117 ++-- state/state.go | 5 +- state/validator.go | 74 +- 35 files changed, 2569 insertions(+), 2227 deletions(-) delete mode 100644 blocks/block_part_set.go create mode 100644 blocks/document.go create mode 100644 blocks/vote.go create mode 100644 common/bit_array.go create mode 100644 common/int.go delete mode 100644 consensus/document.go create mode 100644 consensus/part_set.go create mode 100644 consensus/pol.go create mode 100644 consensus/proposal.go delete mode 100644 consensus/vote.go create mode 100644 consensus/vote_set.go delete mode 100644 mempool/agent.go create mode 100644 mempool/reactor.go delete mode 100644 p2p/pex_agent.go create mode 100644 p2p/pex_reactor.go diff --git a/binary/byteslice.go b/binary/byteslice.go index e3ca14458..e31a5b1e6 100644 --- a/binary/byteslice.go +++ b/binary/byteslice.go @@ -4,8 +4,6 @@ import ( "io" ) -// ByteSlice - func 
WriteByteSlice(w io.Writer, bz []byte, n *int64, err *error) { WriteUInt32(w, uint32(len(bz)), n, err) WriteTo(w, bz, n, err) @@ -20,3 +18,31 @@ func ReadByteSlice(r io.Reader, n *int64, err *error) []byte { ReadFull(r, buf, n, err) return buf } + +//----------------------------------------------------------------------------- + +func WriteByteSlices(w io.Writer, bzz [][]byte, n *int64, err *error) { + WriteUInt32(w, uint32(len(bzz)), n, err) + for _, bz := range bzz { + WriteByteSlice(w, bz, n, err) + if *err != nil { + return + } + } +} + +func ReadByteSlices(r io.Reader, n *int64, err *error) [][]byte { + length := ReadUInt32(r, n, err) + if *err != nil { + return nil + } + bzz := make([][]byte, length) + for i := uint32(0); i < length; i++ { + bz := ReadByteSlice(r, n, err) + if *err != nil { + return nil + } + bzz[i] = bz + } + return bzz +} diff --git a/binary/int.go b/binary/int.go index 85bcbde5e..4ff9d7cb4 100644 --- a/binary/int.go +++ b/binary/int.go @@ -68,6 +68,34 @@ func ReadUInt16(r io.Reader, n *int64, err *error) uint16 { return uint16(binary.LittleEndian.Uint16(buf)) } +// []UInt16 + +func WriteUInt16s(w io.Writer, iz []uint16, n *int64, err *error) { + WriteUInt32(w, uint32(len(iz)), n, err) + for _, i := range iz { + WriteUInt16(w, i, n, err) + if *err != nil { + return + } + } +} + +func ReadUInt16s(r io.Reader, n *int64, err *error) []uint16 { + length := ReadUInt32(r, n, err) + if *err != nil { + return nil + } + iz := make([]uint16, length) + for j := uint32(0); j < length; j++ { + ii := ReadUInt16(r, n, err) + if *err != nil { + return nil + } + iz[j] = ii + } + return iz +} + // Int32 func WriteInt32(w io.Writer, i int32, n *int64, err *error) { diff --git a/blocks/block.go b/blocks/block.go index d9f6d4f43..bcb5d6d6a 100644 --- a/blocks/block.go +++ b/blocks/block.go @@ -8,15 +8,10 @@ import ( "time" . "github.com/tendermint/tendermint/binary" - . "github.com/tendermint/tendermint/common" . 
"github.com/tendermint/tendermint/config" "github.com/tendermint/tendermint/merkle" ) -const ( - defaultBlockPartSizeBytes = 4096 -) - var ( ErrBlockInvalidNetwork = errors.New("Error block invalid network") ErrBlockInvalidBlockHeight = errors.New("Error block invalid height") @@ -72,30 +67,21 @@ func (b *Block) Hash() []byte { b.Data.Hash(), } // Merkle hash from sub-hashes. - return merkle.HashFromByteSlices(hashes) + return merkle.HashFromHashes(hashes) } } -// The returns parts must be signed afterwards. -func (b *Block) ToBlockPartSet() *BlockPartSet { - var parts []*BlockPart - blockBytes := BinaryBytes(b) - total := (len(blockBytes) + defaultBlockPartSizeBytes - 1) / defaultBlockPartSizeBytes - for i := 0; i < total; i++ { - start := defaultBlockPartSizeBytes * i - end := MinInt(start+defaultBlockPartSizeBytes, len(blockBytes)) - partBytes := make([]byte, end-start) - copy(partBytes, blockBytes[start:end]) // Do not ref the original byteslice. - part := &BlockPart{ - Height: b.Height, - Index: uint16(i), - Total: uint16(total), - Bytes: partBytes, - Signature: Signature{}, // No signature. - } - parts = append(parts, part) +// Convenience. +// A nil block never hashes to anything. +// Nothing hashes to a nil hash. +func (b *Block) HashesTo(hash []byte) bool { + if len(hash) == 0 { + return false } - return NewBlockPartSet(b.Height, parts) + if b == nil { + return false + } + return bytes.Equal(b.Hash(), hash) } // Makes an empty next block. @@ -115,58 +101,6 @@ func (b *Block) MakeNextBlock() *Block { //----------------------------------------------------------------------------- -/* -BlockPart represents a chunk of the bytes of a block. -Each block is divided into fixed length chunks (e.g. 4Kb) -for faster propagation across the gossip network. -*/ -type BlockPart struct { - Height uint32 - Round uint16 // Add Round? Well I need to know... 
- Index uint16 - Total uint16 - Bytes []byte - Signature - - // Volatile - hash []byte -} - -func ReadBlockPart(r io.Reader, n *int64, err *error) *BlockPart { - return &BlockPart{ - Height: ReadUInt32(r, n, err), - Round: ReadUInt16(r, n, err), - Index: ReadUInt16(r, n, err), - Total: ReadUInt16(r, n, err), - Bytes: ReadByteSlice(r, n, err), - Signature: ReadSignature(r, n, err), - } -} - -func (bp *BlockPart) WriteTo(w io.Writer) (n int64, err error) { - WriteUInt32(w, bp.Height, &n, &err) - WriteUInt16(w, bp.Round, &n, &err) - WriteUInt16(w, bp.Index, &n, &err) - WriteUInt16(w, bp.Total, &n, &err) - WriteByteSlice(w, bp.Bytes, &n, &err) - WriteBinary(w, bp.Signature, &n, &err) - return -} - -// Hash returns the hash of the block part data bytes. -func (bp *BlockPart) Hash() []byte { - if bp.hash != nil { - return bp.hash - } else { - hasher := sha256.New() - hasher.Write(bp.Bytes) - bp.hash = hasher.Sum(nil) - return bp.hash - } -} - -//----------------------------------------------------------------------------- - type Header struct { Network string Height uint32 @@ -296,7 +230,7 @@ func (data *Data) Hash() []byte { for i, tx := range data.Txs { bs[i] = Binary(tx) } - data.hash = merkle.HashFromBinarySlice(bs) + data.hash = merkle.HashFromBinaries(bs) return data.hash } } diff --git a/blocks/block_part_set.go b/blocks/block_part_set.go deleted file mode 100644 index d97f4e8d3..000000000 --- a/blocks/block_part_set.go +++ /dev/null @@ -1,158 +0,0 @@ -package blocks - -import ( - "bytes" - "errors" - "sync" - - "github.com/tendermint/tendermint/merkle" -) - -// A collection of block parts. -// Doesn't do any validation. 
-type BlockPartSet struct { - mtx sync.Mutex - height uint32 - total uint16 // total number of parts - numParts uint16 // number of parts in this set - parts []*BlockPart - - _block *Block // cache -} - -var ( - ErrInvalidBlockPartConflict = errors.New("Invalid block part conflict") // Signer signed conflicting parts -) - -// parts may be nil if the parts aren't in hand. -func NewBlockPartSet(height uint32, parts []*BlockPart) *BlockPartSet { - bps := &BlockPartSet{ - height: height, - parts: parts, - numParts: uint16(len(parts)), - } - if len(parts) > 0 { - bps.total = parts[0].Total - } - return bps -} - -func (bps *BlockPartSet) Height() uint32 { - return bps.height -} - -func (bps *BlockPartSet) BlockParts() []*BlockPart { - bps.mtx.Lock() - defer bps.mtx.Unlock() - return bps.parts -} - -func (bps *BlockPartSet) BitArray() []byte { - bps.mtx.Lock() - defer bps.mtx.Unlock() - if bps.parts == nil { - return nil - } - bitArray := make([]byte, (len(bps.parts)+7)/8) - for i, part := range bps.parts { - if part != nil { - bitArray[i/8] |= 1 << uint(i%8) - } - } - return bitArray -} - -// If the part isn't valid, returns an error. -// err can be ErrInvalidBlockPartConflict -// NOTE: Caller must check the signature before adding. -func (bps *BlockPartSet) AddBlockPart(part *BlockPart) (added bool, err error) { - bps.mtx.Lock() - defer bps.mtx.Unlock() - - if bps.parts == nil { - // First received part for this round. - bps.parts = make([]*BlockPart, part.Total) - bps.total = uint16(part.Total) - bps.parts[int(part.Index)] = part - bps.numParts++ - return true, nil - } else { - // Check part.Index and part.Total - if uint16(part.Index) >= bps.total { - return false, ErrInvalidBlockPartConflict - } - if uint16(part.Total) != bps.total { - return false, ErrInvalidBlockPartConflict - } - // Check for existing parts. 
- existing := bps.parts[part.Index] - if existing != nil { - if bytes.Equal(existing.Bytes, part.Bytes) { - // Ignore duplicate - return false, nil - } else { - return false, ErrInvalidBlockPartConflict - } - } else { - bps.parts[int(part.Index)] = part - bps.numParts++ - return true, nil - } - } - -} - -func (bps *BlockPartSet) IsComplete() bool { - bps.mtx.Lock() - defer bps.mtx.Unlock() - return bps.total > 0 && bps.total == bps.numParts -} - -func (bps *BlockPartSet) Block() *Block { - if !bps.IsComplete() { - return nil - } - bps.mtx.Lock() - defer bps.mtx.Unlock() - if bps._block == nil { - block, err := BlockPartsToBlock(bps.parts) - if err != nil { - panic(err) - } - bps._block = block - } - return bps._block -} - -func (bps *BlockPartSet) Hash() []byte { - if !bps.IsComplete() { - panic("Cannot get hash of an incomplete BlockPartSet") - } - hashes := [][]byte{} - for _, part := range bps.parts { - partHash := part.Hash() - hashes = append(hashes, partHash) - } - return merkle.HashFromByteSlices(hashes) -} - -// The proposal hash includes both the block hash -// as well as the BlockPartSet merkle hash. -func (bps *BlockPartSet) ProposalHash() []byte { - bpsHash := bps.Hash() - blockHash := bps.Block().Hash() - return merkle.HashFromByteSlices([][]byte{bpsHash, blockHash}) -} - -//----------------------------------------------------------------------------- - -func BlockPartsToBlock(parts []*BlockPart) (*Block, error) { - blockBytes := []byte{} - for _, part := range parts { - blockBytes = append(blockBytes, part.Bytes...) - } - var n int64 - var err error - block := ReadBlock(bytes.NewReader(blockBytes), &n, &err) - return block, err -} diff --git a/blocks/document.go b/blocks/document.go new file mode 100644 index 000000000..658751bc3 --- /dev/null +++ b/blocks/document.go @@ -0,0 +1,46 @@ +package blocks + +import ( + "fmt" + . 
"github.com/tendermint/tendermint/config" +) + +func GenVoteDocument(voteType byte, height uint32, round uint16, blockHash []byte) []byte { + stepName := "" + switch voteType { + case VoteTypeBare: + stepName = "bare" + case VoteTypePrecommit: + stepName = "precommit" + case VoteTypeCommit: + stepName = "commit" + default: + panic("Unknown vote type") + } + return []byte(fmt.Sprintf( + `!!!!!BEGIN TENDERMINT VOTE!!!!! +Network: %v +Height: %v +Round: %v +Step: %v +BlockHash: %v +!!!!!END TENDERMINT VOTE!!!!!`, + Config.Network, height, round, stepName, blockHash, + )) +} + +func GenProposalDocument(height uint32, round uint16, blockPartsTotal uint16, blockPartsHash []byte, + polPartsTotal uint16, polPartsHash []byte) []byte { + return []byte(fmt.Sprintf( + `!!!!!BEGIN TENDERMINT PROPOSAL!!!!! +Network: %v +Height: %v +Round: %v +BlockPartsTotal: %v +BlockPartsHash: %X +POLPartsTotal: %v +POLPartsHash: %X +!!!!!END TENDERMINT PROPOSAL!!!!!`, + Config.Network, height, round, blockPartsTotal, blockPartsHash, polPartsTotal, polPartsHash, + )) +} diff --git a/blocks/signature.go b/blocks/signature.go index 86f911efa..6bfa28b82 100644 --- a/blocks/signature.go +++ b/blocks/signature.go @@ -5,18 +5,6 @@ import ( "io" ) -/* -Signature message wire format: - - |a...|sss...| - - a Account number, varint encoded (1+ bytes) - s Signature of all prior bytes (32 bytes) - -It usually follows the message to be signed. 
- -*/ - type Signature struct { SignerId uint64 Bytes []byte @@ -38,3 +26,21 @@ func (sig Signature) WriteTo(w io.Writer) (n int64, err error) { WriteByteSlice(w, sig.Bytes, &n, &err) return } + +func ReadSignatures(r io.Reader, n *int64, err *error) (sigs []Signature) { + length := ReadUInt32(r, n, err) + for i := uint32(0); i < length; i++ { + sigs = append(sigs, ReadSignature(r, n, err)) + } + return +} + +func WriteSignatures(w io.Writer, sigs []Signature, n *int64, err *error) { + WriteUInt32(w, uint32(len(sigs)), n, err) + for _, sig := range sigs { + WriteBinary(w, sig, n, err) + if *err != nil { + return + } + } +} diff --git a/blocks/store.go b/blocks/store.go index 54a0f3a2f..c56267f6c 100644 --- a/blocks/store.go +++ b/blocks/store.go @@ -70,54 +70,27 @@ func (bs *BlockStore) Height() uint32 { return bs.height } -// LoadBlockPart loads a part of a block. -func (bs *BlockStore) LoadBlockPart(height uint32, index uint16) *BlockPart { - partBytes, err := bs.db.Get(calcBlockPartKey(height, index), nil) +func (bs *BlockStore) LoadBlock(height uint32) *Block { + blockBytes, err := bs.db.Get(calcBlockKey(height), nil) if err != nil { - Panicf("Could not load block part: %v", err) + Panicf("Could not load block: %v", err) } - if partBytes == nil { + if blockBytes == nil { return nil } var n int64 - return ReadBlockPart(bytes.NewReader(partBytes), &n, &err) + return ReadBlock(bytes.NewReader(blockBytes), &n, &err) } -// Convenience method for loading block parts and merging to a block. -func (bs *BlockStore) LoadBlock(height uint32) *Block { - // Get the first part. 
- part0 := bs.LoadBlockPart(height, 0) - if part0 == nil { - return nil - } - parts := []*BlockPart{part0} - for i := uint16(1); i < part0.Total; i++ { - part := bs.LoadBlockPart(height, i) - if part == nil { - Panicf("Failed to retrieve block part %v at height %v", i, height) - } - parts = append(parts, part) - } - block, err := BlockPartsToBlock(parts) - if err != nil { - panic(err) - } - return block -} - -// NOTE: Assumes that parts as well as the block are valid. See StageBlockParts(). // Writes are synchronous and atomic. -func (bs *BlockStore) SaveBlockParts(height uint32, parts []*BlockPart) error { +func (bs *BlockStore) SaveBlock(block *Block) error { + height := block.Height if height != bs.height+1 { return Errorf("BlockStore can only save contiguous blocks. Wanted %v, got %v", bs.height+1, height) } - // Save parts - batch := new(leveldb.Batch) - for _, part := range parts { - partBytes := BinaryBytes(part) - batch.Put(calcBlockPartKey(uint32(part.Height), uint16(part.Index)), partBytes) - } - err := bs.db.Write(batch, &opt.WriteOptions{Sync: true}) + // Save block + blockBytes := BinaryBytes(block) + err := bs.db.Put(calcBlockKey(height), blockBytes, &opt.WriteOptions{Sync: true}) // Save new BlockStoreJSON descriptor BlockStoreJSON{Height: height}.Save(bs.db) return err @@ -125,9 +98,8 @@ func (bs *BlockStore) SaveBlockParts(height uint32, parts []*BlockPart) error { //----------------------------------------------------------------------------- -func calcBlockPartKey(height uint32, index uint16) []byte { - buf := [11]byte{'B'} +func calcBlockKey(height uint32) []byte { + buf := [9]byte{'B'} binary.BigEndian.PutUint32(buf[1:9], height) - binary.BigEndian.PutUint16(buf[9:11], index) return buf[:] } diff --git a/blocks/tx.go b/blocks/tx.go index bd7d3adc0..2df41b428 100644 --- a/blocks/tx.go +++ b/blocks/tx.go @@ -7,17 +7,6 @@ import ( ) /* - -Tx wire format: - - |T|L...|MMM...|A...|SSS...| - - T type of the tx (1 byte) - L length of M, varint encoded 
(1+ bytes) - M Tx bytes (L bytes) - A account number, varint encoded (1+ bytes) - S signature of all prior bytes (32 bytes) - Account Txs: 1. Send Send coins to account 2. Name Associate account with a name @@ -27,8 +16,6 @@ Validation Txs: 4. Unbond Validator leaves 5. Timeout Validator times out 6. Dupeout Validator dupes out (signs twice) - - */ type Tx interface { @@ -89,8 +76,8 @@ func ReadTx(r io.Reader, n *int64, err *error) Tx { case TX_TYPE_DUPEOUT: return &DupeoutTx{ BaseTx: ReadBaseTx(r, n, err), - VoteA: *ReadBlockVote(r, n, err), - VoteB: *ReadBlockVote(r, n, err), + VoteA: *ReadVote(r, n, err), + VoteB: *ReadVote(r, n, err), } default: Panicf("Unknown Tx type %x", t) @@ -234,37 +221,10 @@ func (tx *TimeoutTx) WriteTo(w io.Writer) (n int64, err error) { //----------------------------------------------------------------------------- -/* -The full vote structure is only needed when presented as evidence. -Typically only the signature is passed around, as the hash & height are implied. 
-*/ -type BlockVote struct { - Height uint64 - BlockHash []byte - Signature -} - -func ReadBlockVote(r io.Reader, n *int64, err *error) *BlockVote { - return &BlockVote{ - Height: ReadUInt64(r, n, err), - BlockHash: ReadByteSlice(r, n, err), - Signature: ReadSignature(r, n, err), - } -} - -func (tx BlockVote) WriteTo(w io.Writer) (n int64, err error) { - WriteUInt64(w, tx.Height, &n, &err) - WriteByteSlice(w, tx.BlockHash, &n, &err) - WriteBinary(w, tx.Signature, &n, &err) - return -} - -//----------------------------------------------------------------------------- - type DupeoutTx struct { BaseTx - VoteA BlockVote - VoteB BlockVote + VoteA Vote + VoteB Vote } func (tx *DupeoutTx) Type() byte { @@ -274,7 +234,7 @@ func (tx *DupeoutTx) Type() byte { func (tx *DupeoutTx) WriteTo(w io.Writer) (n int64, err error) { WriteByte(w, tx.Type(), &n, &err) WriteBinary(w, &tx.BaseTx, &n, &err) - WriteBinary(w, tx.VoteA, &n, &err) - WriteBinary(w, tx.VoteB, &n, &err) + WriteBinary(w, &tx.VoteA, &n, &err) + WriteBinary(w, &tx.VoteB, &n, &err) return } diff --git a/blocks/vote.go b/blocks/vote.go new file mode 100644 index 000000000..649476331 --- /dev/null +++ b/blocks/vote.go @@ -0,0 +1,54 @@ +package blocks + +import ( + "errors" + "io" + + . "github.com/tendermint/tendermint/binary" +) + +const ( + VoteTypeBare = byte(0x00) + VoteTypePrecommit = byte(0x01) + VoteTypeCommit = byte(0x02) +) + +var ( + ErrVoteUnexpectedPhase = errors.New("Unexpected phase") + ErrVoteInvalidAccount = errors.New("Invalid round vote account") + ErrVoteInvalidSignature = errors.New("Invalid round vote signature") + ErrVoteInvalidBlockHash = errors.New("Invalid block hash") + ErrVoteConflictingSignature = errors.New("Conflicting round vote signature") +) + +// Represents a bare, precommit, or commit vote for proposals. +type Vote struct { + Height uint32 + Round uint16 + Type byte + BlockHash []byte // empty if vote is nil. 
+ Signature +} + +func ReadVote(r io.Reader, n *int64, err *error) *Vote { + return &Vote{ + Height: ReadUInt32(r, n, err), + Round: ReadUInt16(r, n, err), + Type: ReadByte(r, n, err), + BlockHash: ReadByteSlice(r, n, err), + Signature: ReadSignature(r, n, err), + } +} + +func (v *Vote) WriteTo(w io.Writer) (n int64, err error) { + WriteUInt32(w, v.Height, &n, &err) + WriteUInt16(w, v.Round, &n, &err) + WriteByte(w, v.Type, &n, &err) + WriteByteSlice(w, v.BlockHash, &n, &err) + WriteBinary(w, v.Signature, &n, &err) + return +} + +func (v *Vote) GenDocument() []byte { + return GenVoteDocument(v.Type, v.Height, v.Round, v.BlockHash) +} diff --git a/common/bit_array.go b/common/bit_array.go new file mode 100644 index 000000000..d771c4c39 --- /dev/null +++ b/common/bit_array.go @@ -0,0 +1,155 @@ +package common + +import ( + "io" + "math" + "math/rand" + + . "github.com/tendermint/tendermint/binary" +) + +// Not goroutine safe +type BitArray []uint64 + +func NewBitArray(length uint) BitArray { + return BitArray(make([]uint64, (length+63)/64)) +} + +func ReadBitArray(r io.Reader, n *int64, err *error) BitArray { + lengthTotal := ReadUInt32(r, n, err) + lengthWritten := ReadUInt32(r, n, err) + if *err != nil { + return nil + } + buf := make([]uint64, int(lengthTotal)) + for i := uint32(0); i < lengthWritten; i++ { + buf[i] = ReadUInt64(r, n, err) + if err != nil { + return nil + } + } + return BitArray(buf) +} + +func (bA BitArray) WriteTo(w io.Writer) (n int64, err error) { + // Count the last element > 0. 
+ lastNonzeroIndex := -1 + for i, elem := range bA { + if elem > 0 { + lastNonzeroIndex = i + } + } + WriteUInt32(w, uint32(len(bA)), &n, &err) + WriteUInt32(w, uint32(lastNonzeroIndex+1), &n, &err) + for i, elem := range bA { + if i > lastNonzeroIndex { + break + } + WriteUInt64(w, elem, &n, &err) + } + return +} + +func (bA BitArray) GetIndex(i uint) bool { + return bA[i/64]&uint64(1<<(i%64)) > 0 +} + +func (bA BitArray) SetIndex(i uint, v bool) { + if v { + bA[i/64] |= uint64(1 << (i % 64)) + } else { + bA[i/64] &= ^uint64(1 << (i % 64)) + } +} + +func (bA BitArray) Copy() BitArray { + c := make([]uint64, len(bA)) + copy(c, bA) + return BitArray(c) +} + +func (bA BitArray) Or(o BitArray) BitArray { + c := bA.Copy() + for i, _ := range c { + c[i] = o[i] | c[i] + } + return c +} + +func (bA BitArray) And(o BitArray) BitArray { + c := bA.Copy() + for i, _ := range c { + c[i] = o[i] & c[i] + } + return c +} + +func (bA BitArray) Not() BitArray { + c := bA.Copy() + for i, _ := range c { + c[i] = ^c[i] + } + return c +} + +func (bA BitArray) Sub(o BitArray) BitArray { + return bA.And(o.Not()) +} + +// NOTE: returns counts or a longer int slice as necessary. +func (bA BitArray) AddToCounts(counts []int) []int { + for bytei := 0; bytei < len(bA); bytei++ { + for biti := 0; biti < 64; biti++ { + if (bA[bytei] & (1 << uint(biti))) == 0 { + continue + } + index := 64*bytei + biti + if len(counts) <= index { + counts = append(counts, make([]int, (index-len(counts)+1))...) 
+ } + counts[index]++ + } + } + return counts +} + +func (bA BitArray) PickRandom() (int, bool) { + randStart := rand.Intn(len(bA)) + for i := 0; i < len(bA); i++ { + bytei := ((i + randStart) % len(bA)) + if bA[bytei] > 0 { + randBitStart := rand.Intn(64) + for j := 0; j < 64; j++ { + biti := ((j + randBitStart) % 64) + //fmt.Printf("%X %v %v %v\n", iHas, j, biti, randBitStart) + if (bA[bytei] & (1 << uint(biti))) > 0 { + return 64*int(bytei) + int(biti), true + } + } + panic("should not happen") + } + } + return 0, false +} + +// Pick an index from this BitArray that is 1 && whose count is lowest. +func (bA BitArray) PickRarest(counts []int) (rarest int, ok bool) { + smallestCount := math.MaxInt32 + for bytei := 0; bytei < len(bA); bytei++ { + if bA[bytei] > 0 { + for biti := 0; biti < 64; biti++ { + if (bA[bytei] & (1 << uint(biti))) == 0 { + continue + } + index := 64*bytei + biti + if counts[index] < smallestCount { + smallestCount = counts[index] + rarest = index + ok = true + } + } + panic("should not happen") + } + } + return +} diff --git a/common/int.go b/common/int.go new file mode 100644 index 000000000..b8a25dbc8 --- /dev/null +++ b/common/int.go @@ -0,0 +1,20 @@ +package common + +import ( + "sort" +) + +// Sort for []uint64 + +type UInt64Slice []uint64 + +func (p UInt64Slice) Len() int { return len(p) } +func (p UInt64Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p UInt64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p UInt64Slice) Sort() { sort.Sort(p) } + +func SearchUInt64s(a []uint64, x uint64) int { + return sort.Search(len(a), func(i int) bool { return a[i] >= x }) +} + +func (p UInt64Slice) Search(x uint64) int { return SearchUInt64s(p, x) } diff --git a/consensus/README.md b/consensus/README.md index 82df15f5a..de01bdaa2 100644 --- a/consensus/README.md +++ b/consensus/README.md @@ -1,4 +1,4 @@ -## Determining the order of proposers at height h: +## Determining the order of proposers at height h ``` Determining the order 
of proposers at height h: @@ -33,7 +33,6 @@ round R in the consensus rounds at height h (the parent block). We omit details of dealing with membership changes. ``` - ## Zombie Validators The most likely scenario may be during an upgrade. diff --git a/consensus/consensus.go b/consensus/consensus.go index 1b9293767..1205ba490 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "math" - "math/rand" "sync" "sync/atomic" "time" @@ -20,9 +19,11 @@ import ( ) const ( - ProposalCh = byte(0x20) - KnownPartsCh = byte(0x21) - VoteCh = byte(0x22) + StateCh = byte(0x20) + DataCh = byte(0x21) + VoteCh = byte(0x22) + + peerStateKey = "ConsensusReactor.peerState" voteTypeNil = byte(0x00) voteTypeBlock = byte(0x01) @@ -32,9 +33,9 @@ const ( roundDeadlineBare = float64(1.0 / 3.0) // When the bare vote is due. roundDeadlinePrecommit = float64(2.0 / 3.0) // When the precommit vote is due. - newBlockWaitDuration = roundDuration0 / 3 // The time to wait between commitTime and startTime of next consensus rounds. - voteRankCutoff = 2 // Higher ranks --> do not send votes. - unsolicitedVoteRate = 0.01 // Probability of sending a high ranked vote. + newBlockWaitDuration = roundDuration0 / 3 // The time to wait between commitTime and startTime of next consensus rounds. + peerGossipSleepDuration = 50 * time.Millisecond // Time to sleep if there's nothing to send. + hasVotesThreshold = 50 // After this many new votes we'll send a HasVotesMessage. 
) //----------------------------------------------------------------------------- @@ -91,620 +92,455 @@ func calcRoundInfo(startTime time.Time) (round uint16, roundStartTime time.Time, //----------------------------------------------------------------------------- -type ConsensusAgent struct { - sw *p2p.Switch - swEvents chan interface{} - quit chan struct{} - started uint32 - stopped uint32 +type ConsensusReactor struct { + sw *p2p.Switch + quit chan struct{} + started uint32 + stopped uint32 conS *ConsensusState - blockStore *BlockStore - mempool *Mempool doActionCh chan RoundAction - - mtx sync.Mutex - state *State - privValidator *PrivValidator - peerStates map[string]*PeerState - stagedProposal *BlockPartSet - stagedState *State } -func NewConsensusAgent(sw *p2p.Switch, blockStore *BlockStore, mempool *Mempool, state *State) *ConsensusAgent { - swEvents := make(chan interface{}) - sw.AddEventListener("ConsensusAgent.swEvents", swEvents) - conS := NewConsensusState(state) - conA := &ConsensusAgent{ - sw: sw, - swEvents: swEvents, - quit: make(chan struct{}), +func NewConsensusReactor(sw *p2p.Switch, blockStore *BlockStore, mempool *Mempool, state *State) *ConsensusReactor { + conS := NewConsensusState(state, blockStore, mempool) + conR := &ConsensusReactor{ + sw: sw, + quit: make(chan struct{}), conS: conS, - blockStore: blockStore, - mempool: mempool, doActionCh: make(chan RoundAction, 1), - - state: state, - peerStates: make(map[string]*PeerState), } - return conA + return conR } // Sets our private validator account for signing votes. 
-func (conA *ConsensusAgent) SetPrivValidator(priv *PrivValidator) { - conA.mtx.Lock() - defer conA.mtx.Unlock() - conA.privValidator = priv +func (conR *ConsensusReactor) SetPrivValidator(priv *PrivValidator) { + conR.conS.SetPrivValidator(priv) } -func (conA *ConsensusAgent) PrivValidator() *PrivValidator { - conA.mtx.Lock() - defer conA.mtx.Unlock() - return conA.privValidator +func (conR *ConsensusReactor) Start() { + if atomic.CompareAndSwapUint32(&conR.started, 0, 1) { + log.Info("Starting ConsensusReactor") + go conR.proposeAndVoteRoutine() + } } -func (conA *ConsensusAgent) Start() { - if atomic.CompareAndSwapUint32(&conA.started, 0, 1) { - log.Info("Starting ConsensusAgent") - go conA.switchEventsRoutine() - go conA.gossipProposalRoutine() - go conA.knownPartsRoutine() - go conA.gossipVoteRoutine() - go conA.proposeAndVoteRoutine() +func (conR *ConsensusReactor) Stop() { + if atomic.CompareAndSwapUint32(&conR.stopped, 0, 1) { + log.Info("Stopping ConsensusReactor") + close(conR.quit) } } -func (conA *ConsensusAgent) Stop() { - if atomic.CompareAndSwapUint32(&conA.stopped, 0, 1) { - log.Info("Stopping ConsensusAgent") - close(conA.quit) - close(conA.swEvents) - } +func (conR *ConsensusReactor) IsStopped() bool { + return atomic.LoadUint32(&conR.stopped) == 1 } -// Handle peer new/done events -func (conA *ConsensusAgent) switchEventsRoutine() { - for { - swEvent, ok := <-conA.swEvents - if !ok { - break - } - switch swEvent.(type) { - case p2p.SwitchEventNewPeer: - event := swEvent.(p2p.SwitchEventNewPeer) - // Create peerState for event.Peer - conA.mtx.Lock() - conA.peerStates[event.Peer.Key] = NewPeerState(event.Peer) - conA.mtx.Unlock() - // Share our state with event.Peer - // By sending KnownBlockPartsMessage, - // we send our height/round + startTime, and known block parts, - // which is sufficient for the peer to begin interacting with us. 
- event.Peer.TrySend(ProposalCh, conA.makeKnownBlockPartsMessage(conA.conS.RoundState())) - case p2p.SwitchEventDonePeer: - event := swEvent.(p2p.SwitchEventDonePeer) - // Delete peerState for event.Peer - conA.mtx.Lock() - peerState := conA.peerStates[event.Peer.Key] - if peerState != nil { - peerState.Disconnect() - delete(conA.peerStates, event.Peer.Key) - } - conA.mtx.Unlock() - default: - log.Warning("Unhandled switch event type") - } +// Implements Reactor +func (conR *ConsensusReactor) GetChannels() []*p2p.ChannelDescriptor { + // TODO optimize + return []*p2p.ChannelDescriptor{ + &p2p.ChannelDescriptor{ + Id: StateCh, + SendQueueCapacity: 1, + RecvQueueCapacity: 10, + RecvBufferSize: 10240, + DefaultPriority: 5, + }, + &p2p.ChannelDescriptor{ + Id: DataCh, + SendQueueCapacity: 1, + RecvQueueCapacity: 10, + RecvBufferSize: 10240, + DefaultPriority: 5, + }, + &p2p.ChannelDescriptor{ + Id: VoteCh, + SendQueueCapacity: 1, + RecvQueueCapacity: 1000, + RecvBufferSize: 10240, + DefaultPriority: 5, + }, } } -// Like, how large is it and how often can we send it? -func (conA *ConsensusAgent) makeKnownBlockPartsMessage(rs *RoundState) *KnownBlockPartsMessage { - return &KnownBlockPartsMessage{ - Height: rs.Height, - SecondsSinceStartTime: uint32(time.Now().Sub(rs.StartTime).Seconds()), - BlockPartsBitArray: rs.Proposal.BitArray(), - } +// Implements Reactor +func (conR *ConsensusReactor) AddPeer(peer *p2p.Peer) { + // Create peerState for peer + peerState := NewPeerState(peer) + peer.Data.Set(peerStateKey, peerState) + + // Begin gossip routines for this peer. + go conR.gossipDataRoutine(peer, peerState) + go conR.gossipVotesRoutine(peer, peerState) } -// NOTE: may return nil, but (nil).Wants*() returns false. 
-func (conA *ConsensusAgent) getPeerState(peer *p2p.Peer) *PeerState { - conA.mtx.Lock() - defer conA.mtx.Unlock() - return conA.peerStates[peer.Key] +// Implements Reactor +func (conR *ConsensusReactor) RemovePeer(peer *p2p.Peer, reason interface{}) { + //peer.Data.Get(peerStateKey).(*PeerState).Disconnect() } -func (conA *ConsensusAgent) gossipProposalRoutine() { -OUTER_LOOP: - for { - // Get round state - rs := conA.conS.RoundState() +// Implements Reactor +func (conR *ConsensusReactor) Receive(chId byte, peer *p2p.Peer, msgBytes []byte) { - // Receive incoming message on ProposalCh - inMsg, ok := conA.sw.Receive(ProposalCh) - if !ok { - break OUTER_LOOP // Client has stopped - } - _, msg_ := decodeMessage(inMsg.Bytes) - log.Info("gossipProposalRoutine received %v", msg_) + // Get round state + rs := conR.conS.GetRoundState() + ps := peer.Data.Get(peerStateKey).(*PeerState) + _, msg_ := decodeMessage(msgBytes) + voteAddCounter := 0 + var err error = nil + switch chId { + case StateCh: switch msg_.(type) { - case *BlockPartMessage: - msg := msg_.(*BlockPartMessage) + case *NewRoundStepMessage: + msg := msg_.(*NewRoundStepMessage) + err = ps.ApplyNewRoundStepMessage(msg) - // Add the block part if the height matches. - if msg.BlockPart.Height == rs.Height && - msg.BlockPart.Round == rs.Round { + case *HasVotesMessage: + msg := msg_.(*HasVotesMessage) + err = ps.ApplyHasVotesMessage(msg) - // TODO Continue if we've already voted, then no point processing the part. + default: + // Ignore unknown message + } - // Check that the signature is valid and from proposer. - if rs.Proposer.Verify(msg.BlockPart.Hash(), msg.BlockPart.Signature) { - // TODO handle bad peer. 
- continue OUTER_LOOP - } + case DataCh: + switch msg_.(type) { + case *Proposal: + proposal := msg_.(*Proposal) + ps.SetHasProposal(proposal.Height, proposal.Round) + err = conR.conS.SetProposal(proposal) + + case *PartMessage: + msg := msg_.(*PartMessage) + if msg.Type == partTypeProposalBlock { + ps.SetHasProposalBlockPart(msg.Height, msg.Round, msg.Part.Index) + _, err = conR.conS.AddProposalBlockPart(msg.Height, msg.Round, msg.Part) + } else if msg.Type == partTypeProposalPOL { + ps.SetHasProposalPOLPart(msg.Height, msg.Round, msg.Part.Index) + _, err = conR.conS.AddProposalPOLPart(msg.Height, msg.Round, msg.Part) + } else { + // Ignore unknown part type + } - // If we are the proposer, then don't do anything else. - // We're already sending peers our proposal on another routine. - privValidator := conA.PrivValidator() - if privValidator != nil && rs.Proposer.Account.Id == privValidator.Id { - continue OUTER_LOOP - } + default: + // Ignore unknown message + } - // Add and process the block part - added, err := rs.Proposal.AddBlockPart(msg.BlockPart) - if err == ErrInvalidBlockPartConflict { - // TODO: Bad validator - } else if err != nil { - Panicf("Unexpected blockPartsSet error %v", err) - } - if added { - // If peer wants this part, send peer the part - // and our new blockParts state. - kbpMsg := conA.makeKnownBlockPartsMessage(rs) - partMsg := &BlockPartMessage{BlockPart: msg.BlockPart} - for _, peer := range conA.sw.Peers().List() { - peerState := conA.getPeerState(peer) - if peerState.WantsBlockPart(msg.BlockPart) { - peer.TrySend(KnownPartsCh, kbpMsg) - peer.TrySend(ProposalCh, partMsg) - } + case VoteCh: + switch msg_.(type) { + case *Vote: + vote := msg_.(*Vote) + // We can't deal with votes from another height, + // as they have a different validator set. 
+ if vote.Height != rs.Height || vote.Height != ps.Height { + return + } + index, ok := rs.Validators.GetIndexById(vote.SignerId) + if !ok { + log.Warning("Peer gave us an invalid vote.") + return + } + ps.SetHasVote(rs.Height, rs.Round, vote.Type, uint32(index)) + added, err := conR.conS.AddVote(vote) + if err != nil { + log.Warning("Error attempting to add vote: %v", err) + } + if added { + // Maybe send HasVotesMessage + voteAddCounter++ + if voteAddCounter%hasVotesThreshold == 0 { + // TODO optimize. + msg := &HasVotesMessage{ + Height: rs.Height, + Round: rs.Round, + Votes: rs.Votes.BitArray(), + Precommits: rs.Precommits.BitArray(), + Commits: rs.Commits.BitArray(), } - - } else { - // We failed to process the block part. - // Either an error, which we handled, or duplicate part. - continue OUTER_LOOP + conR.sw.Broadcast(StateCh, msg) } } default: // Ignore unknown message - // conA.sw.StopPeerForError(inMsg.MConn.Peer, errInvalidMessage) } + default: + // Ignore unknown channel } - // Cleanup + if err != nil { + log.Warning("Error in Receive(): %v", err) + } } -func (conA *ConsensusAgent) knownPartsRoutine() { +func (conR *ConsensusReactor) gossipDataRoutine(peer *p2p.Peer, ps *PeerState) { + OUTER_LOOP: for { - // Receive incoming message on ProposalCh - inMsg, ok := conA.sw.Receive(KnownPartsCh) - if !ok { - break OUTER_LOOP // Client has stopped + // Manage disconnects from self or peer. + if peer.IsStopped() || conR.IsStopped() { + log.Info("Stopping gossipDataRoutine for %v.", peer) + return } - _, msg_ := decodeMessage(inMsg.Bytes) - log.Info("knownPartsRoutine received %v", msg_) + rs := conR.conS.GetRoundState() + prs := ps.GetRoundState() - msg, ok := msg_.(*KnownBlockPartsMessage) - if !ok { - // Ignore unknown message type - // conA.sw.StopPeerForError(inMsg.MConn.Peer, errInvalidMessage) + // If height and round doesn't match, sleep. 
+ if rs.Height != prs.Height || rs.Round != prs.Round { + time.Sleep(peerGossipSleepDuration) continue OUTER_LOOP } - peerState := conA.getPeerState(inMsg.MConn.Peer) - if !peerState.IsConnected() { - // Peer disconnected before we were able to process. - continue OUTER_LOOP - } - peerState.ApplyKnownBlockPartsMessage(msg) - } - // Cleanup -} - -// Signs a vote document and broadcasts it. -// hash can be nil to vote "nil" -func (conA *ConsensusAgent) signAndVote(vote *Vote) error { - privValidator := conA.PrivValidator() - if privValidator != nil { - err := privValidator.SignVote(vote) - if err != nil { - return err + // Send proposal? + if rs.Proposal != nil && !prs.Proposal { + msg := p2p.TypedMessage{msgTypeProposal, rs.Proposal} + peer.Send(DataCh, msg) + ps.SetHasProposal(rs.Height, rs.Round) + continue OUTER_LOOP } - msg := p2p.TypedMessage{msgTypeVote, vote} - conA.sw.Broadcast(VoteCh, msg) - } - return nil -} - -func (conA *ConsensusAgent) stageProposal(proposal *BlockPartSet) error { - // Already staged? - conA.mtx.Lock() - if conA.stagedProposal == proposal { - conA.mtx.Unlock() - return nil - } else { - conA.mtx.Unlock() - } - if !proposal.IsComplete() { - return errors.New("Incomplete proposal BlockPartSet") - } - block := proposal.Block() - - // Basic validation is done in state.CommitBlock(). - //err := block.ValidateBasic() - //if err != nil { - // return err - //} - - // Create a copy of the state for staging - conA.mtx.Lock() - stateCopy := conA.state.Copy() // Deep copy the state before staging. - conA.mtx.Unlock() - - // Commit block onto the copied state. - err := stateCopy.CommitBlock(block) - if err != nil { - return err - } - - // Looks good! 
- conA.mtx.Lock() - conA.stagedProposal = proposal - conA.stagedState = stateCopy - conA.mtx.Unlock() - return nil -} - -// Constructs an unsigned proposal -func (conA *ConsensusAgent) constructProposal(rs *RoundState) (*BlockPartSet, error) { - // TODO: make use of state returned from MakeProposal() - proposalBlock, _ := conA.mempool.MakeProposal() - proposal := proposalBlock.ToBlockPartSet() - return proposal, nil -} - -// Vote for (or against) the proposal for this round. -// Call during transition from RoundStepProposal to RoundStepVote. -// We may not have received a full proposal. -func (conA *ConsensusAgent) voteProposal(rs *RoundState) error { - // If we're locked, must vote that. - locked := conA.conS.LockedProposal() - if locked != nil { - block := locked.Block() - err := conA.signAndVote(&Vote{ - Height: rs.Height, - Round: rs.Round, - Type: VoteTypeBare, - Hash: block.Hash(), - }) - return err - } - // Stage proposal - err := conA.stageProposal(rs.Proposal) - if err != nil { - // Vote for nil, whatever the error. - err := conA.signAndVote(&Vote{ - Height: rs.Height, - Round: rs.Round, - Type: VoteTypeBare, - Hash: nil, - }) - return err - } - // Vote for block. - err = conA.signAndVote(&Vote{ - Height: rs.Height, - Round: rs.Round, - Type: VoteTypeBare, - Hash: rs.Proposal.Block().Hash(), - }) - return err -} - -// Precommit proposal if we see enough votes for it. -// Call during transition from RoundStepVote to RoundStepPrecommit. -func (conA *ConsensusAgent) precommitProposal(rs *RoundState) error { - // If we see a 2/3 majority for votes for a block, precommit. - - // TODO: maybe could use commitTime here and avg it with later commitTime? - if hash, _, ok := rs.RoundBareVotes.TwoThirdsMajority(); ok { - if len(hash) == 0 { - // 2/3 majority voted for nil. - return nil - } else { - // 2/3 majority voted for a block. - - // If proposal is invalid or unknown, do nothing. - // See note on ZombieValidators to see why. 
- if conA.stageProposal(rs.Proposal) != nil { - return nil + // Send proposal block part? + if index, ok := rs.ProposalBlockPartSet.BitArray().Sub( + prs.ProposalBlockBitArray).PickRandom(); ok { + msg := &PartMessage{ + Height: rs.Height, + Round: rs.Round, + Type: partTypeProposalBlock, + Part: rs.ProposalBlockPartSet.GetPart(uint16(index)), } + peer.Send(DataCh, msg) + ps.SetHasProposalBlockPart(rs.Height, rs.Round, uint16(index)) + continue OUTER_LOOP + } - // Lock this proposal. - // NOTE: we're unlocking any prior locks. - conA.conS.LockProposal(rs.Proposal) - - // Send precommit vote. - err := conA.signAndVote(&Vote{ + // Send proposal POL part? + if index, ok := rs.ProposalPOLPartSet.BitArray().Sub( + prs.ProposalPOLBitArray).PickRandom(); ok { + msg := &PartMessage{ Height: rs.Height, Round: rs.Round, - Type: VoteTypePrecommit, - Hash: hash, - }) - return err + Type: partTypeProposalPOL, + Part: rs.ProposalPOLPartSet.GetPart(uint16(index)), + } + peer.Send(DataCh, msg) + ps.SetHasProposalPOLPart(rs.Height, rs.Round, uint16(index)) + continue OUTER_LOOP } - } else { - // If we haven't seen enough votes, do nothing. - return nil + + // Nothing to do. Sleep. + time.Sleep(peerGossipSleepDuration) + continue OUTER_LOOP } } -// Commit or unlock. -// Call after RoundStepPrecommit, after round has completely expired. -func (conA *ConsensusAgent) commitOrUnlockProposal(rs *RoundState) (commitTime time.Time, err error) { - // If there exists a 2/3 majority of precommits. - // Validate the block and commit. - if hash, commitTime, ok := rs.RoundPrecommits.TwoThirdsMajority(); ok { - - // If the proposal is invalid or we don't have it, - // do not commit. - // TODO If we were just late to receive the block, when - // do we actually get it? Document it. - if conA.stageProposal(rs.Proposal) != nil { - return time.Time{}, nil +func (conR *ConsensusReactor) gossipVotesRoutine(peer *p2p.Peer, ps *PeerState) { +OUTER_LOOP: + for { + // Manage disconnects from self or peer. 
+ if peer.IsStopped() || conR.IsStopped() { + log.Info("Stopping gossipVotesRoutine for %v.", peer) + return } - // TODO: Remove? - conA.conS.LockProposal(rs.Proposal) - // Vote commit. - err := conA.signAndVote(&Vote{ - Height: rs.Height, - Round: rs.Round, - Type: VoteTypePrecommit, - Hash: hash, - }) - if err != nil { - return time.Time{}, err + rs := conR.conS.GetRoundState() + prs := ps.GetRoundState() + + // If height doens't match, sleep. + if rs.Height != prs.Height { + time.Sleep(peerGossipSleepDuration) + continue OUTER_LOOP } - // Commit block. - conA.commitProposal(rs.Proposal, commitTime) - return commitTime, nil - } else { - // Otherwise, if a 1/3 majority if a block that isn't our locked one exists, unlock. - locked := conA.conS.LockedProposal() - if locked != nil { - for _, hashOrNil := range rs.RoundPrecommits.OneThirdMajority() { - if hashOrNil == nil { - continue + + // If there are bare votes to send... + if prs.Step <= RoundStepVote { + index, ok := rs.Votes.BitArray().Sub(prs.Votes).PickRandom() + if ok { + valId, ok := rs.Validators.GetIdByIndex(uint32(index)) + if ok { + vote := rs.Votes.GetVote(valId) + msg := p2p.TypedMessage{msgTypeVote, vote} + peer.Send(VoteCh, msg) + ps.SetHasVote(rs.Height, rs.Round, VoteTypeBare, uint32(index)) + continue OUTER_LOOP + } else { + log.Error("index is not a valid validator index") } - if !bytes.Equal(hashOrNil, locked.Block().Hash()) { - // Unlock our lock. - conA.conS.LockProposal(nil) + } + } + + // If there are precommits to send... 
+ if prs.Step <= RoundStepPrecommit { + index, ok := rs.Precommits.BitArray().Sub(prs.Precommits).PickRandom() + if ok { + valId, ok := rs.Validators.GetIdByIndex(uint32(index)) + if ok { + vote := rs.Precommits.GetVote(valId) + msg := p2p.TypedMessage{msgTypeVote, vote} + peer.Send(VoteCh, msg) + ps.SetHasVote(rs.Height, rs.Round, VoteTypePrecommit, uint32(index)) + continue OUTER_LOOP + } else { + log.Error("index is not a valid validator index") } } } - return time.Time{}, nil - } -} -func (conA *ConsensusAgent) commitProposal(proposal *BlockPartSet, commitTime time.Time) error { - conA.mtx.Lock() - defer conA.mtx.Unlock() + // If there are any commits to send... + index, ok := rs.Commits.BitArray().Sub(prs.Commits).PickRandom() + if ok { + valId, ok := rs.Validators.GetIdByIndex(uint32(index)) + if ok { + vote := rs.Commits.GetVote(valId) + msg := p2p.TypedMessage{msgTypeVote, vote} + peer.Send(VoteCh, msg) + ps.SetHasVote(rs.Height, rs.Round, VoteTypeCommit, uint32(index)) + continue OUTER_LOOP + } else { + log.Error("index is not a valid validator index") + } + } - if conA.stagedProposal != proposal { - panic("Unexpected stagedProposal.") // Shouldn't happen. + // We sent nothing. Sleep... + time.Sleep(peerGossipSleepDuration) + continue OUTER_LOOP } +} - // Save to blockStore - block, blockParts := proposal.Block(), proposal.BlockParts() - err := conA.blockStore.SaveBlockParts(block.Height, blockParts) - if err != nil { - return err +// Signs a vote document and broadcasts it. +func (conR *ConsensusReactor) signAndBroadcastVote(rs *RoundState, vote *Vote) { + if rs.PrivValidator != nil { + rs.PrivValidator.SignVote(vote) + conR.conS.AddVote(vote) + msg := p2p.TypedMessage{msgTypeVote, vote} + conR.sw.Broadcast(VoteCh, msg) } +} - // What was staged becomes committed. 
- conA.state = conA.stagedState - conA.state.Save(commitTime) - conA.conS.Update(conA.state) - conA.stagedProposal = nil - conA.stagedState = nil - conA.mempool.ResetForBlockAndState(block, conA.state) +//------------------------------------- - return nil +func (conR *ConsensusReactor) runStepPropose(rs *RoundState) { + conR.conS.MakeProposal() } -// Given a RoundState where we are the proposer, -// broadcast rs.proposal to all the peers. -func (conA *ConsensusAgent) shareProposal(rs *RoundState) { - privValidator := conA.PrivValidator() - proposal := rs.Proposal - if privValidator == nil || proposal == nil { - return - } - privValidator.SignProposal(rs.Round, proposal) - blockParts := proposal.BlockParts() - peers := conA.sw.Peers().List() - if len(peers) == 0 { - log.Warning("Could not propose: no peers") - return - } - numBlockParts := uint16(len(blockParts)) - kbpMsg := conA.makeKnownBlockPartsMessage(rs) - for i, peer := range peers { - peerState := conA.getPeerState(peer) - if !peerState.IsConnected() { - continue // Peer was disconnected. - } - startIndex := uint16((i * len(blockParts)) / len(peers)) - // Create a function that when called, - // starts sending block parts to peer. - cb := func(peer *p2p.Peer, startIndex uint16) func() { - return func() { - // TODO: if the clocks are off a bit, - // peer may receive this before the round flips. - peer.Send(KnownPartsCh, kbpMsg) - for i := uint16(0); i < numBlockParts; i++ { - part := blockParts[(startIndex+i)%numBlockParts] - // Ensure round hasn't expired on our end. - currentRS := conA.conS.RoundState() - if currentRS != rs { - return - } - // If peer wants the block: - if peerState.WantsBlockPart(part) { - partMsg := &BlockPartMessage{BlockPart: part} - peer.Send(ProposalCh, partMsg) - } - } - } - }(peer, startIndex) - // Call immediately or schedule cb for when peer is ready. 
- peerState.SetRoundCallback(rs.Height, rs.Round, cb) - } -} +func (conR *ConsensusReactor) runStepVote(rs *RoundState) { -func (conA *ConsensusAgent) gossipVoteRoutine() { -OUTER_LOOP: - for { - // Get round state - rs := conA.conS.RoundState() + // If we have a locked block, we must vote for that. + // NOTE: a locked block is already valid. + if rs.LockedBlock != nil { + conR.signAndBroadcastVote(rs, &Vote{ + Height: rs.Height, + Round: rs.Round, + Type: VoteTypeBare, + BlockHash: rs.LockedBlock.Hash(), + }) + } - // Receive incoming message on VoteCh - inMsg, ok := conA.sw.Receive(VoteCh) - if !ok { - break // Client has stopped - } - type_, msg_ := decodeMessage(inMsg.Bytes) - log.Info("gossipVoteRoutine received %v", msg_) + // Try staging proposed block. + // If Block is nil, an error is returned. + err := conR.conS.stageBlock(rs.ProposalBlock) + if err != nil { - switch msg_.(type) { - case *Vote: - vote := msg_.(*Vote) + // Vote nil + conR.signAndBroadcastVote(rs, &Vote{ + Height: rs.Height, + Round: rs.Round, + Type: VoteTypeBare, + BlockHash: nil, + }) - if vote.Height != rs.Height || vote.Round != rs.Round { - continue OUTER_LOOP - } + } else { - added, rank, err := rs.AddVote(vote, inMsg.MConn.Peer.Key) - // Send peer VoteRankMessage if needed - if type_ == msgTypeVoteAskRank { - msg := &VoteRankMessage{ - ValidatorId: vote.SignerId, - Rank: rank, - } - inMsg.MConn.Peer.TrySend(VoteCh, msg) - } - // Process vote - if !added { - log.Info("Error adding vote %v", err) - } - switch err { - case ErrVoteInvalidAccount, ErrVoteInvalidSignature: - // TODO: Handle bad peer. - case ErrVoteConflictingSignature, ErrVoteInvalidHash: - // TODO: Handle bad validator. - case nil: - break - //case ErrVoteUnexpectedPhase: Shouldn't happen. 
- default: - Panicf("Unexpected error from .AddVote(): %v", err) - } - if !added { - continue - } + // Vote for block + conR.signAndBroadcastVote(rs, &Vote{ + Height: rs.Height, + Round: rs.Round, + Type: VoteTypeBare, + BlockHash: rs.ProposalBlock.Hash(), + }) + } +} - // Gossip vote. - for _, peer := range conA.sw.Peers().List() { - peerState := conA.getPeerState(peer) - wantsVote, unsolicited := peerState.WantsVote(vote) - if wantsVote { - if unsolicited { - // If we're sending an unsolicited vote, - // ask for the rank so we know whether it's good. - msg := p2p.TypedMessage{msgTypeVoteAskRank, vote} - peer.TrySend(VoteCh, msg) - } else { - msg := p2p.TypedMessage{msgTypeVote, vote} - peer.TrySend(VoteCh, msg) - } - } - } +func (conR *ConsensusReactor) runStepPrecommit(rs *RoundState) { - case *VoteRankMessage: - msg := msg_.(*VoteRankMessage) + // If we see a 2/3 majority of votes for a block, lock. + hash := conR.conS.LockOrUnlock(rs.Height, rs.Round) + if len(hash) > 0 { - peerState := conA.getPeerState(inMsg.MConn.Peer) - if !peerState.IsConnected() { - // Peer disconnected before we were able to process. - continue OUTER_LOOP - } - peerState.ApplyVoteRankMessage(msg) + // Precommit + conR.signAndBroadcastVote(rs, &Vote{ + Height: rs.Height, + Round: rs.Round, + Type: VoteTypePrecommit, + BlockHash: hash, + }) - default: - // Ignore unknown message - // conA.sw.StopPeerForError(inMsg.MConn.Peer, errInvalidMessage) - } } +} - // Cleanup +func (conR *ConsensusReactor) runStepCommit(rs *RoundState) bool { + + // If we see a 2/3 majority of precommits for a block, commit. + block := conR.conS.Commit(rs.Height, rs.Round) + if block == nil { + return false + } else { + conR.signAndBroadcastVote(rs, &Vote{ + Height: rs.Height, + Round: rs.Round, + Type: VoteTypePrecommit, + BlockHash: block.Hash(), + }) + return true + } } +//------------------------------------- + type RoundAction struct { - Height uint32 // The block height for which consensus is reaching for. 
- Round uint16 // The round number at given height. - XnToStep uint8 // Transition to this step. Action depends on this value. + Height uint32 // The block height for which consensus is reaching for. + Round uint16 // The round number at given height. + XnToStep uint8 // Transition to this step. Action depends on this value. + RoundElapsed time.Duration // Duration since round start. } // Source of all round state transitions and votes. -// It can be preemptively woken up via amessage to +// It can be preemptively woken up via a message to // doActionCh. -func (conA *ConsensusAgent) proposeAndVoteRoutine() { +func (conR *ConsensusReactor) proposeAndVoteRoutine() { // Figure out when to wake up next (in the absence of other events) setAlarm := func() { - if len(conA.doActionCh) > 0 { + if len(conR.doActionCh) > 0 { return // Already going to wake up later. } // Figure out which height/round/step we're at, // then schedule an action for when it is due. - rs := conA.conS.RoundState() - _, _, roundDuration, _, elapsedRatio := calcRoundInfo(rs.StartTime) - switch rs.Step() { + rs := conR.conS.GetRoundState() + _, _, roundDuration, roundElapsed, elapsedRatio := calcRoundInfo(rs.StartTime) + switch rs.Step { case RoundStepStart: // It's a new RoundState. if elapsedRatio < 0 { // startTime is in the future. time.Sleep(time.Duration(-1.0*elapsedRatio) * roundDuration) } - conA.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepProposal} - case RoundStepProposal: + conR.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepPropose, roundElapsed} + case RoundStepPropose: // Wake up when it's time to vote. time.Sleep(time.Duration(roundDeadlineBare-elapsedRatio) * roundDuration) - conA.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepBareVotes} - case RoundStepBareVotes: + conR.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepVote, roundElapsed} + case RoundStepVote: // Wake up when it's time to precommit. 
time.Sleep(time.Duration(roundDeadlinePrecommit-elapsedRatio) * roundDuration) - conA.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepPrecommits} - case RoundStepPrecommits: + conR.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepPrecommit, roundElapsed} + case RoundStepPrecommit: // Wake up when the round is over. time.Sleep(time.Duration(1.0-elapsedRatio) * roundDuration) - conA.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepCommitOrUnlock} - case RoundStepCommitOrUnlock: + conR.doActionCh <- RoundAction{rs.Height, rs.Round, RoundStepCommit, roundElapsed} + case RoundStepCommit: // This shouldn't happen. // Before setAlarm() got called, // logic should have created a new RoundState for the next round. @@ -714,66 +550,35 @@ func (conA *ConsensusAgent) proposeAndVoteRoutine() { for { func() { - roundAction := <-conA.doActionCh + roundAction := <-conR.doActionCh // Always set the alarm after any processing below. defer setAlarm() - // We only consider acting on given height and round. height := roundAction.Height round := roundAction.Round - // We only consider transitioning to given step. step := roundAction.XnToStep - // This is the current state. - rs := conA.conS.RoundState() + roundElapsed := roundAction.RoundElapsed + rs := conR.conS.GetRoundState() + if height != rs.Height || round != rs.Round { - return // Not relevant. + return // Action is not relevant } - if step == RoundStepProposal && rs.Step() == RoundStepStart { - // Propose a block if I am the proposer. - privValidator := conA.PrivValidator() - if privValidator != nil && rs.Proposer.Account.Id == privValidator.Id { - // If we're already locked on a proposal, use that. - proposal := conA.conS.LockedProposal() - if proposal != nil { - // Otherwise, construct a new proposal. - var err error - proposal, err = conA.constructProposal(rs) - if err != nil { - log.Error("Error attempting to construct a proposal: %v", err) - return // Pretend like we weren't the proposer. Shrug. 
- } - } - // Set proposal for roundState, so we vote correctly subsequently. - rs.Proposal = proposal - // Share the parts. - // We send all parts to all of our peers, but everyone receives parts - // starting at a different index, wrapping around back to 0. - conA.shareProposal(rs) - } - } else if step == RoundStepBareVotes && rs.Step() <= RoundStepProposal { - err := conA.voteProposal(rs) - if err != nil { - log.Info("Error attempting to vote for proposal: %v", err) - } - } else if step == RoundStepPrecommits && rs.Step() <= RoundStepBareVotes { - err := conA.precommitProposal(rs) - if err != nil { - log.Info("Error attempting to precommit for proposal: %v", err) - } - } else if step == RoundStepCommitOrUnlock && rs.Step() <= RoundStepPrecommits { - commitTime, err := conA.commitOrUnlockProposal(rs) - if err != nil { - log.Info("Error attempting to commit or update for proposal: %v", err) - } - - if !commitTime.IsZero() { + // Run step + if step == RoundStepPropose && rs.Step == RoundStepStart { + conR.runStepPropose(rs) + } else if step == RoundStepVote && rs.Step <= RoundStepPropose { + conR.runStepVote(rs) + } else if step == RoundStepPrecommit && rs.Step <= RoundStepVote { + conR.runStepPrecommit(rs) + } else if step == RoundStepCommit && rs.Step <= RoundStepPrecommit { + didCommit := conR.runStepCommit(rs) + if didCommit { // We already set up ConsensusState for the next height - // (it happens in the call to conA.commitProposal). + // (it happens in the call to conR.runStepCommit). } else { - // Round is over. This is a special case. // Prepare a new RoundState for the next state. - conA.conS.SetupRound(rs.Round + 1) + conR.conS.SetupRound(rs.Round + 1) return // setAlarm() takes care of the rest. } } else { @@ -781,224 +586,159 @@ func (conA *ConsensusAgent) proposeAndVoteRoutine() { } // Transition to new step. - rs.SetStep(step) + conR.conS.SetStep(step) + + // Broadcast NewRoundStepMessage. 
+ msg := &NewRoundStepMessage{ + Height: height, + Round: round, + Step: step, + SecondsSinceStartTime: uint32(roundElapsed.Seconds()), + } + conR.sw.Broadcast(StateCh, msg) }() } } //----------------------------------------------------------------------------- +// Read only when returned by PeerState.GetRoundState(). +type PeerRoundState struct { + Height uint32 // Height peer is at + Round uint16 // Round peer is at + Step uint8 // Step peer is at + StartTime time.Time // Estimated start of round 0 at this height + Proposal bool // True if peer has proposal for this round + ProposalBlockHash []byte // Block parts merkle root + ProposalBlockBitArray BitArray // Block parts bitarray + ProposalPOLHash []byte // POL parts merkle root + ProposalPOLBitArray BitArray // POL parts bitarray + Votes BitArray // All votes peer has for this round + Precommits BitArray // All precommits peer has for this round + Commits BitArray // All commits peer has for this height +} + +//----------------------------------------------------------------------------- + var ( ErrPeerStateHeightRegression = errors.New("Error peer state height regression") ErrPeerStateInvalidStartTime = errors.New("Error peer state invalid startTime") ) -// TODO: voteRanks should purge bygone validators. type PeerState struct { - mtx sync.Mutex - connected bool - peer *p2p.Peer - height uint32 - startTime time.Time // Derived from offset seconds. - blockPartsBitArray []byte - voteRanks map[uint64]uint8 - cbHeight uint32 - cbRound uint16 - cbFunc func() + mtx sync.Mutex + PeerRoundState } func NewPeerState(peer *p2p.Peer) *PeerState { - return &PeerState{ - connected: true, - peer: peer, - height: 0, - voteRanks: make(map[uint64]uint8), - } + return &PeerState{} } -func (ps *PeerState) IsConnected() bool { - if ps == nil { - return false - } +// Returns an atomic snapshot of the PeerRoundState. +// There's no point in mutating it since it won't change PeerState. 
+func (ps *PeerState) GetRoundState() *PeerRoundState { ps.mtx.Lock() defer ps.mtx.Unlock() - return ps.connected + prs := ps.PeerRoundState // copy + return &prs } -func (ps *PeerState) Disconnect() { +func (ps *PeerState) SetHasProposal(height uint32, round uint16) { ps.mtx.Lock() defer ps.mtx.Unlock() - ps.connected = false -} -func (ps *PeerState) WantsBlockPart(part *BlockPart) bool { - if ps == nil { - return false + if ps.Height == height && ps.Round == round { + ps.Proposal = true } +} + +func (ps *PeerState) SetHasProposalBlockPart(height uint32, round uint16, index uint16) { ps.mtx.Lock() defer ps.mtx.Unlock() - if !ps.connected { - return false - } - // Only wants the part if peer's current height and round matches. - if ps.height == part.Height { - round := calcRound(ps.startTime) - // NOTE: validators want to receive remaining block parts - // even after it had voted bare or precommit. - // Ergo, we do not check for which step the peer is in. - if round == part.Round { - // Only wants the part if it doesn't already have it. - if ps.blockPartsBitArray[part.Index/8]&byte(1<<(part.Index%8)) == 0 { - return true - } - } + + if ps.Height == height && ps.Round == round { + ps.ProposalBlockBitArray.SetIndex(uint(index), true) } - return false } -func (ps *PeerState) WantsVote(vote *Vote) (wants bool, unsolicited bool) { - if ps == nil { - return false, false - } +func (ps *PeerState) SetHasProposalPOLPart(height uint32, round uint16, index uint16) { ps.mtx.Lock() defer ps.mtx.Unlock() - if !ps.connected { - return false, false - } - // Only wants the vote if peer's current height and round matches. - if ps.height == vote.Height { - round, _, _, _, elapsedRatio := calcRoundInfo(ps.startTime) - if round == vote.Round { - if vote.Type == VoteTypeBare && elapsedRatio > roundDeadlineBare { - return false, false - } - if vote.Type == VoteTypePrecommit && elapsedRatio > roundDeadlinePrecommit { - return false, false - } else { - // continue on ... 
- } - } else { - return false, false - } - } else { - return false, false - } - // Only wants the vote if voteRank is low. - if ps.voteRanks[vote.SignerId] > voteRankCutoff { - // Sometimes, send unsolicited votes to see if peer wants it. - if rand.Float32() < unsolicitedVoteRate { - return true, true - } else { - // Rank too high. Do not send vote. - return false, false - } + if ps.Height == height && ps.Round == round { + ps.ProposalPOLBitArray.SetIndex(uint(index), true) } - return true, false } -func (ps *PeerState) ApplyKnownBlockPartsMessage(msg *KnownBlockPartsMessage) error { +func (ps *PeerState) SetHasVote(height uint32, round uint16, type_ uint8, index uint32) { ps.mtx.Lock() defer ps.mtx.Unlock() - // TODO: Sanity check len(BlockParts) - if msg.Height < ps.height { - return ErrPeerStateHeightRegression - } - if msg.Height == ps.height { - if len(ps.blockPartsBitArray) == 0 { - ps.blockPartsBitArray = msg.BlockPartsBitArray - } else if len(msg.BlockPartsBitArray) > 0 { - if len(ps.blockPartsBitArray) != len(msg.BlockPartsBitArray) { - // TODO: If the peer received a part from - // a proposer who signed a bad (or conflicting) part, - // just about anything can happen with the new blockPartsBitArray. - // In those cases it's alright to ignore the peer for the round, - // and try to induce nil votes for that round. - return nil - } else { - // TODO: Same as above. If previously known parts disappear, - // something is fishy. - // For now, just copy over known parts. - for i, byt := range msg.BlockPartsBitArray { - ps.blockPartsBitArray[i] |= byt - } - } - } - } else { - // TODO: handle peer connection latency estimation. - newStartTime := time.Now().Add(-1 * time.Duration(msg.SecondsSinceStartTime) * time.Second) - // Ensure that the new height's start time is sufficiently after the last startTime. - // TODO: there should be some time between rounds. 
- if !newStartTime.After(ps.startTime) { - return ErrPeerStateInvalidStartTime - } - ps.startTime = newStartTime - ps.height = msg.Height - ps.blockPartsBitArray = msg.BlockPartsBitArray - // Call callback if height+round matches. - peerRound := calcRound(ps.startTime) - if ps.cbFunc != nil && ps.cbHeight == ps.height && ps.cbRound == peerRound { - go ps.cbFunc() - ps.cbFunc = nil + if ps.Height == height && (ps.Round == round || type_ == VoteTypeCommit) { + switch type_ { + case VoteTypeBare: + ps.Votes.SetIndex(uint(index), true) + case VoteTypePrecommit: + ps.Precommits.SetIndex(uint(index), true) + case VoteTypeCommit: + ps.Commits.SetIndex(uint(index), true) + default: + panic("Invalid vote type") } } - return nil } -func (ps *PeerState) ApplyVoteRankMessage(msg *VoteRankMessage) error { +func (ps *PeerState) ApplyNewRoundStepMessage(msg *NewRoundStepMessage) error { ps.mtx.Lock() defer ps.mtx.Unlock() - ps.voteRanks[msg.ValidatorId] = msg.Rank + + // Set step state + startTime := time.Now().Add(-1 * time.Duration(msg.SecondsSinceStartTime) * time.Second) + ps.Height = msg.Height + ps.Round = msg.Round + ps.Step = msg.Step + ps.StartTime = startTime + + // Reset the rest + ps.Proposal = false + ps.ProposalBlockHash = nil + ps.ProposalBlockBitArray = nil + ps.ProposalPOLHash = nil + ps.ProposalPOLBitArray = nil + ps.Votes = nil + ps.Precommits = nil + if ps.Height != msg.Height { + ps.Commits = nil + } return nil } -// Sets a single round callback, to be called when the height+round comes around. -// If the height+round is current, calls "go f()" immediately. -// Otherwise, does nothing. -func (ps *PeerState) SetRoundCallback(height uint32, round uint16, f func()) { +func (ps *PeerState) ApplyHasVotesMessage(msg *HasVotesMessage) error { ps.mtx.Lock() defer ps.mtx.Unlock() - if ps.height < height { - ps.cbHeight = height - ps.cbRound = round - ps.cbFunc = f - // Wait until the height of the peerState changes. - // We'll call cbFunc then. 
- return - } else if ps.height == height { - peerRound := calcRound(ps.startTime) - if peerRound < round { - // Set a timer to call the cbFunc when the time comes. - go func() { - roundStart := calcRoundStartTime(round, ps.startTime) - time.Sleep(roundStart.Sub(time.Now())) - // If peer height is still good - ps.mtx.Lock() - peerHeight := ps.height - ps.mtx.Unlock() - if peerHeight == height { - f() - } - }() - } else if peerRound == round { - go f() + if ps.Height == msg.Height { + ps.Commits = ps.Commits.Or(msg.Commits) + if ps.Round == msg.Round { + ps.Votes = ps.Votes.Or(msg.Votes) + ps.Precommits = ps.Precommits.Or(msg.Precommits) } else { - return + ps.Votes = msg.Votes + ps.Precommits = msg.Precommits } - } else { - return } + return nil } //----------------------------------------------------------------------------- // Messages const ( - msgTypeUnknown = byte(0x00) - msgTypeBlockPart = byte(0x10) - msgTypeKnownBlockParts = byte(0x11) - msgTypeVote = byte(0x20) - msgTypeVoteAskRank = byte(0x21) - msgTypeVoteRank = byte(0x22) + msgTypeUnknown = byte(0x00) + // Messages for communicating state changes + msgTypeNewRoundStep = byte(0x01) + msgTypeHasVotes = byte(0x02) + // Messages of data + msgTypeProposal = byte(0x11) + msgTypePart = byte(0x12) // both block & POL + msgTypeVote = byte(0x13) ) // TODO: check for unnecessary extra bytes at the end. 
@@ -1006,17 +746,20 @@ func decodeMessage(bz []byte) (msgType byte, msg interface{}) { n, err := new(int64), new(error) // log.Debug("decoding msg bytes: %X", bz) msgType = bz[0] + r := bytes.NewReader(bz[1:]) switch msgType { - case msgTypeBlockPart: - msg = readBlockPartMessage(bytes.NewReader(bz[1:]), n, err) - case msgTypeKnownBlockParts: - msg = readKnownBlockPartsMessage(bytes.NewReader(bz[1:]), n, err) + // Messages for communicating state changes + case msgTypeNewRoundStep: + msg = readNewRoundStepMessage(r, n, err) + case msgTypeHasVotes: + msg = readHasVotesMessage(r, n, err) + // Messages of data + case msgTypeProposal: + msg = ReadProposal(r, n, err) + case msgTypePart: + msg = readPartMessage(r, n, err) case msgTypeVote: - msg = ReadVote(bytes.NewReader(bz[1:]), n, err) - case msgTypeVoteAskRank: - msg = ReadVote(bytes.NewReader(bz[1:]), n, err) - case msgTypeVoteRank: - msg = readVoteRankMessage(bytes.NewReader(bz[1:]), n, err) + msg = ReadVote(r, n, err) default: msg = nil } @@ -1025,76 +768,101 @@ func decodeMessage(bz []byte) (msgType byte, msg interface{}) { //------------------------------------- -type BlockPartMessage struct { - BlockPart *BlockPart +type NewRoundStepMessage struct { + Height uint32 + Round uint16 + Step uint8 + SecondsSinceStartTime uint32 } -func readBlockPartMessage(r io.Reader, n *int64, err *error) *BlockPartMessage { - return &BlockPartMessage{ - BlockPart: ReadBlockPart(r, n, err), +func readNewRoundStepMessage(r io.Reader, n *int64, err *error) *NewRoundStepMessage { + return &NewRoundStepMessage{ + Height: ReadUInt32(r, n, err), + Round: ReadUInt16(r, n, err), + Step: ReadUInt8(r, n, err), + SecondsSinceStartTime: ReadUInt32(r, n, err), } } -func (m *BlockPartMessage) WriteTo(w io.Writer) (n int64, err error) { - WriteByte(w, msgTypeBlockPart, &n, &err) - WriteBinary(w, m.BlockPart, &n, &err) +func (m *NewRoundStepMessage) WriteTo(w io.Writer) (n int64, err error) { + WriteByte(w, msgTypeNewRoundStep, &n, &err) + 
WriteUInt32(w, m.Height, &n, &err) + WriteUInt16(w, m.Round, &n, &err) + WriteUInt8(w, m.Step, &n, &err) + WriteUInt32(w, m.SecondsSinceStartTime, &n, &err) return } -func (m *BlockPartMessage) String() string { - return fmt.Sprintf("[BlockPartMessage %v]", m.BlockPart) +func (m *NewRoundStepMessage) String() string { + return fmt.Sprintf("[NewRoundStepMessage H:%v R:%v]", m.Height, m.Round) } //------------------------------------- -type KnownBlockPartsMessage struct { - Height uint32 - SecondsSinceStartTime uint32 - BlockPartsBitArray []byte +type HasVotesMessage struct { + Height uint32 + Round uint16 + Votes BitArray + Precommits BitArray + Commits BitArray } -func readKnownBlockPartsMessage(r io.Reader, n *int64, err *error) *KnownBlockPartsMessage { - return &KnownBlockPartsMessage{ - Height: ReadUInt32(r, n, err), - SecondsSinceStartTime: ReadUInt32(r, n, err), - BlockPartsBitArray: ReadByteSlice(r, n, err), +func readHasVotesMessage(r io.Reader, n *int64, err *error) *HasVotesMessage { + return &HasVotesMessage{ + Height: ReadUInt32(r, n, err), + Round: ReadUInt16(r, n, err), + Votes: ReadBitArray(r, n, err), + Precommits: ReadBitArray(r, n, err), + Commits: ReadBitArray(r, n, err), } } -func (m *KnownBlockPartsMessage) WriteTo(w io.Writer) (n int64, err error) { - WriteByte(w, msgTypeKnownBlockParts, &n, &err) +func (m *HasVotesMessage) WriteTo(w io.Writer) (n int64, err error) { + WriteByte(w, msgTypeHasVotes, &n, &err) WriteUInt32(w, m.Height, &n, &err) - WriteUInt32(w, m.SecondsSinceStartTime, &n, &err) - WriteByteSlice(w, m.BlockPartsBitArray, &n, &err) + WriteUInt16(w, m.Round, &n, &err) + WriteBinary(w, m.Votes, &n, &err) + WriteBinary(w, m.Precommits, &n, &err) + WriteBinary(w, m.Commits, &n, &err) return } -func (m *KnownBlockPartsMessage) String() string { - return fmt.Sprintf("[KnownBlockPartsMessage H:%v SSST:%v, BPBA:%X]", - m.Height, m.SecondsSinceStartTime, m.BlockPartsBitArray) +func (m *HasVotesMessage) String() string { + return 
fmt.Sprintf("[HasVotesMessage H:%v R:%v]", m.Height, m.Round) } //------------------------------------- -type VoteRankMessage struct { - ValidatorId uint64 - Rank uint8 +const ( + partTypeProposalBlock = byte(0x01) + partTypeProposalPOL = byte(0x02) +) + +type PartMessage struct { + Height uint32 + Round uint16 + Type byte + Part *Part } -func readVoteRankMessage(r io.Reader, n *int64, err *error) *VoteRankMessage { - return &VoteRankMessage{ - ValidatorId: ReadUInt64(r, n, err), - Rank: ReadUInt8(r, n, err), +func readPartMessage(r io.Reader, n *int64, err *error) *PartMessage { + return &PartMessage{ + Height: ReadUInt32(r, n, err), + Round: ReadUInt16(r, n, err), + Type: ReadByte(r, n, err), + Part: ReadPart(r, n, err), } } -func (m *VoteRankMessage) WriteTo(w io.Writer) (n int64, err error) { - WriteByte(w, msgTypeVoteRank, &n, &err) - WriteUInt64(w, m.ValidatorId, &n, &err) - WriteUInt8(w, m.Rank, &n, &err) +func (m *PartMessage) WriteTo(w io.Writer) (n int64, err error) { + WriteByte(w, msgTypePart, &n, &err) + WriteUInt32(w, m.Height, &n, &err) + WriteUInt16(w, m.Round, &n, &err) + WriteByte(w, m.Type, &n, &err) + WriteBinary(w, m.Part, &n, &err) return } -func (m *VoteRankMessage) String() string { - return fmt.Sprintf("[VoteRankMessage V:%v, R:%v]", m.ValidatorId, m.Rank) +func (m *PartMessage) String() string { + return fmt.Sprintf("[PartMessage H:%v R:%v T:%X]", m.Height, m.Round, m.Type) } diff --git a/consensus/document.go b/consensus/document.go deleted file mode 100644 index bf9a362db..000000000 --- a/consensus/document.go +++ /dev/null @@ -1,41 +0,0 @@ -package consensus - -import ( - "fmt" - . 
"github.com/tendermint/tendermint/config" -) - -func GenVoteDocument(voteType byte, height uint32, round uint16, proposalHash []byte) string { - stepName := "" - switch voteType { - case VoteTypeBare: - stepName = "bare" - case VoteTypePrecommit: - stepName = "precommit" - case VoteTypeCommit: - stepName = "commit" - default: - panic("Unknown vote type") - } - return fmt.Sprintf( - `-----BEGIN TENDERMINT DOCUMENT----- -URI: %v://consensus/%v/%v/%v -ProposalHash: %X ------END TENDERMINT DOCUMENHT-----`, - Config.Network, height, round, stepName, - proposalHash, - ) -} - -func GenBlockPartDocument(height uint32, round uint16, index uint16, total uint16, blockPartHash []byte) string { - return fmt.Sprintf( - `-----BEGIN TENDERMINT DOCUMENT----- -URI: %v://blockpart/%v/%v/%v -Total: %v -BlockPartHash: %X ------END TENDERMINT DOCUMENHT-----`, - Config.Network, height, round, index, - total, - blockPartHash, - ) -} diff --git a/consensus/part_set.go b/consensus/part_set.go new file mode 100644 index 000000000..a7724795d --- /dev/null +++ b/consensus/part_set.go @@ -0,0 +1,177 @@ +package consensus + +import ( + "bytes" + "crypto/sha256" + "errors" + "io" + "sync" + + . "github.com/tendermint/tendermint/binary" + . 
"github.com/tendermint/tendermint/common" + "github.com/tendermint/tendermint/merkle" +) + +const ( + partSize = 4096 // 4KB +) + +var ( + ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index") + ErrPartSetInvalidTrail = errors.New("Error part set invalid trail") +) + +type Part struct { + Index uint16 + Trail [][]byte + Bytes []byte + + // Cache + hash []byte +} + +func ReadPart(r io.Reader, n *int64, err *error) *Part { + return &Part{ + Index: ReadUInt16(r, n, err), + Trail: ReadByteSlices(r, n, err), + Bytes: ReadByteSlice(r, n, err), + } +} + +func (b *Part) WriteTo(w io.Writer) (n int64, err error) { + WriteUInt16(w, b.Index, &n, &err) + WriteByteSlices(w, b.Trail, &n, &err) + WriteByteSlice(w, b.Bytes, &n, &err) + return +} + +func (pt *Part) Hash() []byte { + if pt.hash != nil { + return pt.hash + } else { + hasher := sha256.New() + _, err := hasher.Write(pt.Bytes) + if err != nil { + panic(err) + } + pt.hash = hasher.Sum(nil) + return pt.hash + } +} + +//------------------------------------- + +type PartSet struct { + rootHash []byte + total uint16 + + mtx sync.Mutex + parts []*Part + partsBitArray BitArray + count uint16 +} + +// Returns an immutable, full PartSet. +func NewPartSetFromData(data []byte) *PartSet { + // divide data into 4kb parts. 
+ total := (len(data) + partSize - 1) / partSize + parts := make([]*Part, total) + parts_ := make([]merkle.Hashable, total) + partsBitArray := NewBitArray(uint(total)) + for i := 0; i < total; i++ { + part := &Part{ + Index: uint16(i), + Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)], + } + parts[i] = part + parts_[i] = part + partsBitArray.SetIndex(uint(i), true) + } + // Compute merkle trails + hashTree := merkle.HashTreeFromHashables(parts_) + for i := 0; i < total; i++ { + parts[i].Trail = merkle.HashTrailForIndex(hashTree, i) + } + return &PartSet{ + parts: parts, + partsBitArray: partsBitArray, + rootHash: hashTree[len(hashTree)/2], + total: uint16(total), + count: uint16(total), + } +} + +// Returns an empty PartSet ready to be populated. +func NewPartSetFromMetadata(total uint16, rootHash []byte) *PartSet { + return &PartSet{ + parts: make([]*Part, total), + partsBitArray: NewBitArray(uint(total)), + rootHash: rootHash, + total: total, + count: 0, + } +} + +func (ps *PartSet) BitArray() BitArray { + ps.mtx.Lock() + defer ps.mtx.Unlock() + return ps.partsBitArray.Copy() +} + +func (ps *PartSet) RootHash() []byte { + return ps.rootHash +} + +func (ps *PartSet) Total() uint16 { + if ps == nil { + return 0 + } + return ps.total +} + +func (ps *PartSet) AddPart(part *Part) (bool, error) { + ps.mtx.Lock() + defer ps.mtx.Unlock() + + // Invalid part index + if part.Index >= ps.total { + return false, ErrPartSetUnexpectedIndex + } + + // If part already exists, return false. 
+ if ps.parts[part.Index] != nil { + return false, nil + } + + // Check hash trail + if !merkle.VerifyHashTrailForIndex(int(part.Index), part.Hash(), part.Trail, ps.rootHash) { + return false, ErrPartSetInvalidTrail + } + + // Add part + ps.parts[part.Index] = part + ps.partsBitArray.SetIndex(uint(part.Index), true) + ps.count++ + return true, nil +} + +func (ps *PartSet) GetPart(index uint16) *Part { + ps.mtx.Lock() + defer ps.mtx.Unlock() + return ps.parts[index] +} + +func (ps *PartSet) IsComplete() bool { + return ps.count == ps.total +} + +func (ps *PartSet) GetReader() io.Reader { + if !ps.IsComplete() { + panic("Cannot GetReader() on incomplete PartSet") + } + buf := []byte{} + for _, part := range ps.parts { + buf = append(buf, part.Bytes...) + } + return bytes.NewReader(buf) +} diff --git a/consensus/pol.go b/consensus/pol.go new file mode 100644 index 000000000..395aca749 --- /dev/null +++ b/consensus/pol.go @@ -0,0 +1,98 @@ +package consensus + +import ( + "io" + + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/blocks" + . "github.com/tendermint/tendermint/common" + . "github.com/tendermint/tendermint/state" +) + +// Proof of lock. +// +2/3 of validators' (bare) votes for a given blockhash (or nil) +type POL struct { + Height uint32 + Round uint16 + BlockHash []byte // Could be nil, which makes this a proof of unlock. + Votes []Signature // Vote signatures for height/round/hash + Commits []Signature // Commit signatures for height/hash + CommitRounds []uint16 // Rounds of the commits, less than POL.Round. 
+} + +func ReadPOL(r io.Reader, n *int64, err *error) *POL { + return &POL{ + Height: ReadUInt32(r, n, err), + Round: ReadUInt16(r, n, err), + BlockHash: ReadByteSlice(r, n, err), + Votes: ReadSignatures(r, n, err), + Commits: ReadSignatures(r, n, err), + CommitRounds: ReadUInt16s(r, n, err), + } +} + +func (pol *POL) WriteTo(w io.Writer) (n int64, err error) { + WriteUInt32(w, pol.Height, &n, &err) + WriteUInt16(w, pol.Round, &n, &err) + WriteByteSlice(w, pol.BlockHash, &n, &err) + WriteSignatures(w, pol.Votes, &n, &err) + WriteSignatures(w, pol.Commits, &n, &err) + WriteUInt16s(w, pol.CommitRounds, &n, &err) + return +} + +// Returns whether +2/3 have voted/committed for BlockHash. +func (pol *POL) Verify(vset *ValidatorSet) error { + + talliedVotingPower := uint64(0) + voteDoc := GenVoteDocument(VoteTypeBare, pol.Height, pol.Round, pol.BlockHash) + seenValidators := map[uint64]struct{}{} + + for _, sig := range pol.Votes { + + // Validate + if _, seen := seenValidators[sig.SignerId]; seen { + return Errorf("Duplicate validator for vote %v for POL %v", sig, pol) + } + validator := vset.GetById(sig.SignerId) + if validator == nil { + return Errorf("Invalid validator for vote %v for POL %v", sig, pol) + } + if !validator.Verify(voteDoc, sig) { + return Errorf("Invalid signature for vote %v for POL %v", sig, pol) + } + + // Tally + seenValidators[validator.Id] = struct{}{} + talliedVotingPower += validator.VotingPower + } + + for i, sig := range pol.Commits { + round := pol.CommitRounds[i] + + // Validate + if _, seen := seenValidators[sig.SignerId]; seen { + return Errorf("Duplicate validator for commit %v for POL %v", sig, pol) + } + validator := vset.GetById(sig.SignerId) + if validator == nil { + return Errorf("Invalid validator for commit %v for POL %v", sig, pol) + } + commitDoc := GenVoteDocument(VoteTypeCommit, pol.Height, round, pol.BlockHash) // TODO cache + if !validator.Verify(commitDoc, sig) { + return Errorf("Invalid signature for commit %v for POL 
%v", sig, pol) + } + + // Tally + seenValidators[validator.Id] = struct{}{} + talliedVotingPower += validator.VotingPower + } + + if talliedVotingPower > vset.TotalVotingPower()*2/3 { + return nil + } else { + return Errorf("Invalid POL, insufficient voting power %v, needed %v", + talliedVotingPower, (vset.TotalVotingPower()*2/3 + 1)) + } + +} diff --git a/consensus/priv_validator.go b/consensus/priv_validator.go index 6e7c767df..72674302a 100644 --- a/consensus/priv_validator.go +++ b/consensus/priv_validator.go @@ -13,28 +13,17 @@ type PrivValidator struct { db *db_.LevelDB } -// Returns new signed blockParts. -// If signatures already exist in proposal BlockParts, -// e.g. a locked proposal from some prior round, -// those signatures are overwritten. -// Double signing (signing multiple proposals for the same height&round) results in an error. -func (pv *PrivValidator) SignProposal(round uint16, proposal *BlockPartSet) (err error) { +// Double signing results in an error. +func (pv *PrivValidator) SignProposal(proposal *Proposal) { //TODO: prevent double signing. - blockParts := proposal.BlockParts() - for i, part := range blockParts { - partHash := part.Hash() - doc := GenBlockPartDocument( - proposal.Height(), round, uint16(i), uint16(len(blockParts)), partHash) - part.Signature = pv.Sign([]byte(doc)) - } - return nil + doc := GenProposalDocument(proposal.Height, proposal.Round, proposal.BlockPartsTotal, + proposal.BlockPartsHash, proposal.POLPartsTotal, proposal.POLPartsHash) + proposal.Signature = pv.Sign([]byte(doc)) } -// Modifies the vote object in memory. // Double signing results in an error. -func (pv *PrivValidator) SignVote(vote *Vote) error { +func (pv *PrivValidator) SignVote(vote *Vote) { //TODO: prevent double signing. 
- doc := GenVoteDocument(vote.Type, vote.Height, vote.Round, vote.Hash) + doc := GenVoteDocument(vote.Type, vote.Height, vote.Round, vote.BlockHash) vote.Signature = pv.Sign([]byte(doc)) - return nil } diff --git a/consensus/proposal.go b/consensus/proposal.go new file mode 100644 index 000000000..8f0c06c30 --- /dev/null +++ b/consensus/proposal.go @@ -0,0 +1,64 @@ +package consensus + +import ( + "errors" + "io" + + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/blocks" +) + +var ( + ErrInvalidBlockPartSignature = errors.New("Error invalid block part signature") + ErrInvalidBlockPartHash = errors.New("Error invalid block part hash") +) + +type Proposal struct { + Height uint32 + Round uint16 + BlockPartsTotal uint16 + BlockPartsHash []byte + POLPartsTotal uint16 + POLPartsHash []byte + Signature +} + +func NewProposal(height uint32, round uint16, blockPartsTotal uint16, blockPartsHash []byte, + polPartsTotal uint16, polPartsHash []byte) *Proposal { + return &Proposal{ + Height: height, + Round: round, + BlockPartsTotal: blockPartsTotal, + BlockPartsHash: blockPartsHash, + POLPartsTotal: polPartsTotal, + POLPartsHash: polPartsHash, + } +} + +func ReadProposal(r io.Reader, n *int64, err *error) *Proposal { + return &Proposal{ + Height: ReadUInt32(r, n, err), + Round: ReadUInt16(r, n, err), + BlockPartsTotal: ReadUInt16(r, n, err), + BlockPartsHash: ReadByteSlice(r, n, err), + POLPartsTotal: ReadUInt16(r, n, err), + POLPartsHash: ReadByteSlice(r, n, err), + Signature: ReadSignature(r, n, err), + } +} + +func (p *Proposal) WriteTo(w io.Writer) (n int64, err error) { + WriteUInt32(w, p.Height, &n, &err) + WriteUInt16(w, p.Round, &n, &err) + WriteUInt16(w, p.BlockPartsTotal, &n, &err) + WriteByteSlice(w, p.BlockPartsHash, &n, &err) + WriteUInt16(w, p.POLPartsTotal, &n, &err) + WriteByteSlice(w, p.POLPartsHash, &n, &err) + WriteBinary(w, p.Signature, &n, &err) + return +} + +func (p *Proposal) GenDocument() []byte { + return 
GenProposalDocument(p.Height, p.Round, p.BlockPartsTotal, p.BlockPartsHash, + p.POLPartsTotal, p.POLPartsHash) +} diff --git a/consensus/state.go b/consensus/state.go index 3d41f863e..73d8a699e 100644 --- a/consensus/state.go +++ b/consensus/state.go @@ -1,204 +1,429 @@ package consensus import ( + "errors" "sync" "time" + . "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/blocks" . "github.com/tendermint/tendermint/common" + . "github.com/tendermint/tendermint/mempool" . "github.com/tendermint/tendermint/state" ) +const ( + RoundStepStart = uint8(0x00) // Round started. + RoundStepPropose = uint8(0x01) // Did propose, broadcasting proposal. + RoundStepVote = uint8(0x02) // Did vote, broadcasting votes. + RoundStepPrecommit = uint8(0x03) // Did precommit, broadcasting precommits. + RoundStepCommit = uint8(0x04) // We committed at this round -- do not progress to the next round. +) + var ( + ErrInvalidProposalSignature = errors.New("Error invalid proposal signature") + consensusStateKey = []byte("consensusState") ) +// Immutable when returned from ConsensusState.GetRoundState() +type RoundState struct { + Height uint32 // Height we are working on + Round uint16 + Step uint8 + StartTime time.Time + Validators *ValidatorSet + Proposer *Validator + Proposal *Proposal + ProposalBlock *Block + ProposalBlockPartSet *PartSet + ProposalPOL *POL + ProposalPOLPartSet *PartSet + LockedBlock *Block + LockedPOL *POL + Votes *VoteSet + Precommits *VoteSet + Commits *VoteSet + PrivValidator *PrivValidator +} + +//------------------------------------- + // Tracks consensus state across block heights and rounds. type ConsensusState struct { - mtx sync.Mutex - height uint32 // Height we are working on. - validatorsR0 *ValidatorSet // A copy of the validators at round 0 - lockedProposal *BlockPartSet // A BlockPartSet of the locked proposal. - startTime time.Time // Start of round 0 for this height. - commits *VoteSet // Commits for this height. 
- roundState *RoundState // The RoundState object for the current round. - commitTime time.Time // Time at which a block was found to be committed by +2/3. -} - -func NewConsensusState(state *State) *ConsensusState { - cs := &ConsensusState{} - cs.Update(state) + mtx sync.Mutex + RoundState + + blockStore *BlockStore + mempool *Mempool + + state *State // State until height-1. + stagedBlock *Block // Cache last staged block. + stagedState *State // Cache result of staged block. +} + +func NewConsensusState(state *State, blockStore *BlockStore, mempool *Mempool) *ConsensusState { + cs := &ConsensusState{ + blockStore: blockStore, + mempool: mempool, + } + cs.updateToState(state) return cs } -func (cs *ConsensusState) LockProposal(blockPartSet *BlockPartSet) { +func (cs *ConsensusState) GetRoundState() *RoundState { cs.mtx.Lock() defer cs.mtx.Unlock() - cs.lockedProposal = blockPartSet + rs := cs.RoundState // copy + return &rs } -func (cs *ConsensusState) UnlockProposal() { +func (cs *ConsensusState) updateToState(state *State) { + // Sanity check state. + stateHeight := state.Height() + if stateHeight > 0 && stateHeight != cs.Height+1 { + Panicf("updateToState() expected state height of %v but found %v", cs.Height+1, stateHeight) + } + + // Reset fields based on state. 
+ height := state.Height() + validators := state.Validators() + cs.Height = height + cs.Round = 0 + cs.Step = RoundStepStart + cs.StartTime = state.CommitTime().Add(newBlockWaitDuration) + cs.Validators = validators + cs.Proposer = validators.GetProposer() + cs.Proposal = nil + cs.ProposalBlock = nil + cs.ProposalBlockPartSet = nil + cs.ProposalPOL = nil + cs.ProposalPOLPartSet = nil + cs.LockedBlock = nil + cs.LockedPOL = nil + cs.Votes = NewVoteSet(height, 0, VoteTypeBare, validators) + cs.Precommits = NewVoteSet(height, 0, VoteTypePrecommit, validators) + cs.Commits = NewVoteSet(height, 0, VoteTypeCommit, validators) + + cs.stagedBlock = nil + cs.stagedState = nil + + // Update the round if we need to. + round := calcRound(cs.StartTime) + if round > 0 { + cs.setupRound(round) + } +} + +func (cs *ConsensusState) SetupRound(round uint16) { cs.mtx.Lock() defer cs.mtx.Unlock() - cs.lockedProposal = nil + if cs.Round >= round { + Panicf("ConsensusState round %v not lower than desired round %v", cs.Round, round) + } + cs.setupRound(round) } -func (cs *ConsensusState) LockedProposal() *BlockPartSet { +func (cs *ConsensusState) setupRound(round uint16) { + + // Increment all the way to round. 
+ validators := cs.Validators.Copy() + for r := cs.Round; r < round; r++ { + validators.IncrementAccum() + } + + cs.Round = round + cs.Step = RoundStepStart + cs.Validators = validators + cs.Proposer = validators.GetProposer() + cs.Proposal = nil + cs.ProposalBlock = nil + cs.ProposalBlockPartSet = nil + cs.ProposalPOL = nil + cs.ProposalPOLPartSet = nil + cs.Votes = NewVoteSet(cs.Height, round, VoteTypeBare, validators) + cs.Votes.AddVotesFromCommits(cs.Commits) + cs.Precommits = NewVoteSet(cs.Height, round, VoteTypePrecommit, validators) + cs.Precommits.AddVotesFromCommits(cs.Commits) +} + +func (cs *ConsensusState) SetStep(step byte) { cs.mtx.Lock() defer cs.mtx.Unlock() - return cs.lockedProposal + if cs.Step < step { + cs.Step = step + } else { + panic("step regression") + } } -func (cs *ConsensusState) RoundState() *RoundState { +func (cs *ConsensusState) SetPrivValidator(priv *PrivValidator) { cs.mtx.Lock() defer cs.mtx.Unlock() - return cs.roundState + cs.PrivValidator = priv } -// Primarily gets called upon block commit by ConsensusAgent. -func (cs *ConsensusState) Update(state *State) { +func (cs *ConsensusState) SetProposal(proposal *Proposal) error { cs.mtx.Lock() defer cs.mtx.Unlock() - // Sanity check state. - stateHeight := state.Height() - if stateHeight > 0 && stateHeight != cs.height+1 { - Panicf("Update() expected state height of %v but found %v", cs.height+1, stateHeight) + // Already have one + if cs.Proposal != nil { + return nil } - // Reset fields based on state. - cs.height = stateHeight - cs.validatorsR0 = state.Validators().Copy() // NOTE: immutable. - cs.lockedProposal = nil - cs.startTime = state.CommitTime().Add(newBlockWaitDuration) // NOTE: likely future time. - cs.commits = NewVoteSet(stateHeight, 0, VoteTypeCommit, cs.validatorsR0) + // Invalid. 
+ if proposal.Height != cs.Height || proposal.Round != cs.Round { + return nil + } - // Setup the roundState - cs.roundState = nil - cs.setupRound(0) + // Verify signature + if !cs.Proposer.Verify(proposal.GenDocument(), proposal.Signature) { + return ErrInvalidProposalSignature + } + cs.Proposal = proposal + cs.ProposalBlockPartSet = NewPartSetFromMetadata(proposal.BlockPartsTotal, proposal.BlockPartsHash) + cs.ProposalPOLPartSet = NewPartSetFromMetadata(proposal.POLPartsTotal, proposal.POLPartsHash) + return nil } -// If cs.roundState isn't at round, set up new roundState at round. -func (cs *ConsensusState) SetupRound(round uint16) { +func (cs *ConsensusState) MakeProposal() { cs.mtx.Lock() defer cs.mtx.Unlock() - if cs.roundState != nil && cs.roundState.Round >= round { + + if cs.PrivValidator == nil || cs.Proposer.Id != cs.PrivValidator.Id { return } - cs.setupRound(round) -} -// Initialize roundState for given round. -// Involves incrementing validators for each past rand. -func (cs *ConsensusState) setupRound(round uint16) { - // Increment validator accums as necessary. - // We need to start with cs.validatorsR0 or cs.roundState.Validators - var validators *ValidatorSet - var validatorsRound uint16 - if cs.roundState == nil { - // We have no roundState so we start from validatorsR0 at round 0. - validators = cs.validatorsR0.Copy() - validatorsRound = 0 + var block *Block + var blockPartSet *PartSet + var pol *POL + var polPartSet *PartSet + + // Decide on block and POL + if cs.LockedBlock != nil { + // If we're locked onto a block, just choose that. + block = cs.LockedBlock + pol = cs.LockedPOL } else { - // We have a previous roundState so we start from that. - validators = cs.roundState.Validators.Copy() - validatorsRound = cs.roundState.Round + // TODO: make use of state returned from MakeProposalBlock() + block, _ = cs.mempool.MakeProposalBlock() + pol = cs.LockedPOL // If exists, is a PoUnlock. } - // Increment all the way to round. 
- for r := validatorsRound; r < round; r++ { - validators.IncrementAccum() + + blockPartSet = NewPartSetFromData(BinaryBytes(block)) + if pol != nil { + polPartSet = NewPartSetFromData(BinaryBytes(pol)) } - roundState := NewRoundState(cs.height, round, cs.startTime, validators, cs.commits) - cs.roundState = roundState + // Make proposal + proposal := NewProposal(cs.Height, cs.Round, blockPartSet.Total(), blockPartSet.RootHash(), + polPartSet.Total(), polPartSet.RootHash()) + cs.PrivValidator.SignProposal(proposal) + + // Set fields + cs.Proposal = proposal + cs.ProposalBlock = block + cs.ProposalBlockPartSet = blockPartSet + cs.ProposalPOL = pol + cs.ProposalPOLPartSet = polPartSet } -//----------------------------------------------------------------------------- +// NOTE: block is not necessarily valid. +func (cs *ConsensusState) AddProposalBlockPart(height uint32, round uint16, part *Part) (added bool, err error) { + cs.mtx.Lock() + defer cs.mtx.Unlock() -const ( - RoundStepStart = uint8(0x00) // Round started. - RoundStepProposal = uint8(0x01) // Did propose, broadcasting proposal. - RoundStepBareVotes = uint8(0x02) // Did vote bare, broadcasting bare votes. - RoundStepPrecommits = uint8(0x03) // Did precommit, broadcasting precommits. - RoundStepCommitOrUnlock = uint8(0x04) // We committed at this round -- do not progress to the next round. -) + // Blocks might be reused, so round mismatch is OK + if cs.Height != height { + return false, nil + } -//----------------------------------------------------------------------------- + // We're not expecting a block part. + if cs.ProposalBlockPartSet != nil { + return false, nil // TODO: bad peer? Return error? + } -// RoundState encapsulates all the state needed to engage in the consensus protocol. -type RoundState struct { - Height uint32 // Immutable - Round uint16 // Immutable - StartTime time.Time // Time in which consensus started for this height. - Expires time.Time // Time after which this round is expired. 
- Proposer *Validator // The proposer to propose a block for this round. - Validators *ValidatorSet // All validators with modified accumPower for this round. - Proposal *BlockPartSet // All block parts received for this round. - RoundBareVotes *VoteSet // All votes received for this round. - RoundPrecommits *VoteSet // All precommits received for this round. - Commits *VoteSet // A shared object for all commit votes of this height. - - mtx sync.Mutex - step uint8 // mutable -} - -func NewRoundState(height uint32, round uint16, startTime time.Time, - validators *ValidatorSet, commits *VoteSet) *RoundState { - - proposer := validators.GetProposer() - blockPartSet := NewBlockPartSet(height, nil) - roundBareVotes := NewVoteSet(height, round, VoteTypeBare, validators) - roundPrecommits := NewVoteSet(height, round, VoteTypePrecommit, validators) - - rs := &RoundState{ - Height: height, - Round: round, - StartTime: startTime, - Expires: calcRoundStartTime(round+1, startTime), - Proposer: proposer, - Validators: validators, - Proposal: blockPartSet, - RoundBareVotes: roundBareVotes, - RoundPrecommits: roundPrecommits, - Commits: commits, - - step: RoundStepStart, - } - return rs -} - -// "source" is typically the Peer.Key of the peer that gave us this vote. -func (rs *RoundState) AddVote(vote *Vote, source string) (added bool, rank uint8, err error) { + added, err = cs.ProposalBlockPartSet.AddPart(part) + if err != nil { + return added, err + } + if added && cs.ProposalBlockPartSet.IsComplete() { + var n int64 + var err error + cs.ProposalBlock = ReadBlock(cs.ProposalBlockPartSet.GetReader(), &n, &err) + return true, err + } + return true, nil +} + +// NOTE: POL is not necessarily valid. +func (cs *ConsensusState) AddProposalPOLPart(height uint32, round uint16, part *Part) (added bool, err error) { + cs.mtx.Lock() + defer cs.mtx.Unlock() + + if cs.Height != height || cs.Round != round { + return false, nil + } + + // We're not expecting a POL part. 
+ if cs.ProposalPOLPartSet != nil { + return false, nil // TODO: bad peer? Return error? + } + + added, err = cs.ProposalPOLPartSet.AddPart(part) + if err != nil { + return added, err + } + if added && cs.ProposalPOLPartSet.IsComplete() { + var n int64 + var err error + cs.ProposalPOL = ReadPOL(cs.ProposalPOLPartSet.GetReader(), &n, &err) + return true, err + } + return true, nil +} + +func (cs *ConsensusState) AddVote(vote *Vote) (added bool, err error) { switch vote.Type { case VoteTypeBare: - return rs.RoundBareVotes.AddVote(vote, source) + // Votes checks for height+round match. + return cs.Votes.AddVote(vote) case VoteTypePrecommit: - return rs.RoundPrecommits.AddVote(vote, source) + // Precommits checks for height+round match. + return cs.Precommits.AddVote(vote) case VoteTypeCommit: - return rs.Commits.AddVote(vote, source) + // Commits checks for height match. + cs.Votes.AddVote(vote) + cs.Precommits.AddVote(vote) + return cs.Commits.AddVote(vote) default: panic("Unknown vote type") } } -func (rs *RoundState) Expired() bool { - return time.Now().After(rs.Expires) +// Lock the ProposalBlock if we have enough votes for it, +// or unlock an existing lock if +2/3 of votes were nil. +// Returns a blockhash if a block was locked. +func (cs *ConsensusState) LockOrUnlock(height uint32, round uint16) []byte { + cs.mtx.Lock() + defer cs.mtx.Unlock() + + if cs.Height != height || cs.Round != round { + return nil + } + + if hash, _, ok := cs.Votes.TwoThirdsMajority(); ok { + + // Remember this POL. (hash may be nil) + cs.LockedPOL = cs.Votes.MakePOL() + + if len(hash) == 0 { + // +2/3 voted nil. Just unlock. + cs.LockedBlock = nil + return nil + } else if cs.ProposalBlock.HashesTo(hash) { + // +2/3 voted for proposal block + // Validate the block. + // See note on ZombieValidators to see why. 
+ if cs.stageBlock(cs.ProposalBlock) != nil { + log.Warning("+2/3 voted for an invalid block.") + return nil + } + cs.LockedBlock = cs.ProposalBlock + return hash + } else if cs.LockedBlock.HashesTo(hash) { + // +2/3 voted for already locked block + // cs.LockedBlock = cs.LockedBlock + return hash + } else { + // We don't have the block that hashes to hash. + // Unlock if we're locked. + cs.LockedBlock = nil + return nil + } + } else { + return nil + } } -func (rs *RoundState) Step() uint8 { - rs.mtx.Lock() - defer rs.mtx.Unlock() - return rs.step +func (cs *ConsensusState) Commit(height uint32, round uint16) *Block { + cs.mtx.Lock() + defer cs.mtx.Unlock() + + if cs.Height != height || cs.Round != round { + return nil + } + + if hash, commitTime, ok := cs.Precommits.TwoThirdsMajority(); ok { + + // There are some strange cases that shouldn't happen + // (unless voters are duplicitous). + // For example, the hash may not be the one that was + // proposed this round. These cases should be identified + // and warn the administrator. We should err on the side of + // caution and not, for example, sign a block. + // TODO: Identify these strange cases. + + var block *Block + if cs.LockedBlock.HashesTo(hash) { + block = cs.LockedBlock + } else if cs.ProposalBlock.HashesTo(hash) { + block = cs.ProposalBlock + } else { + return nil + } + + // The proposal must be valid. + if err := cs.stageBlock(block); err != nil { + log.Warning("Network is commiting an invalid proposal? %v", err) + return nil + } + + // Save to blockStore + err := cs.blockStore.SaveBlock(block) + if err != nil { + return nil + } + + // What was staged becomes committed. + state := cs.stagedState + state.Save(commitTime) + cs.updateToState(state) + + // Update mempool. 
+ cs.mempool.ResetForBlockAndState(block, state) + + return block + } + + return nil } -func (rs *RoundState) SetStep(step uint8) bool { - rs.mtx.Lock() - defer rs.mtx.Unlock() - if rs.step < step { - rs.step = step - return true +func (cs *ConsensusState) stageBlock(block *Block) error { + + // Already staged? + if cs.stagedBlock == block { + return nil + } + + // Basic validation is done in state.CommitBlock(). + //err := block.ValidateBasic() + //if err != nil { + // return err + //} + + // Create a copy of the state for staging + stateCopy := cs.state.Copy() // Deep copy the state before staging. + + // Commit block onto the copied state. + err := stateCopy.CommitBlock(block) + if err != nil { + return err } else { - return false + cs.stagedBlock = block + cs.stagedState = stateCopy + return nil } } diff --git a/consensus/vote.go b/consensus/vote.go deleted file mode 100644 index 774134e1b..000000000 --- a/consensus/vote.go +++ /dev/null @@ -1,184 +0,0 @@ -package consensus - -import ( - "bytes" - "errors" - "io" - "sync" - "time" - - . "github.com/tendermint/tendermint/binary" - . "github.com/tendermint/tendermint/blocks" - . "github.com/tendermint/tendermint/state" -) - -const ( - VoteTypeBare = byte(0x00) - VoteTypePrecommit = byte(0x01) - VoteTypeCommit = byte(0x02) -) - -var ( - ErrVoteUnexpectedPhase = errors.New("Unexpected phase") - ErrVoteInvalidAccount = errors.New("Invalid round vote account") - ErrVoteInvalidSignature = errors.New("Invalid round vote signature") - ErrVoteInvalidHash = errors.New("Invalid hash") - ErrVoteConflictingSignature = errors.New("Conflicting round vote signature") -) - -// Represents a bare, precommit, or commit vote for proposals. -type Vote struct { - Height uint32 - Round uint16 // zero if commit vote. - Type byte - Hash []byte // empty if vote is nil. 
- Signature -} - -func ReadVote(r io.Reader, n *int64, err *error) *Vote { - return &Vote{ - Height: ReadUInt32(r, n, err), - Round: ReadUInt16(r, n, err), - Type: ReadByte(r, n, err), - Hash: ReadByteSlice(r, n, err), - Signature: ReadSignature(r, n, err), - } -} - -func (v *Vote) WriteTo(w io.Writer) (n int64, err error) { - WriteUInt32(w, v.Height, &n, &err) - WriteUInt16(w, v.Round, &n, &err) - WriteByte(w, v.Type, &n, &err) - WriteByteSlice(w, v.Hash, &n, &err) - WriteBinary(w, v.Signature, &n, &err) - return -} - -func (v *Vote) GetDocument() string { - return GenVoteDocument(v.Type, v.Height, v.Round, v.Hash) -} - -//----------------------------------------------------------------------------- - -// VoteSet helps collect signatures from validators at each height+round -// for a predefined vote type. -// TODO: test majority calculations etc. -type VoteSet struct { - mtx sync.Mutex - height uint32 - round uint16 - type_ byte - validators *ValidatorSet - votes map[uint64]*Vote - votesSources map[uint64][]string - votesByHash map[string]uint64 - totalVotes uint64 - totalVotingPower uint64 - oneThirdMajority [][]byte - twoThirdsCommitTime time.Time -} - -// Constructs a new VoteSet struct used to accumulate votes for each round. -func NewVoteSet(height uint32, round uint16, type_ byte, validators *ValidatorSet) *VoteSet { - if type_ == VoteTypeCommit && round != 0 { - panic("Expected round 0 for commit vote set") - } - totalVotingPower := uint64(0) - for _, val := range validators.Map() { - totalVotingPower += val.VotingPower - } - return &VoteSet{ - height: height, - round: round, - type_: type_, - validators: validators, - votes: make(map[uint64]*Vote, validators.Size()), - votesSources: make(map[uint64][]string, validators.Size()), - votesByHash: make(map[string]uint64), - totalVotes: 0, - totalVotingPower: totalVotingPower, - } -} - -// True if added, false if not. 
-// Returns ErrVote[UnexpectedPhase|InvalidAccount|InvalidSignature|InvalidHash|ConflictingSignature] -func (vs *VoteSet) AddVote(vote *Vote, source string) (bool, uint8, error) { - vs.mtx.Lock() - defer vs.mtx.Unlock() - - // Make sure the phase matches. - if vote.Height != vs.height || - (vote.Type != VoteTypeCommit && vote.Round != vs.round) || - vote.Type != vs.type_ { - return false, 0, ErrVoteUnexpectedPhase - } - - val := vs.validators.Get(vote.SignerId) - // Ensure that signer is a validator. - if val == nil { - return false, 0, ErrVoteInvalidAccount - } - // Check signature. - if !val.Verify([]byte(vote.GetDocument()), vote.Signature) { - // Bad signature. - return false, 0, ErrVoteInvalidSignature - } - // Get rank of vote & append provider key - var priorSources = vs.votesSources[vote.SignerId] - var rank = uint8(len(priorSources) + 1) - var alreadyProvidedByPeer = false - for i, otherPeer := range priorSources { - if otherPeer == source { - alreadyProvidedByPeer = true - rank = uint8(i + 1) - } - } - if !alreadyProvidedByPeer { - if len(priorSources) < voteRankCutoff { - vs.votesSources[vote.SignerId] = append(priorSources, source) - } - } - // If vote already exists, return false. - if existingVote, ok := vs.votes[vote.SignerId]; ok { - if bytes.Equal(existingVote.Hash, vote.Hash) { - return false, rank, nil - } else { - return false, rank, ErrVoteConflictingSignature - } - } - vs.votes[vote.SignerId] = vote - totalHashVotes := vs.votesByHash[string(vote.Hash)] + val.VotingPower - vs.votesByHash[string(vote.Hash)] = totalHashVotes - vs.totalVotes += val.VotingPower - // If we just nudged it up to one thirds majority, add it. 
- if totalHashVotes > vs.totalVotingPower/3 && - (totalHashVotes-val.VotingPower) <= vs.totalVotingPower/3 { - vs.oneThirdMajority = append(vs.oneThirdMajority, vote.Hash) - } else if totalHashVotes > vs.totalVotingPower*2/3 && - (totalHashVotes-val.VotingPower) <= vs.totalVotingPower*2/3 { - vs.twoThirdsCommitTime = time.Now() - } - return true, rank, nil -} - -// Returns either a blockhash (or nil) that received +2/3 majority. -// If there exists no such majority, returns (nil, false). -func (vs *VoteSet) TwoThirdsMajority() (hash []byte, commitTime time.Time, ok bool) { - vs.mtx.Lock() - defer vs.mtx.Unlock() - // There's only one or two in the array. - for _, hash := range vs.oneThirdMajority { - if vs.votesByHash[string(hash)] > vs.totalVotingPower*2/3 { - return hash, vs.twoThirdsCommitTime, true - } - } - return nil, time.Time{}, false -} - -// Returns blockhashes (or nil) that received a +1/3 majority. -// If there exists no such majority, returns nil. -func (vs *VoteSet) OneThirdMajority() (hashes [][]byte) { - vs.mtx.Lock() - defer vs.mtx.Unlock() - return vs.oneThirdMajority -} diff --git a/consensus/vote_set.go b/consensus/vote_set.go new file mode 100644 index 000000000..39f49c00c --- /dev/null +++ b/consensus/vote_set.go @@ -0,0 +1,180 @@ +package consensus + +import ( + "bytes" + + "sync" + "time" + + . "github.com/tendermint/tendermint/blocks" + . "github.com/tendermint/tendermint/common" + . "github.com/tendermint/tendermint/state" +) + +// VoteSet helps collect signatures from validators at each height+round +// for a predefined vote type. +// Note that there are three kinds of votes: (bare) votes, precommits, and commits. +// A commit of prior rounds can be added in lieu of votes/precommits. +// TODO: test majority calculations etc.
+type VoteSet struct { + height uint32 + round uint16 + type_ byte + + mtx sync.Mutex + vset *ValidatorSet + votes map[uint64]*Vote + votesBitArray BitArray + votesByBlockHash map[string]uint64 + totalVotes uint64 + twoThirdsMajority []byte + twoThirdsCommitTime time.Time +} + +// Constructs a new VoteSet struct used to accumulate votes for each round. +func NewVoteSet(height uint32, round uint16, type_ byte, vset *ValidatorSet) *VoteSet { + if type_ == VoteTypeCommit && round != 0 { + panic("Expected round 0 for commit vote set") + } + return &VoteSet{ + height: height, + round: round, + type_: type_, + vset: vset, + votes: make(map[uint64]*Vote, vset.Size()), + votesBitArray: NewBitArray(vset.Size()), + votesByBlockHash: make(map[string]uint64), + totalVotes: 0, + } +} + +// True if added, false if not. +// Returns ErrVote[UnexpectedPhase|InvalidAccount|InvalidSignature|InvalidBlockHash|ConflictingSignature] +func (vs *VoteSet) AddVote(vote *Vote) (bool, error) { + vs.mtx.Lock() + defer vs.mtx.Unlock() + + // Make sure the phase matches. (or that vote is commit && round < vs.round) + if vote.Height != vs.height || + (vote.Type != VoteTypeCommit && vote.Round != vs.round) || + (vote.Type != VoteTypeCommit && vote.Type != vs.type_) || + (vote.Type == VoteTypeCommit && vs.type_ != VoteTypeCommit && vote.Round >= vs.round) { + return false, ErrVoteUnexpectedPhase + } + + // Ensure that signer is a validator. + val := vs.vset.GetById(vote.SignerId) + if val == nil { + return false, ErrVoteInvalidAccount + } + + // Check signature. + if !val.Verify(vote.GenDocument(), vote.Signature) { + // Bad signature. + return false, ErrVoteInvalidSignature + } + + return vs.addVote(vote) +} + +func (vs *VoteSet) addVote(vote *Vote) (bool, error) { + // If vote already exists, return false. 
+ if existingVote, ok := vs.votes[vote.SignerId]; ok { + if bytes.Equal(existingVote.BlockHash, vote.BlockHash) { + return false, nil + } else { + return false, ErrVoteConflictingSignature + } + } + + // Add vote. + vs.votes[vote.SignerId] = vote + voterIndex, ok := vs.vset.GetIndexById(vote.SignerId) + if !ok { + return false, ErrVoteInvalidAccount + } + vs.votesBitArray.SetIndex(uint(voterIndex), true) + val := vs.vset.GetById(vote.SignerId) + totalBlockHashVotes := vs.votesByBlockHash[string(vote.BlockHash)] + val.VotingPower + vs.votesByBlockHash[string(vote.BlockHash)] = totalBlockHashVotes + vs.totalVotes += val.VotingPower + + // If we just nudged it up to two thirds majority, add it. + if totalBlockHashVotes > vs.vset.TotalVotingPower()*2/3 && + (totalBlockHashVotes-val.VotingPower) <= vs.vset.TotalVotingPower()*2/3 { + vs.twoThirdsMajority = vote.BlockHash + vs.twoThirdsCommitTime = time.Now() + } + + return true, nil +} + +// Assumes that commits VoteSet is valid. +func (vs *VoteSet) AddVotesFromCommits(commits *VoteSet) { + commitVotes := commits.AllVotes() + for _, commit := range commitVotes { + if commit.Round < vs.round { + vs.addVote(commit) + } + } +} + +func (vs *VoteSet) BitArray() BitArray { + vs.mtx.Lock() + defer vs.mtx.Unlock() + return vs.votesBitArray.Copy() +} + +func (vs *VoteSet) GetVote(id uint64) *Vote { + vs.mtx.Lock() + defer vs.mtx.Unlock() + return vs.votes[id] +} + +func (vs *VoteSet) AllVotes() []*Vote { + vs.mtx.Lock() + defer vs.mtx.Unlock() + votes := []*Vote{} + for _, vote := range vs.votes { + votes = append(votes, vote) + } + return votes +} + +// Returns either a blockhash (or nil) that received +2/3 majority. +// If there exists no such majority, returns (nil, false). 
+func (vs *VoteSet) TwoThirdsMajority() (hash []byte, commitTime time.Time, ok bool) { + vs.mtx.Lock() + defer vs.mtx.Unlock() + if vs.twoThirdsCommitTime.IsZero() { + return nil, time.Time{}, false + } + return vs.twoThirdsMajority, vs.twoThirdsCommitTime, true +} + +func (vs *VoteSet) MakePOL() *POL { + vs.mtx.Lock() + defer vs.mtx.Unlock() + if vs.twoThirdsCommitTime.IsZero() { + return nil + } + majHash := vs.twoThirdsMajority // hash may be nil. + pol := &POL{ + Height: vs.height, + Round: vs.round, + BlockHash: majHash, + } + for _, vote := range vs.votes { + if bytes.Equal(vote.BlockHash, majHash) { + if vote.Type == VoteTypeBare { + pol.Votes = append(pol.Votes, vote.Signature) + } else if vote.Type == VoteTypeCommit { + pol.Commits = append(pol.Commits, vote.Signature) + pol.CommitRounds = append(pol.CommitRounds, vote.Round) + } else { + Panicf("Unexpected vote type %X", vote.Type) + } + } + } + return pol +} diff --git a/main.go b/main.go index d807c017f..d3a5d775a 100644 --- a/main.go +++ b/main.go @@ -5,60 +5,26 @@ import ( "os/signal" "github.com/tendermint/tendermint/config" - "github.com/tendermint/tendermint/consensus" + //"github.com/tendermint/tendermint/consensus" "github.com/tendermint/tendermint/p2p" ) type Node struct { - lz []p2p.Listener - sw *p2p.Switch - book *p2p.AddrBook - pexAgent *p2p.PEXAgent + lz []p2p.Listener + sw *p2p.Switch + book *p2p.AddrBook + pexReactor *p2p.PEXReactor } func NewNode() *Node { - // Define channels for our app - chDescs := []*p2p.ChannelDescriptor{ - // PEX - &p2p.ChannelDescriptor{ - Id: p2p.PexCh, - SendQueueCapacity: 2, - RecvQueueCapacity: 2, - RecvBufferSize: 1024, - DefaultPriority: 1, - }, - // CONSENSUS - &p2p.ChannelDescriptor{ - Id: consensus.ProposalCh, - SendQueueCapacity: 2, - RecvQueueCapacity: 10, - RecvBufferSize: 10240, - DefaultPriority: 5, - }, - &p2p.ChannelDescriptor{ - Id: consensus.KnownPartsCh, - SendQueueCapacity: 2, - RecvQueueCapacity: 10, - RecvBufferSize: 1024, - DefaultPriority: 5,
- }, - &p2p.ChannelDescriptor{ - Id: consensus.VoteCh, - SendQueueCapacity: 100, - RecvQueueCapacity: 1000, - RecvBufferSize: 10240, - DefaultPriority: 5, - }, - // TODO: MEMPOOL - } - sw := p2p.NewSwitch(chDescs) + sw := p2p.NewSwitch(nil) // XXX create and pass reactors book := p2p.NewAddrBook(config.RootDir + "/addrbook.json") - pexAgent := p2p.NewPEXAgent(sw, book) + pexReactor := p2p.NewPEXReactor(sw, book) return &Node{ - sw: sw, - book: book, - pexAgent: pexAgent, + sw: sw, + book: book, + pexReactor: pexReactor, } } @@ -69,7 +35,7 @@ func (n *Node) Start() { } n.sw.Start() n.book.Start() - n.pexAgent.Start() + n.pexReactor.Start() } func (n *Node) Stop() { @@ -77,7 +43,7 @@ func (n *Node) Stop() { // TODO: gracefully disconnect from peers. n.sw.Stop() n.book.Stop() - n.pexAgent.Stop() + n.pexReactor.Stop() } // Add a Listener to accept inbound peer connections. @@ -102,7 +68,7 @@ func (n *Node) inboundConnectionRoutine(l p2p.Listener) { } // NOTE: We don't yet have the external address of the // remote (if they have a listener at all). - // PEXAgent's pexRoutine will handle that. + // PEXReactor's pexRoutine will handle that. } // cleanup diff --git a/mempool/agent.go b/mempool/agent.go deleted file mode 100644 index bedf82dbb..000000000 --- a/mempool/agent.go +++ /dev/null @@ -1,169 +0,0 @@ -package mempool - -import ( - "bytes" - "fmt" - "io" - "sync/atomic" - - . "github.com/tendermint/tendermint/binary" - . "github.com/tendermint/tendermint/blocks" - "github.com/tendermint/tendermint/p2p" -) - -var ( - MempoolCh = byte(0x30) -) - -// MempoolAgent handles mempool tx broadcasting amongst peers. 
-type MempoolAgent struct { - sw *p2p.Switch - swEvents chan interface{} - quit chan struct{} - started uint32 - stopped uint32 - - mempool *Mempool -} - -func NewMempoolAgent(sw *p2p.Switch, mempool *Mempool) *MempoolAgent { - swEvents := make(chan interface{}) - sw.AddEventListener("MempoolAgent.swEvents", swEvents) - memA := &MempoolAgent{ - sw: sw, - swEvents: swEvents, - quit: make(chan struct{}), - mempool: mempool, - } - return memA -} - -func (memA *MempoolAgent) Start() { - if atomic.CompareAndSwapUint32(&memA.started, 0, 1) { - log.Info("Starting MempoolAgent") - go memA.switchEventsRoutine() - go memA.gossipTxRoutine() - } -} - -func (memA *MempoolAgent) Stop() { - if atomic.CompareAndSwapUint32(&memA.stopped, 0, 1) { - log.Info("Stopping MempoolAgent") - close(memA.quit) - close(memA.swEvents) - } -} - -func (memA *MempoolAgent) BroadcastTx(tx Tx) error { - err := memA.mempool.AddTx(tx) - if err != nil { - return err - } - msg := &TxMessage{Tx: tx} - memA.sw.Broadcast(MempoolCh, msg) - return nil -} - -// Handle peer new/done events -func (memA *MempoolAgent) switchEventsRoutine() { - for { - swEvent, ok := <-memA.swEvents - if !ok { - break - } - switch swEvent.(type) { - case p2p.SwitchEventNewPeer: - // event := swEvent.(p2p.SwitchEventNewPeer) - case p2p.SwitchEventDonePeer: - // event := swEvent.(p2p.SwitchEventDonePeer) - default: - log.Warning("Unhandled switch event type") - } - } -} - -func (memA *MempoolAgent) gossipTxRoutine() { -OUTER_LOOP: - for { - // Receive incoming message on MempoolCh - inMsg, ok := memA.sw.Receive(MempoolCh) - if !ok { - break OUTER_LOOP // Client has stopped - } - _, msg_ := decodeMessage(inMsg.Bytes) - log.Info("gossipTxRoutine received %v", msg_) - - switch msg_.(type) { - case *TxMessage: - msg := msg_.(*TxMessage) - err := memA.mempool.AddTx(msg.Tx) - if err != nil { - // Bad, seen, or conflicting tx. 
- log.Debug("Could not add tx %v", msg.Tx) - continue OUTER_LOOP - } else { - log.Debug("Added valid tx %V", msg.Tx) - } - // Share tx. - // We use a simple shotgun approach for now. - // TODO: improve efficiency - for _, peer := range memA.sw.Peers().List() { - if peer.Key == inMsg.MConn.Peer.Key { - continue - } - peer.TrySend(MempoolCh, msg) - } - - default: - // Ignore unknown message - // memA.sw.StopPeerForError(inMsg.MConn.Peer, errInvalidMessage) - } - } - - // Cleanup -} - -//----------------------------------------------------------------------------- -// Messages - -const ( - msgTypeUnknown = byte(0x00) - msgTypeTx = byte(0x10) -) - -// TODO: check for unnecessary extra bytes at the end. -func decodeMessage(bz []byte) (msgType byte, msg interface{}) { - n, err := new(int64), new(error) - // log.Debug("decoding msg bytes: %X", bz) - msgType = bz[0] - switch msgType { - case msgTypeTx: - msg = readTxMessage(bytes.NewReader(bz[1:]), n, err) - // case ...: - default: - msg = nil - } - return -} - -//------------------------------------- - -type TxMessage struct { - Tx Tx -} - -func readTxMessage(r io.Reader, n *int64, err *error) *TxMessage { - return &TxMessage{ - Tx: ReadTx(r, n, err), - } -} - -func (m *TxMessage) WriteTo(w io.Writer) (n int64, err error) { - WriteByte(w, msgTypeTx, &n, &err) - WriteBinary(w, m.Tx, &n, &err) - return -} - -func (m *TxMessage) String() string { - return fmt.Sprintf("[TxMessage %v]", m.Tx) -} diff --git a/mempool/mempool.go b/mempool/mempool.go index c44c77940..6d74a222a 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -45,7 +45,7 @@ func (mem *Mempool) AddTx(tx Tx) (err error) { // Returns a new block from the current state and associated transactions. // The block's Validation is empty, and some parts of the header too. 
-func (mem *Mempool) MakeProposal() (*Block, *State) { +func (mem *Mempool) MakeProposalBlock() (*Block, *State) { mem.mtx.Lock() defer mem.mtx.Unlock() nextBlock := mem.lastBlock.MakeNextBlock() diff --git a/mempool/reactor.go b/mempool/reactor.go new file mode 100644 index 000000000..f28daa159 --- /dev/null +++ b/mempool/reactor.go @@ -0,0 +1,141 @@ +package mempool + +import ( + "bytes" + "fmt" + "io" + "sync/atomic" + + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/blocks" + . "github.com/tendermint/tendermint/p2p" +) + +var ( + MempoolCh = byte(0x30) +) + +// MempoolReactor handles mempool tx broadcasting amongst peers. +type MempoolReactor struct { + sw *Switch + quit chan struct{} + started uint32 + stopped uint32 + + mempool *Mempool +} + +func NewMempoolReactor(sw *Switch, mempool *Mempool) *MempoolReactor { + memR := &MempoolReactor{ + sw: sw, + quit: make(chan struct{}), + mempool: mempool, + } + return memR +} + +func (memR *MempoolReactor) Start() { + if atomic.CompareAndSwapUint32(&memR.started, 0, 1) { + log.Info("Starting MempoolReactor") + } +} + +func (memR *MempoolReactor) Stop() { + if atomic.CompareAndSwapUint32(&memR.stopped, 0, 1) { + log.Info("Stopping MempoolReactor") + close(memR.quit) + } +} + +func (memR *MempoolReactor) BroadcastTx(tx Tx) error { + err := memR.mempool.AddTx(tx) + if err != nil { + return err + } + msg := &TxMessage{Tx: tx} + memR.sw.Broadcast(MempoolCh, msg) + return nil +} + +// Implements Reactor +func (memR *MempoolReactor) AddPeer(peer *Peer) { +} + +// Implements Reactor +func (memR *MempoolReactor) RemovePeer(peer *Peer, err error) { +} + +func (memR *MempoolReactor) Receive(chId byte, src *Peer, msgBytes []byte) { + _, msg_ := decodeMessage(msgBytes) + log.Info("MempoolReactor received %v", msg_) + + switch msg_.(type) { + case *TxMessage: + msg := msg_.(*TxMessage) + err := memR.mempool.AddTx(msg.Tx) + if err != nil { + // Bad, seen, or conflicting tx.
+ log.Debug("Could not add tx %v", msg.Tx) + return + } else { + log.Debug("Added valid tx %v", msg.Tx) + } + // Share tx. + // We use a simple shotgun approach for now. + // TODO: improve efficiency + for _, peer := range memR.sw.Peers().List() { + if peer.Key == src.Key { + continue + } + peer.TrySend(MempoolCh, msg) + } + + default: + // Ignore unknown message + } +} + +//----------------------------------------------------------------------------- +// Messages + +const ( + msgTypeUnknown = byte(0x00) + msgTypeTx = byte(0x10) +) + +// TODO: check for unnecessary extra bytes at the end. +func decodeMessage(bz []byte) (msgType byte, msg interface{}) { + n, err := new(int64), new(error) + // log.Debug("decoding msg bytes: %X", bz) + msgType = bz[0] + switch msgType { + case msgTypeTx: + msg = readTxMessage(bytes.NewReader(bz[1:]), n, err) + // case ...: + default: + msg = nil + } + return +} + +//------------------------------------- + +type TxMessage struct { + Tx Tx +} + +func readTxMessage(r io.Reader, n *int64, err *error) *TxMessage { + return &TxMessage{ + Tx: ReadTx(r, n, err), + } +} + +func (m *TxMessage) WriteTo(w io.Writer) (n int64, err error) { + WriteByte(w, msgTypeTx, &n, &err) + WriteBinary(w, m.Tx, &n, &err) + return +} + +func (m *TxMessage) String() string { + return fmt.Sprintf("[TxMessage %v]", m.Tx) +} diff --git a/merkle/types.go b/merkle/types.go index fe432be37..35d3ef424 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -26,3 +26,7 @@ type Tree interface { func NotFound(key []byte) error { return fmt.Errorf("Key was not found.") } + +type Hashable interface { + Hash() []byte +} diff --git a/merkle/util.go b/merkle/util.go index 262dd9d21..68a6515ec 100644 --- a/merkle/util.go +++ b/merkle/util.go @@ -1,68 +1,230 @@ package merkle import ( + "bytes" "crypto/sha256" "fmt" .
"github.com/tendermint/tendermint/binary" ) -func HashFromByteSlices(items [][]byte) []byte { - switch len(items) { +func HashFromTwoHashes(left []byte, right []byte) []byte { + var n int64 + var err error + var hasher = sha256.New() + WriteByteSlice(hasher, left, &n, &err) + WriteByteSlice(hasher, right, &n, &err) + if err != nil { + panic(err) + } + return hasher.Sum(nil) +} + +/* +Computes a deterministic minimal height merkle tree hash. +If the number of items is not a power of two, some leaves +will be at different levels. + + * + / \ + / \ + / \ + / \ + * * + / \ / \ + / \ / \ + / \ / \ + * h2 * * + / \ / \ / \ + h0 h1 h3 h4 h5 h6 + +*/ +func HashFromHashes(hashes [][]byte) []byte { + switch len(hashes) { case 0: panic("Cannot compute hash of empty slice") case 1: - return items[0] + return hashes[0] default: - var n int64 - var err error - var hasher = sha256.New() - hash := HashFromByteSlices(items[0 : len(items)/2]) - WriteByteSlice(hasher, hash, &n, &err) + left := HashFromHashes(hashes[:len(hashes)/2]) + right := HashFromHashes(hashes[len(hashes)/2:]) + return HashFromTwoHashes(left, right) + } +} + +// Convenience for HashFromHashes. +func HashFromBinaries(items []Binary) []byte { + hashes := [][]byte{} + for _, item := range items { + hasher := sha256.New() + _, err := item.WriteTo(hasher) if err != nil { panic(err) } - hash = HashFromByteSlices(items[len(items)/2:]) - WriteByteSlice(hasher, hash, &n, &err) - if err != nil { - panic(err) + hash := hasher.Sum(nil) + hashes = append(hashes, hash) + } + return HashFromHashes(hashes) +} + +// Convenience for HashFromHashes. +func HashFromHashables(items []Hashable) []byte { + hashes := [][]byte{} + for _, item := range items { + hash := item.Hash() + hashes = append(hashes, hash) + } + return HashFromHashes(hashes) +} + +/* +Calculates an array of hashes, useful for deriving hash trails. 
+ + 7 + / \ + / \ + / \ + / \ + 3 11 + / \ / \ + / \ / \ + / \ / \ + 1 5 9 13 + / \ / \ / \ / \ + 0 2 4 6 8 10 12 14 + h0 h1 h2 h3 h4 h5 h6 h7 + +(diagram and idea borrowed from libswift) + +The hashes provided get assigned to even indices. +The derived merkle hashes get assigned to odd indices. +If "hashes" is not of length power of 2, it is padded +with blank (zeroed) hashes. +*/ +func HashTreeFromHashes(hashes [][]byte) [][]byte { + + // Make length of "hashes" a power of 2 + hashesLen := uint32(len(hashes)) + fullLen := uint32(1) + for { + if fullLen >= hashesLen { + break + } else { + fullLen <<= 1 } - return hasher.Sum(nil) } + blank := make([]byte, len(hashes[0])) + for i := hashesLen; i < fullLen; i++ { + hashes = append(hashes, blank) + } + + // The result is twice the length minus one. + res := make([][]byte, len(hashes)*2-1) + for i, hash := range hashes { + res[i*2] = hash + } + + // Fill all the hashes recursively. + fillTreeRoot(res, 0, len(res)-1) + return res +} + +// Fill in the blanks. +func fillTreeRoot(res [][]byte, start, end int) []byte { + if start == end { + return res[start] + } else { + mid := (start + end) / 2 + left := fillTreeRoot(res, start, mid-1) + right := fillTreeRoot(res, mid+1, end) + root := HashFromTwoHashes(left, right) + res[mid] = root + return root + } +} + +// Convenience for HashTreeFromHashes. +func HashTreeFromHashables(items []Hashable) [][]byte { + hashes := [][]byte{} + for _, item := range items { + hash := item.Hash() + hashes = append(hashes, hash) + } + return HashTreeFromHashes(hashes) } /* -Compute a deterministic merkle hash from a list of Binary objects. +Given the original index of an item, +(e.g. for h5 in the diagram above, the index is 5, not 10) +returns a trail of hashes, which along with the index can be +used to calculate the merkle root. 
+ +See VerifyHashTrailForIndex() */ -func HashFromBinarySlice(items []Binary) []byte { - switch len(items) { - case 0: - panic("Cannot compute hash of empty slice") - case 1: - hasher := sha256.New() - _, err := items[0].WriteTo(hasher) - if err != nil { - panic(err) +func HashTrailForIndex(hashTree [][]byte, index int) [][]byte { + trail := [][]byte{} + index *= 2 + + // We start from the leaf layer and work our way up. + // Notice the indices in the diagram: + // 0 2 4 ... offset 0, stride 2 + // 1 5 9 ... offset 1, stride 4 + // 3 11 19 ... offset 3, stride 8 + // 7 23 39 ... offset 7, stride 16 etc. + + offset := 0 + stride := 2 + + for { + // Calculate sibling of index. + var next int + if ((index-offset)/stride)%2 == 0 { + next = index + stride + } else { + next = index - stride } - return hasher.Sum(nil) - default: - var n int64 - var err error - var hasher = sha256.New() - hash := HashFromBinarySlice(items[0 : len(items)/2]) - WriteByteSlice(hasher, hash, &n, &err) - if err != nil { - panic(err) + if next >= len(hashTree) { + break } - hash = HashFromBinarySlice(items[len(items)/2:]) - WriteByteSlice(hasher, hash, &n, &err) - if err != nil { - panic(err) + // Insert sibling hash to trail. + trail = append(trail, hashTree[next]) + + index = (index + next) / 2 + offset += stride + stride *= 2 + } + + return trail +} + +// Ensures that leafHash is part of rootHash. 
+func VerifyHashTrailForIndex(index int, leafHash []byte, trail [][]byte, rootHash []byte) bool { + index *= 2 + offset := 0 + stride := 2 + + tempHash := make([]byte, len(leafHash)) + copy(tempHash, leafHash) + + for i := 0; i < len(trail); i++ { + var next int + if ((index-offset)/stride)%2 == 0 { + next = index + stride + tempHash = HashFromTwoHashes(tempHash, trail[i]) + } else { + next = index - stride + tempHash = HashFromTwoHashes(trail[i], tempHash) } - return hasher.Sum(nil) + index = (index + next) / 2 + offset += stride + stride *= 2 } + + return bytes.Equal(rootHash, tempHash) } +//----------------------------------------------------------------------------- + func PrintIAVLNode(node *IAVLNode) { fmt.Println("==== NODE") if node != nil { diff --git a/p2p/README.md b/p2p/README.md index 9b5654806..944198487 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -1,119 +1,11 @@ # P2P Module P2P provides an abstraction around peer-to-peer communication.
-Communication happens via Agents that react to messages from peers.
-Each Agent has one or more Channels of communication for each Peer.
+Communication happens via Reactors that react to messages from peers.
+Each Reactor has one or more Channels of communication for each Peer.
Channels are multiplexed automatically and can be configured.
A Switch is started upon app start, and handles Peer management.
-A PEXAgent implementation is provided to automate peer discovery.
- -## Usage - -MempoolAgent started from the following template code.
-Modify the snippet below according to your needs.
-Check out the ConsensusAgent for an example of tracking peer state.
- -```golang -package mempool - -import ( - "bytes" - "fmt" - "io" - "sync/atomic" - - . "github.com/tendermint/tendermint/binary" - . "github.com/tendermint/tendermint/blocks" - "github.com/tendermint/tendermint/p2p" -) - -var ( - MempoolCh = byte(0x30) -) - -// MempoolAgent handles mempool tx broadcasting amongst peers. -type MempoolAgent struct { - sw *p2p.Switch - swEvents chan interface{} - quit chan struct{} - started uint32 - stopped uint32 -} - -func NewMempoolAgent(sw *p2p.Switch) *MempoolAgent { - swEvents := make(chan interface{}) - sw.AddEventListener("MempoolAgent.swEvents", swEvents) - memA := &MempoolAgent{ - sw: sw, - swEvents: swEvents, - quit: make(chan struct{}), - } - return memA -} - -func (memA *MempoolAgent) Start() { - if atomic.CompareAndSwapUint32(&memA.started, 0, 1) { - log.Info("Starting MempoolAgent") - go memA.switchEventsRoutine() - go memA.gossipTxRoutine() - } -} - -func (memA *MempoolAgent) Stop() { - if atomic.CompareAndSwapUint32(&memA.stopped, 0, 1) { - log.Info("Stopping MempoolAgent") - close(memA.quit) - close(memA.swEvents) - } -} - -// Handle peer new/done events -func (memA *MempoolAgent) switchEventsRoutine() { - for { - swEvent, ok := <-memA.swEvents - if !ok { - break - } - switch swEvent.(type) { - case p2p.SwitchEventNewPeer: - // event := swEvent.(p2p.SwitchEventNewPeer) - // NOTE: set up peer state - case p2p.SwitchEventDonePeer: - // event := swEvent.(p2p.SwitchEventDonePeer) - // NOTE: tear down peer state - default: - log.Warning("Unhandled switch event type") - } - } -} - -func (memA *MempoolAgent) gossipTxRoutine() { -OUTER_LOOP: - for { - // Receive incoming message on MempoolCh - inMsg, ok := memA.sw.Receive(MempoolCh) - if !ok { - break OUTER_LOOP // Client has stopped - } - _, msg_ := decodeMessage(inMsg.Bytes) - log.Info("gossipMempoolRoutine received %v", msg_) - - switch msg_.(type) { - case *TxMessage: - // msg := msg_.(*TxMessage) - // handle msg - - default: - // Ignore unknown message - // 
memA.sw.StopPeerForError(inMsg.MConn.Peer, errInvalidMessage) - } - } - - // Cleanup -} - -``` - +A PEXReactor implementation is provided to automate peer discovery.
## Channels diff --git a/p2p/connection.go b/p2p/connection.go index dcf230eeb..6f35eff42 100644 --- a/p2p/connection.go +++ b/p2p/connection.go @@ -27,10 +27,13 @@ const ( defaultRecvRate = 51200 // 5Kb/s ) +type receiveCbFunc func(chId byte, msgBytes []byte) +type errorCbFunc func(interface{}) + /* A MConnection wraps a network connection and handles buffering and multiplexing. Binary messages are sent with ".Send(channelId, msg)". -Inbound byteslices are pushed to the designated chan<- InboundBytes. +Inbound message bytes are handled with an onReceive callback function. */ type MConnection struct { conn net.Conn @@ -48,17 +51,17 @@ type MConnection struct { chStatsTimer *RepeatTimer // update channel stats periodically channels []*Channel channelsIdx map[byte]*Channel - onError func(interface{}) + onReceive receiveCbFunc + onError errorCbFunc started uint32 stopped uint32 errored uint32 - Peer *Peer // hacky optimization, gets set by Peer LocalAddress *NetAddress RemoteAddress *NetAddress } -func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onError func(interface{})) *MConnection { +func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection { mconn := &MConnection{ conn: conn, @@ -74,6 +77,7 @@ func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onError func(in pingTimer: NewRepeatTimer(pingTimeoutMinutes * time.Minute), pong: make(chan struct{}), chStatsTimer: NewRepeatTimer(updateStatsSeconds * time.Second), + onReceive: onReceive, onError: onError, LocalAddress: NewNetAddress(conn.LocalAddr()), RemoteAddress: NewNetAddress(conn.RemoteAddr()), @@ -288,7 +292,7 @@ func (c *MConnection) sendPacket() bool { var leastChannel *Channel for _, channel := range c.channels { // If nothing to send, skip this channel - if !channel.sendPending() { + if !channel.isSendPending() { continue } // Get ratio, and keep track of lowest ratio. 
@@ -319,7 +323,7 @@ func (c *MConnection) sendPacket() bool { } // recvRoutine reads packets and reconstructs the message using the channels' "recving" buffer. -// After a whole message has been assembled, it's pushed to the Channel's recvQueue. +// After a whole message has been assembled, it's pushed to onReceive(). // Blocks depending on how the connection is throttled. func (c *MConnection) recvRoutine() { defer c._recover() @@ -372,7 +376,10 @@ FOR_LOOP: if channel == nil { Panicf("Unknown channel %v", pkt.ChannelId) } - channel.recvPacket(pkt) + msgBytes := channel.recvPacket(pkt) + if msgBytes != nil { + c.onReceive(pkt.ChannelId, msgBytes) + } default: Panicf("Unknown message type %v", pktType) } @@ -397,10 +404,6 @@ type ChannelDescriptor struct { RecvQueueCapacity int // Global for this channel. RecvBufferSize int DefaultPriority uint - - // TODO: kinda hacky. - // This is created by the switch, one per channel. - recvQueue chan InboundBytes } // TODO: lowercase. @@ -409,7 +412,6 @@ type Channel struct { conn *MConnection desc *ChannelDescriptor id byte - recvQueue chan InboundBytes sendQueue chan []byte sendQueueSize uint32 recving []byte @@ -426,7 +428,6 @@ func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel { conn: conn, desc: desc, id: desc.Id, - recvQueue: desc.recvQueue, sendQueue: make(chan []byte, desc.SendQueueCapacity), recving: make([]byte, 0, desc.RecvBufferSize), priority: desc.DefaultPriority, @@ -467,7 +468,7 @@ func (ch *Channel) canSend() bool { // Returns true if any packets are pending to be sent. // Call before calling nextPacket() // Goroutine-safe -func (ch *Channel) sendPending() bool { +func (ch *Channel) isSendPending() bool { if len(ch.sending) == 0 { if len(ch.sendQueue) == 0 { return false @@ -506,14 +507,16 @@ func (ch *Channel) writePacketTo(w io.Writer) (n int64, err error) { return } -// Handles incoming packets. +// Handles incoming packets. Returns a msg bytes if msg is complete. 
// Not goroutine-safe -func (ch *Channel) recvPacket(pkt packet) { +func (ch *Channel) recvPacket(pkt packet) []byte { ch.recving = append(ch.recving, pkt.Bytes...) if pkt.EOF == byte(0x01) { - ch.recvQueue <- InboundBytes{ch.conn, ch.recving} + msgBytes := ch.recving ch.recving = make([]byte, 0, ch.desc.RecvBufferSize) + return msgBytes } + return nil } // Call this periodically to update stats for throttling purposes. @@ -561,13 +564,6 @@ func readPacketSafe(r io.Reader) (pkt packet, n int64, err error) { //----------------------------------------------------------------------------- -type InboundBytes struct { - MConn *MConnection - Bytes []byte -} - -//----------------------------------------------------------------------------- - // Convenience struct for writing typed messages. // Reading requires a custom decoder that switches on the first type byte of a byteslice. type TypedMessage struct { diff --git a/p2p/peer.go b/p2p/peer.go index c331c0c89..1f39cd7b5 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -7,6 +7,7 @@ import ( "sync/atomic" . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" ) /* Peer */ @@ -17,23 +18,31 @@ type Peer struct { started uint32 stopped uint32 - Key string + Key string + Data *CMap // User data. 
} -func newPeer(conn net.Conn, outbound bool, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{})) *Peer { +func newPeer(conn net.Conn, outbound bool, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{})) *Peer { var p *Peer + onReceive := func(chId byte, msgBytes []byte) { + reactor := reactorsByCh[chId] + if reactor == nil { + Panicf("Unknown channel %X", chId) + } + reactor.Receive(chId, p, msgBytes) + } onError := func(r interface{}) { p.stop() onPeerError(p, r) } - mconn := NewMConnection(conn, chDescs, onError) + mconn := NewMConnection(conn, chDescs, onReceive, onError) p = &Peer{ outbound: outbound, mconn: mconn, stopped: 0, Key: mconn.RemoteAddress.String(), + Data: NewCMap(), } - mconn.Peer = p // hacky optimization return p } @@ -51,6 +60,14 @@ func (p *Peer) stop() { } } +func (p *Peer) IsStopped() bool { + return atomic.LoadUint32(&p.stopped) == 1 +} + +func (p *Peer) RemoteAddress() *NetAddress { + return p.mconn.RemoteAddress +} + func (p *Peer) IsOutbound() bool { return p.outbound } diff --git a/p2p/pex_agent.go b/p2p/pex_agent.go deleted file mode 100644 index ea5f55891..000000000 --- a/p2p/pex_agent.go +++ /dev/null @@ -1,278 +0,0 @@ -package p2p - -import ( - "bytes" - "errors" - "fmt" - "io" - "sync/atomic" - "time" - - . "github.com/tendermint/tendermint/binary" - . "github.com/tendermint/tendermint/common" -) - -var pexErrInvalidMessage = errors.New("Invalid PEX message") - -const ( - PexCh = byte(0x00) - ensurePeersPeriodSeconds = 30 - minNumOutboundPeers = 10 - maxNumPeers = 50 -) - -/* -PEXAgent handles PEX (peer exchange) and ensures that an -adequate number of peers are connected to the switch. 
-*/ -type PEXAgent struct { - sw *Switch - swEvents chan interface{} - quit chan struct{} - started uint32 - stopped uint32 - - book *AddrBook -} - -func NewPEXAgent(sw *Switch, book *AddrBook) *PEXAgent { - swEvents := make(chan interface{}) - sw.AddEventListener("PEXAgent.swEvents", swEvents) - pexA := &PEXAgent{ - sw: sw, - swEvents: swEvents, - quit: make(chan struct{}), - book: book, - } - return pexA -} - -func (pexA *PEXAgent) Start() { - if atomic.CompareAndSwapUint32(&pexA.started, 0, 1) { - log.Info("Starting PEXAgent") - go pexA.switchEventsRoutine() - go pexA.requestRoutine() - go pexA.ensurePeersRoutine() - } -} - -func (pexA *PEXAgent) Stop() { - if atomic.CompareAndSwapUint32(&pexA.stopped, 0, 1) { - log.Info("Stopping PEXAgent") - close(pexA.quit) - close(pexA.swEvents) - } -} - -// Asks peer for more addresses. -func (pexA *PEXAgent) RequestPEX(peer *Peer) { - peer.TrySend(PexCh, &pexRequestMessage{}) -} - -func (pexA *PEXAgent) SendAddrs(peer *Peer, addrs []*NetAddress) { - peer.Send(PexCh, &pexAddrsMessage{Addrs: addrs}) -} - -// For new outbound peers, announce our listener addresses if any, -// and if .book needs more addresses, ask for them. -func (pexA *PEXAgent) switchEventsRoutine() { - for { - swEvent, ok := <-pexA.swEvents - if !ok { - break - } - switch swEvent.(type) { - case SwitchEventNewPeer: - event := swEvent.(SwitchEventNewPeer) - if event.Peer.IsOutbound() { - pexA.SendAddrs(event.Peer, pexA.book.OurAddresses()) - if pexA.book.NeedMoreAddrs() { - pexA.RequestPEX(event.Peer) - } - } - case SwitchEventDonePeer: - // TODO - } - } -} - -// Ensures that sufficient peers are connected. (continuous) -func (pexA *PEXAgent) ensurePeersRoutine() { - // fire once immediately. 
- pexA.ensurePeers() - // fire periodically - timer := NewRepeatTimer(ensurePeersPeriodSeconds * time.Second) -FOR_LOOP: - for { - select { - case <-timer.Ch: - pexA.ensurePeers() - case <-pexA.quit: - break FOR_LOOP - } - } - - // Cleanup - timer.Stop() -} - -// Ensures that sufficient peers are connected. (once) -func (pexA *PEXAgent) ensurePeers() { - numOutPeers, _, numDialing := pexA.sw.NumPeers() - numToDial := minNumOutboundPeers - (numOutPeers + numDialing) - if numToDial <= 0 { - return - } - toDial := NewCMap() - - // Try to pick numToDial addresses to dial. - // TODO: improve logic. - for i := 0; i < numToDial; i++ { - newBias := MinInt(numOutPeers, 8)*10 + 10 - var picked *NetAddress - // Try to fetch a new peer 3 times. - // This caps the maximum number of tries to 3 * numToDial. - for j := 0; i < 3; j++ { - picked = pexA.book.PickAddress(newBias) - if picked == nil { - return - } - if toDial.Has(picked.String()) || - pexA.sw.IsDialing(picked) || - pexA.sw.Peers().Has(picked.String()) { - continue - } else { - break - } - } - if picked == nil { - continue - } - toDial.Set(picked.String(), picked) - } - - // Dial picked addresses - for _, item := range toDial.Values() { - picked := item.(*NetAddress) - go func() { - _, err := pexA.sw.DialPeerWithAddress(picked) - if err != nil { - pexA.book.MarkAttempt(picked) - } - }() - } -} - -// Handles incoming PEX messages. -func (pexA *PEXAgent) requestRoutine() { - - for { - inMsg, ok := pexA.sw.Receive(PexCh) // {Peer, Time, Packet} - if !ok { - // Client has stopped - break - } - - // decode message - msg := decodeMessage(inMsg.Bytes) - log.Info("requestRoutine received %v", msg) - - switch msg.(type) { - case *pexRequestMessage: - // inMsg.MConn.Peer requested some peers. - // TODO: prevent abuse. 
- addrs := pexA.book.GetSelection() - msg := &pexAddrsMessage{Addrs: addrs} - queued := inMsg.MConn.Peer.TrySend(PexCh, msg) - if !queued { - // ignore - } - case *pexAddrsMessage: - // We received some peer addresses from inMsg.MConn.Peer. - // TODO: prevent abuse. - // (We don't want to get spammed with bad peers) - srcAddr := inMsg.MConn.RemoteAddress - for _, addr := range msg.(*pexAddrsMessage).Addrs { - pexA.book.AddAddress(addr, srcAddr) - } - default: - // Ignore unknown message. - // pexA.sw.StopPeerForError(inMsg.MConn.Peer, pexErrInvalidMessage) - } - } - - // Cleanup - -} - -//----------------------------------------------------------------------------- - -/* Messages */ - -const ( - msgTypeUnknown = byte(0x00) - msgTypeRequest = byte(0x01) - msgTypeAddrs = byte(0x02) -) - -// TODO: check for unnecessary extra bytes at the end. -func decodeMessage(bz []byte) (msg interface{}) { - var n int64 - var err error - // log.Debug("decoding msg bytes: %X", bz) - switch bz[0] { - case msgTypeRequest: - return &pexRequestMessage{} - case msgTypeAddrs: - return readPexAddrsMessage(bytes.NewReader(bz[1:]), &n, &err) - default: - return nil - } -} - -/* -A pexRequestMessage requests additional peer addresses. -*/ -type pexRequestMessage struct { -} - -func (m *pexRequestMessage) WriteTo(w io.Writer) (n int64, err error) { - WriteByte(w, msgTypeRequest, &n, &err) - return -} - -func (m *pexRequestMessage) String() string { - return "[pexRequest]" -} - -/* -A message with announced peer addresses. 
-*/ -type pexAddrsMessage struct { - Addrs []*NetAddress -} - -func readPexAddrsMessage(r io.Reader, n *int64, err *error) *pexAddrsMessage { - numAddrs := int(ReadUInt32(r, n, err)) - addrs := []*NetAddress{} - for i := 0; i < numAddrs; i++ { - addr := ReadNetAddress(r, n, err) - addrs = append(addrs, addr) - } - return &pexAddrsMessage{ - Addrs: addrs, - } -} - -func (m *pexAddrsMessage) WriteTo(w io.Writer) (n int64, err error) { - WriteByte(w, msgTypeAddrs, &n, &err) - WriteUInt32(w, uint32(len(m.Addrs)), &n, &err) - for _, addr := range m.Addrs { - WriteBinary(w, addr, &n, &err) - } - return -} - -func (m *pexAddrsMessage) String() string { - return fmt.Sprintf("[pexAddrs %v]", m.Addrs) -} diff --git a/p2p/pex_reactor.go b/p2p/pex_reactor.go new file mode 100644 index 000000000..da86f20bf --- /dev/null +++ b/p2p/pex_reactor.go @@ -0,0 +1,267 @@ +package p2p + +import ( + "bytes" + "errors" + "fmt" + "io" + "sync/atomic" + "time" + + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" +) + +var pexErrInvalidMessage = errors.New("Invalid PEX message") + +const ( + PexCh = byte(0x00) + ensurePeersPeriodSeconds = 30 + minNumOutboundPeers = 10 + maxNumPeers = 50 +) + +/* +PEXReactor handles PEX (peer exchange) and ensures that an +adequate number of peers are connected to the switch. +*/ +type PEXReactor struct { + sw *Switch + quit chan struct{} + started uint32 + stopped uint32 + + book *AddrBook +} + +func NewPEXReactor(sw *Switch, book *AddrBook) *PEXReactor { + pexR := &PEXReactor{ + sw: sw, + quit: make(chan struct{}), + book: book, + } + return pexR +} + +func (pexR *PEXReactor) Start() { + if atomic.CompareAndSwapUint32(&pexR.started, 0, 1) { + log.Info("Starting PEXReactor") + go pexR.ensurePeersRoutine() + } +} + +func (pexR *PEXReactor) Stop() { + if atomic.CompareAndSwapUint32(&pexR.stopped, 0, 1) { + log.Info("Stopping PEXReactor") + close(pexR.quit) + } +} + +// Asks peer for more addresses. 
+func (pexR *PEXReactor) RequestPEX(peer *Peer) {
+	peer.TrySend(PexCh, &pexRequestMessage{})
+}
+
+func (pexR *PEXReactor) SendAddrs(peer *Peer, addrs []*NetAddress) {
+	peer.Send(PexCh, &pexRddrsMessage{Addrs: addrs})
+}
+
+// GetChannels implements Reactor. PEX uses the single channel PexCh.
+func (pexR *PEXReactor) GetChannels() []*ChannelDescriptor {
+	// TODO optimize
+	return []*ChannelDescriptor{
+		&ChannelDescriptor{
+			Id:                PexCh,
+			SendQueueCapacity: 1,
+			RecvQueueCapacity: 2,
+			RecvBufferSize:    1024,
+			DefaultPriority:   1,
+		},
+	}
+}
+
+// AddPeer implements Reactor. Outbound peers are sent our addresses and, if the book needs more, a PEX request.
+func (pexR *PEXReactor) AddPeer(peer *Peer) {
+	if peer.IsOutbound() {
+		pexR.SendAddrs(peer, pexR.book.OurAddresses())
+		if pexR.book.NeedMoreAddrs() {
+			pexR.RequestPEX(peer)
+		}
+	}
+}
+
+// RemovePeer implements Reactor. The reason parameter must be interface{} (not error) to match the Reactor method set.
+func (pexR *PEXReactor) RemovePeer(peer *Peer, reason interface{}) {
+	// TODO
+}
+
+// Receive implements Reactor.
+// Handles incoming PEX messages: answers address requests and records announced addresses.
+func (pexR *PEXReactor) Receive(chId byte, src *Peer, msgBytes []byte) {
+
+	// decode message
+	msg := decodeMessage(msgBytes)
+	log.Info("requestRoutine received %v", msg)
+
+	switch msg.(type) {
+	case *pexRequestMessage:
+		// src requested some peers.
+		// TODO: prevent abuse.
+		addrs := pexR.book.GetSelection()
+		msg := &pexRddrsMessage{Addrs: addrs}
+		queued := src.TrySend(PexCh, msg)
+		if !queued {
+			// ignore
+		}
+	case *pexRddrsMessage:
+		// We received some peer addresses from src.
+		// TODO: prevent abuse.
+		// (We don't want to get spammed with bad peers)
+		srcAddr := src.RemoteAddress()
+		for _, addr := range msg.(*pexRddrsMessage).Addrs {
+			pexR.book.AddAddress(addr, srcAddr)
+		}
+	default:
+		// Ignore unknown message.
+	}
+
+}
+
+// Ensures that sufficient peers are connected. (continuous)
+func (pexR *PEXReactor) ensurePeersRoutine() {
+	// fire once immediately.
+	pexR.ensurePeers()
+	// fire periodically
+	timer := NewRepeatTimer(ensurePeersPeriodSeconds * time.Second)
+FOR_LOOP:
+	for {
+		select {
+		case <-timer.Ch:
+			pexR.ensurePeers()
+		case <-pexR.quit:
+			break FOR_LOOP
+		}
+	}
+
+	// Cleanup
+	timer.Stop()
+}
+
+// Ensures that sufficient peers are connected. (once)
+func (pexR *PEXReactor) ensurePeers() {
+	numOutPeers, _, numDialing := pexR.sw.NumPeers()
+	numToDial := minNumOutboundPeers - (numOutPeers + numDialing)
+	if numToDial <= 0 {
+		return
+	}
+	toDial := NewCMap()
+
+	// Try to pick numToDial addresses to dial.
+	// TODO: improve logic.
+	for i := 0; i < numToDial; i++ {
+		newBias := MinInt(numOutPeers, 8)*10 + 10
+		var picked *NetAddress
+		// Try to fetch a new peer 3 times.
+		// This caps the maximum number of tries to 3 * numToDial.
+		for j := 0; j < 3; j++ {
+			picked = pexR.book.PickAddress(newBias)
+			if picked == nil {
+				break // no candidate returned; keep what was already picked so it still gets dialed
+			}
+			if toDial.Has(picked.String()) ||
+				pexR.sw.IsDialing(picked) ||
+				pexR.sw.Peers().Has(picked.String()) {
+				picked = nil // already selected, dialing or connected: discard and retry
+			} else {
+				break
+			}
+		}
+		if picked == nil {
+			continue
+		}
+		toDial.Set(picked.String(), picked)
+	}
+
+	// Dial picked addresses
+	for _, item := range toDial.Values() {
+		picked := item.(*NetAddress)
+		go func() {
+			_, err := pexR.sw.DialPeerWithAddress(picked)
+			if err != nil {
+				pexR.book.MarkAttempt(picked)
+			}
+		}()
+	}
+}
+
+//-----------------------------------------------------------------------------
+
+/* Messages */
+
+const (
+	msgTypeUnknown = byte(0x00)
+	msgTypeRequest = byte(0x01)
+	msgTypeAddrs   = byte(0x02)
+)
+
+// TODO: check for unnecessary extra bytes at the end.
+func decodeMessage(bz []byte) (msg interface{}) {
+	var n int64
+	var err error
+	switch {
+	case len(bz) == 0: // empty message: no type byte to inspect, treat as unknown
+		return nil
+	case bz[0] == msgTypeRequest:
+		return &pexRequestMessage{}
+	case bz[0] == msgTypeAddrs:
+		return readPexAddrsMessage(bytes.NewReader(bz[1:]), &n, &err)
+	}
+	return nil
+}
+
+/*
+A pexRequestMessage requests additional peer addresses.
+*/ +type pexRequestMessage struct { +} + +func (m *pexRequestMessage) WriteTo(w io.Writer) (n int64, err error) { + WriteByte(w, msgTypeRequest, &n, &err) + return +} + +func (m *pexRequestMessage) String() string { + return "[pexRequest]" +} + +/* +A message with announced peer addresses. +*/ +type pexRddrsMessage struct { + Addrs []*NetAddress +} + +func readPexAddrsMessage(r io.Reader, n *int64, err *error) *pexRddrsMessage { + numAddrs := int(ReadUInt32(r, n, err)) + addrs := []*NetAddress{} + for i := 0; i < numAddrs; i++ { + addr := ReadNetAddress(r, n, err) + addrs = append(addrs, addr) + } + return &pexRddrsMessage{ + Addrs: addrs, + } +} + +func (m *pexRddrsMessage) WriteTo(w io.Writer) (n int64, err error) { + WriteByte(w, msgTypeAddrs, &n, &err) + WriteUInt32(w, uint32(len(m.Addrs)), &n, &err) + for _, addr := range m.Addrs { + WriteBinary(w, addr, &n, &err) + } + return +} + +func (m *pexRddrsMessage) String() string { + return fmt.Sprintf("[pexRddrs %v]", m.Addrs) +} diff --git a/p2p/switch.go b/p2p/switch.go index 9a3386892..565d2c16d 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -10,6 +10,15 @@ import ( . "github.com/tendermint/tendermint/common" ) +type Reactor interface { + GetChannels() []*ChannelDescriptor + AddPeer(peer *Peer) + RemovePeer(peer *Peer, reason interface{}) + Receive(chId byte, peer *Peer, msgBytes []byte) +} + +//----------------------------------------------------------------------------- + /* All communication amongst peers are multiplexed by "channels". (Not the same as Go "channels") @@ -26,14 +35,15 @@ The receiver is responsible for decoding the message bytes, which may be precede by a single type byte if a TypedBytes{} was used. 
*/ type Switch struct { - chDescs []*ChannelDescriptor - recvQueues map[byte]chan InboundBytes - peers *PeerSet - dialing *CMap - listeners *CMap // listenerName -> chan interface{} - quit chan struct{} - started uint32 - stopped uint32 + reactors []Reactor + chDescs []*ChannelDescriptor + reactorsByCh map[byte]Reactor + peers *PeerSet + dialing *CMap + listeners *CMap // listenerName -> chan interface{} + quit chan struct{} + started uint32 + stopped uint32 } var ( @@ -45,22 +55,32 @@ const ( peerDialTimeoutSeconds = 30 ) -func NewSwitch(chDescs []*ChannelDescriptor) *Switch { - s := &Switch{ - chDescs: chDescs, - recvQueues: make(map[byte]chan InboundBytes), - peers: NewPeerSet(), - dialing: NewCMap(), - listeners: NewCMap(), - quit: make(chan struct{}), - stopped: 0, +func NewSwitch(reactors []Reactor) *Switch { + + // Validate the reactors. no two reactors can share the same channel. + chDescs := []*ChannelDescriptor{} + reactorsByCh := make(map[byte]Reactor) + for _, reactor := range reactors { + reactorChannels := reactor.GetChannels() + for _, chDesc := range reactorChannels { + chId := chDesc.Id + if reactorsByCh[chId] != nil { + Panicf("Channel %X has multiple reactors %v & %v", chId, reactorsByCh[chId], reactor) + } + chDescs = append(chDescs, chDesc) + reactorsByCh[chId] = reactor + } } - // Create global recvQueues, one per channel. 
- for _, chDesc := range chDescs { - recvQueue := make(chan InboundBytes, chDesc.RecvQueueCapacity) - chDesc.recvQueue = recvQueue - s.recvQueues[chDesc.Id] = recvQueue + s := &Switch{ + reactors: reactors, + chDescs: chDescs, + reactorsByCh: reactorsByCh, + peers: NewPeerSet(), + dialing: NewCMap(), + listeners: NewCMap(), + quit: make(chan struct{}), + stopped: 0, } return s @@ -90,7 +110,7 @@ func (s *Switch) AddPeerWithConnection(conn net.Conn, outbound bool) (*Peer, err return nil, ErrSwitchStopped } - peer := newPeer(conn, outbound, s.chDescs, s.StopPeerForError) + peer := newPeer(conn, outbound, s.reactorsByCh, s.chDescs, s.StopPeerForError) // Add the peer to .peers if s.peers.Add(peer) { @@ -104,7 +124,7 @@ func (s *Switch) AddPeerWithConnection(conn net.Conn, outbound bool) (*Peer, err go peer.start() // Notify listeners. - s.emit(SwitchEventNewPeer{Peer: peer}) + s.doAddPeer(peer) return peer, nil } @@ -151,38 +171,6 @@ func (s *Switch) Broadcast(chId byte, msg Binary) (numSuccess, numFailure int) { } -// The events are of type SwitchEvent* defined below. -// Switch does not close these listeners. -func (s *Switch) AddEventListener(name string, listener chan<- interface{}) { - s.listeners.Set(name, listener) -} - -func (s *Switch) RemoveEventListener(name string) { - s.listeners.Delete(name) -} - -/* -Receive blocks on a channel until a message is found. -*/ -func (s *Switch) Receive(chId byte) (InboundBytes, bool) { - if atomic.LoadUint32(&s.stopped) == 1 { - return InboundBytes{}, false - } - - q := s.recvQueues[chId] - if q == nil { - Panicf("Expected recvQueues[%X], found none", chId) - } - - select { - case <-s.quit: - return InboundBytes{}, false - case inBytes := <-q: - log.Debug("RECV %v", inBytes) - return inBytes, true - } -} - // Returns the count of outbound/inbound and outbound-dialing peers. 
func (s *Switch) NumPeers() (outbound, inbound, dialing int) { peers := s.peers.List() @@ -209,7 +197,7 @@ func (s *Switch) StopPeerForError(peer *Peer, reason interface{}) { peer.stop() // Notify listeners - s.emit(SwitchEventDonePeer{Peer: peer, Error: reason}) + s.doRemovePeer(peer, reason) } // Disconnect from a peer gracefully. @@ -220,13 +208,18 @@ func (s *Switch) StopPeerGracefully(peer *Peer) { peer.stop() // Notify listeners - s.emit(SwitchEventDonePeer{Peer: peer}) + s.doRemovePeer(peer, nil) +} + +func (s *Switch) doAddPeer(peer *Peer) { + for _, reactor := range s.reactors { + reactor.AddPeer(peer) + } } -func (s *Switch) emit(event interface{}) { - for _, ch_i := range s.listeners.Values() { - ch := ch_i.(chan<- interface{}) - ch <- event +func (s *Switch) doRemovePeer(peer *Peer, reason interface{}) { + for _, reactor := range s.reactors { + reactor.RemovePeer(peer, reason) } } diff --git a/state/state.go b/state/state.go index 3923e7ead..10eee0dd1 100644 --- a/state/state.go +++ b/state/state.go @@ -46,11 +46,12 @@ func LoadState(db db_.Db) *State { s.blockHash = ReadByteSlice(reader, &n, &err) accountsMerkleRoot := ReadByteSlice(reader, &n, &err) s.accounts = merkle.NewIAVLTreeFromHash(db, accountsMerkleRoot) - s.validators = NewValidatorSet(nil) + var validators = map[uint64]*Validator{} for reader.Len() > 0 { validator := ReadValidator(reader, &n, &err) - s.validators.Add(validator) + validators[validator.Id] = validator } + s.validators = NewValidatorSet(validators) if err != nil { panic(err) } diff --git a/state/validator.go b/state/validator.go index eab2719e7..9f493de69 100644 --- a/state/validator.go +++ b/state/validator.go @@ -54,30 +54,49 @@ func (v *Validator) WriteTo(w io.Writer) (n int64, err error) { // Not goroutine-safe. 
type ValidatorSet struct { - validators map[uint64]*Validator + validators map[uint64]*Validator + indexToId map[uint32]uint64 // bitarray index to validator id + idToIndex map[uint64]uint32 // validator id to bitarray index + totalVotingPower uint64 } func NewValidatorSet(validators map[uint64]*Validator) *ValidatorSet { if validators == nil { validators = make(map[uint64]*Validator) } + ids := []uint64{} + indexToId := map[uint32]uint64{} + idToIndex := map[uint64]uint32{} + totalVotingPower := uint64(0) + for id, val := range validators { + ids = append(ids, id) + totalVotingPower += val.VotingPower + } + UInt64Slice(ids).Sort() + for i, id := range ids { + indexToId[uint32(i)] = id + idToIndex[id] = uint32(i) + } return &ValidatorSet{ - validators: validators, + validators: validators, + indexToId: indexToId, + idToIndex: idToIndex, + totalVotingPower: totalVotingPower, } } -func (v *ValidatorSet) IncrementAccum() { +func (vset *ValidatorSet) IncrementAccum() { totalDelta := int64(0) - for _, validator := range v.validators { + for _, validator := range vset.validators { validator.Accum += int64(validator.VotingPower) totalDelta += int64(validator.VotingPower) } - proposer := v.GetProposer() + proposer := vset.GetProposer() proposer.Accum -= totalDelta // NOTE: sum(v) here should be zero. 
if true { totalAccum := int64(0) - for _, validator := range v.validators { + for _, validator := range vset.validators { totalAccum += validator.Accum } if totalAccum != 0 { @@ -86,36 +105,49 @@ func (v *ValidatorSet) IncrementAccum() { } } -func (v *ValidatorSet) Copy() *ValidatorSet { - mapCopy := map[uint64]*Validator{} - for _, val := range v.validators { - mapCopy[val.Id] = val.Copy() +func (vset *ValidatorSet) Copy() *ValidatorSet { + validators := map[uint64]*Validator{} + for id, val := range vset.validators { + validators[id] = val.Copy() } return &ValidatorSet{ - validators: mapCopy, + validators: validators, + indexToId: vset.indexToId, + idToIndex: vset.idToIndex, + totalVotingPower: vset.totalVotingPower, } } -func (v *ValidatorSet) Add(validator *Validator) { - v.validators[validator.Id] = validator +func (vset *ValidatorSet) GetById(id uint64) *Validator { + return vset.validators[id] +} + +func (vset *ValidatorSet) GetIndexById(id uint64) (uint32, bool) { + index, ok := vset.idToIndex[id] + return index, ok +} + +func (vset *ValidatorSet) GetIdByIndex(index uint32) (uint64, bool) { + id, ok := vset.indexToId[index] + return id, ok } -func (v *ValidatorSet) Get(id uint64) *Validator { - return v.validators[id] +func (vset *ValidatorSet) Map() map[uint64]*Validator { + return vset.validators } -func (v *ValidatorSet) Map() map[uint64]*Validator { - return v.validators +func (vset *ValidatorSet) Size() uint { + return uint(len(vset.validators)) } -func (v *ValidatorSet) Size() int { - return len(v.validators) +func (vset *ValidatorSet) TotalVotingPower() uint64 { + return vset.totalVotingPower } // TODO: cache proposer. invalidate upon increment. 
-func (v *ValidatorSet) GetProposer() (proposer *Validator) { +func (vset *ValidatorSet) GetProposer() (proposer *Validator) { highestAccum := int64(0) - for _, validator := range v.validators { + for _, validator := range vset.validators { if validator.Accum > highestAccum { highestAccum = validator.Accum proposer = validator