package types
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
|
|
"github.com/tendermint/tendermint/crypto/merkle"
|
|
"github.com/tendermint/tendermint/libs/bits"
|
|
tmbytes "github.com/tendermint/tendermint/libs/bytes"
|
|
tmjson "github.com/tendermint/tendermint/libs/json"
|
|
tmmath "github.com/tendermint/tendermint/libs/math"
|
|
tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
|
|
)
|
|
|
|
// Sentinel errors reported when a part is rejected by a PartSet
// (see (*PartSet).AddPart).
var (
	// ErrPartSetUnexpectedIndex reports a part whose index lies outside the
	// range of parts the set expects.
	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")

	// ErrPartSetInvalidProof reports a part whose merkle proof is rejected.
	ErrPartSetInvalidProof = errors.New("error part set invalid proof")
)
|
|
|
|
// Part is one chunk of a PartSet: its position in the set, its raw bytes,
// and the merkle proof that is checked against the set's hash when the part
// is added (see (*PartSet).AddPart).
type Part struct {
	// Index is the position of this part within its PartSet.
	Index uint32 `json:"index"`
	// Bytes is this part's slice of the original data.
	Bytes tmbytes.HexBytes `json:"bytes"`
	// Proof links Bytes to the PartSet's merkle root.
	Proof merkle.Proof `json:"proof"`
}
|
|
|
|
// ValidateBasic performs basic validation.
|
|
func (part *Part) ValidateBasic() error {
|
|
if len(part.Bytes) > int(BlockPartSizeBytes) {
|
|
return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
|
|
}
|
|
if err := part.Proof.ValidateBasic(); err != nil {
|
|
return fmt.Errorf("wrong Proof: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// String returns a string representation of Part.
|
|
//
|
|
// See StringIndented.
|
|
func (part *Part) String() string {
|
|
return part.StringIndented("")
|
|
}
|
|
|
|
// StringIndented returns a multi-line representation of the Part: its index,
// the tmbytes.Fingerprint of its bytes (as upper-case hex), and its proof as
// rendered by merkle.Proof.StringIndented with the indent extended by one
// space. Every line after the first is prefixed with indent.
//
// See merkle.Proof#StringIndented
func (part *Part) StringIndented(indent string) string {
	// NOTE(review): the format literal below contains blank lines between the
	// fields, so the output does too; presumably an artifact of reformatting —
	// confirm before relying on the exact layout.
	return fmt.Sprintf(`Part{#%v


%s Bytes: %X...


%s Proof: %v


%s}`,
		part.Index,
		indent, tmbytes.Fingerprint(part.Bytes),
		indent, part.Proof.StringIndented(indent+" "),
		indent)
}
|
|
|
|
func (part *Part) ToProto() (*tmproto.Part, error) {
|
|
if part == nil {
|
|
return nil, errors.New("nil part")
|
|
}
|
|
pb := new(tmproto.Part)
|
|
proof := part.Proof.ToProto()
|
|
|
|
pb.Index = part.Index
|
|
pb.Bytes = part.Bytes
|
|
pb.Proof = *proof
|
|
|
|
return pb, nil
|
|
}
|
|
|
|
func PartFromProto(pb *tmproto.Part) (*Part, error) {
|
|
if pb == nil {
|
|
return nil, errors.New("nil part")
|
|
}
|
|
|
|
part := new(Part)
|
|
proof, err := merkle.ProofFromProto(&pb.Proof)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
part.Index = pb.Index
|
|
part.Bytes = pb.Bytes
|
|
part.Proof = *proof
|
|
|
|
return part, part.ValidateBasic()
|
|
}
|
|
|
|
//-------------------------------------
|
|
|
|
// PartSetHeader identifies a PartSet by its number of parts and its merkle
// root hash (the same pair returned by (*PartSet).Header).
type PartSetHeader struct {
	// Total is the number of parts in the set.
	Total uint32 `json:"total"`
	// Hash is the merkle root the parts are verified against.
	Hash tmbytes.HexBytes `json:"hash"`
}
|
|
|
|
// String returns a string representation of PartSetHeader.
|
|
//
|
|
// 1. total number of parts
|
|
// 2. first 6 bytes of the hash
|
|
func (psh PartSetHeader) String() string {
|
|
return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
|
|
}
|
|
|
|
func (psh PartSetHeader) IsZero() bool {
|
|
return psh.Total == 0 && len(psh.Hash) == 0
|
|
}
|
|
|
|
func (psh PartSetHeader) Equals(other PartSetHeader) bool {
|
|
return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
|
|
}
|
|
|
|
// ValidateBasic performs basic validation.
|
|
func (psh PartSetHeader) ValidateBasic() error {
|
|
// Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
|
|
if err := ValidateHash(psh.Hash); err != nil {
|
|
return fmt.Errorf("wrong Hash: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ToProto converts PartSetHeader to protobuf
|
|
func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
|
|
if psh == nil {
|
|
return tmproto.PartSetHeader{}
|
|
}
|
|
|
|
return tmproto.PartSetHeader{
|
|
Total: psh.Total,
|
|
Hash: psh.Hash,
|
|
}
|
|
}
|
|
|
|
// FromProto sets a protobuf PartSetHeader to the given pointer
|
|
func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
|
|
if ppsh == nil {
|
|
return nil, errors.New("nil PartSetHeader")
|
|
}
|
|
psh := new(PartSetHeader)
|
|
psh.Total = ppsh.Total
|
|
psh.Hash = ppsh.Hash
|
|
|
|
return psh, psh.ValidateBasic()
|
|
}
|
|
|
|
//-------------------------------------
|
|
|
|
// PartSet holds the parts of a piece of data together with the merkle root
// (hash) the parts are verified against. It is either built complete from the
// data (NewPartSetFromData) or created empty from a header
// (NewPartSetFromHeader) and filled in with AddPart.
type PartSet struct {
	total uint32 // number of parts the set is made of
	hash  []byte // merkle root over the parts' bytes

	// mtx is taken by AddPart (which mutates the fields below), BitArray,
	// GetPart, StringShort and MarshalJSON. NOTE(review): some readers such
	// as Count and ByteSize do not take it.
	mtx           sync.Mutex
	parts         []*Part        // parts[i] stays nil until part i is added
	partsBitArray *bits.BitArray // bit i is set once parts[i] is present
	count         uint32         // number of parts added so far
	// a count of the total size (in bytes). Used to ensure that the
	// part set doesn't exceed the maximum block bytes
	byteSize int64
}
|
|
|
|
// Returns an immutable, full PartSet from the data bytes.
|
|
// The data bytes are split into "partSize" chunks, and merkle tree computed.
|
|
// CONTRACT: partSize is greater than zero.
|
|
func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
|
|
// divide data into 4kb parts.
|
|
total := (uint32(len(data)) + partSize - 1) / partSize
|
|
parts := make([]*Part, total)
|
|
partsBytes := make([][]byte, total)
|
|
partsBitArray := bits.NewBitArray(int(total))
|
|
for i := uint32(0); i < total; i++ {
|
|
part := &Part{
|
|
Index: i,
|
|
Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
|
|
}
|
|
parts[i] = part
|
|
partsBytes[i] = part.Bytes
|
|
partsBitArray.SetIndex(int(i), true)
|
|
}
|
|
// Compute merkle proofs
|
|
root, proofs := merkle.ProofsFromByteSlices(partsBytes)
|
|
for i := uint32(0); i < total; i++ {
|
|
parts[i].Proof = *proofs[i]
|
|
}
|
|
return &PartSet{
|
|
total: total,
|
|
hash: root,
|
|
parts: parts,
|
|
partsBitArray: partsBitArray,
|
|
count: total,
|
|
byteSize: int64(len(data)),
|
|
}
|
|
}
|
|
|
|
// Returns an empty PartSet ready to be populated.
|
|
func NewPartSetFromHeader(header PartSetHeader) *PartSet {
|
|
return &PartSet{
|
|
total: header.Total,
|
|
hash: header.Hash,
|
|
parts: make([]*Part, header.Total),
|
|
partsBitArray: bits.NewBitArray(int(header.Total)),
|
|
count: 0,
|
|
byteSize: 0,
|
|
}
|
|
}
|
|
|
|
func (ps *PartSet) Header() PartSetHeader {
|
|
if ps == nil {
|
|
return PartSetHeader{}
|
|
}
|
|
return PartSetHeader{
|
|
Total: ps.total,
|
|
Hash: ps.hash,
|
|
}
|
|
}
|
|
|
|
func (ps *PartSet) HasHeader(header PartSetHeader) bool {
|
|
if ps == nil {
|
|
return false
|
|
}
|
|
return ps.Header().Equals(header)
|
|
}
|
|
|
|
func (ps *PartSet) BitArray() *bits.BitArray {
|
|
ps.mtx.Lock()
|
|
defer ps.mtx.Unlock()
|
|
return ps.partsBitArray.Copy()
|
|
}
|
|
|
|
func (ps *PartSet) Hash() []byte {
|
|
if ps == nil {
|
|
return merkle.HashFromByteSlices(nil)
|
|
}
|
|
return ps.hash
|
|
}
|
|
|
|
func (ps *PartSet) HashesTo(hash []byte) bool {
|
|
if ps == nil {
|
|
return false
|
|
}
|
|
return bytes.Equal(ps.hash, hash)
|
|
}
|
|
|
|
func (ps *PartSet) Count() uint32 {
|
|
if ps == nil {
|
|
return 0
|
|
}
|
|
return ps.count
|
|
}
|
|
|
|
func (ps *PartSet) ByteSize() int64 {
|
|
if ps == nil {
|
|
return 0
|
|
}
|
|
return ps.byteSize
|
|
}
|
|
|
|
func (ps *PartSet) Total() uint32 {
|
|
if ps == nil {
|
|
return 0
|
|
}
|
|
return ps.total
|
|
}
|
|
|
|
func (ps *PartSet) AddPart(part *Part) (bool, error) {
|
|
if ps == nil {
|
|
return false, nil
|
|
}
|
|
ps.mtx.Lock()
|
|
defer ps.mtx.Unlock()
|
|
|
|
// Invalid part index
|
|
if part.Index >= ps.total {
|
|
return false, ErrPartSetUnexpectedIndex
|
|
}
|
|
|
|
// If part already exists, return false.
|
|
if ps.parts[part.Index] != nil {
|
|
return false, nil
|
|
}
|
|
|
|
// Check hash proof
|
|
if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
|
|
return false, ErrPartSetInvalidProof
|
|
}
|
|
|
|
// Add part
|
|
ps.parts[part.Index] = part
|
|
ps.partsBitArray.SetIndex(int(part.Index), true)
|
|
ps.count++
|
|
ps.byteSize += int64(len(part.Bytes))
|
|
return true, nil
|
|
}
|
|
|
|
func (ps *PartSet) GetPart(index int) *Part {
|
|
ps.mtx.Lock()
|
|
defer ps.mtx.Unlock()
|
|
return ps.parts[index]
|
|
}
|
|
|
|
func (ps *PartSet) IsComplete() bool {
|
|
return ps.count == ps.total
|
|
}
|
|
|
|
func (ps *PartSet) GetReader() io.Reader {
|
|
if !ps.IsComplete() {
|
|
panic("Cannot GetReader() on incomplete PartSet")
|
|
}
|
|
return NewPartSetReader(ps.parts)
|
|
}
|
|
|
|
// PartSetReader is an io.Reader over the concatenated Bytes of a slice of
// parts, read in slice order.
type PartSetReader struct {
	i      int           // index of the part currently being read
	parts  []*Part       // parts to read from, in order
	reader *bytes.Reader // reader over parts[i].Bytes
}
|
|
|
|
func NewPartSetReader(parts []*Part) *PartSetReader {
|
|
return &PartSetReader{
|
|
i: 0,
|
|
parts: parts,
|
|
reader: bytes.NewReader(parts[0].Bytes),
|
|
}
|
|
}
|
|
|
|
func (psr *PartSetReader) Read(p []byte) (n int, err error) {
|
|
readerLen := psr.reader.Len()
|
|
if readerLen >= len(p) {
|
|
return psr.reader.Read(p)
|
|
} else if readerLen > 0 {
|
|
n1, err := psr.Read(p[:readerLen])
|
|
if err != nil {
|
|
return n1, err
|
|
}
|
|
n2, err := psr.Read(p[readerLen:])
|
|
return n1 + n2, err
|
|
}
|
|
|
|
psr.i++
|
|
if psr.i >= len(psr.parts) {
|
|
return 0, io.EOF
|
|
}
|
|
psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
|
|
return psr.Read(p)
|
|
}
|
|
|
|
// StringShort returns a short version of String.
|
|
//
|
|
// (Count of Total)
|
|
func (ps *PartSet) StringShort() string {
|
|
if ps == nil {
|
|
return "nil-PartSet"
|
|
}
|
|
ps.mtx.Lock()
|
|
defer ps.mtx.Unlock()
|
|
return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
|
|
}
|
|
|
|
func (ps *PartSet) MarshalJSON() ([]byte, error) {
|
|
if ps == nil {
|
|
return []byte("{}"), nil
|
|
}
|
|
|
|
ps.mtx.Lock()
|
|
defer ps.mtx.Unlock()
|
|
|
|
return tmjson.Marshal(struct {
|
|
CountTotal string `json:"count/total"`
|
|
PartsBitArray *bits.BitArray `json:"parts_bit_array"`
|
|
}{
|
|
fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
|
|
ps.partsBitArray,
|
|
})
|
|
}
|