You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

318 lines
9.0 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. "testing"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. "github.com/tendermint/tendermint/crypto/ed25519"
  17. cmn "github.com/tendermint/tendermint/libs/common"
  18. )
  // kvstoreConn is an in-memory connection assembled from io.Pipe halves:
// Read is served by the embedded PipeReader and Write by the embedded
// PipeWriter.
type kvstoreConn struct {
	*io.PipeReader
	*io.PipeWriter
}

// Close closes both halves of the connection. The writer is closed with
// io.EOF, so reads on the other end of the write pipe return io.EOF.
// It returns the writer's close error if there is one, otherwise the
// reader's close error.
func (drw kvstoreConn) Close() error {
	errW := drw.PipeWriter.CloseWithError(io.EOF)
	errR := drw.PipeReader.Close()
	if errW != nil {
		// Previously this returned an (always nil) named result, which
		// silently dropped the writer's error.
		return errW
	}
	return errR
}
  31. // Each returned ReadWriteCloser is akin to a net.Connection
  32. func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
  33. barReader, fooWriter := io.Pipe()
  34. fooReader, barWriter := io.Pipe()
  35. return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
  36. }
  37. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  38. var fooConn, barConn = makeKVStoreConnPair()
  39. var fooPrvKey = ed25519.GenPrivKey()
  40. var fooPubKey = fooPrvKey.PubKey()
  41. var barPrvKey = ed25519.GenPrivKey()
  42. var barPubKey = barPrvKey.PubKey()
  43. // Make connections from both sides in parallel.
  44. var trs, ok = cmn.Parallel(
  45. func(_ int) (val interface{}, err error, abort bool) {
  46. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  47. if err != nil {
  48. tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
  49. return nil, err, true
  50. }
  51. remotePubBytes := fooSecConn.RemotePubKey()
  52. if !remotePubBytes.Equals(barPubKey) {
  53. err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  54. barPubKey, fooSecConn.RemotePubKey())
  55. tb.Error(err)
  56. return nil, err, false
  57. }
  58. return nil, nil, false
  59. },
  60. func(_ int) (val interface{}, err error, abort bool) {
  61. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  62. if barSecConn == nil {
  63. tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
  64. return nil, err, true
  65. }
  66. remotePubBytes := barSecConn.RemotePubKey()
  67. if !remotePubBytes.Equals(fooPubKey) {
  68. err = fmt.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  69. fooPubKey, barSecConn.RemotePubKey())
  70. tb.Error(err)
  71. return nil, nil, false
  72. }
  73. return nil, nil, false
  74. },
  75. )
  76. require.Nil(tb, trs.FirstError())
  77. require.True(tb, ok, "Unexpected task abortion")
  78. return
  79. }
  80. func TestSecretConnectionHandshake(t *testing.T) {
  81. fooSecConn, barSecConn := makeSecretConnPair(t)
  82. if err := fooSecConn.Close(); err != nil {
  83. t.Error(err)
  84. }
  85. if err := barSecConn.Close(); err != nil {
  86. t.Error(err)
  87. }
  88. }
  89. func TestSecretConnectionReadWrite(t *testing.T) {
  90. fooConn, barConn := makeKVStoreConnPair()
  91. fooWrites, barWrites := []string{}, []string{}
  92. fooReads, barReads := []string{}, []string{}
  93. // Pre-generate the things to write (for foo & bar)
  94. for i := 0; i < 100; i++ {
  95. fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  96. barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  97. }
  98. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  99. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
  100. return func(_ int) (interface{}, error, bool) {
  101. // Initiate cryptographic private key and secret connection trhough nodeConn.
  102. nodePrvKey := ed25519.GenPrivKey()
  103. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  104. if err != nil {
  105. t.Errorf("Failed to establish SecretConnection for node: %v", err)
  106. return nil, err, true
  107. }
  108. // In parallel, handle some reads and writes.
  109. var trs, ok = cmn.Parallel(
  110. func(_ int) (interface{}, error, bool) {
  111. // Node writes:
  112. for _, nodeWrite := range nodeWrites {
  113. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  114. if err != nil {
  115. t.Errorf("Failed to write to nodeSecretConn: %v", err)
  116. return nil, err, true
  117. }
  118. if n != len(nodeWrite) {
  119. err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  120. t.Error(err)
  121. return nil, err, true
  122. }
  123. }
  124. if err := nodeConn.PipeWriter.Close(); err != nil {
  125. t.Error(err)
  126. return nil, err, true
  127. }
  128. return nil, nil, false
  129. },
  130. func(_ int) (interface{}, error, bool) {
  131. // Node reads:
  132. readBuffer := make([]byte, dataMaxSize)
  133. for {
  134. n, err := nodeSecretConn.Read(readBuffer)
  135. if err == io.EOF {
  136. return nil, nil, false
  137. } else if err != nil {
  138. t.Errorf("Failed to read from nodeSecretConn: %v", err)
  139. return nil, err, true
  140. }
  141. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  142. }
  143. if err := nodeConn.PipeReader.Close(); err != nil {
  144. t.Error(err)
  145. return nil, err, true
  146. }
  147. return nil, nil, false
  148. },
  149. )
  150. assert.True(t, ok, "Unexpected task abortion")
  151. // If error:
  152. if trs.FirstError() != nil {
  153. return nil, trs.FirstError(), true
  154. }
  155. // Otherwise:
  156. return nil, nil, false
  157. }
  158. }
  159. // Run foo & bar in parallel
  160. var trs, ok = cmn.Parallel(
  161. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  162. genNodeRunner("bar", barConn, barWrites, &barReads),
  163. )
  164. require.Nil(t, trs.FirstError())
  165. require.True(t, ok, "unexpected task abortion")
  166. // A helper to ensure that the writes and reads match.
  167. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  168. compareWritesReads := func(writes []string, reads []string) {
  169. for {
  170. // Pop next write & corresponding reads
  171. var read, write string = "", writes[0]
  172. var readCount = 0
  173. for _, readChunk := range reads {
  174. read += readChunk
  175. readCount++
  176. if len(write) <= len(read) {
  177. break
  178. }
  179. if len(write) <= dataMaxSize {
  180. break // atomicity of small writes
  181. }
  182. }
  183. // Compare
  184. if write != read {
  185. t.Errorf("Expected to read %X, got %X", write, read)
  186. }
  187. // Iterate
  188. writes = writes[1:]
  189. reads = reads[readCount:]
  190. if len(writes) == 0 {
  191. break
  192. }
  193. }
  194. }
  195. compareWritesReads(fooWrites, barReads)
  196. compareWritesReads(barWrites, fooReads)
  197. }
  // update is enabled by running `go test -update` from within this module.
