You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

348 lines
7.4 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "github.com/tendermint/tendermint/crypto/merkle"
  9. "github.com/tendermint/tendermint/libs/bits"
  10. tmbytes "github.com/tendermint/tendermint/libs/bytes"
  11. tmmath "github.com/tendermint/tendermint/libs/math"
  12. tmproto "github.com/tendermint/tendermint/proto/types"
  13. )
// Sentinel errors returned by PartSet.AddPart.
var (
	// ErrPartSetUnexpectedIndex is returned when a part's Index is not below
	// the total number of parts in the set.
	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
	// ErrPartSetInvalidProof is returned when a part's proof does not verify
	// against the set's hash and the part's bytes.
	ErrPartSetInvalidProof = errors.New("error part set invalid proof")
)
// Part is one chunk of a larger byte payload, together with the Merkle proof
// that PartSet.AddPart verifies against the set's hash.
type Part struct {
	// Index is the position of this part within its PartSet.
	Index uint32 `json:"index"`
	// Bytes is the chunk payload.
	Bytes tmbytes.HexBytes `json:"bytes"`
	// Proof is the Merkle proof for Bytes.
	Proof merkle.SimpleProof `json:"proof"`
}
  23. // ValidateBasic performs basic validation.
  24. func (part *Part) ValidateBasic() error {
  25. if len(part.Bytes) > int(BlockPartSizeBytes) {
  26. return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
  27. }
  28. if err := part.Proof.ValidateBasic(); err != nil {
  29. return fmt.Errorf("wrong Proof: %w", err)
  30. }
  31. return nil
  32. }
  33. func (part *Part) String() string {
  34. return part.StringIndented("")
  35. }
// StringIndented renders the part as a multi-line string: the index, a
// fingerprint of the payload (not the full bytes), and the proof's own
// indented rendering. Every line after the first is prefixed with indent.
func (part *Part) StringIndented(indent string) string {
	return fmt.Sprintf(`Part{#%v
%s Bytes: %X...
%s Proof: %v
%s}`,
		part.Index,
		indent, tmbytes.Fingerprint(part.Bytes),
		indent, part.Proof.StringIndented(indent+" "),
		indent)
}
  46. func (part *Part) ToProto() (*tmproto.Part, error) {
  47. if part == nil {
  48. return nil, errors.New("nil part")
  49. }
  50. pb := new(tmproto.Part)
  51. proof := part.Proof.ToProto()
  52. pb.Index = part.Index
  53. pb.Bytes = part.Bytes
  54. pb.Proof = *proof
  55. return pb, nil
  56. }
  57. func PartFromProto(pb *tmproto.Part) (*Part, error) {
  58. if pb == nil {
  59. return nil, errors.New("nil part")
  60. }
  61. part := new(Part)
  62. proof, err := merkle.SimpleProofFromProto(&pb.Proof)
  63. if err != nil {
  64. return nil, err
  65. }
  66. part.Index = pb.Index
  67. part.Bytes = pb.Bytes
  68. part.Proof = *proof
  69. return part, part.ValidateBasic()
  70. }
  71. //-------------------------------------
// PartSetHeader identifies a PartSet by its number of parts and its hash.
type PartSetHeader struct {
	// Total is the number of parts in the set.
	Total uint32 `json:"total"`
	// Hash is the hash of the set (the Merkle root when the set is built by
	// NewPartSetFromData).
	Hash tmbytes.HexBytes `json:"hash"`
}
  76. func (psh PartSetHeader) String() string {
  77. return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
  78. }
  79. func (psh PartSetHeader) IsZero() bool {
  80. return psh.Total == 0 && len(psh.Hash) == 0
  81. }
  82. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  83. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  84. }
  85. // ValidateBasic performs basic validation.
  86. func (psh PartSetHeader) ValidateBasic() error {
  87. // Hash can be empty in case of POLBlockID.PartsHeader in Proposal.
  88. if err := ValidateHash(psh.Hash); err != nil {
  89. return fmt.Errorf("wrong Hash: %w", err)
  90. }
  91. return nil
  92. }
  93. // ToProto converts BloPartSetHeaderckID to protobuf
  94. func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
  95. if psh == nil {
  96. return tmproto.PartSetHeader{}
  97. }
  98. return tmproto.PartSetHeader{
  99. Total: psh.Total,
  100. Hash: psh.Hash,
  101. }
  102. }
  103. // FromProto sets a protobuf PartSetHeader to the given pointer
  104. func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
  105. if ppsh == nil {
  106. return nil, errors.New("nil PartSetHeader")
  107. }
  108. psh := new(PartSetHeader)
  109. psh.Total = ppsh.Total
  110. psh.Hash = ppsh.Hash
  111. return psh, psh.ValidateBasic()
  112. }
  113. //-------------------------------------
