You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

644 lines
18 KiB

  1. package p2p_test
import (
	"context"
	"io"
	"net"
	"sort"
	"testing"
	"time"

	"github.com/fortytw2/leaktest"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/tendermint/tendermint/crypto/ed25519"
	"github.com/tendermint/tendermint/libs/bytes"
	"github.com/tendermint/tendermint/p2p"
)
  15. // transportFactory is used to set up transports for tests.
  16. type transportFactory func(t *testing.T) p2p.Transport
  17. var (
  18. ctx = context.Background() // convenience context
  19. chID = p2p.ChannelID(1) // channel ID for use in tests
  20. testTransports = map[string]transportFactory{} // registry for withTransports
  21. )
  22. // withTransports is a test helper that runs a test against all transports
  23. // registered in testTransports.
  24. func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) {
  25. t.Helper()
  26. for name, transportFactory := range testTransports {
  27. transportFactory := transportFactory
  28. t.Run(name, func(t *testing.T) {
  29. t.Cleanup(leaktest.Check(t))
  30. tester(t, transportFactory)
  31. })
  32. }
  33. }
  34. func TestTransport_AcceptClose(t *testing.T) {
  35. // Just test accept unblock on close, happy path is tested widely elsewhere.
  36. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  37. a := makeTransport(t)
  38. // In-progress Accept should error on concurrent close.
  39. errCh := make(chan error, 1)
  40. go func() {
  41. time.Sleep(200 * time.Millisecond)
  42. errCh <- a.Close()
  43. }()
  44. _, err := a.Accept()
  45. require.Error(t, err)
  46. require.Equal(t, io.EOF, err)
  47. require.NoError(t, <-errCh)
  48. // Closed transport should return error immediately.
  49. _, err = a.Accept()
  50. require.Error(t, err)
  51. require.Equal(t, io.EOF, err)
  52. })
  53. }
  54. func TestTransport_DialEndpoints(t *testing.T) {
  55. ipTestCases := []struct {
  56. ip net.IP
  57. ok bool
  58. }{
  59. {net.IPv4zero, true},
  60. {net.IPv6zero, true},
  61. {nil, false},
  62. {net.IPv4bcast, false},
  63. {net.IPv4allsys, false},
  64. {[]byte{1, 2, 3}, false},
  65. {[]byte{1, 2, 3, 4, 5}, false},
  66. }
  67. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  68. a := makeTransport(t)
  69. endpoints := a.Endpoints()
  70. require.NotEmpty(t, endpoints)
  71. endpoint := endpoints[0]
  72. // Spawn a goroutine to simply accept any connections until closed.
  73. go func() {
  74. for {
  75. conn, err := a.Accept()
  76. if err != nil {
  77. return
  78. }
  79. _ = conn.Close()
  80. }
  81. }()
  82. // Dialing self should work.
  83. conn, err := a.Dial(ctx, endpoint)
  84. require.NoError(t, err)
  85. require.NoError(t, conn.Close())
  86. // Dialing empty endpoint should error.
  87. _, err = a.Dial(ctx, p2p.Endpoint{})
  88. require.Error(t, err)
  89. // Dialing without protocol should error.
  90. noProtocol := endpoint
  91. noProtocol.Protocol = ""
  92. _, err = a.Dial(ctx, noProtocol)
  93. require.Error(t, err)
  94. // Dialing with invalid protocol should error.
  95. fooProtocol := endpoint
  96. fooProtocol.Protocol = "foo"
  97. _, err = a.Dial(ctx, fooProtocol)
  98. require.Error(t, err)
  99. // Tests for networked endpoints (with IP).
  100. if len(endpoint.IP) > 0 {
  101. for _, tc := range ipTestCases {
  102. tc := tc
  103. t.Run(tc.ip.String(), func(t *testing.T) {
  104. e := endpoint
  105. e.IP = tc.ip
  106. conn, err := a.Dial(ctx, e)
  107. if tc.ok {
  108. require.NoError(t, conn.Close())
  109. require.NoError(t, err)
  110. } else {
  111. require.Error(t, err)
  112. }
  113. })
  114. }
  115. // Non-networked endpoints should error.
  116. noIP := endpoint
  117. noIP.IP = nil
  118. noIP.Port = 0
  119. noIP.Path = "foo"
  120. _, err := a.Dial(ctx, noIP)
  121. require.Error(t, err)
  122. } else {
  123. // Tests for non-networked endpoints (no IP).
  124. noPath := endpoint
  125. noPath.Path = ""
  126. _, err = a.Dial(ctx, noPath)
  127. require.Error(t, err)
  128. }
  129. })
  130. }
  131. func TestTransport_Dial(t *testing.T) {
  132. // Most just tests dial failures, happy path is tested widely elsewhere.
  133. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  134. a := makeTransport(t)
  135. b := makeTransport(t)
  136. require.NotEmpty(t, a.Endpoints())
  137. require.NotEmpty(t, b.Endpoints())
  138. aEndpoint := a.Endpoints()[0]
  139. bEndpoint := b.Endpoints()[0]
  140. // Context cancellation should error. We can't test timeouts since we'd
  141. // need a non-responsive endpoint.
  142. cancelCtx, cancel := context.WithCancel(ctx)
  143. cancel()
  144. _, err := a.Dial(cancelCtx, bEndpoint)
  145. require.Error(t, err)
  146. require.Equal(t, err, context.Canceled)
  147. // Unavailable endpoint should error.
  148. err = b.Close()
  149. require.NoError(t, err)
  150. _, err = a.Dial(ctx, bEndpoint)
  151. require.Error(t, err)
  152. // Dialing from a closed transport should still work.
  153. errCh := make(chan error, 1)
  154. go func() {
  155. conn, err := a.Accept()
  156. if err == nil {
  157. _ = conn.Close()
  158. }
  159. errCh <- err
  160. }()
  161. conn, err := b.Dial(ctx, aEndpoint)
  162. require.NoError(t, err)
  163. require.NoError(t, conn.Close())
  164. require.NoError(t, <-errCh)
  165. })
  166. }
  167. func TestTransport_Endpoints(t *testing.T) {
  168. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  169. a := makeTransport(t)
  170. b := makeTransport(t)
  171. // Both transports return valid and different endpoints.
  172. aEndpoints := a.Endpoints()
  173. bEndpoints := b.Endpoints()
  174. require.NotEmpty(t, aEndpoints)
  175. require.NotEmpty(t, bEndpoints)
  176. require.NotEqual(t, aEndpoints, bEndpoints)
  177. for _, endpoint := range append(aEndpoints, bEndpoints...) {
  178. err := endpoint.Validate()
  179. require.NoError(t, err, "invalid endpoint %q", endpoint)
  180. }
  181. // When closed, the transport should no longer return any endpoints.
  182. err := a.Close()
  183. require.NoError(t, err)
  184. require.Empty(t, a.Endpoints())
  185. require.NotEmpty(t, b.Endpoints())
  186. })
  187. }
  188. func TestTransport_Protocols(t *testing.T) {
  189. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  190. a := makeTransport(t)
  191. protocols := a.Protocols()
  192. endpoints := a.Endpoints()
  193. require.NotEmpty(t, protocols)
  194. require.NotEmpty(t, endpoints)
  195. for _, endpoint := range endpoints {
  196. require.Contains(t, protocols, endpoint.Protocol)
  197. }
  198. })
  199. }
  200. func TestTransport_String(t *testing.T) {
  201. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  202. a := makeTransport(t)
  203. require.NotEmpty(t, a.String())
  204. })
  205. }
  206. func TestConnection_Handshake(t *testing.T) {
  207. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  208. a := makeTransport(t)
  209. b := makeTransport(t)
  210. ab, ba := dialAccept(t, a, b)
  211. // A handshake should pass the given keys and NodeInfo.
  212. aKey := ed25519.GenPrivKey()
  213. aInfo := p2p.NodeInfo{
  214. NodeID: p2p.NodeIDFromPubKey(aKey.PubKey()),
  215. ProtocolVersion: p2p.NewProtocolVersion(1, 2, 3),
  216. ListenAddr: "listenaddr",
  217. Network: "network",
  218. Version: "1.2.3",
  219. Channels: bytes.HexBytes([]byte{0xf0, 0x0f}),
  220. Moniker: "moniker",
  221. Other: p2p.NodeInfoOther{
  222. TxIndex: "txindex",
  223. RPCAddress: "rpc.domain.com",
  224. },
  225. }
  226. bKey := ed25519.GenPrivKey()
  227. bInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(bKey.PubKey())}
  228. errCh := make(chan error, 1)
  229. go func() {
  230. // Must use assert due to goroutine.
  231. peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey)
  232. if err == nil {
  233. assert.Equal(t, aInfo, peerInfo)
  234. assert.Equal(t, aKey.PubKey(), peerKey)
  235. }
  236. errCh <- err
  237. }()
  238. peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey)
  239. require.NoError(t, err)
  240. require.Equal(t, bInfo, peerInfo)
  241. require.Equal(t, bKey.PubKey(), peerKey)
  242. require.NoError(t, <-errCh)
  243. })
  244. }
  245. func TestConnection_HandshakeCancel(t *testing.T) {
  246. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  247. a := makeTransport(t)
  248. b := makeTransport(t)
  249. // Handshake should error on context cancellation.
  250. ab, ba := dialAccept(t, a, b)
  251. timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
  252. cancel()
  253. _, _, err := ab.Handshake(timeoutCtx, p2p.NodeInfo{}, ed25519.GenPrivKey())
  254. require.Error(t, err)
  255. require.Equal(t, context.Canceled, err)
  256. _ = ab.Close()
  257. _ = ba.Close()
  258. // Handshake should error on context timeout.
  259. ab, ba = dialAccept(t, a, b)
  260. timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond)
  261. defer cancel()
  262. _, _, err = ab.Handshake(timeoutCtx, p2p.NodeInfo{}, ed25519.GenPrivKey())
  263. require.Error(t, err)
  264. require.Equal(t, context.DeadlineExceeded, err)
  265. _ = ab.Close()
  266. _ = ba.Close()
  267. })
  268. }
  269. func TestConnection_FlushClose(t *testing.T) {
  270. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  271. a := makeTransport(t)
  272. b := makeTransport(t)
  273. ab, _ := dialAcceptHandshake(t, a, b)
  274. // FIXME: FlushClose should be removed (and replaced by separate Flush
  275. // and Close calls if necessary). We can't reliably test it, so we just
  276. // make sure it closes both ends and that it's idempotent.
  277. err := ab.FlushClose()
  278. require.NoError(t, err)
  279. _, _, err = ab.ReceiveMessage()
  280. require.Error(t, err)
  281. require.Equal(t, io.EOF, err)
  282. _, err = ab.SendMessage(chID, []byte("closed"))
  283. require.Error(t, err)
  284. require.Equal(t, io.EOF, err)
  285. err = ab.FlushClose()
  286. require.NoError(t, err)
  287. })
  288. }
  289. func TestConnection_LocalRemoteEndpoint(t *testing.T) {
  290. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  291. a := makeTransport(t)
  292. b := makeTransport(t)
  293. ab, ba := dialAcceptHandshake(t, a, b)
  294. // Local and remote connection endpoints correspond to each other.
  295. require.NotEmpty(t, ab.LocalEndpoint())
  296. require.NotEmpty(t, ba.LocalEndpoint())
  297. require.Equal(t, ab.LocalEndpoint(), ba.RemoteEndpoint())
  298. require.Equal(t, ab.RemoteEndpoint(), ba.LocalEndpoint())
  299. })
  300. }
  301. func TestConnection_SendReceive(t *testing.T) {
  302. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  303. a := makeTransport(t)
  304. b := makeTransport(t)
  305. ab, ba := dialAcceptHandshake(t, a, b)
  306. // Can send and receive a to b.
  307. ok, err := ab.SendMessage(chID, []byte("foo"))
  308. require.NoError(t, err)
  309. require.True(t, ok)
  310. ch, msg, err := ba.ReceiveMessage()
  311. require.NoError(t, err)
  312. require.Equal(t, []byte("foo"), msg)
  313. require.Equal(t, chID, ch)
  314. // Can send and receive b to a.
  315. _, err = ba.SendMessage(chID, []byte("bar"))
  316. require.NoError(t, err)
  317. _, msg, err = ab.ReceiveMessage()
  318. require.NoError(t, err)
  319. require.Equal(t, []byte("bar"), msg)
  320. // TrySendMessage also works.
  321. ok, err = ba.TrySendMessage(chID, []byte("try"))
  322. require.NoError(t, err)
  323. require.True(t, ok)
  324. ch, msg, err = ab.ReceiveMessage()
  325. require.NoError(t, err)
  326. require.Equal(t, []byte("try"), msg)
  327. require.Equal(t, chID, ch)
  328. // Connections should still be active after closing the transports.
  329. err = a.Close()
  330. require.NoError(t, err)
  331. err = b.Close()
  332. require.NoError(t, err)
  333. _, err = ab.SendMessage(chID, []byte("still here"))
  334. require.NoError(t, err)
  335. ch, msg, err = ba.ReceiveMessage()
  336. require.NoError(t, err)
  337. require.Equal(t, chID, ch)
  338. require.Equal(t, []byte("still here"), msg)
  339. // Close one side of the connection. Both sides should then error
  340. // with io.EOF when trying to send or receive.
  341. err = ba.Close()
  342. require.NoError(t, err)
  343. _, _, err = ab.ReceiveMessage()
  344. require.Error(t, err)
  345. require.Equal(t, io.EOF, err)
  346. _, err = ab.SendMessage(chID, []byte("closed"))
  347. require.Error(t, err)
  348. require.Equal(t, io.EOF, err)
  349. _, _, err = ba.ReceiveMessage()
  350. require.Error(t, err)
  351. require.Equal(t, io.EOF, err)
  352. _, err = ba.SendMessage(chID, []byte("closed"))
  353. require.Error(t, err)
  354. require.Equal(t, io.EOF, err)
  355. })
  356. }
  357. func TestConnection_Status(t *testing.T) {
  358. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  359. a := makeTransport(t)
  360. b := makeTransport(t)
  361. ab, _ := dialAcceptHandshake(t, a, b)
  362. // FIXME: This isn't implemented in all transports, so for now we just
  363. // check that it doesn't panic, which isn't really much of a test.
  364. ab.Status()
  365. })
  366. }
  367. func TestConnection_String(t *testing.T) {
  368. withTransports(t, func(t *testing.T, makeTransport transportFactory) {
  369. a := makeTransport(t)
  370. b := makeTransport(t)
  371. ab, _ := dialAccept(t, a, b)
  372. require.NotEmpty(t, ab.String())
  373. })
  374. }
  375. func TestEndpoint_NodeAddress(t *testing.T) {
  376. var (
  377. ip4 = []byte{1, 2, 3, 4}
  378. ip4in6 = net.IPv4(1, 2, 3, 4)
  379. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  380. id = p2p.NodeID("00112233445566778899aabbccddeeff00112233")
  381. )
  382. testcases := []struct {
  383. endpoint p2p.Endpoint
  384. expect p2p.NodeAddress
  385. }{
  386. // Valid endpoints.
  387. {
  388. p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"},
  389. p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
  390. },
  391. {
  392. p2p.Endpoint{Protocol: "tcp", IP: ip4in6, Port: 8080, Path: "path"},
  393. p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
  394. },
  395. {
  396. p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "path"},
  397. p2p.NodeAddress{Protocol: "tcp", Hostname: "b10c::1", Port: 8080, Path: "path"},
  398. },
  399. {
  400. p2p.Endpoint{Protocol: "memory", Path: "foo"},
  401. p2p.NodeAddress{Protocol: "memory", Path: "foo"},
  402. },
  403. {
  404. p2p.Endpoint{Protocol: "memory", Path: string(id)},
  405. p2p.NodeAddress{Protocol: "memory", Path: string(id)},
  406. },
  407. // Partial (invalid) endpoints.
  408. {p2p.Endpoint{}, p2p.NodeAddress{}},
  409. {p2p.Endpoint{Protocol: "tcp"}, p2p.NodeAddress{Protocol: "tcp"}},
  410. {p2p.Endpoint{IP: net.IPv4(1, 2, 3, 4)}, p2p.NodeAddress{Hostname: "1.2.3.4"}},
  411. {p2p.Endpoint{Port: 8080}, p2p.NodeAddress{}},
  412. {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}},
  413. }
  414. for _, tc := range testcases {
  415. tc := tc
  416. t.Run(tc.endpoint.String(), func(t *testing.T) {
  417. // Without NodeID.
  418. expect := tc.expect
  419. require.Equal(t, expect, tc.endpoint.NodeAddress(""))
  420. // With NodeID.
  421. expect.NodeID = id
  422. require.Equal(t, expect, tc.endpoint.NodeAddress(expect.NodeID))
  423. })
  424. }
  425. }
  426. func TestEndpoint_String(t *testing.T) {
  427. var (
  428. ip4 = []byte{1, 2, 3, 4}
  429. ip4in6 = net.IPv4(1, 2, 3, 4)
  430. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  431. nodeID = p2p.NodeID("00112233445566778899aabbccddeeff00112233")
  432. )
  433. testcases := []struct {
  434. endpoint p2p.Endpoint
  435. expect string
  436. }{
  437. // Non-networked endpoints.
  438. {p2p.Endpoint{Protocol: "memory", Path: string(nodeID)}, "memory:" + string(nodeID)},
  439. {p2p.Endpoint{Protocol: "file", Path: "foo"}, "file:///foo"},
  440. {p2p.Endpoint{Protocol: "file", Path: "👋"}, "file:///%F0%9F%91%8B"},
  441. // IPv4 endpoints.
  442. {p2p.Endpoint{Protocol: "tcp", IP: ip4}, "tcp://1.2.3.4"},
  443. {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, "tcp://1.2.3.4"},
  444. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080}, "tcp://1.2.3.4:8080"},
  445. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "/path"}, "tcp://1.2.3.4:8080/path"},
  446. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"},
  447. // IPv6 endpoints.
  448. {p2p.Endpoint{Protocol: "tcp", IP: ip6}, "tcp://b10c::1"},
  449. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080}, "tcp://[b10c::1]:8080"},
  450. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "/path"}, "tcp://[b10c::1]:8080/path"},
  451. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"},
  452. // Partial (invalid) endpoints.
  453. {p2p.Endpoint{}, ""},
  454. {p2p.Endpoint{Protocol: "tcp"}, "tcp:"},
  455. {p2p.Endpoint{IP: []byte{1, 2, 3, 4}}, "1.2.3.4"},
  456. {p2p.Endpoint{IP: []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}}, "b10c::1"},
  457. {p2p.Endpoint{Port: 8080}, ""},
  458. {p2p.Endpoint{Path: "foo"}, "/foo"},
  459. }
  460. for _, tc := range testcases {
  461. tc := tc
  462. t.Run(tc.expect, func(t *testing.T) {
  463. require.Equal(t, tc.expect, tc.endpoint.String())
  464. })
  465. }
  466. }
  467. func TestEndpoint_Validate(t *testing.T) {
  468. var (
  469. ip4 = []byte{1, 2, 3, 4}
  470. ip4in6 = net.IPv4(1, 2, 3, 4)
  471. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  472. )
  473. testcases := []struct {
  474. endpoint p2p.Endpoint
  475. expectValid bool
  476. }{
  477. // Valid endpoints.
  478. {p2p.Endpoint{Protocol: "tcp", IP: ip4}, true},
  479. {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, true},
  480. {p2p.Endpoint{Protocol: "tcp", IP: ip6}, true},
  481. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8008}, true},
  482. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, true},
  483. {p2p.Endpoint{Protocol: "memory", Path: "path"}, true},
  484. // Invalid endpoints.
  485. {p2p.Endpoint{}, false},
  486. {p2p.Endpoint{IP: ip4}, false},
  487. {p2p.Endpoint{Protocol: "tcp"}, false},
  488. {p2p.Endpoint{Protocol: "tcp", IP: []byte{1, 2, 3}}, false},
  489. {p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false},
  490. }
  491. for _, tc := range testcases {
  492. tc := tc
  493. t.Run(tc.endpoint.String(), func(t *testing.T) {
  494. err := tc.endpoint.Validate()
  495. if tc.expectValid {
  496. require.NoError(t, err)
  497. } else {
  498. require.Error(t, err)
  499. }
  500. })
  501. }
  502. }
  503. // dialAccept is a helper that dials b from a and returns both sides of the
  504. // connection.
  505. func dialAccept(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
  506. t.Helper()
  507. endpoints := b.Endpoints()
  508. require.NotEmpty(t, endpoints, "peer not listening on any endpoints")
  509. ctx, cancel := context.WithTimeout(ctx, time.Second)
  510. defer cancel()
  511. acceptCh := make(chan p2p.Connection, 1)
  512. errCh := make(chan error, 1)
  513. go func() {
  514. conn, err := b.Accept()
  515. errCh <- err
  516. acceptCh <- conn
  517. }()
  518. dialConn, err := a.Dial(ctx, endpoints[0])
  519. require.NoError(t, err)
  520. acceptConn := <-acceptCh
  521. require.NoError(t, <-errCh)
  522. t.Cleanup(func() {
  523. _ = dialConn.Close()
  524. _ = acceptConn.Close()
  525. })
  526. return dialConn, acceptConn
  527. }
  528. // dialAcceptHandshake is a helper that dials and handshakes b from a and
  529. // returns both sides of the connection.
  530. func dialAcceptHandshake(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
  531. t.Helper()
  532. ab, ba := dialAccept(t, a, b)
  533. ctx, cancel := context.WithTimeout(ctx, time.Second)
  534. defer cancel()
  535. errCh := make(chan error, 1)
  536. go func() {
  537. privKey := ed25519.GenPrivKey()
  538. nodeInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(privKey.PubKey())}
  539. _, _, err := ba.Handshake(ctx, nodeInfo, privKey)
  540. errCh <- err
  541. }()
  542. privKey := ed25519.GenPrivKey()
  543. nodeInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(privKey.PubKey())}
  544. _, _, err := ab.Handshake(ctx, nodeInfo, privKey)
  545. require.NoError(t, err)
  546. timer := time.NewTimer(2 * time.Second)
  547. defer timer.Stop()
  548. select {
  549. case err := <-errCh:
  550. require.NoError(t, err)
  551. case <-timer.C:
  552. require.Fail(t, "handshake timed out")
  553. }
  554. return ab, ba
  555. }