You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

214 lines
4.3 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. package consensus
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. "sync"
  10. . "github.com/tendermint/tendermint/binary"
  11. . "github.com/tendermint/tendermint/common"
  12. "github.com/tendermint/tendermint/merkle"
  13. )
  // partSize is the maximum number of payload bytes placed in a single Part
  // by NewPartSetFromData.
  const partSize = 4 * 1024 // 4KB

  // Sentinel errors returned by PartSet.AddPart when a part is rejected.
  var (
  	// ErrPartSetUnexpectedIndex: the part's Index is not below the set's total.
  	ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")
  	// ErrPartSetInvalidTrail: the part's merkle trail did not verify against
  	// the set's root hash.
  	ErrPartSetInvalidTrail = errors.New("Error part set invalid trail")
  )
  // Part is one chunk of a PartSet's payload together with the merkle hash
  // trail used by PartSet.AddPart to check it against the set's root hash.
  type Part struct {
  	Index uint16   // position of this part within its PartSet
  	Trail [][]byte // merkle hash trail for Index, verified in AddPart
  	Bytes []byte   // raw payload chunk

  	// hash caches the SHA-256 digest of Bytes once Hash has computed it.
  	hash []byte
  }
  28. func ReadPart(r io.Reader, n *int64, err *error) *Part {
  29. return &Part{
  30. Index: ReadUInt16(r, n, err),
  31. Trail: ReadByteSlices(r, n, err),
  32. Bytes: ReadByteSlice(r, n, err),
  33. }
  34. }
  35. func (b *Part) WriteTo(w io.Writer) (n int64, err error) {
  36. WriteUInt16(w, b.Index, &n, &err)
  37. WriteByteSlices(w, b.Trail, &n, &err)
  38. WriteByteSlice(w, b.Bytes, &n, &err)
  39. return
  40. }
  41. func (pt *Part) Hash() []byte {
  42. if pt.hash != nil {
  43. return pt.hash
  44. } else {
  45. hasher := sha256.New()
  46. _, err := hasher.Write(pt.Bytes)
  47. if err != nil {
  48. panic(err)
  49. }
  50. pt.hash = hasher.Sum(nil)
  51. return pt.hash
  52. }
  53. }
  54. func (pt *Part) String() string {
  55. return pt.StringWithIndent("")
  56. }
  57. func (pt *Part) StringWithIndent(indent string) string {
  58. trailStrings := make([]string, len(pt.Trail))
  59. for i, hash := range pt.Trail {
  60. trailStrings[i] = fmt.Sprintf("%X", hash)
  61. }
  62. return fmt.Sprintf(`Part{
  63. %s Index: %v
  64. %s Trail:
  65. %s %v
  66. %s}`,
  67. indent, pt.Index,
  68. indent,
  69. indent, strings.Join(trailStrings, "\n"+indent+" "),
  70. indent)
  71. }
  72. //-------------------------------------
  73. type PartSet struct {
  74. rootHash []byte
  75. total uint16
  76. mtx sync.Mutex
  77. parts []*Part
  78. partsBitArray BitArray
  79. count uint16
  80. }
  81. // Returns an immutable, full PartSet.
  82. func NewPartSetFromData(data []byte) *PartSet {
  83. // divide data into 4kb parts.
  84. total := (len(data) + partSize - 1) / partSize
  85. parts := make([]*Part, total)
  86. parts_ := make([]merkle.Hashable, total)
  87. partsBitArray := NewBitArray(uint(total))
  88. for i := 0; i < total; i++ {
  89. part := &Part{
  90. Index: uint16(i),
  91. Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)],
  92. }
  93. parts[i] = part
  94. parts_[i] = part
  95. partsBitArray.SetIndex(uint(i), true)
  96. }
  97. // Compute merkle trails
  98. hashTree := merkle.HashTreeFromHashables(parts_)
  99. for i := 0; i < total; i++ {
  100. parts[i].Trail = merkle.HashTrailForIndex(hashTree, i)
  101. }
  102. return &PartSet{
  103. parts: parts,
  104. partsBitArray: partsBitArray,
  105. rootHash: hashTree[len(hashTree)/2],
  106. total: uint16(total),
  107. count: uint16(total),
  108. }
  109. }
  110. // Returns an empty PartSet ready to be populated.
  111. func NewPartSetFromMetadata(total uint16, rootHash []byte) *PartSet {
  112. return &PartSet{
  113. parts: make([]*Part, total),
  114. partsBitArray: NewBitArray(uint(total)),
  115. rootHash: rootHash,
  116. total: total,
  117. count: 0,
  118. }
  119. }
  120. func (ps *PartSet) BitArray() BitArray {
  121. ps.mtx.Lock()
  122. defer ps.mtx.Unlock()
  123. return ps.partsBitArray.Copy()
  124. }
  125. func (ps *PartSet) RootHash() []byte {
  126. return ps.rootHash
  127. }
  128. func (ps *PartSet) Count() uint16 {
  129. if ps == nil {
  130. return 0
  131. }
  132. return ps.count
  133. }
  134. func (ps *PartSet) Total() uint16 {
  135. if ps == nil {
  136. return 0
  137. }
  138. return ps.total
  139. }
  140. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  141. ps.mtx.Lock()
  142. defer ps.mtx.Unlock()
  143. // Invalid part index
  144. if part.Index >= ps.total {
  145. return false, ErrPartSetUnexpectedIndex
  146. }
  147. // If part already exists, return false.
  148. if ps.parts[part.Index] != nil {
  149. return false, nil
  150. }
  151. // Check hash trail
  152. if !merkle.VerifyHashTrailForIndex(int(part.Index), part.Hash(), part.Trail, ps.rootHash) {
  153. return false, ErrPartSetInvalidTrail
  154. }
  155. // Add part
  156. ps.parts[part.Index] = part
  157. ps.partsBitArray.SetIndex(uint(part.Index), true)
  158. ps.count++
  159. return true, nil
  160. }
  161. func (ps *PartSet) GetPart(index uint16) *Part {
  162. ps.mtx.Lock()
  163. defer ps.mtx.Unlock()
  164. return ps.parts[index]
  165. }
  166. func (ps *PartSet) IsComplete() bool {
  167. return ps.count == ps.total
  168. }
  169. func (ps *PartSet) GetReader() io.Reader {
  170. if !ps.IsComplete() {
  171. panic("Cannot GetReader() on incomplete PartSet")
  172. }
  173. buf := []byte{}
  174. for _, part := range ps.parts {
  175. buf = append(buf, part.Bytes...)
  176. }
  177. return bytes.NewReader(buf)
  178. }
  179. func (ps *PartSet) Description() string {
  180. if ps == nil {
  181. return "nil-PartSet"
  182. } else {
  183. return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
  184. }
  185. }