You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

296 lines
6.2 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "sync"
  7. "github.com/pkg/errors"
  8. "github.com/tendermint/tendermint/crypto/merkle"
  9. cmn "github.com/tendermint/tendermint/libs/common"
  10. )
  // ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when a part's
// index lies beyond the set's total number of parts.
var ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")

// ErrPartSetInvalidProof is returned by PartSet.AddPart when a part's merkle
// proof is rejected for the set's hash.
var ErrPartSetInvalidProof = errors.New("Error part set invalid proof")
  // Part is one chunk of a PartSet's data, carried together with the merkle
// proof that links the chunk to the set's root hash.
type Part struct {
	Index int                `json:"index"` // position of this chunk within the set
	Bytes cmn.HexBytes       `json:"bytes"` // raw contents of the chunk
	Proof merkle.SimpleProof `json:"proof"` // proof of Bytes, checked against the set's hash in PartSet.AddPart
}
  20. // ValidateBasic performs basic validation.
  21. func (part *Part) ValidateBasic() error {
  22. if part.Index < 0 {
  23. return errors.New("negative Index")
  24. }
  25. if len(part.Bytes) > BlockPartSizeBytes {
  26. return errors.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
  27. }
  28. if err := part.Proof.ValidateBasic(); err != nil {
  29. return errors.Wrap(err, "wrong Proof")
  30. }
  31. return nil
  32. }
  33. func (part *Part) String() string {
  34. return part.StringIndented("")
  35. }
  36. func (part *Part) StringIndented(indent string) string {
  37. return fmt.Sprintf(`Part{#%v
  38. %s Bytes: %X...
  39. %s Proof: %v
  40. %s}`,
  41. part.Index,
  42. indent, cmn.Fingerprint(part.Bytes),
  43. indent, part.Proof.StringIndented(indent+" "),
  44. indent)
  45. }
  46. //-------------------------------------
  // PartSetHeader identifies a PartSet by its number of parts and the merkle
// root hash computed over the parts' bytes.
type PartSetHeader struct {
	Total int          `json:"total"` // number of parts in the set
	Hash  cmn.HexBytes `json:"hash"`  // merkle root the parts are verified against
}
  51. func (psh PartSetHeader) String() string {
  52. return fmt.Sprintf("%v:%X", psh.Total, cmn.Fingerprint(psh.Hash))
  53. }
  54. func (psh PartSetHeader) IsZero() bool {
  55. return psh.Total == 0 && len(psh.Hash) == 0
  56. }
  57. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  58. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  59. }
  60. // ValidateBasic performs basic validation.
  61. func (psh PartSetHeader) ValidateBasic() error {
  62. if psh.Total < 0 {
  63. return errors.New("Negative Total")
  64. }
  65. // Hash can be empty in case of POLBlockID.PartsHeader in Proposal.
  66. if err := ValidateHash(psh.Hash); err != nil {
  67. return errors.Wrap(err, "Wrong Hash")
  68. }
  69. return nil
  70. }
  71. //-------------------------------------
  // PartSet holds data split into Parts, the merkle root hash over those