// When set, the golden test vector file is regenerated before it is checked.
var (
	update = flag.Bool("update", false, "update .golden files")
)
  201. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  202. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  203. if *update {
  204. t.Logf("Updating golden test vector file %s", goldenFilepath)
  205. data := createGoldenTestVectors(t)
  206. cmn.WriteFile(goldenFilepath, []byte(data), 0644)
  207. }
  208. f, err := os.Open(goldenFilepath)
  209. if err != nil {
  210. log.Fatal(err)
  211. }
  212. defer f.Close()
  213. scanner := bufio.NewScanner(f)
  214. for scanner.Scan() {
  215. line := scanner.Text()
  216. params := strings.Split(line, ",")
  217. randSecretVector, err := hex.DecodeString(params[0])
  218. require.Nil(t, err)
  219. randSecret := new([32]byte)
  220. copy((*randSecret)[:], randSecretVector)
  221. locIsLeast, err := strconv.ParseBool(params[1])
  222. require.Nil(t, err)
  223. expectedRecvSecret, err := hex.DecodeString(params[2])
  224. require.Nil(t, err)
  225. expectedSendSecret, err := hex.DecodeString(params[3])
  226. require.Nil(t, err)
  227. expectedChallenge, err := hex.DecodeString(params[4])
  228. require.Nil(t, err)
  229. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  230. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  231. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  232. require.Equal(t, expectedChallenge, (*challenge)[:], "challenges aren't equal")
  233. }
  234. }
  235. // Creates the data for a test vector file.
  236. // The file format is:
  237. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  238. func createGoldenTestVectors(t *testing.T) string {
  239. data := ""
  240. for i := 0; i < 32; i++ {
  241. randSecretVector := cmn.RandBytes(32)
  242. randSecret := new([32]byte)
  243. copy((*randSecret)[:], randSecretVector)
  244. data += hex.EncodeToString((*randSecret)[:]) + ","
  245. locIsLeast := cmn.RandBool()
  246. data += strconv.FormatBool(locIsLeast) + ","
  247. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  248. data += hex.EncodeToString((*recvSecret)[:]) + ","
  249. data += hex.EncodeToString((*sendSecret)[:]) + ","
  250. data += hex.EncodeToString((*challenge)[:]) + "\n"
  251. }
  252. return data
  253. }
  254. func BenchmarkSecretConnection(b *testing.B) {
  255. b.StopTimer()
  256. fooSecConn, barSecConn := makeSecretConnPair(b)
  257. fooWriteText := cmn.RandStr(dataMaxSize)
  258. // Consume reads from bar's reader
  259. go func() {
  260. readBuffer := make([]byte, dataMaxSize)
  261. for {
  262. _, err := barSecConn.Read(readBuffer)
  263. if err == io.EOF {
  264. return
  265. } else if err != nil {
  266. b.Fatalf("Failed to read from barSecConn: %v", err)
  267. }
  268. }
  269. }()
  270. b.StartTimer()
  271. for i := 0; i < b.N; i++ {
  272. _, err := fooSecConn.Write([]byte(fooWriteText))
  273. if err != nil {
  274. b.Fatalf("Failed to write to fooSecConn: %v", err)
  275. }
  276. }
  277. b.StopTimer()
  278. if err := fooSecConn.Close(); err != nil {
  279. b.Error(err)
  280. }
  281. //barSecConn.Close() race condition
  282. }
  // fingerprint returns at most the first 40 bytes of bz. Inputs shorter than
// 40 bytes are returned unchanged. The result shares bz's backing array and
// is not a copy.
func fingerprint(bz []byte) []byte {
	const fbsize = 40
	if len(bz) < fbsize {
		return bz
	}
	return bz[:fbsize]
}