You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

349 lines
7.5 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "github.com/tendermint/tendermint/crypto/merkle"
  9. "github.com/tendermint/tendermint/libs/bits"
  10. tmbytes "github.com/tendermint/tendermint/libs/bytes"
  11. tmjson "github.com/tendermint/tendermint/libs/json"
  12. tmmath "github.com/tendermint/tendermint/libs/math"
  13. tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
  14. )
  // ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when a part's
// index is not below the set's total number of parts.
var ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")

// ErrPartSetInvalidProof is returned by PartSet.AddPart when a part's merkle
// proof is rejected for the set's hash.
var ErrPartSetInvalidProof = errors.New("error part set invalid proof")
  // Part is one chunk of a PartSet: its position in the set, its raw bytes,
// and the merkle proof that PartSet.AddPart verifies those bytes with
// against the set's hash.
type Part struct {
	// Index is the part's position within its PartSet (0-based).
	Index uint32 `json:"index"`
	// Bytes is the chunk's payload.
	Bytes tmbytes.HexBytes `json:"bytes"`
	// Proof is the merkle inclusion proof for Bytes.
	Proof merkle.Proof `json:"proof"`
}
  24. // ValidateBasic performs basic validation.
  25. func (part *Part) ValidateBasic() error {
  26. if len(part.Bytes) > int(BlockPartSizeBytes) {
  27. return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
  28. }
  29. if err := part.Proof.ValidateBasic(); err != nil {
  30. return fmt.Errorf("wrong Proof: %w", err)
  31. }
  32. return nil
  33. }
  34. func (part *Part) String() string {
  35. return part.StringIndented("")
  36. }
  37. func (part *Part) StringIndented(indent string) string {
  38. return fmt.Sprintf(`Part{#%v
  39. %s Bytes: %X...
  40. %s Proof: %v
  41. %s}`,
  42. part.Index,
  43. indent, tmbytes.Fingerprint(part.Bytes),
  44. indent, part.Proof.StringIndented(indent+" "),
  45. indent)
  46. }
  47. func (part *Part) ToProto() (*tmproto.Part, error) {
  48. if part == nil {
  49. return nil, errors.New("nil part")
  50. }
  51. pb := new(tmproto.Part)
  52. proof := part.Proof.ToProto()
  53. pb.Index = part.Index
  54. pb.Bytes = part.Bytes
  55. pb.Proof = *proof
  56. return pb, nil
  57. }
  58. func PartFromProto(pb *tmproto.Part) (*Part, error) {
  59. if pb == nil {
  60. return nil, errors.New("nil part")
  61. }
  62. part := new(Part)
  63. proof, err := merkle.ProofFromProto(&pb.Proof)
  64. if err != nil {
  65. return nil, err
  66. }
  67. part.Index = pb.Index
  68. part.Bytes = pb.Bytes
  69. part.Proof = *proof
  70. return part, part.ValidateBasic()
  71. }
  72. //-------------------------------------
  // PartSetHeader summarizes a PartSet by its number of parts and its hash.
// It is the input NewPartSetFromHeader uses to create an empty set.
type PartSetHeader struct {
	// Total is the number of parts in the set.
	Total uint32 `json:"total"`
	// Hash is the set's hash (the merkle root that parts are verified against).
	Hash tmbytes.HexBytes `json:"hash"`
}
  77. func (psh PartSetHeader) String() string {
  78. return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
  79. }
  80. func (psh PartSetHeader) IsZero() bool {
  81. return psh.Total == 0 && len(psh.Hash) == 0
  82. }
  83. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  84. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  85. }
  86. // ValidateBasic performs basic validation.
  87. func (psh PartSetHeader) ValidateBasic() error {
  88. // Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
  89. if err := ValidateHash(psh.Hash); err != nil {
  90. return fmt.Errorf("wrong Hash: %w", err)
  91. }
  92. return nil
  93. }
  94. // ToProto converts BloPartSetHeaderckID to protobuf
  95. func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
  96. if psh == nil {
  97. return tmproto.PartSetHeader{}
  98. }
  99. return tmproto.PartSetHeader{
  100. Total: psh.Total,
  101. Hash: psh.Hash,
  102. }
  103. }
  104. // FromProto sets a protobuf PartSetHeader to the given pointer
  105. func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
  106. if ppsh == nil {
  107. return nil, errors.New("nil PartSetHeader")
  108. }
  109. psh := new(PartSetHeader)
  110. psh.Total = ppsh.Total
  111. psh.Hash = ppsh.Hash
  112. return psh, psh.ValidateBasic()
  113. }
  114. //-------------------------------------
  // PartSet is a collection of Parts indexed 0..total-1 that all verify
