You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

523 lines
14 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. "github.com/tendermint/tendermint/crypto"
  18. "github.com/tendermint/tendermint/crypto/ed25519"
  19. "github.com/tendermint/tendermint/crypto/secp256k1"
  20. cmn "github.com/tendermint/tendermint/libs/common"
  21. )
  22. type kvstoreConn struct {
  23. *io.PipeReader
  24. *io.PipeWriter
  25. }
  26. func (drw kvstoreConn) Close() (err error) {
  27. err2 := drw.PipeWriter.CloseWithError(io.EOF)
  28. err1 := drw.PipeReader.Close()
  29. if err2 != nil {
  30. return err
  31. }
  32. return err1
  33. }
  34. // Each returned ReadWriteCloser is akin to a net.Connection
  35. func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
  36. barReader, fooWriter := io.Pipe()
  37. fooReader, barWriter := io.Pipe()
  38. return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
  39. }
  40. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  41. var fooConn, barConn = makeKVStoreConnPair()
  42. var fooPrvKey = ed25519.GenPrivKey()
  43. var fooPubKey = fooPrvKey.PubKey()
  44. var barPrvKey = ed25519.GenPrivKey()
  45. var barPubKey = barPrvKey.PubKey()
  46. // Make connections from both sides in parallel.
  47. var trs, ok = cmn.Parallel(
  48. func(_ int) (val interface{}, err error, abort bool) {
  49. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  50. if err != nil {
  51. tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
  52. return nil, err, true
  53. }
  54. remotePubBytes := fooSecConn.RemotePubKey()
  55. if !remotePubBytes.Equals(barPubKey) {
  56. err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  57. barPubKey, fooSecConn.RemotePubKey())
  58. tb.Error(err)
  59. return nil, err, false
  60. }
  61. return nil, nil, false
  62. },
  63. func(_ int) (val interface{}, err error, abort bool) {
  64. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  65. if barSecConn == nil {
  66. tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
  67. return nil, err, true
  68. }
  69. remotePubBytes := barSecConn.RemotePubKey()
  70. if !remotePubBytes.Equals(fooPubKey) {
  71. err = fmt.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  72. fooPubKey, barSecConn.RemotePubKey())
  73. tb.Error(err)
  74. return nil, nil, false
  75. }
  76. return nil, nil, false
  77. },
  78. )
  79. require.Nil(tb, trs.FirstError())
  80. require.True(tb, ok, "Unexpected task abortion")
  81. return fooSecConn, barSecConn
  82. }
  83. func TestSecretConnectionHandshake(t *testing.T) {
  84. fooSecConn, barSecConn := makeSecretConnPair(t)
  85. if err := fooSecConn.Close(); err != nil {
  86. t.Error(err)
  87. }
  88. if err := barSecConn.Close(); err != nil {
  89. t.Error(err)
  90. }
  91. }
  92. // Test that shareEphPubKey rejects lower order public keys based on an
  93. // (incomplete) blacklist.
  94. func TestShareLowOrderPubkey(t *testing.T) {
  95. var fooConn, barConn = makeKVStoreConnPair()
  96. defer fooConn.Close()
  97. defer barConn.Close()
  98. locEphPub, _ := genEphKeys()
  99. // all blacklisted low order points:
  100. for _, remLowOrderPubKey := range blacklist {
  101. remLowOrderPubKey := remLowOrderPubKey
  102. _, _ = cmn.Parallel(
  103. func(_ int) (val interface{}, err error, abort bool) {
  104. _, err = shareEphPubKey(fooConn, locEphPub)
  105. require.Error(t, err)
  106. require.Equal(t, err, ErrSmallOrderRemotePubKey)
  107. return nil, nil, false
  108. },
  109. func(_ int) (val interface{}, err error, abort bool) {
  110. readRemKey, err := shareEphPubKey(barConn, &remLowOrderPubKey)
  111. require.NoError(t, err)
  112. require.Equal(t, locEphPub, readRemKey)
  113. return nil, nil, false
  114. })
  115. }
  116. }
  117. // Test that additionally that the Diffie-Hellman shared secret is non-zero.
  118. // The shared secret would be zero for lower order pub-keys (but tested against the blacklist only).
  119. func TestComputeDHFailsOnLowOrder(t *testing.T) {
  120. _, locPrivKey := genEphKeys()
  121. for _, remLowOrderPubKey := range blacklist {
  122. remLowOrderPubKey := remLowOrderPubKey
  123. shared, err := computeDHSecret(&remLowOrderPubKey, locPrivKey)
  124. assert.Error(t, err)
  125. assert.Equal(t, err, ErrSharedSecretIsZero)
  126. assert.Empty(t, shared)
  127. }
  128. }
  129. func TestConcurrentWrite(t *testing.T) {
  130. fooSecConn, barSecConn := makeSecretConnPair(t)
  131. fooWriteText := cmn.RandStr(dataMaxSize)
  132. // write from two routines.
  133. // should be safe from race according to net.Conn:
  134. // https://golang.org/pkg/net/#Conn
  135. n := 100
  136. wg := new(sync.WaitGroup)
  137. wg.Add(3)
  138. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  139. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  140. // Consume reads from bar's reader
  141. readLots(t, wg, barSecConn, n*2)
  142. wg.Wait()
  143. if err := fooSecConn.Close(); err != nil {
  144. t.Error(err)
  145. }
  146. }
  147. func TestConcurrentRead(t *testing.T) {
  148. fooSecConn, barSecConn := makeSecretConnPair(t)
  149. fooWriteText := cmn.RandStr(dataMaxSize)
  150. n := 100
  151. // read from two routines.
  152. // should be safe from race according to net.Conn:
  153. // https://golang.org/pkg/net/#Conn
  154. wg := new(sync.WaitGroup)
  155. wg.Add(3)
  156. go readLots(t, wg, fooSecConn, n/2)
  157. go readLots(t, wg, fooSecConn, n/2)
  158. // write to bar
  159. writeLots(t, wg, barSecConn, fooWriteText, n)
  160. wg.Wait()
  161. if err := fooSecConn.Close(); err != nil {
  162. t.Error(err)
  163. }
  164. }
  165. func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
  166. defer wg.Done()
  167. for i := 0; i < n; i++ {
  168. _, err := conn.Write([]byte(txt))
  169. if err != nil {
  170. t.Errorf("Failed to write to fooSecConn: %v", err)
  171. return
  172. }
  173. }
  174. }
  175. func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
  176. readBuffer := make([]byte, dataMaxSize)
  177. for i := 0; i < n; i++ {
  178. _, err := conn.Read(readBuffer)
  179. assert.NoError(t, err)
  180. }
  181. wg.Done()
  182. }
  183. func TestSecretConnectionReadWrite(t *testing.T) {
  184. fooConn, barConn := makeKVStoreConnPair()
  185. fooWrites, barWrites := []string{}, []string{}
  186. fooReads, barReads := []string{}, []string{}
  187. // Pre-generate the things to write (for foo & bar)
  188. for i := 0; i < 100; i++ {
  189. fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  190. barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  191. }
  192. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  193. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
  194. return func(_ int) (interface{}, error, bool) {
  195. // Initiate cryptographic private key and secret connection trhough nodeConn.
  196. nodePrvKey := ed25519.GenPrivKey()
  197. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  198. if err != nil {
  199. t.Errorf("Failed to establish SecretConnection for node: %v", err)
  200. return nil, err, true
  201. }
  202. // In parallel, handle some reads and writes.
  203. var trs, ok = cmn.Parallel(
  204. func(_ int) (interface{}, error, bool) {
  205. // Node writes:
  206. for _, nodeWrite := range nodeWrites {
  207. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  208. if err != nil {
  209. t.Errorf("Failed to write to nodeSecretConn: %v", err)
  210. return nil, err, true
  211. }
  212. if n != len(nodeWrite) {
  213. err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  214. t.Error(err)
  215. return nil, err, true
  216. }
  217. }
  218. if err := nodeConn.PipeWriter.Close(); err != nil {
  219. t.Error(err)
  220. return nil, err, true
  221. }
  222. return nil, nil, false
  223. },
  224. func(_ int) (interface{}, error, bool) {
  225. // Node reads:
  226. readBuffer := make([]byte, dataMaxSize)
  227. for {
  228. n, err := nodeSecretConn.Read(readBuffer)
  229. if err == io.EOF {
  230. if err := nodeConn.PipeReader.Close(); err != nil {
  231. t.Error(err)
  232. return nil, err, true
  233. }
  234. return nil, nil, false
  235. } else if err != nil {
  236. t.Errorf("Failed to read from nodeSecretConn: %v", err)
  237. return nil, err, true
  238. }
  239. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  240. }
  241. },
  242. )
  243. assert.True(t, ok, "Unexpected task abortion")
  244. // If error:
  245. if trs.FirstError() != nil {
  246. return nil, trs.FirstError(), true
  247. }
  248. // Otherwise:
  249. return nil, nil, false
  250. }
  251. }
  252. // Run foo & bar in parallel
  253. var trs, ok = cmn.Parallel(
  254. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  255. genNodeRunner("bar", barConn, barWrites, &barReads),
  256. )
  257. require.Nil(t, trs.FirstError())
  258. require.True(t, ok, "unexpected task abortion")
  259. // A helper to ensure that the writes and reads match.
  260. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  261. compareWritesReads := func(writes []string, reads []string) {
  262. for {
  263. // Pop next write & corresponding reads
  264. var read, write string = "", writes[0]
  265. var readCount = 0
  266. for _, readChunk := range reads {
  267. read += readChunk
  268. readCount++
  269. if len(write) <= len(read) {
  270. break
  271. }
  272. if len(write) <= dataMaxSize {
  273. break // atomicity of small writes
  274. }
  275. }
  276. // Compare
  277. if write != read {
  278. t.Errorf("Expected to read %X, got %X", write, read)
  279. }
  280. // Iterate
  281. writes = writes[1:]
  282. reads = reads[readCount:]
  283. if len(writes) == 0 {
  284. break
  285. }
  286. }
  287. }
  288. compareWritesReads(fooWrites, barReads)
  289. compareWritesReads(barWrites, fooReads)
  290. }
  291. // Run go test -update from within this module
  292. // to update the golden test vector file
  293. var update = flag.Bool("update", false, "update .golden files")
  294. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  295. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  296. if *update {
  297. t.Logf("Updating golden test vector file %s", goldenFilepath)
  298. data := createGoldenTestVectors(t)
  299. cmn.WriteFile(goldenFilepath, []byte(data), 0644)
  300. }
  301. f, err := os.Open(goldenFilepath)
  302. if err != nil {
  303. log.Fatal(err)
  304. }
  305. defer f.Close()
  306. scanner := bufio.NewScanner(f)
  307. for scanner.Scan() {
  308. line := scanner.Text()
  309. params := strings.Split(line, ",")
  310. randSecretVector, err := hex.DecodeString(params[0])
  311. require.Nil(t, err)
  312. randSecret := new([32]byte)
  313. copy((*randSecret)[:], randSecretVector)
  314. locIsLeast, err := strconv.ParseBool(params[1])
  315. require.Nil(t, err)
  316. expectedRecvSecret, err := hex.DecodeString(params[2])
  317. require.Nil(t, err)
  318. expectedSendSecret, err := hex.DecodeString(params[3])
  319. require.Nil(t, err)
  320. expectedChallenge, err := hex.DecodeString(params[4])
  321. require.Nil(t, err)
  322. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  323. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  324. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  325. require.Equal(t, expectedChallenge, (*challenge)[:], "challenges aren't equal")
  326. }
  327. }
  328. type privKeyWithNilPubKey struct {
  329. orig crypto.PrivKey
  330. }
  331. func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() }
  332. func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
  333. func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil }
  334. func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) }
  335. func TestNilPubkey(t *testing.T) {
  336. var fooConn, barConn = makeKVStoreConnPair()
  337. var fooPrvKey = ed25519.GenPrivKey()
  338. var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
  339. go func() {
  340. _, err := MakeSecretConnection(barConn, barPrvKey)
  341. assert.NoError(t, err)
  342. }()
  343. assert.NotPanics(t, func() {
  344. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  345. if assert.Error(t, err) {
  346. assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error())
  347. }
  348. })
  349. }
  350. func TestNonEd25519Pubkey(t *testing.T) {
  351. var fooConn, barConn = makeKVStoreConnPair()
  352. var fooPrvKey = ed25519.GenPrivKey()
  353. var barPrvKey = secp256k1.GenPrivKey()
  354. go func() {
  355. _, err := MakeSecretConnection(barConn, barPrvKey)
  356. assert.NoError(t, err)
  357. }()
  358. assert.NotPanics(t, func() {
  359. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  360. if assert.Error(t, err) {
  361. assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKeySecp256k1", err.Error())
  362. }
  363. })
  364. }
  365. // Creates the data for a test vector file.
  366. // The file format is:
  367. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  368. func createGoldenTestVectors(t *testing.T) string {
  369. data := ""
  370. for i := 0; i < 32; i++ {
  371. randSecretVector := cmn.RandBytes(32)
  372. randSecret := new([32]byte)
  373. copy((*randSecret)[:], randSecretVector)
  374. data += hex.EncodeToString((*randSecret)[:]) + ","
  375. locIsLeast := cmn.RandBool()
  376. data += strconv.FormatBool(locIsLeast) + ","
  377. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  378. data += hex.EncodeToString((*recvSecret)[:]) + ","
  379. data += hex.EncodeToString((*sendSecret)[:]) + ","
  380. data += hex.EncodeToString((*challenge)[:]) + "\n"
  381. }
  382. return data
  383. }
  384. func BenchmarkWriteSecretConnection(b *testing.B) {
  385. b.StopTimer()
  386. b.ReportAllocs()
  387. fooSecConn, barSecConn := makeSecretConnPair(b)
  388. randomMsgSizes := []int{
  389. dataMaxSize / 10,
  390. dataMaxSize / 3,
  391. dataMaxSize / 2,
  392. dataMaxSize,
  393. dataMaxSize * 3 / 2,
  394. dataMaxSize * 2,
  395. dataMaxSize * 7 / 2,
  396. }
  397. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  398. for _, size := range randomMsgSizes {
  399. fooWriteBytes = append(fooWriteBytes, cmn.RandBytes(size))
  400. }
  401. // Consume reads from bar's reader
  402. go func() {
  403. readBuffer := make([]byte, dataMaxSize)
  404. for {
  405. _, err := barSecConn.Read(readBuffer)
  406. if err == io.EOF {
  407. return
  408. } else if err != nil {
  409. b.Errorf("Failed to read from barSecConn: %v", err)
  410. return
  411. }
  412. }
  413. }()
  414. b.StartTimer()
  415. for i := 0; i < b.N; i++ {
  416. idx := cmn.RandIntn(len(fooWriteBytes))
  417. _, err := fooSecConn.Write(fooWriteBytes[idx])
  418. if err != nil {
  419. b.Errorf("Failed to write to fooSecConn: %v", err)
  420. return
  421. }
  422. }
  423. b.StopTimer()
  424. if err := fooSecConn.Close(); err != nil {
  425. b.Error(err)
  426. }
  427. //barSecConn.Close() race condition
  428. }
  429. func BenchmarkReadSecretConnection(b *testing.B) {
  430. b.StopTimer()
  431. b.ReportAllocs()
  432. fooSecConn, barSecConn := makeSecretConnPair(b)
  433. randomMsgSizes := []int{
  434. dataMaxSize / 10,
  435. dataMaxSize / 3,
  436. dataMaxSize / 2,
  437. dataMaxSize,
  438. dataMaxSize * 3 / 2,
  439. dataMaxSize * 2,
  440. dataMaxSize * 7 / 2,
  441. }
  442. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  443. for _, size := range randomMsgSizes {
  444. fooWriteBytes = append(fooWriteBytes, cmn.RandBytes(size))
  445. }
  446. go func() {
  447. for i := 0; i < b.N; i++ {
  448. idx := cmn.RandIntn(len(fooWriteBytes))
  449. _, err := fooSecConn.Write(fooWriteBytes[idx])
  450. if err != nil {
  451. b.Errorf("Failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
  452. return
  453. }
  454. }
  455. }()
  456. b.StartTimer()
  457. for i := 0; i < b.N; i++ {
  458. readBuffer := make([]byte, dataMaxSize)
  459. _, err := barSecConn.Read(readBuffer)
  460. if err == io.EOF {
  461. return
  462. } else if err != nil {
  463. b.Fatalf("Failed to read from barSecConn: %v", err)
  464. }
  465. }
  466. b.StopTimer()
  467. }