You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

362 lines
7.9 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "github.com/tendermint/tendermint/crypto/merkle"
  8. "github.com/tendermint/tendermint/libs/bits"
  9. tmbytes "github.com/tendermint/tendermint/libs/bytes"
  10. tmjson "github.com/tendermint/tendermint/libs/json"
  11. tmmath "github.com/tendermint/tendermint/libs/math"
  12. tmsync "github.com/tendermint/tendermint/libs/sync"
  13. tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
  14. )
  // ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when a part's
  // index is not below the total number of parts in the set.
  var ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")

  // ErrPartSetInvalidProof is returned by PartSet.AddPart when a part's proof
  // does not verify against the hash of the set.
  var ErrPartSetInvalidProof = errors.New("error part set invalid proof")
  19. type Part struct {
  20. Index uint32 `json:"index"`
  21. Bytes tmbytes.HexBytes `json:"bytes"`
  22. Proof merkle.Proof `json:"proof"`
  23. }
  24. // ValidateBasic performs basic validation.
  25. func (part *Part) ValidateBasic() error {
  26. if len(part.Bytes) > int(BlockPartSizeBytes) {
  27. return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
  28. }
  29. if err := part.Proof.ValidateBasic(); err != nil {
  30. return fmt.Errorf("wrong Proof: %w", err)
  31. }
  32. return nil
  33. }
  34. // String returns a string representation of Part.
  35. //
  36. // See StringIndented.
  37. func (part *Part) String() string {
  38. return part.StringIndented("")
  39. }
  40. // StringIndented returns an indented Part.
  41. //
  42. // See merkle.Proof#StringIndented
  43. func (part *Part) StringIndented(indent string) string {
  44. return fmt.Sprintf(`Part{#%v
  45. %s Bytes: %X...
  46. %s Proof: %v
  47. %s}`,
  48. part.Index,
  49. indent, tmbytes.Fingerprint(part.Bytes),
  50. indent, part.Proof.StringIndented(indent+" "),
  51. indent)
  52. }
  53. func (part *Part) ToProto() (*tmproto.Part, error) {
  54. if part == nil {
  55. return nil, errors.New("nil part")
  56. }
  57. pb := new(tmproto.Part)
  58. proof := part.Proof.ToProto()
  59. pb.Index = part.Index
  60. pb.Bytes = part.Bytes
  61. pb.Proof = *proof
  62. return pb, nil
  63. }
  64. func PartFromProto(pb *tmproto.Part) (*Part, error) {
  65. if pb == nil {
  66. return nil, errors.New("nil part")
  67. }
  68. part := new(Part)
  69. proof, err := merkle.ProofFromProto(&pb.Proof)
  70. if err != nil {
  71. return nil, err
  72. }
  73. part.Index = pb.Index
  74. part.Bytes = pb.Bytes
  75. part.Proof = *proof
  76. return part, part.ValidateBasic()
  77. }
  78. //-------------------------------------
  79. type PartSetHeader struct {
  80. Total uint32 `json:"total"`
  81. Hash tmbytes.HexBytes `json:"hash"`
  82. }
  83. // String returns a string representation of PartSetHeader.
  84. //
  85. // 1. total number of parts
  86. // 2. first 6 bytes of the hash
  87. func (psh PartSetHeader) String() string {
  88. return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
  89. }
  90. func (psh PartSetHeader) IsZero() bool {
  91. return psh.Total == 0 && len(psh.Hash) == 0
  92. }
  93. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  94. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  95. }
  96. // ValidateBasic performs basic validation.
  97. func (psh PartSetHeader) ValidateBasic() error {
  98. // Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
  99. if err := ValidateHash(psh.Hash); err != nil {
  100. return fmt.Errorf("wrong Hash: %w", err)
  101. }
  102. return nil
  103. }
  104. // ToProto converts BloPartSetHeaderckID to protobuf
  105. func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
  106. if psh == nil {
  107. return tmproto.PartSetHeader{}
  108. }
  109. return tmproto.PartSetHeader{
  110. Total: psh.Total,
  111. Hash: psh.Hash,
  112. }
  113. }
  114. // FromProto sets a protobuf PartSetHeader to the given pointer
  115. func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
  116. if ppsh == nil {
  117. return nil, errors.New("nil PartSetHeader")
  118. }
  119. psh := new(PartSetHeader)
  120. psh.Total = ppsh.Total
  121. psh.Hash = ppsh.Hash
  122. return psh, psh.ValidateBasic()
  123. }
  124. //-------------------------------------
  125. type PartSet struct {
  126. total uint32
  127. hash []byte
  128. mtx tmsync.Mutex
  129. parts []*Part
  130. partsBitArray *bits.BitArray
  131. count uint32
  132. }
  133. // Returns an immutable, full PartSet from the data bytes.
  134. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  135. // CONTRACT: partSize is greater than zero.
  136. func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
  137. // divide data into 4kb parts.
  138. total := (uint32(len(data)) + partSize - 1) / partSize
  139. parts := make([]*Part, total)
  140. partsBytes := make([][]byte, total)
  141. partsBitArray := bits.NewBitArray(int(total))
  142. for i := uint32(0); i < total; i++ {
  143. part := &Part{
  144. Index: i,
  145. Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
  146. }
  147. parts[i] = part
  148. partsBytes[i] = part.Bytes
  149. partsBitArray.SetIndex(int(i), true)
  150. }
  151. // Compute merkle proofs
  152. root, proofs := merkle.ProofsFromByteSlices(partsBytes)
  153. for i := uint32(0); i < total; i++ {
  154. parts[i].Proof = *proofs[i]
  155. }
  156. return &PartSet{
  157. total: total,
  158. hash: root,
  159. parts: parts,
  160. partsBitArray: partsBitArray,
  161. count: total,
  162. }
  163. }
  164. // Returns an empty PartSet ready to be populated.
  165. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  166. return &PartSet{
  167. total: header.Total,
  168. hash: header.Hash,
  169. parts: make([]*Part, header.Total),
  170. partsBitArray: bits.NewBitArray(int(header.Total)),
  171. count: 0,
  172. }
  173. }
  174. func (ps *PartSet) Header() PartSetHeader {
  175. if ps == nil {
  176. return PartSetHeader{}
  177. }
  178. return PartSetHeader{
  179. Total: ps.total,
  180. Hash: ps.hash,
  181. }
  182. }
  183. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  184. if ps == nil {
  185. return false
  186. }
  187. return ps.Header().Equals(header)
  188. }
  189. func (ps *PartSet) BitArray() *bits.BitArray {
  190. ps.mtx.Lock()
  191. defer ps.mtx.Unlock()
  192. return ps.partsBitArray.Copy()
  193. }
  194. func (ps *PartSet) Hash() []byte {
  195. if ps == nil {
  196. return merkle.HashFromByteSlices(nil)
  197. }
  198. return ps.hash
  199. }
  200. func (ps *PartSet) HashesTo(hash []byte) bool {
  201. if ps == nil {
  202. return false
  203. }
  204. return bytes.Equal(ps.hash, hash)
  205. }
  206. func (ps *PartSet) Count() uint32 {
  207. if ps == nil {
  208. return 0
  209. }
  210. return ps.count
  211. }
  212. func (ps *PartSet) Total() uint32 {
  213. if ps == nil {
  214. return 0
  215. }
  216. return ps.total
  217. }
  218. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  219. if ps == nil {
  220. return false, nil
  221. }
  222. ps.mtx.Lock()
  223. defer ps.mtx.Unlock()
  224. // Invalid part index
  225. if part.Index >= ps.total {
  226. return false, ErrPartSetUnexpectedIndex
  227. }
  228. // If part already exists, return false.
  229. if ps.parts[part.Index] != nil {
  230. return false, nil
  231. }
  232. // Check hash proof
  233. if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
  234. return false, ErrPartSetInvalidProof
  235. }
  236. // Add part
  237. ps.parts[part.Index] = part
  238. ps.partsBitArray.SetIndex(int(part.Index), true)
  239. ps.count++
  240. return true, nil
  241. }
  242. func (ps *PartSet) GetPart(index int) *Part {
  243. ps.mtx.Lock()
  244. defer ps.mtx.Unlock()
  245. return ps.parts[index]
  246. }
  247. func (ps *PartSet) IsComplete() bool {
  248. return ps.count == ps.total
  249. }
  250. func (ps *PartSet) GetReader() io.Reader {
  251. if !ps.IsComplete() {
  252. panic("Cannot GetReader() on incomplete PartSet")
  253. }
  254. return NewPartSetReader(ps.parts)
  255. }
  256. type PartSetReader struct {
  257. i int
  258. parts []*Part
  259. reader *bytes.Reader
  260. }
  261. func NewPartSetReader(parts []*Part) *PartSetReader {
  262. return &PartSetReader{
  263. i: 0,
  264. parts: parts,
  265. reader: bytes.NewReader(parts[0].Bytes),
  266. }
  267. }
  268. func (psr *PartSetReader) Read(p []byte) (n int, err error) {
  269. readerLen := psr.reader.Len()
  270. if readerLen >= len(p) {
  271. return psr.reader.Read(p)
  272. } else if readerLen > 0 {
  273. n1, err := psr.Read(p[:readerLen])
  274. if err != nil {
  275. return n1, err
  276. }
  277. n2, err := psr.Read(p[readerLen:])
  278. return n1 + n2, err
  279. }
  280. psr.i++
  281. if psr.i >= len(psr.parts) {
  282. return 0, io.EOF
  283. }
  284. psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
  285. return psr.Read(p)
  286. }
  287. // StringShort returns a short version of String.
  288. //
  289. // (Count of Total)
  290. func (ps *PartSet) StringShort() string {
  291. if ps == nil {
  292. return "nil-PartSet"
  293. }
  294. ps.mtx.Lock()
  295. defer ps.mtx.Unlock()
  296. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  297. }
  298. func (ps *PartSet) MarshalJSON() ([]byte, error) {
  299. if ps == nil {
  300. return []byte("{}"), nil
  301. }
  302. ps.mtx.Lock()
  303. defer ps.mtx.Unlock()
  304. return tmjson.Marshal(struct {
  305. CountTotal string `json:"count/total"`
  306. PartsBitArray *bits.BitArray `json:"parts_bit_array"`
  307. }{
  308. fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
  309. ps.partsBitArray,
  310. })
  311. }