You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

250 lines
6.7 KiB

9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
  1. package conn
  2. import (
  3. "fmt"
  4. "io"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/stretchr/testify/require"
  8. crypto "github.com/tendermint/go-crypto"
  9. cmn "github.com/tendermint/tmlibs/common"
  10. )
  // kvstoreConn is one end of an in-memory, full-duplex connection built from
// two io.Pipes: it reads from one pipe and writes to the other.
type kvstoreConn struct {
	*io.PipeReader
	*io.PipeWriter
}

// Close closes both halves of the connection.
//
// The write half is closed with io.EOF, so the peer's reads end with io.EOF.
// The read half is closed, so the peer's subsequent writes fail with
// io.ErrClosedPipe instead of blocking. Both halves are always closed. The
// write-half error is returned if there is one, otherwise the read-half error.
func (drw kvstoreConn) Close() error {
	writeErr := drw.PipeWriter.CloseWithError(io.EOF)
	readErr := drw.PipeReader.Close()
	if writeErr != nil {
		return writeErr
	}
	return readErr
}

// makeKVStoreConnPair returns two connected kvstoreConns. Each one is akin to
// one end of a net.Conn: bytes written to fooConn are read from barConn, and
// bytes written to barConn are read from fooConn.
func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
	barReader, fooWriter := io.Pipe()
	fooReader, barWriter := io.Pipe()
	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
}
  29. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  30. var fooConn, barConn = makeKVStoreConnPair()
  31. var fooPrvKey = crypto.GenPrivKeyEd25519()
  32. var fooPubKey = fooPrvKey.PubKey()
  33. var barPrvKey = crypto.GenPrivKeyEd25519()
  34. var barPubKey = barPrvKey.PubKey()
  35. // Make connections from both sides in parallel.
  36. var trs, ok = cmn.Parallel(
  37. func(_ int) (val interface{}, err error, abort bool) {
  38. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  39. if err != nil {
  40. tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
  41. return nil, err, true
  42. }
  43. remotePubBytes := fooSecConn.RemotePubKey()
  44. if !remotePubBytes.Equals(barPubKey) {
  45. err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  46. barPubKey, fooSecConn.RemotePubKey())
  47. tb.Error(err)
  48. return nil, err, false
  49. }
  50. return nil, nil, false
  51. },
  52. func(_ int) (val interface{}, err error, abort bool) {
  53. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  54. if barSecConn == nil {
  55. tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
  56. return nil, err, true
  57. }
  58. remotePubBytes := barSecConn.RemotePubKey()
  59. if !remotePubBytes.Equals(fooPubKey) {
  60. err = fmt.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  61. fooPubKey, barSecConn.RemotePubKey())
  62. tb.Error(err)
  63. return nil, nil, false
  64. }
  65. return nil, nil, false
  66. },
  67. )
  68. require.Nil(tb, trs.FirstError())
  69. require.True(tb, ok, "Unexpected task abortion")
  70. return
  71. }
  72. func TestSecretConnectionHandshake(t *testing.T) {
  73. fooSecConn, barSecConn := makeSecretConnPair(t)
  74. if err := fooSecConn.Close(); err != nil {
  75. t.Error(err)
  76. }
  77. if err := barSecConn.Close(); err != nil {
  78. t.Error(err)
  79. }
  80. }
  81. func TestSecretConnectionReadWrite(t *testing.T) {
  82. fooConn, barConn := makeKVStoreConnPair()
  83. fooWrites, barWrites := []string{}, []string{}
  84. fooReads, barReads := []string{}, []string{}
  85. // Pre-generate the things to write (for foo & bar)
  86. for i := 0; i < 100; i++ {
  87. fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  88. barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  89. }
  90. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  91. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
  92. return func(_ int) (interface{}, error, bool) {
  93. // Initiate cryptographic private key and secret connection trhough nodeConn.
  94. nodePrvKey := crypto.GenPrivKeyEd25519()
  95. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  96. if err != nil {
  97. t.Errorf("Failed to establish SecretConnection for node: %v", err)
  98. return nil, err, true
  99. }
  100. // In parallel, handle some reads and writes.
  101. var trs, ok = cmn.Parallel(
  102. func(_ int) (interface{}, error, bool) {
  103. // Node writes:
  104. for _, nodeWrite := range nodeWrites {
  105. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  106. if err != nil {
  107. t.Errorf("Failed to write to nodeSecretConn: %v", err)
  108. return nil, err, true
  109. }
  110. if n != len(nodeWrite) {
  111. err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  112. t.Error(err)
  113. return nil, err, true
  114. }
  115. }
  116. if err := nodeConn.PipeWriter.Close(); err != nil {
  117. t.Error(err)
  118. return nil, err, true
  119. }
  120. return nil, nil, false
  121. },
  122. func(_ int) (interface{}, error, bool) {
  123. // Node reads:
  124. readBuffer := make([]byte, dataMaxSize)
  125. for {
  126. n, err := nodeSecretConn.Read(readBuffer)
  127. if err == io.EOF {
  128. return nil, nil, false
  129. } else if err != nil {
  130. t.Errorf("Failed to read from nodeSecretConn: %v", err)
  131. return nil, err, true
  132. }
  133. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  134. }
  135. if err := nodeConn.PipeReader.Close(); err != nil {
  136. t.Error(err)
  137. return nil, err, true
  138. }
  139. return nil, nil, false
  140. },
  141. )
  142. assert.True(t, ok, "Unexpected task abortion")
  143. // If error:
  144. if trs.FirstError() != nil {
  145. return nil, trs.FirstError(), true
  146. }
  147. // Otherwise:
  148. return nil, nil, false
  149. }
  150. }
  151. // Run foo & bar in parallel
  152. var trs, ok = cmn.Parallel(
  153. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  154. genNodeRunner("bar", barConn, barWrites, &barReads),
  155. )
  156. require.Nil(t, trs.FirstError())
  157. require.True(t, ok, "unexpected task abortion")
  158. // A helper to ensure that the writes and reads match.
  159. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  160. compareWritesReads := func(writes []string, reads []string) {
  161. for {
  162. // Pop next write & corresponding reads
  163. var read, write string = "", writes[0]
  164. var readCount = 0
  165. for _, readChunk := range reads {
  166. read += readChunk
  167. readCount++
  168. if len(write) <= len(read) {
  169. break
  170. }
  171. if len(write) <= dataMaxSize {
  172. break // atomicity of small writes
  173. }
  174. }
  175. // Compare
  176. if write != read {
  177. t.Errorf("Expected to read %X, got %X", write, read)
  178. }
  179. // Iterate
  180. writes = writes[1:]
  181. reads = reads[readCount:]
  182. if len(writes) == 0 {
  183. break
  184. }
  185. }
  186. }
  187. compareWritesReads(fooWrites, barReads)
  188. compareWritesReads(barWrites, fooReads)
  189. }
  190. func BenchmarkSecretConnection(b *testing.B) {
  191. b.StopTimer()
  192. fooSecConn, barSecConn := makeSecretConnPair(b)
  193. fooWriteText := cmn.RandStr(dataMaxSize)
  194. // Consume reads from bar's reader
  195. go func() {
  196. readBuffer := make([]byte, dataMaxSize)
  197. for {
  198. _, err := barSecConn.Read(readBuffer)
  199. if err == io.EOF {
  200. return
  201. } else if err != nil {
  202. b.Fatalf("Failed to read from barSecConn: %v", err)
  203. }
  204. }
  205. }()
  206. b.StartTimer()
  207. for i := 0; i < b.N; i++ {
  208. _, err := fooSecConn.Write([]byte(fooWriteText))
  209. if err != nil {
  210. b.Fatalf("Failed to write to fooSecConn: %v", err)
  211. }
  212. }
  213. b.StopTimer()
  214. if err := fooSecConn.Close(); err != nil {
  215. b.Error(err)
  216. }
  217. //barSecConn.Close() race condition
  218. }
  // fingerprint returns at most the first 40 bytes of bz. The result is a
// subslice of bz, not a copy, so later changes to bz are visible through it.
// A nil or short input is returned unchanged.
func fingerprint(bz []byte) []byte {
	const fbsize = 40
	if len(bz) < fbsize {
		return bz
	}
	return bz[:fbsize]
}