|
|
- package blocks
-
- import (
- "bytes"
- "errors"
- "sync"
- )
-
- // Helper for keeping track of block parts.
- type BlockPartSet struct {
- mtx sync.Mutex
- signer *Account
- height uint32
- round uint16 // Not used
- total uint16
- numParts uint16
- parts []*BlockPart
-
- _block *Block // cache
- }
-
// Errors that can be returned while adding a block part.
var (
	// ErrInvalidBlockPartSignature means a peer gave us a fake part.
	ErrInvalidBlockPartSignature = errors.New("Invalid block part signature")

	// ErrInvalidBlockPartConflict means the signer signed conflicting parts.
	ErrInvalidBlockPartConflict = errors.New("Invalid block part conflict")
)
-
- // Signer may be nil if signer is unknown beforehand.
- func NewBlockPartSet(height uint32, round uint16, signer *Account) *BlockPartSet {
- return &BlockPartSet{
- signer: signer,
- height: height,
- round: round,
- }
- }
-
- // In the case where the signer wasn't known prior to NewBlockPartSet(),
- // user should call SetSigner() prior to AddBlockPart().
- func (bps *BlockPartSet) SetSigner(signer *Account) {
- bps.mtx.Lock()
- defer bps.mtx.Unlock()
- if bps.signer != nil {
- panic("BlockPartSet signer already set.")
- }
- bps.signer = signer
- }
-
- func (bps *BlockPartSet) BlockParts() []*BlockPart {
- bps.mtx.Lock()
- defer bps.mtx.Unlock()
- return bps.parts
- }
-
- func (bps *BlockPartSet) BitArray() []byte {
- bps.mtx.Lock()
- defer bps.mtx.Unlock()
- if bps.parts == nil {
- return nil
- }
- bitArray := make([]byte, (len(bps.parts)+7)/8)
- for i, part := range bps.parts {
- if part != nil {
- bitArray[i/8] |= 1 << uint(i%8)
- }
- }
- return bitArray
- }
-
- // If the part isn't valid, returns an error.
- // err can be ErrInvalidBlockPart[Conflict|Signature]
- func (bps *BlockPartSet) AddBlockPart(part *BlockPart) (added bool, err error) {
- bps.mtx.Lock()
- defer bps.mtx.Unlock()
-
- // If part is invalid, return an error.
- err = part.ValidateWithSigner(bps.signer)
- if err != nil {
- return false, err
- }
-
- if bps.parts == nil {
- // First received part for this round.
- bps.parts = make([]*BlockPart, part.Total)
- bps.total = uint16(part.Total)
- bps.parts[int(part.Index)] = part
- bps.numParts++
- return true, nil
- } else {
- // Check part.Index and part.Total
- if uint16(part.Index) >= bps.total {
- return false, ErrInvalidBlockPartConflict
- }
- if uint16(part.Total) != bps.total {
- return false, ErrInvalidBlockPartConflict
- }
- // Check for existing parts.
- existing := bps.parts[part.Index]
- if existing != nil {
- if existing.Bytes.Equals(part.Bytes) {
- // Ignore duplicate
- return false, nil
- } else {
- return false, ErrInvalidBlockPartConflict
- }
- } else {
- bps.parts[int(part.Index)] = part
- bps.numParts++
- return true, nil
- }
- }
-
- }
-
- func (bps *BlockPartSet) IsComplete() bool {
- bps.mtx.Lock()
- defer bps.mtx.Unlock()
- return bps.total > 0 && bps.total == bps.numParts
- }
-
- func (bps *BlockPartSet) Block() *Block {
- if !bps.IsComplete() {
- return nil
- }
- bps.mtx.Lock()
- defer bps.mtx.Unlock()
- if bps._block == nil {
- blockBytes := []byte{}
- for _, part := range bps.parts {
- blockBytes = append(blockBytes, part.Bytes...)
- }
- block := ReadBlock(bytes.NewReader(blockBytes))
- bps._block = block
- }
- return bps._block
- }
|