You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

375 lines
8.2 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "github.com/tendermint/tendermint/crypto/merkle"
  8. tmsync "github.com/tendermint/tendermint/internal/libs/sync"
  9. "github.com/tendermint/tendermint/libs/bits"
  10. tmbytes "github.com/tendermint/tendermint/libs/bytes"
  11. tmjson "github.com/tendermint/tendermint/libs/json"
  12. tmmath "github.com/tendermint/tendermint/libs/math"
  13. tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
  14. )
  var (
	// ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when a part's
	// index is not smaller than the number of parts in the set.
	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")

	// ErrPartSetInvalidProof is returned by PartSet.AddPart when it rejects a
	// part's merkle proof.
	ErrPartSetInvalidProof = errors.New("error part set invalid proof")
)
  19. type Part struct {
  20. Index uint32 `json:"index"`
  21. Bytes tmbytes.HexBytes `json:"bytes"`
  22. Proof merkle.Proof `json:"proof"`
  23. }
  24. // ValidateBasic performs basic validation.
  25. func (part *Part) ValidateBasic() error {
  26. if len(part.Bytes) > int(BlockPartSizeBytes) {
  27. return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
  28. }
  29. if err := part.Proof.ValidateBasic(); err != nil {
  30. return fmt.Errorf("wrong Proof: %w", err)
  31. }
  32. return nil
  33. }
  34. // String returns a string representation of Part.
  35. //
  36. // See StringIndented.
  37. func (part *Part) String() string {
  38. return part.StringIndented("")
  39. }
  40. // StringIndented returns an indented Part.
  41. //
  42. // See merkle.Proof#StringIndented
  43. func (part *Part) StringIndented(indent string) string {
  44. return fmt.Sprintf(`Part{#%v
  45. %s Bytes: %X...
  46. %s Proof: %v
  47. %s}`,
  48. part.Index,
  49. indent, tmbytes.Fingerprint(part.Bytes),
  50. indent, part.Proof.StringIndented(indent+" "),
  51. indent)
  52. }
  53. func (part *Part) ToProto() (*tmproto.Part, error) {
  54. if part == nil {
  55. return nil, errors.New("nil part")
  56. }
  57. pb := new(tmproto.Part)
  58. proof := part.Proof.ToProto()
  59. pb.Index = part.Index
  60. pb.Bytes = part.Bytes
  61. pb.Proof = *proof
  62. return pb, nil
  63. }
  64. func PartFromProto(pb *tmproto.Part) (*Part, error) {
  65. if pb == nil {
  66. return nil, errors.New("nil part")
  67. }
  68. part := new(Part)
  69. proof, err := merkle.ProofFromProto(&pb.Proof)
  70. if err != nil {
  71. return nil, err
  72. }
  73. part.Index = pb.Index
  74. part.Bytes = pb.Bytes
  75. part.Proof = *proof
  76. return part, part.ValidateBasic()
  77. }
  78. //-------------------------------------
  79. type PartSetHeader struct {
  80. Total uint32 `json:"total"`
  81. Hash tmbytes.HexBytes `json:"hash"`
  82. }
  83. // String returns a string representation of PartSetHeader.
  84. //
  85. // 1. total number of parts
  86. // 2. first 6 bytes of the hash
  87. func (psh PartSetHeader) String() string {
  88. return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
  89. }
  90. func (psh PartSetHeader) IsZero() bool {
  91. return psh.Total == 0 && len(psh.Hash) == 0
  92. }
  93. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  94. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  95. }
  96. // ValidateBasic performs basic validation.
  97. func (psh PartSetHeader) ValidateBasic() error {
  98. // Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
  99. if err := ValidateHash(psh.Hash); err != nil {
  100. return fmt.Errorf("wrong Hash: %w", err)
  101. }
  102. return nil
  103. }
  104. // ToProto converts PartSetHeader to protobuf
  105. func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
  106. if psh == nil {
  107. return tmproto.PartSetHeader{}
  108. }
  109. return tmproto.PartSetHeader{
  110. Total: psh.Total,
  111. Hash: psh.Hash,
  112. }
  113. }
  114. // FromProto sets a protobuf PartSetHeader to the given pointer
  115. func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
  116. if ppsh == nil {
  117. return nil, errors.New("nil PartSetHeader")
  118. }
  119. psh := new(PartSetHeader)
  120. psh.Total = ppsh.Total
  121. psh.Hash = ppsh.Hash
  122. return psh, psh.ValidateBasic()
  123. }
  124. //-------------------------------------
  125. type PartSet struct {
  126. total uint32
  127. hash []byte
  128. mtx tmsync.Mutex
  129. parts []*Part
  130. partsBitArray *bits.BitArray
  131. count uint32
  132. // a count of the total size (in bytes). Used to ensure that the
  133. // part set doesn't exceed the maximum block bytes
  134. byteSize int64
  135. }
  136. // Returns an immutable, full PartSet from the data bytes.
  137. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  138. // CONTRACT: partSize is greater than zero.
  139. func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
  140. // divide data into 4kb parts.
  141. total := (uint32(len(data)) + partSize - 1) / partSize
  142. parts := make([]*Part, total)
  143. partsBytes := make([][]byte, total)
  144. partsBitArray := bits.NewBitArray(int(total))
  145. for i := uint32(0); i < total; i++ {
  146. part := &Part{
  147. Index: i,
  148. Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
  149. }
  150. parts[i] = part
  151. partsBytes[i] = part.Bytes
  152. partsBitArray.SetIndex(int(i), true)
  153. }
  154. // Compute merkle proofs
  155. root, proofs := merkle.ProofsFromByteSlices(partsBytes)
  156. for i := uint32(0); i < total; i++ {
  157. parts[i].Proof = *proofs[i]
  158. }
  159. return &PartSet{
  160. total: total,
  161. hash: root,
  162. parts: parts,
  163. partsBitArray: partsBitArray,
  164. count: total,
  165. byteSize: int64(len(data)),
  166. }
  167. }
  168. // Returns an empty PartSet ready to be populated.
  169. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  170. return &PartSet{
  171. total: header.Total,
  172. hash: header.Hash,
  173. parts: make([]*Part, header.Total),
  174. partsBitArray: bits.NewBitArray(int(header.Total)),
  175. count: 0,
  176. byteSize: 0,
  177. }
  178. }
  179. func (ps *PartSet) Header() PartSetHeader {
  180. if ps == nil {
  181. return PartSetHeader{}
  182. }
  183. return PartSetHeader{
  184. Total: ps.total,
  185. Hash: ps.hash,
  186. }
  187. }
  188. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  189. if ps == nil {
  190. return false
  191. }
  192. return ps.Header().Equals(header)
  193. }
  194. func (ps *PartSet) BitArray() *bits.BitArray {
  195. ps.mtx.Lock()
  196. defer ps.mtx.Unlock()
  197. return ps.partsBitArray.Copy()
  198. }
  199. func (ps *PartSet) Hash() []byte {
  200. if ps == nil {
  201. return merkle.HashFromByteSlices(nil)
  202. }
  203. return ps.hash
  204. }
  205. func (ps *PartSet) HashesTo(hash []byte) bool {
  206. if ps == nil {
  207. return false
  208. }
  209. return bytes.Equal(ps.hash, hash)
  210. }
  211. func (ps *PartSet) Count() uint32 {
  212. if ps == nil {
  213. return 0
  214. }
  215. return ps.count
  216. }
  217. func (ps *PartSet) ByteSize() int64 {
  218. if ps == nil {
  219. return 0
  220. }
  221. return ps.byteSize
  222. }
  223. func (ps *PartSet) Total() uint32 {
  224. if ps == nil {
  225. return 0
  226. }
  227. return ps.total
  228. }
  229. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  230. if ps == nil {
  231. return false, nil
  232. }
  233. ps.mtx.Lock()
  234. defer ps.mtx.Unlock()
  235. // Invalid part index
  236. if part.Index >= ps.total {
  237. return false, ErrPartSetUnexpectedIndex
  238. }
  239. // If part already exists, return false.
  240. if ps.parts[part.Index] != nil {
  241. return false, nil
  242. }
  243. // Check hash proof
  244. if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
  245. return false, ErrPartSetInvalidProof
  246. }
  247. // Add part
  248. ps.parts[part.Index] = part
  249. ps.partsBitArray.SetIndex(int(part.Index), true)
  250. ps.count++
  251. ps.byteSize += int64(len(part.Bytes))
  252. return true, nil
  253. }
  254. func (ps *PartSet) GetPart(index int) *Part {
  255. ps.mtx.Lock()
  256. defer ps.mtx.Unlock()
  257. return ps.parts[index]
  258. }
  259. func (ps *PartSet) IsComplete() bool {
  260. return ps.count == ps.total
  261. }
  262. func (ps *PartSet) GetReader() io.Reader {
  263. if !ps.IsComplete() {
  264. panic("Cannot GetReader() on incomplete PartSet")
  265. }
  266. return NewPartSetReader(ps.parts)
  267. }
  268. type PartSetReader struct {
  269. i int
  270. parts []*Part
  271. reader *bytes.Reader
  272. }
  273. func NewPartSetReader(parts []*Part) *PartSetReader {
  274. return &PartSetReader{
  275. i: 0,
  276. parts: parts,
  277. reader: bytes.NewReader(parts[0].Bytes),
  278. }
  279. }
  280. func (psr *PartSetReader) Read(p []byte) (n int, err error) {
  281. readerLen := psr.reader.Len()
  282. if readerLen >= len(p) {
  283. return psr.reader.Read(p)
  284. } else if readerLen > 0 {
  285. n1, err := psr.Read(p[:readerLen])
  286. if err != nil {
  287. return n1, err
  288. }
  289. n2, err := psr.Read(p[readerLen:])
  290. return n1 + n2, err
  291. }
  292. psr.i++
  293. if psr.i >= len(psr.parts) {
  294. return 0, io.EOF
  295. }
  296. psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
  297. return psr.Read(p)
  298. }
  299. // StringShort returns a short version of String.
  300. //
  301. // (Count of Total)
  302. func (ps *PartSet) StringShort() string {
  303. if ps == nil {
  304. return "nil-PartSet"
  305. }
  306. ps.mtx.Lock()
  307. defer ps.mtx.Unlock()
  308. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  309. }
  310. func (ps *PartSet) MarshalJSON() ([]byte, error) {
  311. if ps == nil {
  312. return []byte("{}"), nil
  313. }
  314. ps.mtx.Lock()
  315. defer ps.mtx.Unlock()
  316. return tmjson.Marshal(struct {
  317. CountTotal string `json:"count/total"`
  318. PartsBitArray *bits.BitArray `json:"parts_bit_array"`
  319. }{
  320. fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
  321. ps.partsBitArray,
  322. })
  323. }