- package types
-
- import (
- "bytes"
- "errors"
- "fmt"
- "io"
-
- "github.com/tendermint/tendermint/crypto/merkle"
- "github.com/tendermint/tendermint/libs/bits"
- tmbytes "github.com/tendermint/tendermint/libs/bytes"
- tmjson "github.com/tendermint/tendermint/libs/json"
- tmmath "github.com/tendermint/tendermint/libs/math"
- tmsync "github.com/tendermint/tendermint/libs/sync"
- tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
- )
-
// ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when a part's
// index is not below the part set's total.
var ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")

// ErrPartSetInvalidProof is returned by PartSet.AddPart when a part's merkle
// proof is rejected for the part set's hash.
var ErrPartSetInvalidProof = errors.New("error part set invalid proof")
-
// Part is one chunk of a PartSet: its position within the set, the chunk's
// raw bytes, and a merkle proof of those bytes against the set's hash
// (checked by PartSet.AddPart).
type Part struct {
	Index uint32           `json:"index"`
	Bytes tmbytes.HexBytes `json:"bytes"`
	Proof merkle.Proof     `json:"proof"`
}
-
- // ValidateBasic performs basic validation.
- func (part *Part) ValidateBasic() error {
- if len(part.Bytes) > int(BlockPartSizeBytes) {
- return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
- }
- if err := part.Proof.ValidateBasic(); err != nil {
- return fmt.Errorf("wrong Proof: %w", err)
- }
- return nil
- }
-
- // String returns a string representation of Part.
- //
- // See StringIndented.
- func (part *Part) String() string {
- return part.StringIndented("")
- }
-
- // StringIndented returns an indented Part.
- //
- // See merkle.Proof#StringIndented
- func (part *Part) StringIndented(indent string) string {
- return fmt.Sprintf(`Part{#%v
- %s Bytes: %X...
- %s Proof: %v
- %s}`,
- part.Index,
- indent, tmbytes.Fingerprint(part.Bytes),
- indent, part.Proof.StringIndented(indent+" "),
- indent)
- }
-
- func (part *Part) ToProto() (*tmproto.Part, error) {
- if part == nil {
- return nil, errors.New("nil part")
- }
- pb := new(tmproto.Part)
- proof := part.Proof.ToProto()
-
- pb.Index = part.Index
- pb.Bytes = part.Bytes
- pb.Proof = *proof
-
- return pb, nil
- }
-
- func PartFromProto(pb *tmproto.Part) (*Part, error) {
- if pb == nil {
- return nil, errors.New("nil part")
- }
-
- part := new(Part)
- proof, err := merkle.ProofFromProto(&pb.Proof)
- if err != nil {
- return nil, err
- }
- part.Index = pb.Index
- part.Bytes = pb.Bytes
- part.Proof = *proof
-
- return part, part.ValidateBasic()
- }
-
- //-------------------------------------
-
// PartSetHeader identifies a PartSet by its number of parts (Total) and its
// merkle root hash (Hash); see PartSet.Header.
type PartSetHeader struct {
	Total uint32           `json:"total"`
	Hash  tmbytes.HexBytes `json:"hash"`
}
-
- // String returns a string representation of PartSetHeader.
- //
- // 1. total number of parts
- // 2. first 6 bytes of the hash
- func (psh PartSetHeader) String() string {
- return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
- }
-
- func (psh PartSetHeader) IsZero() bool {
- return psh.Total == 0 && len(psh.Hash) == 0
- }
-
- func (psh PartSetHeader) Equals(other PartSetHeader) bool {
- return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
- }
-
- // ValidateBasic performs basic validation.
- func (psh PartSetHeader) ValidateBasic() error {
- // Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
- if err := ValidateHash(psh.Hash); err != nil {
- return fmt.Errorf("wrong Hash: %w", err)
- }
- return nil
- }
-
- // ToProto converts BloPartSetHeaderckID to protobuf
- func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
- if psh == nil {
- return tmproto.PartSetHeader{}
- }
-
- return tmproto.PartSetHeader{
- Total: psh.Total,
- Hash: psh.Hash,
- }
- }
-
- // FromProto sets a protobuf PartSetHeader to the given pointer
- func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
- if ppsh == nil {
- return nil, errors.New("nil PartSetHeader")
- }
- psh := new(PartSetHeader)
- psh.Total = ppsh.Total
- psh.Hash = ppsh.Hash
-
- return psh, psh.ValidateBasic()
- }
-
- //-------------------------------------
-
// PartSet is a collection of Parts that together make up one byte payload,
// authenticated by a merkle root hash. It is either built complete from the
// payload (NewPartSetFromData) or created empty from a PartSetHeader
// (NewPartSetFromHeader) and filled incrementally with AddPart.
type PartSet struct {
	total uint32 // number of parts the set holds once complete
	hash  []byte // merkle root that added parts' proofs are verified against

	// mtx is taken by AddPart, GetPart, BitArray, StringShort and MarshalJSON.
	// NOTE(review): Count, ByteSize and IsComplete read count/byteSize without
	// taking it — confirm callers never race them with AddPart.
	mtx           tmsync.Mutex
	parts         []*Part        // parts[i] holds the part with Index i; nil until added
	partsBitArray *bits.BitArray // bit i is set once parts[i] is present
	count         uint32         // number of parts added so far
	// a count of the total size (in bytes). Used to ensure that the
	// part set doesn't exceed the maximum block bytes
	byteSize int64
}
-
- // Returns an immutable, full PartSet from the data bytes.
- // The data bytes are split into "partSize" chunks, and merkle tree computed.
- // CONTRACT: partSize is greater than zero.
- func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
- // divide data into 4kb parts.
- total := (uint32(len(data)) + partSize - 1) / partSize
- parts := make([]*Part, total)
- partsBytes := make([][]byte, total)
- partsBitArray := bits.NewBitArray(int(total))
- for i := uint32(0); i < total; i++ {
- part := &Part{
- Index: i,
- Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
- }
- parts[i] = part
- partsBytes[i] = part.Bytes
- partsBitArray.SetIndex(int(i), true)
- }
- // Compute merkle proofs
- root, proofs := merkle.ProofsFromByteSlices(partsBytes)
- for i := uint32(0); i < total; i++ {
- parts[i].Proof = *proofs[i]
- }
- return &PartSet{
- total: total,
- hash: root,
- parts: parts,
- partsBitArray: partsBitArray,
- count: total,
- byteSize: int64(len(data)),
- }
- }
-
- // Returns an empty PartSet ready to be populated.
- func NewPartSetFromHeader(header PartSetHeader) *PartSet {
- return &PartSet{
- total: header.Total,
- hash: header.Hash,
- parts: make([]*Part, header.Total),
- partsBitArray: bits.NewBitArray(int(header.Total)),
- count: 0,
- byteSize: 0,
- }
- }
-
- func (ps *PartSet) Header() PartSetHeader {
- if ps == nil {
- return PartSetHeader{}
- }
- return PartSetHeader{
- Total: ps.total,
- Hash: ps.hash,
- }
- }
-
- func (ps *PartSet) HasHeader(header PartSetHeader) bool {
- if ps == nil {
- return false
- }
- return ps.Header().Equals(header)
- }
-
- func (ps *PartSet) BitArray() *bits.BitArray {
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
- return ps.partsBitArray.Copy()
- }
-
- func (ps *PartSet) Hash() []byte {
- if ps == nil {
- return merkle.HashFromByteSlices(nil)
- }
- return ps.hash
- }
-
- func (ps *PartSet) HashesTo(hash []byte) bool {
- if ps == nil {
- return false
- }
- return bytes.Equal(ps.hash, hash)
- }
-
- func (ps *PartSet) Count() uint32 {
- if ps == nil {
- return 0
- }
- return ps.count
- }
-
- func (ps *PartSet) ByteSize() int64 {
- if ps == nil {
- return 0
- }
- return ps.byteSize
- }
-
- func (ps *PartSet) Total() uint32 {
- if ps == nil {
- return 0
- }
- return ps.total
- }
-
- func (ps *PartSet) AddPart(part *Part) (bool, error) {
- if ps == nil {
- return false, nil
- }
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
-
- // Invalid part index
- if part.Index >= ps.total {
- return false, ErrPartSetUnexpectedIndex
- }
-
- // If part already exists, return false.
- if ps.parts[part.Index] != nil {
- return false, nil
- }
-
- // Check hash proof
- if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
- return false, ErrPartSetInvalidProof
- }
-
- // Add part
- ps.parts[part.Index] = part
- ps.partsBitArray.SetIndex(int(part.Index), true)
- ps.count++
- ps.byteSize += int64(len(part.Bytes))
- return true, nil
- }
-
- func (ps *PartSet) GetPart(index int) *Part {
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
- return ps.parts[index]
- }
-
- func (ps *PartSet) IsComplete() bool {
- return ps.count == ps.total
- }
-
- func (ps *PartSet) GetReader() io.Reader {
- if !ps.IsComplete() {
- panic("Cannot GetReader() on incomplete PartSet")
- }
- return NewPartSetReader(ps.parts)
- }
-
// PartSetReader reads the Bytes of a sequence of Parts one after another, in
// slice order.
type PartSetReader struct {
	i      int           // index in parts of the part currently being read
	parts  []*Part       // parts whose Bytes are read in sequence
	reader *bytes.Reader // reader over parts[i].Bytes
}
-
- func NewPartSetReader(parts []*Part) *PartSetReader {
- return &PartSetReader{
- i: 0,
- parts: parts,
- reader: bytes.NewReader(parts[0].Bytes),
- }
- }
-
- func (psr *PartSetReader) Read(p []byte) (n int, err error) {
- readerLen := psr.reader.Len()
- if readerLen >= len(p) {
- return psr.reader.Read(p)
- } else if readerLen > 0 {
- n1, err := psr.Read(p[:readerLen])
- if err != nil {
- return n1, err
- }
- n2, err := psr.Read(p[readerLen:])
- return n1 + n2, err
- }
-
- psr.i++
- if psr.i >= len(psr.parts) {
- return 0, io.EOF
- }
- psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
- return psr.Read(p)
- }
-
- // StringShort returns a short version of String.
- //
- // (Count of Total)
- func (ps *PartSet) StringShort() string {
- if ps == nil {
- return "nil-PartSet"
- }
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
- return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
- }
-
- func (ps *PartSet) MarshalJSON() ([]byte, error) {
- if ps == nil {
- return []byte("{}"), nil
- }
-
- ps.mtx.Lock()
- defer ps.mtx.Unlock()
-
- return tmjson.Marshal(struct {
- CountTotal string `json:"count/total"`
- PartsBitArray *bits.BitArray `json:"parts_bit_array"`
- }{
- fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
- ps.partsBitArray,
- })
- }
|