// parts, and a record of which parts are currently present.
type PartSet struct {
	// total and hash are set by the constructors; no method in this file
	// reassigns them.
	total int
	hash  []byte

	// AddPart updates parts, partsBitArray and count while holding mtx.
	mtx           sync.Mutex
	parts         []*Part       // parts[i] is nil until part i has been added
	partsBitArray *cmn.BitArray // bit i is set once part i is present
	count         int           // number of non-nil entries in parts
}
  80. // Returns an immutable, full PartSet from the data bytes.
  81. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  82. func NewPartSetFromData(data []byte, partSize int) *PartSet {
  83. // divide data into 4kb parts.
  84. total := (len(data) + partSize - 1) / partSize
  85. parts := make([]*Part, total)
  86. partsBytes := make([][]byte, total)
  87. partsBitArray := cmn.NewBitArray(total)
  88. for i := 0; i < total; i++ {
  89. part := &Part{
  90. Index: i,
  91. Bytes: data[i*partSize : cmn.MinInt(len(data), (i+1)*partSize)],
  92. }
  93. parts[i] = part
  94. partsBytes[i] = part.Bytes
  95. partsBitArray.SetIndex(i, true)
  96. }
  97. // Compute merkle proofs
  98. root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes)
  99. for i := 0; i < total; i++ {
  100. parts[i].Proof = *proofs[i]
  101. }
  102. return &PartSet{
  103. total: total,
  104. hash: root,
  105. parts: parts,
  106. partsBitArray: partsBitArray,
  107. count: total,
  108. }
  109. }
  110. // Returns an empty PartSet ready to be populated.
  111. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  112. return &PartSet{
  113. total: header.Total,
  114. hash: header.Hash,
  115. parts: make([]*Part, header.Total),
  116. partsBitArray: cmn.NewBitArray(header.Total),
  117. count: 0,
  118. }
  119. }
  120. func (ps *PartSet) Header() PartSetHeader {
  121. if ps == nil {
  122. return PartSetHeader{}
  123. }
  124. return PartSetHeader{
  125. Total: ps.total,
  126. Hash: ps.hash,
  127. }
  128. }
  129. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  130. if ps == nil {
  131. return false
  132. }
  133. return ps.Header().Equals(header)
  134. }
  135. func (ps *PartSet) BitArray() *cmn.BitArray {
  136. ps.mtx.Lock()
  137. defer ps.mtx.Unlock()
  138. return ps.partsBitArray.Copy()
  139. }
  140. func (ps *PartSet) Hash() []byte {
  141. if ps == nil {
  142. return nil
  143. }
  144. return ps.hash
  145. }
  146. func (ps *PartSet) HashesTo(hash []byte) bool {
  147. if ps == nil {
  148. return false
  149. }
  150. return bytes.Equal(ps.hash, hash)
  151. }
  152. func (ps *PartSet) Count() int {
  153. if ps == nil {
  154. return 0
  155. }
  156. return ps.count
  157. }
  158. func (ps *PartSet) Total() int {
  159. if ps == nil {
  160. return 0
  161. }
  162. return ps.total
  163. }
  164. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  165. if ps == nil {
  166. return false, nil
  167. }
  168. ps.mtx.Lock()
  169. defer ps.mtx.Unlock()
  170. // Invalid part index
  171. if part.Index >= ps.total {
  172. return false, ErrPartSetUnexpectedIndex
  173. }
  174. // If part already exists, return false.
  175. if ps.parts[part.Index] != nil {
  176. return false, nil
  177. }
  178. // Check hash proof
  179. if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
  180. return false, ErrPartSetInvalidProof
  181. }
  182. // Add part
  183. ps.parts[part.Index] = part
  184. ps.partsBitArray.SetIndex(part.Index, true)
  185. ps.count++
  186. return true, nil
  187. }
  188. func (ps *PartSet) GetPart(index int) *Part {
  189. ps.mtx.Lock()
  190. defer ps.mtx.Unlock()
  191. return ps.parts[index]
  192. }
  193. func (ps *PartSet) IsComplete() bool {
  194. return ps.count == ps.total
  195. }
  196. func (ps *PartSet) GetReader() io.Reader {
  197. if !ps.IsComplete() {
  198. panic("Cannot GetReader() on incomplete PartSet")
  199. }
  200. return NewPartSetReader(ps.parts)
  201. }
  // PartSetReader is an io.Reader that yields the Bytes of each part in turn,
// starting from parts[0].
type PartSetReader struct {
	i      int           // index of the part currently being read
	parts  []*Part       // parts to read, in order
	reader *bytes.Reader // reader over parts[i].Bytes
}
  207. func NewPartSetReader(parts []*Part) *PartSetReader {
  208. return &PartSetReader{
  209. i: 0,
  210. parts: parts,
  211. reader: bytes.NewReader(parts[0].Bytes),
  212. }
  213. }
  214. func (psr *PartSetReader) Read(p []byte) (n int, err error) {
  215. readerLen := psr.reader.Len()
  216. if readerLen >= len(p) {
  217. return psr.reader.Read(p)
  218. } else if readerLen > 0 {
  219. n1, err := psr.Read(p[:readerLen])
  220. if err != nil {
  221. return n1, err
  222. }
  223. n2, err := psr.Read(p[readerLen:])
  224. return n1 + n2, err
  225. }
  226. psr.i++
  227. if psr.i >= len(psr.parts) {
  228. return 0, io.EOF
  229. }
  230. psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
  231. return psr.Read(p)
  232. }
  233. func (ps *PartSet) StringShort() string {
  234. if ps == nil {
  235. return "nil-PartSet"
  236. }
  237. ps.mtx.Lock()
  238. defer ps.mtx.Unlock()
  239. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  240. }
  241. func (ps *PartSet) MarshalJSON() ([]byte, error) {
  242. if ps == nil {
  243. return []byte("{}"), nil
  244. }
  245. ps.mtx.Lock()
  246. defer ps.mtx.Unlock()
  247. return cdc.MarshalJSON(struct {
  248. CountTotal string `json:"count/total"`
  249. PartsBitArray *cmn.BitArray `json:"parts_bit_array"`
  250. }{
  251. fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
  252. ps.partsBitArray,
  253. })
  254. }