// PartSet collects the parts of a payload, tracking which indices have been
// received so far.
type PartSet struct {
	total uint32 // number of parts the complete set contains
	hash  []byte // hash that each added part's proof is verified against

	// mtx is held by AddPart, GetPart, BitArray, StringShort and MarshalJSON.
	// Count, Total and IsComplete read their fields without taking it.
	mtx           sync.Mutex
	parts         []*Part        // indexed by Part.Index; nil where not yet added
	partsBitArray *bits.BitArray // bit i is set once parts[i] has been added
	count         uint32         // number of parts added so far
}
  122. // Returns an immutable, full PartSet from the data bytes.
  123. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  124. // CONTRACT: partSize is greater than zero.
  125. func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
  126. // divide data into 4kb parts.
  127. total := (uint32(len(data)) + partSize - 1) / partSize
  128. parts := make([]*Part, total)
  129. partsBytes := make([][]byte, total)
  130. partsBitArray := bits.NewBitArray(int(total))
  131. for i := uint32(0); i < total; i++ {
  132. part := &Part{
  133. Index: i,
  134. Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
  135. }
  136. parts[i] = part
  137. partsBytes[i] = part.Bytes
  138. partsBitArray.SetIndex(int(i), true)
  139. }
  140. // Compute merkle proofs
  141. root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes)
  142. for i := uint32(0); i < total; i++ {
  143. parts[i].Proof = *proofs[i]
  144. }
  145. return &PartSet{
  146. total: total,
  147. hash: root,
  148. parts: parts,
  149. partsBitArray: partsBitArray,
  150. count: total,
  151. }
  152. }
  153. // Returns an empty PartSet ready to be populated.
  154. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  155. return &PartSet{
  156. total: header.Total,
  157. hash: header.Hash,
  158. parts: make([]*Part, header.Total),
  159. partsBitArray: bits.NewBitArray(int(header.Total)),
  160. count: 0,
  161. }
  162. }
  163. func (ps *PartSet) Header() PartSetHeader {
  164. if ps == nil {
  165. return PartSetHeader{}
  166. }
  167. return PartSetHeader{
  168. Total: ps.total,
  169. Hash: ps.hash,
  170. }
  171. }
  172. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  173. if ps == nil {
  174. return false
  175. }
  176. return ps.Header().Equals(header)
  177. }
  178. func (ps *PartSet) BitArray() *bits.BitArray {
  179. ps.mtx.Lock()
  180. defer ps.mtx.Unlock()
  181. return ps.partsBitArray.Copy()
  182. }
  183. func (ps *PartSet) Hash() []byte {
  184. if ps == nil {
  185. return nil
  186. }
  187. return ps.hash
  188. }
  189. func (ps *PartSet) HashesTo(hash []byte) bool {
  190. if ps == nil {
  191. return false
  192. }
  193. return bytes.Equal(ps.hash, hash)
  194. }
  195. func (ps *PartSet) Count() uint32 {
  196. if ps == nil {
  197. return 0
  198. }
  199. return ps.count
  200. }
  201. func (ps *PartSet) Total() uint32 {
  202. if ps == nil {
  203. return 0
  204. }
  205. return ps.total
  206. }
  207. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  208. if ps == nil {
  209. return false, nil
  210. }
  211. ps.mtx.Lock()
  212. defer ps.mtx.Unlock()
  213. // Invalid part index
  214. if part.Index >= ps.total {
  215. return false, ErrPartSetUnexpectedIndex
  216. }
  217. // If part already exists, return false.
  218. if ps.parts[part.Index] != nil {
  219. return false, nil
  220. }
  221. // Check hash proof
  222. if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
  223. return false, ErrPartSetInvalidProof
  224. }
  225. // Add part
  226. ps.parts[part.Index] = part
  227. ps.partsBitArray.SetIndex(int(part.Index), true)
  228. ps.count++
  229. return true, nil
  230. }
  231. func (ps *PartSet) GetPart(index int) *Part {
  232. ps.mtx.Lock()
  233. defer ps.mtx.Unlock()
  234. return ps.parts[index]
  235. }
  236. func (ps *PartSet) IsComplete() bool {
  237. return ps.count == ps.total
  238. }
  239. func (ps *PartSet) GetReader() io.Reader {
  240. if !ps.IsComplete() {
  241. panic("Cannot GetReader() on incomplete PartSet")
  242. }
  243. return NewPartSetReader(ps.parts)
  244. }
  245. type PartSetReader struct {
  246. i int
  247. parts []*Part
  248. reader *bytes.Reader
  249. }
  250. func NewPartSetReader(parts []*Part) *PartSetReader {
  251. return &PartSetReader{
  252. i: 0,
  253. parts: parts,
  254. reader: bytes.NewReader(parts[0].Bytes),
  255. }
  256. }
  257. func (psr *PartSetReader) Read(p []byte) (n int, err error) {
  258. readerLen := psr.reader.Len()
  259. if readerLen >= len(p) {
  260. return psr.reader.Read(p)
  261. } else if readerLen > 0 {
  262. n1, err := psr.Read(p[:readerLen])
  263. if err != nil {
  264. return n1, err
  265. }
  266. n2, err := psr.Read(p[readerLen:])
  267. return n1 + n2, err
  268. }
  269. psr.i++
  270. if psr.i >= len(psr.parts) {
  271. return 0, io.EOF
  272. }
  273. psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
  274. return psr.Read(p)
  275. }
  276. func (ps *PartSet) StringShort() string {
  277. if ps == nil {
  278. return "nil-PartSet"
  279. }
  280. ps.mtx.Lock()
  281. defer ps.mtx.Unlock()
  282. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  283. }
  284. func (ps *PartSet) MarshalJSON() ([]byte, error) {
  285. if ps == nil {
  286. return []byte("{}"), nil
  287. }
  288. ps.mtx.Lock()
  289. defer ps.mtx.Unlock()
  290. return cdc.MarshalJSON(struct {
  291. CountTotal string `json:"count/total"`
  292. PartsBitArray *bits.BitArray `json:"parts_bit_array"`
  293. }{
  294. fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
  295. ps.partsBitArray,
  296. })
  297. }