You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

476 lines
13 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. "github.com/tendermint/tendermint/crypto/ed25519"
  18. cmn "github.com/tendermint/tendermint/libs/common"
  19. )
  20. type kvstoreConn struct {
  21. *io.PipeReader
  22. *io.PipeWriter
  23. }
  24. func (drw kvstoreConn) Close() (err error) {
  25. err2 := drw.PipeWriter.CloseWithError(io.EOF)
  26. err1 := drw.PipeReader.Close()
  27. if err2 != nil {
  28. return err
  29. }
  30. return err1
  31. }
  32. // Each returned ReadWriteCloser is akin to a net.Connection
  33. func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
  34. barReader, fooWriter := io.Pipe()
  35. fooReader, barWriter := io.Pipe()
  36. return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
  37. }
  38. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  39. var fooConn, barConn = makeKVStoreConnPair()
  40. var fooPrvKey = ed25519.GenPrivKey()
  41. var fooPubKey = fooPrvKey.PubKey()
  42. var barPrvKey = ed25519.GenPrivKey()
  43. var barPubKey = barPrvKey.PubKey()
  44. // Make connections from both sides in parallel.
  45. var trs, ok = cmn.Parallel(
  46. func(_ int) (val interface{}, err error, abort bool) {
  47. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  48. if err != nil {
  49. tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
  50. return nil, err, true
  51. }
  52. remotePubBytes := fooSecConn.RemotePubKey()
  53. if !remotePubBytes.Equals(barPubKey) {
  54. err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  55. barPubKey, fooSecConn.RemotePubKey())
  56. tb.Error(err)
  57. return nil, err, false
  58. }
  59. return nil, nil, false
  60. },
  61. func(_ int) (val interface{}, err error, abort bool) {
  62. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  63. if barSecConn == nil {
  64. tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
  65. return nil, err, true
  66. }
  67. remotePubBytes := barSecConn.RemotePubKey()
  68. if !remotePubBytes.Equals(fooPubKey) {
  69. err = fmt.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  70. fooPubKey, barSecConn.RemotePubKey())
  71. tb.Error(err)
  72. return nil, nil, false
  73. }
  74. return nil, nil, false
  75. },
  76. )
  77. require.Nil(tb, trs.FirstError())
  78. require.True(tb, ok, "Unexpected task abortion")
  79. return fooSecConn, barSecConn
  80. }
  81. func TestSecretConnectionHandshake(t *testing.T) {
  82. fooSecConn, barSecConn := makeSecretConnPair(t)
  83. if err := fooSecConn.Close(); err != nil {
  84. t.Error(err)
  85. }
  86. if err := barSecConn.Close(); err != nil {
  87. t.Error(err)
  88. }
  89. }
  90. // Test that shareEphPubKey rejects lower order public keys based on an
  91. // (incomplete) blacklist.
  92. func TestShareLowOrderPubkey(t *testing.T) {
  93. var fooConn, barConn = makeKVStoreConnPair()
  94. defer fooConn.Close()
  95. defer barConn.Close()
  96. locEphPub, _ := genEphKeys()
  97. // all blacklisted low order points:
  98. for _, remLowOrderPubKey := range blacklist {
  99. remLowOrderPubKey := remLowOrderPubKey
  100. _, _ = cmn.Parallel(
  101. func(_ int) (val interface{}, err error, abort bool) {
  102. _, err = shareEphPubKey(fooConn, locEphPub)
  103. require.Error(t, err)
  104. require.Equal(t, err, ErrSmallOrderRemotePubKey)
  105. return nil, nil, false
  106. },
  107. func(_ int) (val interface{}, err error, abort bool) {
  108. readRemKey, err := shareEphPubKey(barConn, &remLowOrderPubKey)
  109. require.NoError(t, err)
  110. require.Equal(t, locEphPub, readRemKey)
  111. return nil, nil, false
  112. })
  113. }
  114. }
  115. // Test that additionally that the Diffie-Hellman shared secret is non-zero.
  116. // The shared secret would be zero for lower order pub-keys (but tested against the blacklist only).
  117. func TestComputeDHFailsOnLowOrder(t *testing.T) {
  118. _, locPrivKey := genEphKeys()
  119. for _, remLowOrderPubKey := range blacklist {
  120. remLowOrderPubKey := remLowOrderPubKey
  121. shared, err := computeDHSecret(&remLowOrderPubKey, locPrivKey)
  122. assert.Error(t, err)
  123. assert.Equal(t, err, ErrSharedSecretIsZero)
  124. assert.Empty(t, shared)
  125. }
  126. }
  127. func TestConcurrentWrite(t *testing.T) {
  128. fooSecConn, barSecConn := makeSecretConnPair(t)
  129. fooWriteText := cmn.RandStr(dataMaxSize)
  130. // write from two routines.
  131. // should be safe from race according to net.Conn:
  132. // https://golang.org/pkg/net/#Conn
  133. n := 100
  134. wg := new(sync.WaitGroup)
  135. wg.Add(3)
  136. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  137. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  138. // Consume reads from bar's reader
  139. readLots(t, wg, barSecConn, n*2)
  140. wg.Wait()
  141. if err := fooSecConn.Close(); err != nil {
  142. t.Error(err)
  143. }
  144. }
  145. func TestConcurrentRead(t *testing.T) {
  146. fooSecConn, barSecConn := makeSecretConnPair(t)
  147. fooWriteText := cmn.RandStr(dataMaxSize)
  148. n := 100
  149. // read from two routines.
  150. // should be safe from race according to net.Conn:
  151. // https://golang.org/pkg/net/#Conn
  152. wg := new(sync.WaitGroup)
  153. wg.Add(3)
  154. go readLots(t, wg, fooSecConn, n/2)
  155. go readLots(t, wg, fooSecConn, n/2)
  156. // write to bar
  157. writeLots(t, wg, barSecConn, fooWriteText, n)
  158. wg.Wait()
  159. if err := fooSecConn.Close(); err != nil {
  160. t.Error(err)
  161. }
  162. }
  163. func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
  164. defer wg.Done()
  165. for i := 0; i < n; i++ {
  166. _, err := conn.Write([]byte(txt))
  167. if err != nil {
  168. t.Errorf("Failed to write to fooSecConn: %v", err)
  169. return
  170. }
  171. }
  172. }
  173. func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
  174. readBuffer := make([]byte, dataMaxSize)
  175. for i := 0; i < n; i++ {
  176. _, err := conn.Read(readBuffer)
  177. assert.NoError(t, err)
  178. }
  179. wg.Done()
  180. }
  181. func TestSecretConnectionReadWrite(t *testing.T) {
  182. fooConn, barConn := makeKVStoreConnPair()
  183. fooWrites, barWrites := []string{}, []string{}
  184. fooReads, barReads := []string{}, []string{}
  185. // Pre-generate the things to write (for foo & bar)
  186. for i := 0; i < 100; i++ {
  187. fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  188. barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  189. }
  190. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  191. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
  192. return func(_ int) (interface{}, error, bool) {
  193. // Initiate cryptographic private key and secret connection trhough nodeConn.
  194. nodePrvKey := ed25519.GenPrivKey()
  195. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  196. if err != nil {
  197. t.Errorf("Failed to establish SecretConnection for node: %v", err)
  198. return nil, err, true
  199. }
  200. // In parallel, handle some reads and writes.
  201. var trs, ok = cmn.Parallel(
  202. func(_ int) (interface{}, error, bool) {
  203. // Node writes:
  204. for _, nodeWrite := range nodeWrites {
  205. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  206. if err != nil {
  207. t.Errorf("Failed to write to nodeSecretConn: %v", err)
  208. return nil, err, true
  209. }
  210. if n != len(nodeWrite) {
  211. err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  212. t.Error(err)
  213. return nil, err, true
  214. }
  215. }
  216. if err := nodeConn.PipeWriter.Close(); err != nil {
  217. t.Error(err)
  218. return nil, err, true
  219. }
  220. return nil, nil, false
  221. },
  222. func(_ int) (interface{}, error, bool) {
  223. // Node reads:
  224. readBuffer := make([]byte, dataMaxSize)
  225. for {
  226. n, err := nodeSecretConn.Read(readBuffer)
  227. if err == io.EOF {
  228. if err := nodeConn.PipeReader.Close(); err != nil {
  229. t.Error(err)
  230. return nil, err, true
  231. }
  232. return nil, nil, false
  233. } else if err != nil {
  234. t.Errorf("Failed to read from nodeSecretConn: %v", err)
  235. return nil, err, true
  236. }
  237. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  238. }
  239. },
  240. )
  241. assert.True(t, ok, "Unexpected task abortion")
  242. // If error:
  243. if trs.FirstError() != nil {
  244. return nil, trs.FirstError(), true
  245. }
  246. // Otherwise:
  247. return nil, nil, false
  248. }
  249. }
  250. // Run foo & bar in parallel
  251. var trs, ok = cmn.Parallel(
  252. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  253. genNodeRunner("bar", barConn, barWrites, &barReads),
  254. )
  255. require.Nil(t, trs.FirstError())
  256. require.True(t, ok, "unexpected task abortion")
  257. // A helper to ensure that the writes and reads match.
  258. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  259. compareWritesReads := func(writes []string, reads []string) {
  260. for {
  261. // Pop next write & corresponding reads
  262. var read, write string = "", writes[0]
  263. var readCount = 0
  264. for _, readChunk := range reads {
  265. read += readChunk
  266. readCount++
  267. if len(write) <= len(read) {
  268. break
  269. }
  270. if len(write) <= dataMaxSize {
  271. break // atomicity of small writes
  272. }
  273. }
  274. // Compare
  275. if write != read {
  276. t.Errorf("Expected to read %X, got %X", write, read)
  277. }
  278. // Iterate
  279. writes = writes[1:]
  280. reads = reads[readCount:]
  281. if len(writes) == 0 {
  282. break
  283. }
  284. }
  285. }
  286. compareWritesReads(fooWrites, barReads)
  287. compareWritesReads(barWrites, fooReads)
  288. }
  289. // Run go test -update from within this module
  290. // to update the golden test vector file
  291. var update = flag.Bool("update", false, "update .golden files")
  292. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  293. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  294. if *update {
  295. t.Logf("Updating golden test vector file %s", goldenFilepath)
  296. data := createGoldenTestVectors(t)
  297. cmn.WriteFile(goldenFilepath, []byte(data), 0644)
  298. }
  299. f, err := os.Open(goldenFilepath)
  300. if err != nil {
  301. log.Fatal(err)
  302. }
  303. defer f.Close()
  304. scanner := bufio.NewScanner(f)
  305. for scanner.Scan() {
  306. line := scanner.Text()
  307. params := strings.Split(line, ",")
  308. randSecretVector, err := hex.DecodeString(params[0])
  309. require.Nil(t, err)
  310. randSecret := new([32]byte)
  311. copy((*randSecret)[:], randSecretVector)
  312. locIsLeast, err := strconv.ParseBool(params[1])
  313. require.Nil(t, err)
  314. expectedRecvSecret, err := hex.DecodeString(params[2])
  315. require.Nil(t, err)
  316. expectedSendSecret, err := hex.DecodeString(params[3])
  317. require.Nil(t, err)
  318. expectedChallenge, err := hex.DecodeString(params[4])
  319. require.Nil(t, err)
  320. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  321. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  322. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  323. require.Equal(t, expectedChallenge, (*challenge)[:], "challenges aren't equal")
  324. }
  325. }
  326. // Creates the data for a test vector file.
  327. // The file format is:
  328. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  329. func createGoldenTestVectors(t *testing.T) string {
  330. data := ""
  331. for i := 0; i < 32; i++ {
  332. randSecretVector := cmn.RandBytes(32)
  333. randSecret := new([32]byte)
  334. copy((*randSecret)[:], randSecretVector)
  335. data += hex.EncodeToString((*randSecret)[:]) + ","
  336. locIsLeast := cmn.RandBool()
  337. data += strconv.FormatBool(locIsLeast) + ","
  338. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  339. data += hex.EncodeToString((*recvSecret)[:]) + ","
  340. data += hex.EncodeToString((*sendSecret)[:]) + ","
  341. data += hex.EncodeToString((*challenge)[:]) + "\n"
  342. }
  343. return data
  344. }
  345. func BenchmarkWriteSecretConnection(b *testing.B) {
  346. b.StopTimer()
  347. b.ReportAllocs()
  348. fooSecConn, barSecConn := makeSecretConnPair(b)
  349. randomMsgSizes := []int{
  350. dataMaxSize / 10,
  351. dataMaxSize / 3,
  352. dataMaxSize / 2,
  353. dataMaxSize,
  354. dataMaxSize * 3 / 2,
  355. dataMaxSize * 2,
  356. dataMaxSize * 7 / 2,
  357. }
  358. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  359. for _, size := range randomMsgSizes {
  360. fooWriteBytes = append(fooWriteBytes, cmn.RandBytes(size))
  361. }
  362. // Consume reads from bar's reader
  363. go func() {
  364. readBuffer := make([]byte, dataMaxSize)
  365. for {
  366. _, err := barSecConn.Read(readBuffer)
  367. if err == io.EOF {
  368. return
  369. } else if err != nil {
  370. b.Errorf("Failed to read from barSecConn: %v", err)
  371. return
  372. }
  373. }
  374. }()
  375. b.StartTimer()
  376. for i := 0; i < b.N; i++ {
  377. idx := cmn.RandIntn(len(fooWriteBytes))
  378. _, err := fooSecConn.Write(fooWriteBytes[idx])
  379. if err != nil {
  380. b.Errorf("Failed to write to fooSecConn: %v", err)
  381. return
  382. }
  383. }
  384. b.StopTimer()
  385. if err := fooSecConn.Close(); err != nil {
  386. b.Error(err)
  387. }
  388. //barSecConn.Close() race condition
  389. }
  390. func BenchmarkReadSecretConnection(b *testing.B) {
  391. b.StopTimer()
  392. b.ReportAllocs()
  393. fooSecConn, barSecConn := makeSecretConnPair(b)
  394. randomMsgSizes := []int{
  395. dataMaxSize / 10,
  396. dataMaxSize / 3,
  397. dataMaxSize / 2,
  398. dataMaxSize,
  399. dataMaxSize * 3 / 2,
  400. dataMaxSize * 2,
  401. dataMaxSize * 7 / 2,
  402. }
  403. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  404. for _, size := range randomMsgSizes {
  405. fooWriteBytes = append(fooWriteBytes, cmn.RandBytes(size))
  406. }
  407. go func() {
  408. for i := 0; i < b.N; i++ {
  409. idx := cmn.RandIntn(len(fooWriteBytes))
  410. _, err := fooSecConn.Write(fooWriteBytes[idx])
  411. if err != nil {
  412. b.Errorf("Failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
  413. return
  414. }
  415. }
  416. }()
  417. b.StartTimer()
  418. for i := 0; i < b.N; i++ {
  419. readBuffer := make([]byte, dataMaxSize)
  420. _, err := barSecConn.Read(readBuffer)
  421. if err == io.EOF {
  422. return
  423. } else if err != nil {
  424. b.Fatalf("Failed to read from barSecConn: %v", err)
  425. }
  426. }
  427. b.StopTimer()
  428. }