You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

177 lines
3.6 KiB

10 years ago
  1. package consensus
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "errors"
  6. "io"
  7. "sync"
  8. . "github.com/tendermint/tendermint/binary"
  9. . "github.com/tendermint/tendermint/common"
  10. "github.com/tendermint/tendermint/merkle"
  11. )
  12. const (
  13. partSize = 4096 // 4KB
  14. )
  15. var (
  16. ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")
  17. ErrPartSetInvalidTrail = errors.New("Error part set invalid trail")
  18. )
  19. type Part struct {
  20. Index uint16
  21. Trail [][]byte
  22. Bytes []byte
  23. // Cache
  24. hash []byte
  25. }
  26. func ReadPart(r io.Reader, n *int64, err *error) *Part {
  27. return &Part{
  28. Index: ReadUInt16(r, n, err),
  29. Trail: ReadByteSlices(r, n, err),
  30. Bytes: ReadByteSlice(r, n, err),
  31. }
  32. }
  33. func (b *Part) WriteTo(w io.Writer) (n int64, err error) {
  34. WriteUInt16(w, b.Index, &n, &err)
  35. WriteByteSlices(w, b.Trail, &n, &err)
  36. WriteByteSlice(w, b.Bytes, &n, &err)
  37. return
  38. }
  39. func (pt *Part) Hash() []byte {
  40. if pt.hash != nil {
  41. return pt.hash
  42. } else {
  43. hasher := sha256.New()
  44. _, err := hasher.Write(pt.Bytes)
  45. if err != nil {
  46. panic(err)
  47. }
  48. pt.hash = hasher.Sum(nil)
  49. return pt.hash
  50. }
  51. }
  52. //-------------------------------------
  53. type PartSet struct {
  54. rootHash []byte
  55. total uint16
  56. mtx sync.Mutex
  57. parts []*Part
  58. partsBitArray BitArray
  59. count uint16
  60. }
  61. // Returns an immutable, full PartSet.
  62. func NewPartSetFromData(data []byte) *PartSet {
  63. // divide data into 4kb parts.
  64. total := (len(data) + partSize - 1) / partSize
  65. parts := make([]*Part, total)
  66. parts_ := make([]merkle.Hashable, total)
  67. partsBitArray := NewBitArray(uint(total))
  68. for i := 0; i < total; i++ {
  69. part := &Part{
  70. Index: uint16(i),
  71. Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)],
  72. }
  73. parts[i] = part
  74. parts_[i] = part
  75. partsBitArray.SetIndex(uint(i), true)
  76. }
  77. // Compute merkle trails
  78. hashTree := merkle.HashTreeFromHashables(parts_)
  79. for i := 0; i < total; i++ {
  80. parts[i].Trail = merkle.HashTrailForIndex(hashTree, i)
  81. }
  82. return &PartSet{
  83. parts: parts,
  84. partsBitArray: partsBitArray,
  85. rootHash: hashTree[len(hashTree)/2],
  86. total: uint16(total),
  87. count: uint16(total),
  88. }
  89. }
  90. // Returns an empty PartSet ready to be populated.
  91. func NewPartSetFromMetadata(total uint16, rootHash []byte) *PartSet {
  92. return &PartSet{
  93. parts: make([]*Part, total),
  94. partsBitArray: NewBitArray(uint(total)),
  95. rootHash: rootHash,
  96. total: total,
  97. count: 0,
  98. }
  99. }
  100. func (ps *PartSet) BitArray() BitArray {
  101. ps.mtx.Lock()
  102. defer ps.mtx.Unlock()
  103. return ps.partsBitArray.Copy()
  104. }
  105. func (ps *PartSet) RootHash() []byte {
  106. return ps.rootHash
  107. }
  108. func (ps *PartSet) Total() uint16 {
  109. if ps == nil {
  110. return 0
  111. }
  112. return ps.total
  113. }
  114. func (ps *PartSet) AddPart(part *Part) (bool, error) {
  115. ps.mtx.Lock()
  116. defer ps.mtx.Unlock()
  117. // Invalid part index
  118. if part.Index >= ps.total {
  119. return false, ErrPartSetUnexpectedIndex
  120. }
  121. // If part already exists, return false.
  122. if ps.parts[part.Index] != nil {
  123. return false, nil
  124. }
  125. // Check hash trail
  126. if !merkle.VerifyHashTrailForIndex(int(part.Index), part.Hash(), part.Trail, ps.rootHash) {
  127. return false, ErrPartSetInvalidTrail
  128. }
  129. // Add part
  130. ps.parts[part.Index] = part
  131. ps.partsBitArray.SetIndex(uint(part.Index), true)
  132. ps.count++
  133. return true, nil
  134. }
  135. func (ps *PartSet) GetPart(index uint16) *Part {
  136. ps.mtx.Lock()
  137. defer ps.mtx.Unlock()
  138. return ps.parts[index]
  139. }
  140. func (ps *PartSet) IsComplete() bool {
  141. return ps.count == ps.total
  142. }
  143. func (ps *PartSet) GetReader() io.Reader {
  144. if !ps.IsComplete() {
  145. panic("Cannot GetReader() on incomplete PartSet")
  146. }
  147. buf := []byte{}
  148. for _, part := range ps.parts {
  149. buf = append(buf, part.Bytes...)
  150. }
  151. return bytes.NewReader(buf)
  152. }