You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

656 lines
18 KiB

  1. package p2p_test
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "testing"
  7. "time"
  8. "github.com/fortytw2/leaktest"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "github.com/tendermint/tendermint/crypto/ed25519"
  12. "github.com/tendermint/tendermint/internal/p2p"
  13. "github.com/tendermint/tendermint/libs/bytes"
  14. "github.com/tendermint/tendermint/types"
  15. )
  16. // transportFactory is used to set up transports for tests.
  17. type transportFactory func(t *testing.T) p2p.Transport
  18. // testTransports is a registry of transport factories for withTransports().
  19. var testTransports = map[string]transportFactory{}
  20. // withTransports is a test helper that runs a test against all transports
  21. // registered in testTransports.
  22. func withTransports(ctx context.Context, t *testing.T, tester func(context.Context, *testing.T, transportFactory)) {
  23. t.Helper()
  24. for name, transportFactory := range testTransports {
  25. transportFactory := transportFactory
  26. t.Run(name, func(t *testing.T) {
  27. t.Cleanup(leaktest.Check(t))
  28. tctx, cancel := context.WithCancel(ctx)
  29. defer cancel()
  30. tester(tctx, t, transportFactory)
  31. })
  32. }
  33. }
  34. func TestTransport_AcceptClose(t *testing.T) {
  35. // Just test accept unblock on close, happy path is tested widely elsewhere.
  36. ctx, cancel := context.WithCancel(context.Background())
  37. defer cancel()
  38. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  39. a := makeTransport(t)
  40. opctx, opcancel := context.WithTimeout(ctx, 200*time.Millisecond)
  41. defer opcancel()
  42. _, err := a.Accept(opctx)
  43. require.Error(t, err)
  44. require.Equal(t, io.EOF, err)
  45. <-opctx.Done()
  46. _ = a.Close()
  47. // Closed transport should return error immediately,
  48. // because the transport is closed. We use the base
  49. // context (ctx) rather than the operation context
  50. // (opctx) because using the later would mean this
  51. // could error because the context was canceled.
  52. _, err = a.Accept(ctx)
  53. require.Error(t, err)
  54. require.Equal(t, io.EOF, err)
  55. })
  56. }
  57. func TestTransport_DialEndpoints(t *testing.T) {
  58. ipTestCases := []struct {
  59. ip net.IP
  60. ok bool
  61. }{
  62. {net.IPv4zero, true},
  63. {net.IPv6zero, true},
  64. {nil, false},
  65. {net.IPv4bcast, false},
  66. {net.IPv4allsys, false},
  67. {[]byte{1, 2, 3}, false},
  68. {[]byte{1, 2, 3, 4, 5}, false},
  69. }
  70. ctx, cancel := context.WithCancel(context.Background())
  71. defer cancel()
  72. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  73. a := makeTransport(t)
  74. endpoints := a.Endpoints()
  75. require.NotEmpty(t, endpoints)
  76. endpoint := endpoints[0]
  77. // Spawn a goroutine to simply accept any connections until closed.
  78. go func() {
  79. for {
  80. conn, err := a.Accept(ctx)
  81. if err != nil {
  82. return
  83. }
  84. _ = conn.Close()
  85. }
  86. }()
  87. // Dialing self should work.
  88. conn, err := a.Dial(ctx, endpoint)
  89. require.NoError(t, err)
  90. require.NoError(t, conn.Close())
  91. // Dialing empty endpoint should error.
  92. _, err = a.Dial(ctx, p2p.Endpoint{})
  93. require.Error(t, err)
  94. // Dialing without protocol should error.
  95. noProtocol := endpoint
  96. noProtocol.Protocol = ""
  97. _, err = a.Dial(ctx, noProtocol)
  98. require.Error(t, err)
  99. // Dialing with invalid protocol should error.
  100. fooProtocol := endpoint
  101. fooProtocol.Protocol = "foo"
  102. _, err = a.Dial(ctx, fooProtocol)
  103. require.Error(t, err)
  104. // Tests for networked endpoints (with IP).
  105. if len(endpoint.IP) > 0 && endpoint.Protocol != p2p.MemoryProtocol {
  106. for _, tc := range ipTestCases {
  107. tc := tc
  108. t.Run(tc.ip.String(), func(t *testing.T) {
  109. e := endpoint
  110. e.IP = tc.ip
  111. conn, err := a.Dial(ctx, e)
  112. if tc.ok {
  113. require.NoError(t, conn.Close())
  114. require.NoError(t, err)
  115. } else {
  116. require.Error(t, err, "endpoint=%s", e)
  117. }
  118. })
  119. }
  120. // Non-networked endpoints should error.
  121. noIP := endpoint
  122. noIP.IP = nil
  123. noIP.Port = 0
  124. noIP.Path = "foo"
  125. _, err := a.Dial(ctx, noIP)
  126. require.Error(t, err)
  127. } else {
  128. // Tests for non-networked endpoints (no IP).
  129. noPath := endpoint
  130. noPath.Path = ""
  131. _, err = a.Dial(ctx, noPath)
  132. require.Error(t, err)
  133. }
  134. })
  135. }
  136. func TestTransport_Dial(t *testing.T) {
  137. ctx, cancel := context.WithCancel(context.Background())
  138. defer cancel()
  139. // Most just tests dial failures, happy path is tested widely elsewhere.
  140. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  141. a := makeTransport(t)
  142. b := makeTransport(t)
  143. require.NotEmpty(t, a.Endpoints())
  144. require.NotEmpty(t, b.Endpoints())
  145. aEndpoint := a.Endpoints()[0]
  146. bEndpoint := b.Endpoints()[0]
  147. // Context cancellation should error. We can't test timeouts since we'd
  148. // need a non-responsive endpoint.
  149. cancelCtx, cancel := context.WithCancel(ctx)
  150. cancel()
  151. _, err := a.Dial(cancelCtx, bEndpoint)
  152. require.Error(t, err)
  153. // Unavailable endpoint should error.
  154. err = b.Close()
  155. require.NoError(t, err)
  156. _, err = a.Dial(ctx, bEndpoint)
  157. require.Error(t, err)
  158. // Dialing from a closed transport should still work.
  159. errCh := make(chan error, 1)
  160. go func() {
  161. conn, err := a.Accept(ctx)
  162. if err == nil {
  163. _ = conn.Close()
  164. }
  165. errCh <- err
  166. }()
  167. conn, err := b.Dial(ctx, aEndpoint)
  168. require.NoError(t, err)
  169. require.NoError(t, conn.Close())
  170. require.NoError(t, <-errCh)
  171. })
  172. }
  173. func TestTransport_Endpoints(t *testing.T) {
  174. ctx, cancel := context.WithCancel(context.Background())
  175. defer cancel()
  176. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  177. a := makeTransport(t)
  178. b := makeTransport(t)
  179. // Both transports return valid and different endpoints.
  180. aEndpoints := a.Endpoints()
  181. bEndpoints := b.Endpoints()
  182. require.NotEmpty(t, aEndpoints)
  183. require.NotEmpty(t, bEndpoints)
  184. require.NotEqual(t, aEndpoints, bEndpoints)
  185. for _, endpoint := range append(aEndpoints, bEndpoints...) {
  186. err := endpoint.Validate()
  187. require.NoError(t, err, "invalid endpoint %q", endpoint)
  188. }
  189. // When closed, the transport should no longer return any endpoints.
  190. err := a.Close()
  191. require.NoError(t, err)
  192. require.Empty(t, a.Endpoints())
  193. require.NotEmpty(t, b.Endpoints())
  194. })
  195. }
  196. func TestTransport_Protocols(t *testing.T) {
  197. ctx, cancel := context.WithCancel(context.Background())
  198. defer cancel()
  199. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  200. a := makeTransport(t)
  201. protocols := a.Protocols()
  202. endpoints := a.Endpoints()
  203. require.NotEmpty(t, protocols)
  204. require.NotEmpty(t, endpoints)
  205. for _, endpoint := range endpoints {
  206. require.Contains(t, protocols, endpoint.Protocol)
  207. }
  208. })
  209. }
  210. func TestTransport_String(t *testing.T) {
  211. ctx, cancel := context.WithCancel(context.Background())
  212. defer cancel()
  213. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  214. a := makeTransport(t)
  215. require.NotEmpty(t, a.String())
  216. })
  217. }
  218. func TestConnection_Handshake(t *testing.T) {
  219. ctx, cancel := context.WithCancel(context.Background())
  220. defer cancel()
  221. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  222. a := makeTransport(t)
  223. b := makeTransport(t)
  224. ab, ba := dialAccept(ctx, t, a, b)
  225. // A handshake should pass the given keys and NodeInfo.
  226. aKey := ed25519.GenPrivKey()
  227. aInfo := types.NodeInfo{
  228. NodeID: types.NodeIDFromPubKey(aKey.PubKey()),
  229. ProtocolVersion: types.ProtocolVersion{
  230. P2P: 1,
  231. Block: 2,
  232. App: 3,
  233. },
  234. ListenAddr: "listenaddr",
  235. Network: "network",
  236. Version: "1.2.3",
  237. Channels: bytes.HexBytes([]byte{0xf0, 0x0f}),
  238. Moniker: "moniker",
  239. Other: types.NodeInfoOther{
  240. TxIndex: "txindex",
  241. RPCAddress: "rpc.domain.com",
  242. },
  243. }
  244. bKey := ed25519.GenPrivKey()
  245. bInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(bKey.PubKey())}
  246. errCh := make(chan error, 1)
  247. go func() {
  248. // Must use assert due to goroutine.
  249. peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey)
  250. if err == nil {
  251. assert.Equal(t, aInfo, peerInfo)
  252. assert.Equal(t, aKey.PubKey(), peerKey)
  253. }
  254. select {
  255. case errCh <- err:
  256. case <-ctx.Done():
  257. }
  258. }()
  259. peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey)
  260. require.NoError(t, err)
  261. require.Equal(t, bInfo, peerInfo)
  262. require.Equal(t, bKey.PubKey(), peerKey)
  263. require.NoError(t, <-errCh)
  264. })
  265. }
  266. func TestConnection_HandshakeCancel(t *testing.T) {
  267. ctx, cancel := context.WithCancel(context.Background())
  268. defer cancel()
  269. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  270. a := makeTransport(t)
  271. b := makeTransport(t)
  272. // Handshake should error on context cancellation.
  273. ab, ba := dialAccept(ctx, t, a, b)
  274. timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
  275. cancel()
  276. _, _, err := ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey())
  277. require.Error(t, err)
  278. require.Equal(t, context.Canceled, err)
  279. _ = ab.Close()
  280. _ = ba.Close()
  281. // Handshake should error on context timeout.
  282. ab, ba = dialAccept(ctx, t, a, b)
  283. timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond)
  284. defer cancel()
  285. _, _, err = ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey())
  286. require.Error(t, err)
  287. require.Equal(t, context.DeadlineExceeded, err)
  288. _ = ab.Close()
  289. _ = ba.Close()
  290. })
  291. }
  292. func TestConnection_FlushClose(t *testing.T) {
  293. ctx, cancel := context.WithCancel(context.Background())
  294. defer cancel()
  295. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  296. a := makeTransport(t)
  297. b := makeTransport(t)
  298. ab, _ := dialAcceptHandshake(ctx, t, a, b)
  299. err := ab.Close()
  300. require.NoError(t, err)
  301. _, _, err = ab.ReceiveMessage(ctx)
  302. require.Error(t, err)
  303. require.Equal(t, io.EOF, err)
  304. err = ab.SendMessage(ctx, chID, []byte("closed"))
  305. require.Error(t, err)
  306. })
  307. }
  308. func TestConnection_LocalRemoteEndpoint(t *testing.T) {
  309. ctx, cancel := context.WithCancel(context.Background())
  310. defer cancel()
  311. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  312. a := makeTransport(t)
  313. b := makeTransport(t)
  314. ab, ba := dialAcceptHandshake(ctx, t, a, b)
  315. // Local and remote connection endpoints correspond to each other.
  316. require.NotEmpty(t, ab.LocalEndpoint())
  317. require.NotEmpty(t, ba.LocalEndpoint())
  318. require.Equal(t, ab.LocalEndpoint(), ba.RemoteEndpoint())
  319. require.Equal(t, ab.RemoteEndpoint(), ba.LocalEndpoint())
  320. })
  321. }
  322. func TestConnection_SendReceive(t *testing.T) {
  323. ctx, cancel := context.WithCancel(context.Background())
  324. defer cancel()
  325. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  326. a := makeTransport(t)
  327. b := makeTransport(t)
  328. ab, ba := dialAcceptHandshake(ctx, t, a, b)
  329. // Can send and receive a to b.
  330. err := ab.SendMessage(ctx, chID, []byte("foo"))
  331. require.NoError(t, err)
  332. ch, msg, err := ba.ReceiveMessage(ctx)
  333. require.NoError(t, err)
  334. require.Equal(t, []byte("foo"), msg)
  335. require.Equal(t, chID, ch)
  336. // Can send and receive b to a.
  337. err = ba.SendMessage(ctx, chID, []byte("bar"))
  338. require.NoError(t, err)
  339. _, msg, err = ab.ReceiveMessage(ctx)
  340. require.NoError(t, err)
  341. require.Equal(t, []byte("bar"), msg)
  342. // Connections should still be active after closing the transports.
  343. err = a.Close()
  344. require.NoError(t, err)
  345. err = b.Close()
  346. require.NoError(t, err)
  347. err = ab.SendMessage(ctx, chID, []byte("still here"))
  348. require.NoError(t, err)
  349. ch, msg, err = ba.ReceiveMessage(ctx)
  350. require.NoError(t, err)
  351. require.Equal(t, chID, ch)
  352. require.Equal(t, []byte("still here"), msg)
  353. // Close one side of the connection. Both sides should then error
  354. // with io.EOF when trying to send or receive.
  355. err = ba.Close()
  356. require.NoError(t, err)
  357. _, _, err = ab.ReceiveMessage(ctx)
  358. require.Error(t, err)
  359. require.Equal(t, io.EOF, err)
  360. err = ab.SendMessage(ctx, chID, []byte("closed"))
  361. require.Error(t, err)
  362. require.Equal(t, io.EOF, err)
  363. _, _, err = ba.ReceiveMessage(ctx)
  364. require.Error(t, err)
  365. require.Equal(t, io.EOF, err)
  366. err = ba.SendMessage(ctx, chID, []byte("closed"))
  367. require.Error(t, err)
  368. })
  369. }
  370. func TestConnection_String(t *testing.T) {
  371. ctx, cancel := context.WithCancel(context.Background())
  372. defer cancel()
  373. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  374. a := makeTransport(t)
  375. b := makeTransport(t)
  376. ab, _ := dialAccept(ctx, t, a, b)
  377. require.NotEmpty(t, ab.String())
  378. })
  379. }
  380. func TestEndpoint_NodeAddress(t *testing.T) {
  381. var (
  382. ip4 = []byte{1, 2, 3, 4}
  383. ip4in6 = net.IPv4(1, 2, 3, 4)
  384. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  385. id = types.NodeID("00112233445566778899aabbccddeeff00112233")
  386. )
  387. testcases := []struct {
  388. endpoint p2p.Endpoint
  389. expect p2p.NodeAddress
  390. }{
  391. // Valid endpoints.
  392. {
  393. p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"},
  394. p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
  395. },
  396. {
  397. p2p.Endpoint{Protocol: "tcp", IP: ip4in6, Port: 8080, Path: "path"},
  398. p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
  399. },
  400. {
  401. p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "path"},
  402. p2p.NodeAddress{Protocol: "tcp", Hostname: "b10c::1", Port: 8080, Path: "path"},
  403. },
  404. {
  405. p2p.Endpoint{Protocol: "memory", Path: "foo"},
  406. p2p.NodeAddress{Protocol: "memory", Path: "foo"},
  407. },
  408. {
  409. p2p.Endpoint{Protocol: "memory", Path: string(id)},
  410. p2p.NodeAddress{Protocol: "memory", Path: string(id)},
  411. },
  412. // Partial (invalid) endpoints.
  413. {p2p.Endpoint{}, p2p.NodeAddress{}},
  414. {p2p.Endpoint{Protocol: "tcp"}, p2p.NodeAddress{Protocol: "tcp"}},
  415. {p2p.Endpoint{IP: net.IPv4(1, 2, 3, 4)}, p2p.NodeAddress{Hostname: "1.2.3.4"}},
  416. {p2p.Endpoint{Port: 8080}, p2p.NodeAddress{}},
  417. {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}},
  418. }
  419. for _, tc := range testcases {
  420. tc := tc
  421. t.Run(tc.endpoint.String(), func(t *testing.T) {
  422. // Without NodeID.
  423. expect := tc.expect
  424. require.Equal(t, expect, tc.endpoint.NodeAddress(""))
  425. // With NodeID.
  426. expect.NodeID = id
  427. require.Equal(t, expect, tc.endpoint.NodeAddress(expect.NodeID))
  428. })
  429. }
  430. }
  431. func TestEndpoint_String(t *testing.T) {
  432. var (
  433. ip4 = []byte{1, 2, 3, 4}
  434. ip4in6 = net.IPv4(1, 2, 3, 4)
  435. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  436. nodeID = types.NodeID("00112233445566778899aabbccddeeff00112233")
  437. )
  438. testcases := []struct {
  439. endpoint p2p.Endpoint
  440. expect string
  441. }{
  442. // Non-networked endpoints.
  443. {p2p.Endpoint{Protocol: "memory", Path: string(nodeID)}, "memory:" + string(nodeID)},
  444. {p2p.Endpoint{Protocol: "file", Path: "foo"}, "file:///foo"},
  445. {p2p.Endpoint{Protocol: "file", Path: "👋"}, "file:///%F0%9F%91%8B"},
  446. // IPv4 endpoints.
  447. {p2p.Endpoint{Protocol: "tcp", IP: ip4}, "tcp://1.2.3.4"},
  448. {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, "tcp://1.2.3.4"},
  449. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080}, "tcp://1.2.3.4:8080"},
  450. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "/path"}, "tcp://1.2.3.4:8080/path"},
  451. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"},
  452. // IPv6 endpoints.
  453. {p2p.Endpoint{Protocol: "tcp", IP: ip6}, "tcp://b10c::1"},
  454. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080}, "tcp://[b10c::1]:8080"},
  455. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "/path"}, "tcp://[b10c::1]:8080/path"},
  456. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"},
  457. // Partial (invalid) endpoints.
  458. {p2p.Endpoint{}, ""},
  459. {p2p.Endpoint{Protocol: "tcp"}, "tcp:"},
  460. {p2p.Endpoint{IP: []byte{1, 2, 3, 4}}, "1.2.3.4"},
  461. {p2p.Endpoint{IP: []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}}, "b10c::1"},
  462. {p2p.Endpoint{Port: 8080}, ""},
  463. {p2p.Endpoint{Path: "foo"}, "/foo"},
  464. }
  465. for _, tc := range testcases {
  466. tc := tc
  467. t.Run(tc.expect, func(t *testing.T) {
  468. require.Equal(t, tc.expect, tc.endpoint.String())
  469. })
  470. }
  471. }
  472. func TestEndpoint_Validate(t *testing.T) {
  473. var (
  474. ip4 = []byte{1, 2, 3, 4}
  475. ip4in6 = net.IPv4(1, 2, 3, 4)
  476. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  477. )
  478. testcases := []struct {
  479. endpoint p2p.Endpoint
  480. expectValid bool
  481. }{
  482. // Valid endpoints.
  483. {p2p.Endpoint{Protocol: "tcp", IP: ip4}, true},
  484. {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, true},
  485. {p2p.Endpoint{Protocol: "tcp", IP: ip6}, true},
  486. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8008}, true},
  487. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, true},
  488. {p2p.Endpoint{Protocol: "memory", Path: "path"}, true},
  489. // Invalid endpoints.
  490. {p2p.Endpoint{}, false},
  491. {p2p.Endpoint{IP: ip4}, false},
  492. {p2p.Endpoint{Protocol: "tcp"}, false},
  493. {p2p.Endpoint{Protocol: "tcp", IP: []byte{1, 2, 3}}, false},
  494. {p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false},
  495. }
  496. for _, tc := range testcases {
  497. tc := tc
  498. t.Run(tc.endpoint.String(), func(t *testing.T) {
  499. err := tc.endpoint.Validate()
  500. if tc.expectValid {
  501. require.NoError(t, err)
  502. } else {
  503. require.Error(t, err)
  504. }
  505. })
  506. }
  507. }
  508. // dialAccept is a helper that dials b from a and returns both sides of the
  509. // connection.
  510. func dialAccept(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
  511. t.Helper()
  512. endpoints := b.Endpoints()
  513. require.NotEmpty(t, endpoints, "peer not listening on any endpoints")
  514. ctx, cancel := context.WithTimeout(ctx, time.Second)
  515. defer cancel()
  516. acceptCh := make(chan p2p.Connection, 1)
  517. errCh := make(chan error, 1)
  518. go func() {
  519. conn, err := b.Accept(ctx)
  520. errCh <- err
  521. acceptCh <- conn
  522. }()
  523. dialConn, err := a.Dial(ctx, endpoints[0])
  524. require.NoError(t, err)
  525. acceptConn := <-acceptCh
  526. require.NoError(t, <-errCh)
  527. t.Cleanup(func() {
  528. _ = dialConn.Close()
  529. _ = acceptConn.Close()
  530. })
  531. return dialConn, acceptConn
  532. }
  533. // dialAcceptHandshake is a helper that dials and handshakes b from a and
  534. // returns both sides of the connection.
  535. func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
  536. t.Helper()
  537. ab, ba := dialAccept(ctx, t, a, b)
  538. errCh := make(chan error, 1)
  539. go func() {
  540. privKey := ed25519.GenPrivKey()
  541. nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
  542. _, _, err := ba.Handshake(ctx, nodeInfo, privKey)
  543. errCh <- err
  544. }()
  545. privKey := ed25519.GenPrivKey()
  546. nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
  547. _, _, err := ab.Handshake(ctx, nodeInfo, privKey)
  548. require.NoError(t, err)
  549. timer := time.NewTimer(2 * time.Second)
  550. defer timer.Stop()
  551. select {
  552. case err := <-errCh:
  553. require.NoError(t, err)
  554. case <-timer.C:
  555. require.Fail(t, "handshake timed out")
  556. }
  557. return ab, ba
  558. }