|
|
- package consensus
-
- import (
- "bytes"
- "crypto/sha256"
- "errors"
- "fmt"
- "io"
- "strings"
- "sync"
-
- . "github.com/tendermint/tendermint/binary"
- . "github.com/tendermint/tendermint/common"
- "github.com/tendermint/tendermint/merkle"
- )
-
// partSize is the number of data bytes carried by one Part (4KB). The final
// part produced by NewPartSetFromData may hold fewer bytes.
const partSize = 4096
-
// ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when the part's
// index is not below the set's total.
var ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")

// ErrPartSetInvalidTrail is returned by PartSet.AddPart when the part's hash
// trail fails verification against the set's root hash.
var ErrPartSetInvalidTrail = errors.New("Error part set invalid trail")
-
// Part is one chunk of a larger piece of data, together with the hash trail
// that PartSet.AddPart verifies against the set's root hash.
type Part struct {
	Index uint16   // position of this part within its PartSet
	Trail [][]byte // hash trail for this part
	Bytes []byte   // the chunk of data carried by this part

	// hash memoizes the SHA-256 digest of Bytes; it is filled in lazily by
	// Hash and is not serialized by WriteTo.
	hash []byte
}
-
- func ReadPart(r io.Reader, n *int64, err *error) *Part {
- return &Part{
- Index: ReadUInt16(r, n, err),
- Trail: ReadByteSlices(r, n, err),
- Bytes: ReadByteSlice(r, n, err),
- }
- }
-
- func (b *Part) WriteTo(w io.Writer) (n int64, err error) {
- WriteUInt16(w, b.Index, &n, &err)
- WriteByteSlices(w, b.Trail, &n, &err)
- WriteByteSlice(w, b.Bytes, &n, &err)
- return
- }
-
- func (pt *Part) Hash() []byte {
- if pt.hash != nil {
- return pt.hash
- } else {
- hasher := sha256.New()
- _, err := hasher.Write(pt.Bytes)
- if err != nil {
- panic(err)
- }
- pt.hash = hasher.Sum(nil)
- return pt.hash
- }
- }
-
- func (pt *Part) String() string {
- return pt.StringWithIndent("")
- }
-
- func (pt *Part) StringWithIndent(indent string) string {
- trailStrings := make([]string, len(pt.Trail))
- for i, hash := range pt.Trail {
- trailStrings[i] = fmt.Sprintf("%X", hash)
- }
- return fmt.Sprintf(`Part{
- %s Index: %v
- %s Trail:
- %s %v
- %s}`,
- indent, pt.Index,
- indent,
- indent, strings.Join(trailStrings, "\n"+indent+" "),
- indent)
- }
-
- //-------------------------------------
-
- type PartSet struct {
- rootHash []byte
- total uint16
-
- mtx sync.Mutex
- parts []*Part
- partsBitArray BitArray
- count uint16
- }
-
- // Returns an immutable, full PartSet.
- func NewPartSetFromData(data []byte) *PartSet {
- // divide data into 4kb parts.
- total := (len(data) + partSize - 1) / partSize
- parts := make([]*Part, total)
- parts_ := make([]merkle.Hashable, total)
- partsBitArray := NewBitArray(uint(total))
- for i := 0; i < total; i++ {
- part := &Part{
- Index: uint16(i),
- Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)],
- }
- parts[i] = part
- parts_[i] = part
- partsBitArray.SetIndex(uint(i), true)
- }
- // Compute merkle trails
- hashTree := merkle.HashTreeFromHashables(parts_)
- for i := 0; i < total; i++ {
- parts[i].Trail = merkle.HashTrailForIndex(hashTree, i)
- }
- return &PartSet{
- parts: parts,
- partsBitArray: partsBitArray,
- rootHash: hashTree[len(hashTree)/2],
- total: uint16(total),
- count: uint16(total),
- }
- }
-
- // Returns an empty PartSet ready to be populated.
- func NewPartSetFromMetadata(total uint16, rootHash []byte) *PartSet {
- return &PartSet{
- parts: make([]*Part, total),
- partsBitArray: NewBitArray(uint(total)),
- rootHash: rootHash,
- total: total,
- count: 0,
- }
- }
-
- func (ps *PartSet) BitArray() BitArray {
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
- return ps.partsBitArray.Copy()
- }
-
- func (ps *PartSet) RootHash() []byte {
- return ps.rootHash
- }
-
- func (ps *PartSet) Count() uint16 {
- if ps == nil {
- return 0
- }
- return ps.count
- }
-
- func (ps *PartSet) Total() uint16 {
- if ps == nil {
- return 0
- }
- return ps.total
- }
-
- func (ps *PartSet) AddPart(part *Part) (bool, error) {
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
-
- // Invalid part index
- if part.Index >= ps.total {
- return false, ErrPartSetUnexpectedIndex
- }
-
- // If part already exists, return false.
- if ps.parts[part.Index] != nil {
- return false, nil
- }
-
- // Check hash trail
- if !merkle.VerifyHashTrailForIndex(int(part.Index), part.Hash(), part.Trail, ps.rootHash) {
- return false, ErrPartSetInvalidTrail
- }
-
- // Add part
- ps.parts[part.Index] = part
- ps.partsBitArray.SetIndex(uint(part.Index), true)
- ps.count++
- return true, nil
- }
-
- func (ps *PartSet) GetPart(index uint16) *Part {
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
- return ps.parts[index]
- }
-
- func (ps *PartSet) IsComplete() bool {
- return ps.count == ps.total
- }
-
- func (ps *PartSet) GetReader() io.Reader {
- if !ps.IsComplete() {
- panic("Cannot GetReader() on incomplete PartSet")
- }
- buf := []byte{}
- for _, part := range ps.parts {
- buf = append(buf, part.Bytes...)
- }
- return bytes.NewReader(buf)
- }
-
- func (ps *PartSet) Description() string {
- if ps == nil {
- return "nil-PartSet"
- } else {
- return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
- }
- }
|