// against a single hash. It is either built complete from data
// (NewPartSetFromData) or filled part by part from a header
// (NewPartSetFromHeader + AddPart).
type PartSet struct {
	total uint32 // number of parts in the set
	hash  []byte // hash that every part's proof must verify against

	// mtx is held by AddPart, GetPart, BitArray, StringShort and MarshalJSON.
	// NOTE(review): Count and IsComplete read count without it.
	mtx           sync.Mutex
	parts         []*Part        // parts[i] has Index i; nil until added
	partsBitArray *bits.BitArray // bit i is set once parts[i] is present
	count         uint32         // number of non-nil entries in parts
}
  123. // Returns an immutable, full PartSet from the data bytes.
  124. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  125. // CONTRACT: partSize is greater than zero.
  126. func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
  127. // divide data into 4kb parts.
  128. total := (uint32(len(data)) + partSize - 1) / partSize
  129. parts := make([]*Part, total)
  130. partsBytes := make([][]byte, total)
  131. partsBitArray := bits.NewBitArray(int(total))
  132. for i := uint32(0); i < total; i++ {
  133. part := &Part{
  134. Index: i,
  135. Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
  136. }
  137. parts[i] = part
  138. partsBytes[i] = part.Bytes
  139. partsBitArray.SetIndex(int(i), true)
  140. }
  141. // Compute merkle proofs
  142. root, proofs := merkle.ProofsFromByteSlices(partsBytes)
  143. for i := uint32(0); i < total; i++ {
  144. parts[i].Proof = *proofs[i]
  145. }
  146. return &PartSet{
  147. total: total,
  148. hash: root,
  149. parts: parts,
  150. partsBitArray: partsBitArray,
  151. count: total,
  152. }
  153. }
  154. // Returns an empty PartSet ready to be populated.
  155. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  156. return &PartSet{
  157. total: header.Total,
  158. hash: header.Hash,
  159. parts: make([]*Part, header.Total),
  160. partsBitArray: bits.NewBitArray(int(header.Total)),
  161. count: 0,
  162. }
  163. }
  164. func (ps *PartSet) Header() PartSetHeader {
  165. if ps == nil {
  166. return PartSetHeader{}
  167. }
  168. return PartSetHeader{
  169. Total: ps.total,
  170. Hash: ps.hash,
  171. }
  172. }
  173. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  174. if ps == nil {
  175. return false
  176. }
  177. return ps.Header().Equals(header)
  178. }
  179. func (ps *PartSet) BitArray() *bits.BitArray {
  180. ps.mtx.Lock()
  181. defer ps.mtx.Unlock()
  182. return ps.partsBitArray.Copy()
  183. }
  184. func (ps *PartSet) Hash() []byte {
  185. if ps == nil {
  186. return nil
  187. }
  188. return ps.hash
  189. }
  190. func (ps *PartSet) HashesTo(hash []byte) bool {
  191. if ps == nil {
  192. return false
  193. }
  194. return bytes.Equal(ps.hash, hash)
  195. }
  196. func (ps *PartSet) Count() uint32 {
  197. if ps == nil {
  198. return 0
  199. }
  200. return ps.count
  201. }
  202. func (ps *PartSet) Total() uint32 {
  203. if ps == nil {
  204. return 0
  205. }
  206. return ps.total
  207. }
  208. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  209. if ps == nil {
  210. return false, nil
  211. }
  212. ps.mtx.Lock()
  213. defer ps.mtx.Unlock()
  214. // Invalid part index
  215. if part.Index >= ps.total {
  216. return false, ErrPartSetUnexpectedIndex
  217. }
  218. // If part already exists, return false.
  219. if ps.parts[part.Index] != nil {
  220. return false, nil
  221. }
  222. // Check hash proof
  223. if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
  224. return false, ErrPartSetInvalidProof
  225. }
  226. // Add part
  227. ps.parts[part.Index] = part
  228. ps.partsBitArray.SetIndex(int(part.Index), true)
  229. ps.count++
  230. return true, nil
  231. }
  232. func (ps *PartSet) GetPart(index int) *Part {
  233. ps.mtx.Lock()
  234. defer ps.mtx.Unlock()
  235. return ps.parts[index]
  236. }
  237. func (ps *PartSet) IsComplete() bool {
  238. return ps.count == ps.total
  239. }
  240. func (ps *PartSet) GetReader() io.Reader {
  241. if !ps.IsComplete() {
  242. panic("Cannot GetReader() on incomplete PartSet")
  243. }
  244. return NewPartSetReader(ps.parts)
  245. }
  // PartSetReader reads the Bytes of a slice of parts as one continuous
// stream, in slice order. See its Read method.
type PartSetReader struct {
	i      int           // index into parts of the part currently being read
	parts  []*Part       // parts to read, in order
	reader *bytes.Reader // reader over parts[i].Bytes
}
  251. func NewPartSetReader(parts []*Part) *PartSetReader {
  252. return &PartSetReader{
  253. i: 0,
  254. parts: parts,
  255. reader: bytes.NewReader(parts[0].Bytes),
  256. }
  257. }
  258. func (psr *PartSetReader) Read(p []byte) (n int, err error) {
  259. readerLen := psr.reader.Len()
  260. if readerLen >= len(p) {
  261. return psr.reader.Read(p)
  262. } else if readerLen > 0 {
  263. n1, err := psr.Read(p[:readerLen])
  264. if err != nil {
  265. return n1, err
  266. }
  267. n2, err := psr.Read(p[readerLen:])
  268. return n1 + n2, err
  269. }
  270. psr.i++
  271. if psr.i >= len(psr.parts) {
  272. return 0, io.EOF
  273. }
  274. psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
  275. return psr.Read(p)
  276. }
  277. func (ps *PartSet) StringShort() string {
  278. if ps == nil {
  279. return "nil-PartSet"
  280. }
  281. ps.mtx.Lock()
  282. defer ps.mtx.Unlock()
  283. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  284. }
  285. func (ps *PartSet) MarshalJSON() ([]byte, error) {
  286. if ps == nil {
  287. return []byte("{}"), nil
  288. }
  289. ps.mtx.Lock()
  290. defer ps.mtx.Unlock()
  291. return tmjson.Marshal(struct {
  292. CountTotal string `json:"count/total"`
  293. PartsBitArray *bits.BitArray `json:"parts_bit_array"`
  294. }{
  295. fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
  296. ps.partsBitArray,
  297. })
  298. }