You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

322 lines
6.9 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "github.com/tendermint/tendermint/crypto/merkle"
  9. "github.com/tendermint/tendermint/libs/bits"
  10. tmbytes "github.com/tendermint/tendermint/libs/bytes"
  11. tmmath "github.com/tendermint/tendermint/libs/math"
  12. tmproto "github.com/tendermint/tendermint/proto/types"
  13. )
  14. var (
  15. ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
  16. ErrPartSetInvalidProof = errors.New("error part set invalid proof")
  17. )
  18. type Part struct {
  19. Index int `json:"index"`
  20. Bytes tmbytes.HexBytes `json:"bytes"`
  21. Proof merkle.SimpleProof `json:"proof"`
  22. }
  23. // ValidateBasic performs basic validation.
  24. func (part *Part) ValidateBasic() error {
  25. if part.Index < 0 {
  26. return errors.New("negative Index")
  27. }
  28. if len(part.Bytes) > BlockPartSizeBytes {
  29. return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
  30. }
  31. if err := part.Proof.ValidateBasic(); err != nil {
  32. return fmt.Errorf("wrong Proof: %w", err)
  33. }
  34. return nil
  35. }
  36. func (part *Part) String() string {
  37. return part.StringIndented("")
  38. }
  39. func (part *Part) StringIndented(indent string) string {
  40. return fmt.Sprintf(`Part{#%v
  41. %s Bytes: %X...
  42. %s Proof: %v
  43. %s}`,
  44. part.Index,
  45. indent, tmbytes.Fingerprint(part.Bytes),
  46. indent, part.Proof.StringIndented(indent+" "),
  47. indent)
  48. }
  49. //-------------------------------------
  50. type PartSetHeader struct {
  51. Total int `json:"total"`
  52. Hash tmbytes.HexBytes `json:"hash"`
  53. }
  54. func (psh PartSetHeader) String() string {
  55. return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
  56. }
  57. func (psh PartSetHeader) IsZero() bool {
  58. return psh.Total == 0 && len(psh.Hash) == 0
  59. }
  60. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  61. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  62. }
  63. // ValidateBasic performs basic validation.
  64. func (psh PartSetHeader) ValidateBasic() error {
  65. if psh.Total < 0 {
  66. return errors.New("negative Total")
  67. }
  68. // Hash can be empty in case of POLBlockID.PartsHeader in Proposal.
  69. if err := ValidateHash(psh.Hash); err != nil {
  70. return fmt.Errorf("wrong Hash: %w", err)
  71. }
  72. return nil
  73. }
  74. // ToProto converts BloPartSetHeaderckID to protobuf
  75. func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
  76. if psh == nil {
  77. return tmproto.PartSetHeader{}
  78. }
  79. return tmproto.PartSetHeader{
  80. Total: int64(psh.Total),
  81. Hash: psh.Hash,
  82. }
  83. }
  84. // FromProto sets a protobuf PartSetHeader to the given pointer
  85. func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
  86. if ppsh == nil {
  87. return nil, errors.New("nil PartSetHeader")
  88. }
  89. psh := new(PartSetHeader)
  90. psh.Total = int(ppsh.Total)
  91. psh.Hash = ppsh.Hash
  92. return psh, psh.ValidateBasic()
  93. }
  94. //-------------------------------------
  95. type PartSet struct {
  96. total int
  97. hash []byte
  98. mtx sync.Mutex
  99. parts []*Part
  100. partsBitArray *bits.BitArray
  101. count int
  102. }
  103. // Returns an immutable, full PartSet from the data bytes.
  104. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  105. func NewPartSetFromData(data []byte, partSize int) *PartSet {
  106. // divide data into 4kb parts.
  107. total := (len(data) + partSize - 1) / partSize
  108. parts := make([]*Part, total)
  109. partsBytes := make([][]byte, total)
  110. partsBitArray := bits.NewBitArray(total)
  111. for i := 0; i < total; i++ {
  112. part := &Part{
  113. Index: i,
  114. Bytes: data[i*partSize : tmmath.MinInt(len(data), (i+1)*partSize)],
  115. }
  116. parts[i] = part
  117. partsBytes[i] = part.Bytes
  118. partsBitArray.SetIndex(i, true)
  119. }
  120. // Compute merkle proofs
  121. root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes)
  122. for i := 0; i < total; i++ {
  123. parts[i].Proof = *proofs[i]
  124. }
  125. return &PartSet{
  126. total: total,
  127. hash: root,
  128. parts: parts,
  129. partsBitArray: partsBitArray,
  130. count: total,
  131. }
  132. }
  133. // Returns an empty PartSet ready to be populated.
  134. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  135. return &PartSet{
  136. total: header.Total,
  137. hash: header.Hash,
  138. parts: make([]*Part, header.Total),
  139. partsBitArray: bits.NewBitArray(header.Total),
  140. count: 0,
  141. }
  142. }
  143. func (ps *PartSet) Header() PartSetHeader {
  144. if ps == nil {
  145. return PartSetHeader{}
  146. }
  147. return PartSetHeader{
  148. Total: ps.total,
  149. Hash: ps.hash,
  150. }
  151. }
  152. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  153. if ps == nil {
  154. return false
  155. }
  156. return ps.Header().Equals(header)
  157. }
  158. func (ps *PartSet) BitArray() *bits.BitArray {
  159. ps.mtx.Lock()
  160. defer ps.mtx.Unlock()
  161. return ps.partsBitArray.Copy()
  162. }
  163. func (ps *PartSet) Hash() []byte {
  164. if ps == nil {
  165. return nil
  166. }
  167. return ps.hash
  168. }
  169. func (ps *PartSet) HashesTo(hash []byte) bool {
  170. if ps == nil {
  171. return false
  172. }
  173. return bytes.Equal(ps.hash, hash)
  174. }
  175. func (ps *PartSet) Count() int {
  176. if ps == nil {
  177. return 0
  178. }
  179. return ps.count
  180. }
  181. func (ps *PartSet) Total() int {
  182. if ps == nil {
  183. return 0
  184. }
  185. return ps.total
  186. }
  187. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  188. if ps == nil {
  189. return false, nil
  190. }
  191. ps.mtx.Lock()
  192. defer ps.mtx.Unlock()
  193. // Invalid part index
  194. if part.Index >= ps.total {
  195. return false, ErrPartSetUnexpectedIndex
  196. }
  197. // If part already exists, return false.
  198. if ps.parts[part.Index] != nil {
  199. return false, nil
  200. }
  201. // Check hash proof
  202. if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
  203. return false, ErrPartSetInvalidProof
  204. }
  205. // Add part
  206. ps.parts[part.Index] = part
  207. ps.partsBitArray.SetIndex(part.Index, true)
  208. ps.count++
  209. return true, nil
  210. }
  211. func (ps *PartSet) GetPart(index int) *Part {
  212. ps.mtx.Lock()
  213. defer ps.mtx.Unlock()
  214. return ps.parts[index]
  215. }
  216. func (ps *PartSet) IsComplete() bool {
  217. return ps.count == ps.total
  218. }
  219. func (ps *PartSet) GetReader() io.Reader {
  220. if !ps.IsComplete() {
  221. panic("Cannot GetReader() on incomplete PartSet")
  222. }
  223. return NewPartSetReader(ps.parts)
  224. }
  225. type PartSetReader struct {
  226. i int
  227. parts []*Part
  228. reader *bytes.Reader
  229. }
  230. func NewPartSetReader(parts []*Part) *PartSetReader {
  231. return &PartSetReader{
  232. i: 0,
  233. parts: parts,
  234. reader: bytes.NewReader(parts[0].Bytes),
  235. }
  236. }
  237. func (psr *PartSetReader) Read(p []byte) (n int, err error) {
  238. readerLen := psr.reader.Len()
  239. if readerLen >= len(p) {
  240. return psr.reader.Read(p)
  241. } else if readerLen > 0 {
  242. n1, err := psr.Read(p[:readerLen])
  243. if err != nil {
  244. return n1, err
  245. }
  246. n2, err := psr.Read(p[readerLen:])
  247. return n1 + n2, err
  248. }
  249. psr.i++
  250. if psr.i >= len(psr.parts) {
  251. return 0, io.EOF
  252. }
  253. psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
  254. return psr.Read(p)
  255. }
  256. func (ps *PartSet) StringShort() string {
  257. if ps == nil {
  258. return "nil-PartSet"
  259. }
  260. ps.mtx.Lock()
  261. defer ps.mtx.Unlock()
  262. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  263. }
  264. func (ps *PartSet) MarshalJSON() ([]byte, error) {
  265. if ps == nil {
  266. return []byte("{}"), nil
  267. }
  268. ps.mtx.Lock()
  269. defer ps.mtx.Unlock()
  270. return cdc.MarshalJSON(struct {
  271. CountTotal string `json:"count/total"`
  272. PartsBitArray *bits.BitArray `json:"parts_bit_array"`
  273. }{
  274. fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
  275. ps.partsBitArray,
  276. })
  277. }