You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

252 lines
5.1 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package types
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. "sync"
  10. "github.com/tendermint/tendermint/binary"
  11. . "github.com/tendermint/tendermint/common"
  12. "github.com/tendermint/tendermint/merkle"
  13. )
  14. const (
  15. partSize = 4096 // 4KB
  16. )
  17. var (
  18. ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")
  19. ErrPartSetInvalidTrail = errors.New("Error part set invalid trail")
  20. )
  21. type Part struct {
  22. Index uint `json:"index"`
  23. Trail [][]byte `json:"trail"`
  24. Bytes []byte `json:"bytes"`
  25. // Cache
  26. hash []byte
  27. }
  28. func (part *Part) Hash() []byte {
  29. if part.hash != nil {
  30. return part.hash
  31. } else {
  32. hasher := sha256.New()
  33. _, err := hasher.Write(part.Bytes)
  34. if err != nil {
  35. panic(err)
  36. }
  37. part.hash = hasher.Sum(nil)
  38. return part.hash
  39. }
  40. }
  41. func (part *Part) String() string {
  42. return part.StringIndented("")
  43. }
  44. func (part *Part) StringIndented(indent string) string {
  45. trailStrings := make([]string, len(part.Trail))
  46. for i, hash := range part.Trail {
  47. trailStrings[i] = fmt.Sprintf("%X", hash)
  48. }
  49. return fmt.Sprintf(`Part{
  50. %s Index: %v
  51. %s Trail:
  52. %s %v
  53. %s}`,
  54. indent, part.Index,
  55. indent,
  56. indent, strings.Join(trailStrings, "\n"+indent+" "),
  57. indent)
  58. }
  59. //-------------------------------------
  60. type PartSetHeader struct {
  61. Total uint `json:"total"`
  62. Hash []byte `json:"hash"`
  63. }
  64. func (psh PartSetHeader) String() string {
  65. return fmt.Sprintf("PartSet{T:%v %X}", psh.Total, Fingerprint(psh.Hash))
  66. }
  67. func (psh PartSetHeader) IsZero() bool {
  68. return psh.Total == 0
  69. }
  70. func (psh PartSetHeader) Equals(other PartSetHeader) bool {
  71. return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
  72. }
  73. func (psh PartSetHeader) WriteSignBytes(w io.Writer, n *int64, err *error) {
  74. binary.WriteTo([]byte(Fmt(`{"hash":"%X","total":%v}`, psh.Hash, psh.Total)), w, n, err)
  75. }
  76. //-------------------------------------
  77. type PartSet struct {
  78. total uint
  79. hash []byte
  80. mtx sync.Mutex
  81. parts []*Part
  82. partsBitArray *BitArray
  83. count uint
  84. }
  85. // Returns an immutable, full PartSet from the data bytes.
  86. // The data bytes are split into "partSize" chunks, and merkle tree computed.
  87. func NewPartSetFromData(data []byte) *PartSet {
  88. // divide data into 4kb parts.
  89. total := (len(data) + partSize - 1) / partSize
  90. parts := make([]*Part, total)
  91. parts_ := make([]merkle.Hashable, total)
  92. partsBitArray := NewBitArray(uint(total))
  93. for i := 0; i < total; i++ {
  94. part := &Part{
  95. Index: uint(i),
  96. Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)],
  97. }
  98. parts[i] = part
  99. parts_[i] = part
  100. partsBitArray.SetIndex(uint(i), true)
  101. }
  102. // Compute merkle trails
  103. trails, rootTrail := merkle.HashTrailsFromHashables(parts_)
  104. for i := 0; i < total; i++ {
  105. parts[i].Trail = trails[i].Flatten()
  106. }
  107. return &PartSet{
  108. total: uint(total),
  109. hash: rootTrail.Hash,
  110. parts: parts,
  111. partsBitArray: partsBitArray,
  112. count: uint(total),
  113. }
  114. }
  115. // Returns an empty PartSet ready to be populated.
  116. func NewPartSetFromHeader(header PartSetHeader) *PartSet {
  117. return &PartSet{
  118. total: header.Total,
  119. hash: header.Hash,
  120. parts: make([]*Part, header.Total),
  121. partsBitArray: NewBitArray(uint(header.Total)),
  122. count: 0,
  123. }
  124. }
  125. func (ps *PartSet) Header() PartSetHeader {
  126. if ps == nil {
  127. return PartSetHeader{}
  128. } else {
  129. return PartSetHeader{
  130. Total: ps.total,
  131. Hash: ps.hash,
  132. }
  133. }
  134. }
  135. func (ps *PartSet) HasHeader(header PartSetHeader) bool {
  136. if ps == nil {
  137. return false
  138. } else {
  139. return ps.Header().Equals(header)
  140. }
  141. }
  142. func (ps *PartSet) BitArray() *BitArray {
  143. ps.mtx.Lock()
  144. defer ps.mtx.Unlock()
  145. return ps.partsBitArray.Copy()
  146. }
  147. func (ps *PartSet) Hash() []byte {
  148. if ps == nil {
  149. return nil
  150. }
  151. return ps.hash
  152. }
  153. func (ps *PartSet) HashesTo(hash []byte) bool {
  154. if ps == nil {
  155. return false
  156. }
  157. return bytes.Equal(ps.hash, hash)
  158. }
  159. func (ps *PartSet) Count() uint {
  160. if ps == nil {
  161. return 0
  162. }
  163. return ps.count
  164. }
  165. func (ps *PartSet) Total() uint {
  166. if ps == nil {
  167. return 0
  168. }
  169. return ps.total
  170. }
  171. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  172. ps.mtx.Lock()
  173. defer ps.mtx.Unlock()
  174. // Invalid part index
  175. if part.Index >= ps.total {
  176. return false, ErrPartSetUnexpectedIndex
  177. }
  178. // If part already exists, return false.
  179. if ps.parts[part.Index] != nil {
  180. return false, nil
  181. }
  182. // Check hash trail
  183. if !merkle.VerifyHashTrail(uint(part.Index), uint(ps.total), part.Hash(), part.Trail, ps.hash) {
  184. return false, ErrPartSetInvalidTrail
  185. }
  186. // Add part
  187. ps.parts[part.Index] = part
  188. ps.partsBitArray.SetIndex(uint(part.Index), true)
  189. ps.count++
  190. return true, nil
  191. }
  192. func (ps *PartSet) GetPart(index uint) *Part {
  193. ps.mtx.Lock()
  194. defer ps.mtx.Unlock()
  195. return ps.parts[index]
  196. }
  197. func (ps *PartSet) IsComplete() bool {
  198. return ps.count == ps.total
  199. }
  200. func (ps *PartSet) GetReader() io.Reader {
  201. if !ps.IsComplete() {
  202. panic("Cannot GetReader() on incomplete PartSet")
  203. }
  204. buf := []byte{}
  205. for _, part := range ps.parts {
  206. buf = append(buf, part.Bytes...)
  207. }
  208. return bytes.NewReader(buf)
  209. }
  210. func (ps *PartSet) StringShort() string {
  211. if ps == nil {
  212. return "nil-PartSet"
  213. } else {
  214. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  215. }
  216. }