package types import ( "bytes" "errors" "fmt" "io" "sync" "github.com/tendermint/tendermint/crypto/merkle" "github.com/tendermint/tendermint/libs/bits" tmbytes "github.com/tendermint/tendermint/libs/bytes" tmjson "github.com/tendermint/tendermint/libs/json" tmmath "github.com/tendermint/tendermint/libs/math" tmproto "github.com/tendermint/tendermint/proto/types" ) var ( ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index") ErrPartSetInvalidProof = errors.New("error part set invalid proof") ) type Part struct { Index uint32 `json:"index"` Bytes tmbytes.HexBytes `json:"bytes"` Proof merkle.Proof `json:"proof"` } // ValidateBasic performs basic validation. func (part *Part) ValidateBasic() error { if len(part.Bytes) > int(BlockPartSizeBytes) { return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes) } if err := part.Proof.ValidateBasic(); err != nil { return fmt.Errorf("wrong Proof: %w", err) } return nil } func (part *Part) String() string { return part.StringIndented("") } func (part *Part) StringIndented(indent string) string { return fmt.Sprintf(`Part{#%v %s Bytes: %X... %s Proof: %v %s}`, part.Index, indent, tmbytes.Fingerprint(part.Bytes), indent, part.Proof.StringIndented(indent+" "), indent) } func (part *Part) ToProto() (*tmproto.Part, error) { if part == nil { return nil, errors.New("nil part") } pb := new(tmproto.Part) proof := part.Proof.ToProto() pb.Index = part.Index pb.Bytes = part.Bytes pb.Proof = *proof return pb, nil } func PartFromProto(pb *tmproto.Part) (*Part, error) { if pb == nil { return nil, errors.New("nil part") } part := new(Part) proof, err := merkle.ProofFromProto(&pb.Proof) if err != nil { return nil, err } part.Index = pb.Index part.Bytes = pb.Bytes part.Proof = *proof return part, part.ValidateBasic() } //------------------------------------- type PartSetHeader struct { Total uint32 `json:"total"` Hash tmbytes.HexBytes `json:"hash"` } func (psh PartSetHeader) String() string { return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash)) } func (psh PartSetHeader) IsZero() bool { return psh.Total == 0 && len(psh.Hash) == 0 } func (psh PartSetHeader) Equals(other PartSetHeader) bool { return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash) } // ValidateBasic performs basic validation. func (psh PartSetHeader) ValidateBasic() error { // Hash can be empty in case of POLBlockID.PartsHeader in Proposal. if err := ValidateHash(psh.Hash); err != nil { return fmt.Errorf("wrong Hash: %w", err) } return nil } // ToProto converts BloPartSetHeaderckID to protobuf func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader { if psh == nil { return tmproto.PartSetHeader{} } return tmproto.PartSetHeader{ Total: psh.Total, Hash: psh.Hash, } } // FromProto sets a protobuf PartSetHeader to the given pointer func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) { if ppsh == nil { return nil, errors.New("nil PartSetHeader") } psh := new(PartSetHeader) psh.Total = ppsh.Total psh.Hash = ppsh.Hash return psh, psh.ValidateBasic() } //------------------------------------- type PartSet struct { total uint32 hash []byte mtx sync.Mutex parts []*Part partsBitArray *bits.BitArray count uint32 } // Returns an immutable, full PartSet from the data bytes. // The data bytes are split into "partSize" chunks, and merkle tree computed. // CONTRACT: partSize is greater than zero. func NewPartSetFromData(data []byte, partSize uint32) *PartSet { // divide data into 4kb parts. total := (uint32(len(data)) + partSize - 1) / partSize parts := make([]*Part, total) partsBytes := make([][]byte, total) partsBitArray := bits.NewBitArray(int(total)) for i := uint32(0); i < total; i++ { part := &Part{ Index: i, Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))], } parts[i] = part partsBytes[i] = part.Bytes partsBitArray.SetIndex(int(i), true) } // Compute merkle proofs root, proofs := merkle.ProofsFromByteSlices(partsBytes) for i := uint32(0); i < total; i++ { parts[i].Proof = *proofs[i] } return &PartSet{ total: total, hash: root, parts: parts, partsBitArray: partsBitArray, count: total, } } // Returns an empty PartSet ready to be populated. func NewPartSetFromHeader(header PartSetHeader) *PartSet { return &PartSet{ total: header.Total, hash: header.Hash, parts: make([]*Part, header.Total), partsBitArray: bits.NewBitArray(int(header.Total)), count: 0, } } func (ps *PartSet) Header() PartSetHeader { if ps == nil { return PartSetHeader{} } return PartSetHeader{ Total: ps.total, Hash: ps.hash, } } func (ps *PartSet) HasHeader(header PartSetHeader) bool { if ps == nil { return false } return ps.Header().Equals(header) } func (ps *PartSet) BitArray() *bits.BitArray { ps.mtx.Lock() defer ps.mtx.Unlock() return ps.partsBitArray.Copy() } func (ps *PartSet) Hash() []byte { if ps == nil { return nil } return ps.hash } func (ps *PartSet) HashesTo(hash []byte) bool { if ps == nil { return false } return bytes.Equal(ps.hash, hash) } func (ps *PartSet) Count() uint32 { if ps == nil { return 0 } return ps.count } func (ps *PartSet) Total() uint32 { if ps == nil { return 0 } return ps.total } func (ps *PartSet) AddPart(part *Part) (bool, error) { if ps == nil { return false, nil } ps.mtx.Lock() defer ps.mtx.Unlock() // Invalid part index if part.Index >= ps.total { return false, ErrPartSetUnexpectedIndex } // If part already exists, return false. if ps.parts[part.Index] != nil { return false, nil } // Check hash proof if part.Proof.Verify(ps.Hash(), part.Bytes) != nil { return false, ErrPartSetInvalidProof } // Add part ps.parts[part.Index] = part ps.partsBitArray.SetIndex(int(part.Index), true) ps.count++ return true, nil } func (ps *PartSet) GetPart(index int) *Part { ps.mtx.Lock() defer ps.mtx.Unlock() return ps.parts[index] } func (ps *PartSet) IsComplete() bool { return ps.count == ps.total } func (ps *PartSet) GetReader() io.Reader { if !ps.IsComplete() { panic("Cannot GetReader() on incomplete PartSet") } return NewPartSetReader(ps.parts) } type PartSetReader struct { i int parts []*Part reader *bytes.Reader } func NewPartSetReader(parts []*Part) *PartSetReader { return &PartSetReader{ i: 0, parts: parts, reader: bytes.NewReader(parts[0].Bytes), } } func (psr *PartSetReader) Read(p []byte) (n int, err error) { readerLen := psr.reader.Len() if readerLen >= len(p) { return psr.reader.Read(p) } else if readerLen > 0 { n1, err := psr.Read(p[:readerLen]) if err != nil { return n1, err } n2, err := psr.Read(p[readerLen:]) return n1 + n2, err } psr.i++ if psr.i >= len(psr.parts) { return 0, io.EOF } psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes) return psr.Read(p) } func (ps *PartSet) StringShort() string { if ps == nil { return "nil-PartSet" } ps.mtx.Lock() defer ps.mtx.Unlock() return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total()) } func (ps *PartSet) MarshalJSON() ([]byte, error) { if ps == nil { return []byte("{}"), nil } ps.mtx.Lock() defer ps.mtx.Unlock() return tmjson.Marshal(struct { CountTotal string `json:"count/total"` PartsBitArray *bits.BitArray `json:"parts_bit_array"` }{ fmt.Sprintf("%d/%d", ps.Count(), ps.Total()), ps.partsBitArray, }) }