You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

307 lines
6.3 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "sync"
  7. "github.com/pkg/errors"
  8. "github.com/tendermint/tendermint/crypto/merkle"
  9. "github.com/tendermint/tendermint/crypto/tmhash"
  10. cmn "github.com/tendermint/tendermint/libs/common"
  11. )
  12. var (
  13. ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")
  14. ErrPartSetInvalidProof = errors.New("Error part set invalid proof")
  15. )
  16. type Part struct {
  17. Index int `json:"index"`
  18. Bytes cmn.HexBytes `json:"bytes"`
  19. Proof merkle.SimpleProof `json:"proof"`
  20. // Cache
  21. hash []byte
  22. }
  23. func (part *Part) Hash() []byte {
  24. if part.hash != nil {
  25. return part.hash
  26. }
  27. hasher := tmhash.New()
  28. hasher.Write(part.Bytes) // nolint: errcheck, gas
  29. part.hash = hasher.Sum(nil)
  30. return part.hash
  31. }
  32. // ValidateBasic performs basic validation.
  33. func (part *Part) ValidateBasic() error {
  34. if part.Index < 0 {
  35. return errors.New("Negative Index")
  36. }
  37. if len(part.Bytes) > BlockPartSizeBytes {
  38. return fmt.Errorf("Too big (max: %d)", BlockPartSizeBytes)
  39. }
  40. return nil
  41. }
  42. func (part *Part) String() string {
  43. return part.StringIndented("")
  44. }
  45. func (part *Part) StringIndented(indent string) string {
  46. return fmt.Sprintf(`Part{#%v
  47. %s Bytes: %X...
  48. %s Proof: %v
  49. %s}`,
  50. part.Index,
  51. indent, cmn.Fingerprint(part.Bytes),
  52. indent, part.Proof.StringIndented(indent+" "),
  53. indent)
  54. }
  55. //-------------------------------------
  56. type PartSetHeader struct {
  57. Total int `json:"total"`
  58. Hash cmn.HexBytes `json:"hash"`
  59. }
  60. func (psh PartSetHeader) String() string {
  61. return fmt.Sprintf("%v:%X", psh.Total, cmn.Fingerprint(psh.Hash))
  62. }
  63. func (psh PartSetHeader) IsZero() bool {
  64. return psh.Total == 0 && len(psh.Hash) == 0
  65. }
  66. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  67. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  68. }
  69. // ValidateBasic performs basic validation.
  70. func (psh PartSetHeader) ValidateBasic() error {
  71. if psh.Total < 0 {
  72. return errors.New("Negative Total")
  73. }
  74. // Hash can be empty in case of POLBlockID.PartsHeader in Proposal.
  75. if err := ValidateHash(psh.Hash); err != nil {
  76. return errors.Wrap(err, "Wrong Hash")
  77. }
  78. return nil
  79. }
  80. //-------------------------------------
  81. type PartSet struct {
  82. total int
  83. hash []byte
  84. mtx sync.Mutex
  85. parts []*Part
  86. partsBitArray *cmn.BitArray
  87. count int
  88. }
  89. // Returns an immutable, full PartSet from the data bytes.
  90. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  91. func NewPartSetFromData(data []byte, partSize int) *PartSet {
  92. // divide data into 4kb parts.
  93. total := (len(data) + partSize - 1) / partSize
  94. parts := make([]*Part, total)
  95. partsBytes := make([][]byte, total)
  96. partsBitArray := cmn.NewBitArray(total)
  97. for i := 0; i < total; i++ {
  98. part := &Part{
  99. Index: i,
  100. Bytes: data[i*partSize : cmn.MinInt(len(data), (i+1)*partSize)],
  101. }
  102. parts[i] = part
  103. partsBytes[i] = part.Bytes
  104. partsBitArray.SetIndex(i, true)
  105. }
  106. // Compute merkle proofs
  107. root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes)
  108. for i := 0; i < total; i++ {
  109. parts[i].Proof = *proofs[i]
  110. }
  111. return &PartSet{
  112. total: total,
  113. hash: root,
  114. parts: parts,
  115. partsBitArray: partsBitArray,
  116. count: total,
  117. }
  118. }
  119. // Returns an empty PartSet ready to be populated.
  120. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  121. return &PartSet{
  122. total: header.Total,
  123. hash: header.Hash,
  124. parts: make([]*Part, header.Total),
  125. partsBitArray: cmn.NewBitArray(header.Total),
  126. count: 0,
  127. }
  128. }
  129. func (ps *PartSet) Header() PartSetHeader {
  130. if ps == nil {
  131. return PartSetHeader{}
  132. }
  133. return PartSetHeader{
  134. Total: ps.total,
  135. Hash: ps.hash,
  136. }
  137. }
  138. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  139. if ps == nil {
  140. return false
  141. }
  142. return ps.Header().Equals(header)
  143. }
  144. func (ps *PartSet) BitArray() *cmn.BitArray {
  145. ps.mtx.Lock()
  146. defer ps.mtx.Unlock()
  147. return ps.partsBitArray.Copy()
  148. }
  149. func (ps *PartSet) Hash() []byte {
  150. if ps == nil {
  151. return nil
  152. }
  153. return ps.hash
  154. }
  155. func (ps *PartSet) HashesTo(hash []byte) bool {
  156. if ps == nil {
  157. return false
  158. }
  159. return bytes.Equal(ps.hash, hash)
  160. }
  161. func (ps *PartSet) Count() int {
  162. if ps == nil {
  163. return 0
  164. }
  165. return ps.count
  166. }
  167. func (ps *PartSet) Total() int {
  168. if ps == nil {
  169. return 0
  170. }
  171. return ps.total
  172. }
  173. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  174. if ps == nil {
  175. return false, nil
  176. }
  177. ps.mtx.Lock()
  178. defer ps.mtx.Unlock()
  179. // Invalid part index
  180. if part.Index >= ps.total {
  181. return false, ErrPartSetUnexpectedIndex
  182. }
  183. // If part already exists, return false.
  184. if ps.parts[part.Index] != nil {
  185. return false, nil
  186. }
  187. // Check hash proof
  188. if part.Proof.Verify(ps.Hash(), part.Hash()) != nil {
  189. return false, ErrPartSetInvalidProof
  190. }
  191. // Add part
  192. ps.parts[part.Index] = part
  193. ps.partsBitArray.SetIndex(part.Index, true)
  194. ps.count++
  195. return true, nil
  196. }
  197. func (ps *PartSet) GetPart(index int) *Part {
  198. ps.mtx.Lock()
  199. defer ps.mtx.Unlock()
  200. return ps.parts[index]
  201. }
  202. func (ps *PartSet) IsComplete() bool {
  203. return ps.count == ps.total
  204. }
  205. func (ps *PartSet) GetReader() io.Reader {
  206. if !ps.IsComplete() {
  207. cmn.PanicSanity("Cannot GetReader() on incomplete PartSet")
  208. }
  209. return NewPartSetReader(ps.parts)
  210. }
  211. type PartSetReader struct {
  212. i int
  213. parts []*Part
  214. reader *bytes.Reader
  215. }
  216. func NewPartSetReader(parts []*Part) *PartSetReader {
  217. return &PartSetReader{
  218. i: 0,
  219. parts: parts,
  220. reader: bytes.NewReader(parts[0].Bytes),
  221. }
  222. }
  223. func (psr *PartSetReader) Read(p []byte) (n int, err error) {
  224. readerLen := psr.reader.Len()
  225. if readerLen >= len(p) {
  226. return psr.reader.Read(p)
  227. } else if readerLen > 0 {
  228. n1, err := psr.Read(p[:readerLen])
  229. if err != nil {
  230. return n1, err
  231. }
  232. n2, err := psr.Read(p[readerLen:])
  233. return n1 + n2, err
  234. }
  235. psr.i++
  236. if psr.i >= len(psr.parts) {
  237. return 0, io.EOF
  238. }
  239. psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
  240. return psr.Read(p)
  241. }
  242. func (ps *PartSet) StringShort() string {
  243. if ps == nil {
  244. return "nil-PartSet"
  245. }
  246. ps.mtx.Lock()
  247. defer ps.mtx.Unlock()
  248. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  249. }
  250. func (ps *PartSet) MarshalJSON() ([]byte, error) {
  251. if ps == nil {
  252. return []byte("{}"), nil
  253. }
  254. ps.mtx.Lock()
  255. defer ps.mtx.Unlock()
  256. return cdc.MarshalJSON(struct {
  257. CountTotal string `json:"count/total"`
  258. PartsBitArray *cmn.BitArray `json:"parts_bit_array"`
  259. }{
  260. fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
  261. ps.partsBitArray,
  262. })
  263. }