You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

247 lines
4.8 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. "sync"
  10. . "github.com/tendermint/tendermint/common"
  11. "github.com/tendermint/tendermint/merkle"
  12. )
  // partSize is the chunk size, in bytes, used by NewPartSetFromData when it
  // cuts data into parts (4096 bytes = 4 KiB).
  const partSize = 4096
  // Sentinel errors returned by PartSet.AddPart.
  var (
  	// ErrPartSetUnexpectedIndex reports a part whose Index is not below the
  	// total number of parts in the set.
  	ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")

  	// ErrPartSetInvalidTrail reports a part whose hash trail failed
  	// verification against the set's hash.
  	ErrPartSetInvalidTrail = errors.New("Error part set invalid trail")
  )
  // Part is one chunk of a larger byte payload together with the data needed
  // to check that it belongs to a particular PartSet.
  type Part struct {
  	// Index is the position of this part inside its PartSet.
  	Index uint
  	// Trail is the hash trail that PartSet.AddPart verifies against the
  	// set's hash before accepting the part.
  	Trail [][]byte
  	// Bytes is the payload carried by this part.
  	Bytes []byte

  	// hash memoizes the SHA-256 digest of Bytes; Hash fills it on first use.
  	hash []byte
  }
  27. func (part *Part) Hash() []byte {
  28. if part.hash != nil {
  29. return part.hash
  30. } else {
  31. hasher := sha256.New()
  32. _, err := hasher.Write(part.Bytes)
  33. if err != nil {
  34. panic(err)
  35. }
  36. part.hash = hasher.Sum(nil)
  37. return part.hash
  38. }
  39. }
  40. func (part *Part) String() string {
  41. return part.StringIndented("")
  42. }
  43. func (part *Part) StringIndented(indent string) string {
  44. trailStrings := make([]string, len(part.Trail))
  45. for i, hash := range part.Trail {
  46. trailStrings[i] = fmt.Sprintf("%X", hash)
  47. }
  48. return fmt.Sprintf(`Part{
  49. %s Index: %v
  50. %s Trail:
  51. %s %v
  52. %s}`,
  53. indent, part.Index,
  54. indent,
  55. indent, strings.Join(trailStrings, "\n"+indent+" "),
  56. indent)
  57. }
  58. //-------------------------------------
  // PartSetHeader identifies a PartSet by the number of parts it holds and its
  // hash. PartSet.Header produces one and NewPartSetFromHeader consumes one.
  type PartSetHeader struct {
  	// Total is the number of parts in the set.
  	Total uint
  	// Hash is the hash of the set (the merkle root when the set was built by
  	// NewPartSetFromData).
  	Hash []byte
  }
  63. func (psh PartSetHeader) String() string {
  64. return fmt.Sprintf("PartSet{T:%v %X}", psh.Total, Fingerprint(psh.Hash))
  65. }
  66. func (psh PartSetHeader) IsZero() bool {
  67. return psh.Total == 0
  68. }
  69. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  70. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  71. }
  72. //-------------------------------------
  73. type PartSet struct {
  74. total uint
  75. hash []byte
  76. mtx sync.Mutex
  77. parts []*Part
  78. partsBitArray BitArray
  79. count uint
  80. }
  81. // Returns an immutable, full PartSet from the data bytes.
  82. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  83. func NewPartSetFromData(data []byte) *PartSet {
  84. // divide data into 4kb parts.
  85. total := (len(data) + partSize - 1) / partSize
  86. parts := make([]*Part, total)
  87. parts_ := make([]merkle.Hashable, total)
  88. partsBitArray := NewBitArray(uint(total))
  89. for i := 0; i < total; i++ {
  90. part := &Part{
  91. Index: uint(i),
  92. Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)],
  93. }
  94. parts[i] = part
  95. parts_[i] = part
  96. partsBitArray.SetIndex(uint(i), true)
  97. }
  98. // Compute merkle trails
  99. trails, rootTrail := merkle.HashTrailsFromHashables(parts_)
  100. for i := 0; i < total; i++ {
  101. parts[i].Trail = trails[i].Flatten()
  102. }
  103. return &PartSet{
  104. total: uint(total),
  105. hash: rootTrail.Hash,
  106. parts: parts,
  107. partsBitArray: partsBitArray,
  108. count: uint(total),
  109. }
  110. }
  111. // Returns an empty PartSet ready to be populated.
  112. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  113. return &PartSet{
  114. total: header.Total,
  115. hash: header.Hash,
  116. parts: make([]*Part, header.Total),
  117. partsBitArray: NewBitArray(uint(header.Total)),
  118. count: 0,
  119. }
  120. }
  121. func (ps *PartSet) Header() PartSetHeader {
  122. if ps == nil {
  123. return PartSetHeader{}
  124. } else {
  125. return PartSetHeader{
  126. Total: ps.total,
  127. Hash: ps.hash,
  128. }
  129. }
  130. }
  131. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  132. if ps == nil {
  133. return false
  134. } else {
  135. return ps.Header().Equals(header)
  136. }
  137. }
  138. func (ps *PartSet) BitArray() BitArray {
  139. ps.mtx.Lock()
  140. defer ps.mtx.Unlock()
  141. return ps.partsBitArray.Copy()
  142. }
  143. func (ps *PartSet) Hash() []byte {
  144. if ps == nil {
  145. return nil
  146. }
  147. return ps.hash
  148. }
  149. func (ps *PartSet) HashesTo(hash []byte) bool {
  150. if ps == nil {
  151. return false
  152. }
  153. return bytes.Equal(ps.hash, hash)
  154. }
  155. func (ps *PartSet) Count() uint {
  156. if ps == nil {
  157. return 0
  158. }
  159. return ps.count
  160. }
  161. func (ps *PartSet) Total() uint {
  162. if ps == nil {
  163. return 0
  164. }
  165. return ps.total
  166. }
  167. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  168. ps.mtx.Lock()
  169. defer ps.mtx.Unlock()
  170. // Invalid part index
  171. if part.Index >= ps.total {
  172. return false, ErrPartSetUnexpectedIndex
  173. }
  174. // If part already exists, return false.
  175. if ps.parts[part.Index] != nil {
  176. return false, nil
  177. }
  178. // Check hash trail
  179. if !merkle.VerifyHashTrail(uint(part.Index), uint(ps.total), part.Hash(), part.Trail, ps.hash) {
  180. return false, ErrPartSetInvalidTrail
  181. }
  182. // Add part
  183. ps.parts[part.Index] = part
  184. ps.partsBitArray.SetIndex(uint(part.Index), true)
  185. ps.count++
  186. return true, nil
  187. }
  188. func (ps *PartSet) GetPart(index uint) *Part {
  189. ps.mtx.Lock()
  190. defer ps.mtx.Unlock()
  191. return ps.parts[index]
  192. }
  193. func (ps *PartSet) IsComplete() bool {
  194. return ps.count == ps.total
  195. }
  196. func (ps *PartSet) GetReader() io.Reader {
  197. if !ps.IsComplete() {
  198. panic("Cannot GetReader() on incomplete PartSet")
  199. }
  200. buf := []byte{}
  201. for _, part := range ps.parts {
  202. buf = append(buf, part.Bytes...)
  203. }
  204. return bytes.NewReader(buf)
  205. }
  206. func (ps *PartSet) StringShort() string {
  207. if ps == nil {
  208. return "nil-PartSet"
  209. } else {
  210. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  211. }
  212. }