You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

657 lines
18 KiB

  1. package p2p_test
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "testing"
  7. "time"
  8. "github.com/fortytw2/leaktest"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "github.com/tendermint/tendermint/crypto/ed25519"
  12. "github.com/tendermint/tendermint/internal/p2p"
  13. "github.com/tendermint/tendermint/libs/bytes"
  14. "github.com/tendermint/tendermint/types"
  15. )
  16. // transportFactory is used to set up transports for tests.
  17. type transportFactory func(t *testing.T) p2p.Transport
  18. // testTransports is a registry of transport factories for withTransports().
  19. var testTransports = map[string]transportFactory{}
  20. // withTransports is a test helper that runs a test against all transports
  21. // registered in testTransports.
  22. func withTransports(ctx context.Context, t *testing.T, tester func(context.Context, *testing.T, transportFactory)) {
  23. t.Helper()
  24. for name, transportFactory := range testTransports {
  25. transportFactory := transportFactory
  26. t.Run(name, func(t *testing.T) {
  27. t.Cleanup(leaktest.Check(t))
  28. tctx, cancel := context.WithCancel(ctx)
  29. defer cancel()
  30. tester(tctx, t, transportFactory)
  31. })
  32. }
  33. }
  34. func TestTransport_AcceptClose(t *testing.T) {
  35. // Just test accept unblock on close, happy path is tested widely elsewhere.
  36. ctx, cancel := context.WithCancel(context.Background())
  37. defer cancel()
  38. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  39. a := makeTransport(t)
  40. opctx, opcancel := context.WithCancel(ctx)
  41. // In-progress Accept should error on concurrent close.
  42. errCh := make(chan error, 1)
  43. go func() {
  44. time.Sleep(200 * time.Millisecond)
  45. opcancel()
  46. errCh <- a.Close()
  47. }()
  48. _, err := a.Accept(opctx)
  49. require.Error(t, err)
  50. require.Equal(t, io.EOF, err)
  51. require.NoError(t, <-errCh)
  52. // Closed transport should return error immediately.
  53. _, err = a.Accept(opctx)
  54. require.Error(t, err)
  55. require.Equal(t, io.EOF, err)
  56. })
  57. }
  58. func TestTransport_DialEndpoints(t *testing.T) {
  59. ipTestCases := []struct {
  60. ip net.IP
  61. ok bool
  62. }{
  63. {net.IPv4zero, true},
  64. {net.IPv6zero, true},
  65. {nil, false},
  66. {net.IPv4bcast, false},
  67. {net.IPv4allsys, false},
  68. {[]byte{1, 2, 3}, false},
  69. {[]byte{1, 2, 3, 4, 5}, false},
  70. }
  71. ctx, cancel := context.WithCancel(context.Background())
  72. defer cancel()
  73. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  74. a := makeTransport(t)
  75. endpoints := a.Endpoints()
  76. require.NotEmpty(t, endpoints)
  77. endpoint := endpoints[0]
  78. // Spawn a goroutine to simply accept any connections until closed.
  79. go func() {
  80. for {
  81. conn, err := a.Accept(ctx)
  82. if err != nil {
  83. return
  84. }
  85. _ = conn.Close()
  86. }
  87. }()
  88. // Dialing self should work.
  89. conn, err := a.Dial(ctx, endpoint)
  90. require.NoError(t, err)
  91. require.NoError(t, conn.Close())
  92. // Dialing empty endpoint should error.
  93. _, err = a.Dial(ctx, p2p.Endpoint{})
  94. require.Error(t, err)
  95. // Dialing without protocol should error.
  96. noProtocol := endpoint
  97. noProtocol.Protocol = ""
  98. _, err = a.Dial(ctx, noProtocol)
  99. require.Error(t, err)
  100. // Dialing with invalid protocol should error.
  101. fooProtocol := endpoint
  102. fooProtocol.Protocol = "foo"
  103. _, err = a.Dial(ctx, fooProtocol)
  104. require.Error(t, err)
  105. // Tests for networked endpoints (with IP).
  106. if len(endpoint.IP) > 0 && endpoint.Protocol != p2p.MemoryProtocol {
  107. for _, tc := range ipTestCases {
  108. tc := tc
  109. t.Run(tc.ip.String(), func(t *testing.T) {
  110. e := endpoint
  111. e.IP = tc.ip
  112. conn, err := a.Dial(ctx, e)
  113. if tc.ok {
  114. require.NoError(t, conn.Close())
  115. require.NoError(t, err)
  116. } else {
  117. require.Error(t, err, "endpoint=%s", e)
  118. }
  119. })
  120. }
  121. // Non-networked endpoints should error.
  122. noIP := endpoint
  123. noIP.IP = nil
  124. noIP.Port = 0
  125. noIP.Path = "foo"
  126. _, err := a.Dial(ctx, noIP)
  127. require.Error(t, err)
  128. } else {
  129. // Tests for non-networked endpoints (no IP).
  130. noPath := endpoint
  131. noPath.Path = ""
  132. _, err = a.Dial(ctx, noPath)
  133. require.Error(t, err)
  134. }
  135. })
  136. }
  137. func TestTransport_Dial(t *testing.T) {
  138. ctx, cancel := context.WithCancel(context.Background())
  139. defer cancel()
  140. // Most just tests dial failures, happy path is tested widely elsewhere.
  141. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  142. a := makeTransport(t)
  143. b := makeTransport(t)
  144. require.NotEmpty(t, a.Endpoints())
  145. require.NotEmpty(t, b.Endpoints())
  146. aEndpoint := a.Endpoints()[0]
  147. bEndpoint := b.Endpoints()[0]
  148. // Context cancellation should error. We can't test timeouts since we'd
  149. // need a non-responsive endpoint.
  150. cancelCtx, cancel := context.WithCancel(ctx)
  151. cancel()
  152. _, err := a.Dial(cancelCtx, bEndpoint)
  153. require.Error(t, err)
  154. // Unavailable endpoint should error.
  155. err = b.Close()
  156. require.NoError(t, err)
  157. _, err = a.Dial(ctx, bEndpoint)
  158. require.Error(t, err)
  159. // Dialing from a closed transport should still work.
  160. errCh := make(chan error, 1)
  161. go func() {
  162. conn, err := a.Accept(ctx)
  163. if err == nil {
  164. _ = conn.Close()
  165. }
  166. errCh <- err
  167. }()
  168. conn, err := b.Dial(ctx, aEndpoint)
  169. require.NoError(t, err)
  170. require.NoError(t, conn.Close())
  171. require.NoError(t, <-errCh)
  172. })
  173. }
  174. func TestTransport_Endpoints(t *testing.T) {
  175. ctx, cancel := context.WithCancel(context.Background())
  176. defer cancel()
  177. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  178. a := makeTransport(t)
  179. b := makeTransport(t)
  180. // Both transports return valid and different endpoints.
  181. aEndpoints := a.Endpoints()
  182. bEndpoints := b.Endpoints()
  183. require.NotEmpty(t, aEndpoints)
  184. require.NotEmpty(t, bEndpoints)
  185. require.NotEqual(t, aEndpoints, bEndpoints)
  186. for _, endpoint := range append(aEndpoints, bEndpoints...) {
  187. err := endpoint.Validate()
  188. require.NoError(t, err, "invalid endpoint %q", endpoint)
  189. }
  190. // When closed, the transport should no longer return any endpoints.
  191. err := a.Close()
  192. require.NoError(t, err)
  193. require.Empty(t, a.Endpoints())
  194. require.NotEmpty(t, b.Endpoints())
  195. })
  196. }
  197. func TestTransport_Protocols(t *testing.T) {
  198. ctx, cancel := context.WithCancel(context.Background())
  199. defer cancel()
  200. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  201. a := makeTransport(t)
  202. protocols := a.Protocols()
  203. endpoints := a.Endpoints()
  204. require.NotEmpty(t, protocols)
  205. require.NotEmpty(t, endpoints)
  206. for _, endpoint := range endpoints {
  207. require.Contains(t, protocols, endpoint.Protocol)
  208. }
  209. })
  210. }
  211. func TestTransport_String(t *testing.T) {
  212. ctx, cancel := context.WithCancel(context.Background())
  213. defer cancel()
  214. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  215. a := makeTransport(t)
  216. require.NotEmpty(t, a.String())
  217. })
  218. }
  219. func TestConnection_Handshake(t *testing.T) {
  220. ctx, cancel := context.WithCancel(context.Background())
  221. defer cancel()
  222. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  223. a := makeTransport(t)
  224. b := makeTransport(t)
  225. ab, ba := dialAccept(ctx, t, a, b)
  226. // A handshake should pass the given keys and NodeInfo.
  227. aKey := ed25519.GenPrivKey()
  228. aInfo := types.NodeInfo{
  229. NodeID: types.NodeIDFromPubKey(aKey.PubKey()),
  230. ProtocolVersion: types.ProtocolVersion{
  231. P2P: 1,
  232. Block: 2,
  233. App: 3,
  234. },
  235. ListenAddr: "listenaddr",
  236. Network: "network",
  237. Version: "1.2.3",
  238. Channels: bytes.HexBytes([]byte{0xf0, 0x0f}),
  239. Moniker: "moniker",
  240. Other: types.NodeInfoOther{
  241. TxIndex: "txindex",
  242. RPCAddress: "rpc.domain.com",
  243. },
  244. }
  245. bKey := ed25519.GenPrivKey()
  246. bInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(bKey.PubKey())}
  247. errCh := make(chan error, 1)
  248. go func() {
  249. // Must use assert due to goroutine.
  250. peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey)
  251. if err == nil {
  252. assert.Equal(t, aInfo, peerInfo)
  253. assert.Equal(t, aKey.PubKey(), peerKey)
  254. }
  255. select {
  256. case errCh <- err:
  257. case <-ctx.Done():
  258. }
  259. }()
  260. peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey)
  261. require.NoError(t, err)
  262. require.Equal(t, bInfo, peerInfo)
  263. require.Equal(t, bKey.PubKey(), peerKey)
  264. require.NoError(t, <-errCh)
  265. })
  266. }
  267. func TestConnection_HandshakeCancel(t *testing.T) {
  268. ctx, cancel := context.WithCancel(context.Background())
  269. defer cancel()
  270. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  271. a := makeTransport(t)
  272. b := makeTransport(t)
  273. // Handshake should error on context cancellation.
  274. ab, ba := dialAccept(ctx, t, a, b)
  275. timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
  276. cancel()
  277. _, _, err := ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey())
  278. require.Error(t, err)
  279. require.Equal(t, context.Canceled, err)
  280. _ = ab.Close()
  281. _ = ba.Close()
  282. // Handshake should error on context timeout.
  283. ab, ba = dialAccept(ctx, t, a, b)
  284. timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond)
  285. defer cancel()
  286. _, _, err = ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey())
  287. require.Error(t, err)
  288. require.Equal(t, context.DeadlineExceeded, err)
  289. _ = ab.Close()
  290. _ = ba.Close()
  291. })
  292. }
  293. func TestConnection_FlushClose(t *testing.T) {
  294. ctx, cancel := context.WithCancel(context.Background())
  295. defer cancel()
  296. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  297. a := makeTransport(t)
  298. b := makeTransport(t)
  299. ab, _ := dialAcceptHandshake(ctx, t, a, b)
  300. err := ab.Close()
  301. require.NoError(t, err)
  302. _, _, err = ab.ReceiveMessage(ctx)
  303. require.Error(t, err)
  304. require.Equal(t, io.EOF, err)
  305. err = ab.SendMessage(ctx, chID, []byte("closed"))
  306. require.Error(t, err)
  307. })
  308. }
  309. func TestConnection_LocalRemoteEndpoint(t *testing.T) {
  310. ctx, cancel := context.WithCancel(context.Background())
  311. defer cancel()
  312. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  313. a := makeTransport(t)
  314. b := makeTransport(t)
  315. ab, ba := dialAcceptHandshake(ctx, t, a, b)
  316. // Local and remote connection endpoints correspond to each other.
  317. require.NotEmpty(t, ab.LocalEndpoint())
  318. require.NotEmpty(t, ba.LocalEndpoint())
  319. require.Equal(t, ab.LocalEndpoint(), ba.RemoteEndpoint())
  320. require.Equal(t, ab.RemoteEndpoint(), ba.LocalEndpoint())
  321. })
  322. }
  323. func TestConnection_SendReceive(t *testing.T) {
  324. ctx, cancel := context.WithCancel(context.Background())
  325. defer cancel()
  326. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  327. a := makeTransport(t)
  328. b := makeTransport(t)
  329. ab, ba := dialAcceptHandshake(ctx, t, a, b)
  330. // Can send and receive a to b.
  331. err := ab.SendMessage(ctx, chID, []byte("foo"))
  332. require.NoError(t, err)
  333. ch, msg, err := ba.ReceiveMessage(ctx)
  334. require.NoError(t, err)
  335. require.Equal(t, []byte("foo"), msg)
  336. require.Equal(t, chID, ch)
  337. // Can send and receive b to a.
  338. err = ba.SendMessage(ctx, chID, []byte("bar"))
  339. require.NoError(t, err)
  340. _, msg, err = ab.ReceiveMessage(ctx)
  341. require.NoError(t, err)
  342. require.Equal(t, []byte("bar"), msg)
  343. // Connections should still be active after closing the transports.
  344. err = a.Close()
  345. require.NoError(t, err)
  346. err = b.Close()
  347. require.NoError(t, err)
  348. err = ab.SendMessage(ctx, chID, []byte("still here"))
  349. require.NoError(t, err)
  350. ch, msg, err = ba.ReceiveMessage(ctx)
  351. require.NoError(t, err)
  352. require.Equal(t, chID, ch)
  353. require.Equal(t, []byte("still here"), msg)
  354. // Close one side of the connection. Both sides should then error
  355. // with io.EOF when trying to send or receive.
  356. err = ba.Close()
  357. require.NoError(t, err)
  358. _, _, err = ab.ReceiveMessage(ctx)
  359. require.Error(t, err)
  360. require.Equal(t, io.EOF, err)
  361. err = ab.SendMessage(ctx, chID, []byte("closed"))
  362. require.Error(t, err)
  363. require.Equal(t, io.EOF, err)
  364. _, _, err = ba.ReceiveMessage(ctx)
  365. require.Error(t, err)
  366. require.Equal(t, io.EOF, err)
  367. err = ba.SendMessage(ctx, chID, []byte("closed"))
  368. require.Error(t, err)
  369. })
  370. }
  371. func TestConnection_String(t *testing.T) {
  372. ctx, cancel := context.WithCancel(context.Background())
  373. defer cancel()
  374. withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
  375. a := makeTransport(t)
  376. b := makeTransport(t)
  377. ab, _ := dialAccept(ctx, t, a, b)
  378. require.NotEmpty(t, ab.String())
  379. })
  380. }
  381. func TestEndpoint_NodeAddress(t *testing.T) {
  382. var (
  383. ip4 = []byte{1, 2, 3, 4}
  384. ip4in6 = net.IPv4(1, 2, 3, 4)
  385. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  386. id = types.NodeID("00112233445566778899aabbccddeeff00112233")
  387. )
  388. testcases := []struct {
  389. endpoint p2p.Endpoint
  390. expect p2p.NodeAddress
  391. }{
  392. // Valid endpoints.
  393. {
  394. p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"},
  395. p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
  396. },
  397. {
  398. p2p.Endpoint{Protocol: "tcp", IP: ip4in6, Port: 8080, Path: "path"},
  399. p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
  400. },
  401. {
  402. p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "path"},
  403. p2p.NodeAddress{Protocol: "tcp", Hostname: "b10c::1", Port: 8080, Path: "path"},
  404. },
  405. {
  406. p2p.Endpoint{Protocol: "memory", Path: "foo"},
  407. p2p.NodeAddress{Protocol: "memory", Path: "foo"},
  408. },
  409. {
  410. p2p.Endpoint{Protocol: "memory", Path: string(id)},
  411. p2p.NodeAddress{Protocol: "memory", Path: string(id)},
  412. },
  413. // Partial (invalid) endpoints.
  414. {p2p.Endpoint{}, p2p.NodeAddress{}},
  415. {p2p.Endpoint{Protocol: "tcp"}, p2p.NodeAddress{Protocol: "tcp"}},
  416. {p2p.Endpoint{IP: net.IPv4(1, 2, 3, 4)}, p2p.NodeAddress{Hostname: "1.2.3.4"}},
  417. {p2p.Endpoint{Port: 8080}, p2p.NodeAddress{}},
  418. {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}},
  419. }
  420. for _, tc := range testcases {
  421. tc := tc
  422. t.Run(tc.endpoint.String(), func(t *testing.T) {
  423. // Without NodeID.
  424. expect := tc.expect
  425. require.Equal(t, expect, tc.endpoint.NodeAddress(""))
  426. // With NodeID.
  427. expect.NodeID = id
  428. require.Equal(t, expect, tc.endpoint.NodeAddress(expect.NodeID))
  429. })
  430. }
  431. }
  432. func TestEndpoint_String(t *testing.T) {
  433. var (
  434. ip4 = []byte{1, 2, 3, 4}
  435. ip4in6 = net.IPv4(1, 2, 3, 4)
  436. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  437. nodeID = types.NodeID("00112233445566778899aabbccddeeff00112233")
  438. )
  439. testcases := []struct {
  440. endpoint p2p.Endpoint
  441. expect string
  442. }{
  443. // Non-networked endpoints.
  444. {p2p.Endpoint{Protocol: "memory", Path: string(nodeID)}, "memory:" + string(nodeID)},
  445. {p2p.Endpoint{Protocol: "file", Path: "foo"}, "file:///foo"},
  446. {p2p.Endpoint{Protocol: "file", Path: "👋"}, "file:///%F0%9F%91%8B"},
  447. // IPv4 endpoints.
  448. {p2p.Endpoint{Protocol: "tcp", IP: ip4}, "tcp://1.2.3.4"},
  449. {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, "tcp://1.2.3.4"},
  450. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080}, "tcp://1.2.3.4:8080"},
  451. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "/path"}, "tcp://1.2.3.4:8080/path"},
  452. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"},
  453. // IPv6 endpoints.
  454. {p2p.Endpoint{Protocol: "tcp", IP: ip6}, "tcp://b10c::1"},
  455. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080}, "tcp://[b10c::1]:8080"},
  456. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "/path"}, "tcp://[b10c::1]:8080/path"},
  457. {p2p.Endpoint{Protocol: "tcp", IP: ip6, Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"},
  458. // Partial (invalid) endpoints.
  459. {p2p.Endpoint{}, ""},
  460. {p2p.Endpoint{Protocol: "tcp"}, "tcp:"},
  461. {p2p.Endpoint{IP: []byte{1, 2, 3, 4}}, "1.2.3.4"},
  462. {p2p.Endpoint{IP: []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}}, "b10c::1"},
  463. {p2p.Endpoint{Port: 8080}, ""},
  464. {p2p.Endpoint{Path: "foo"}, "/foo"},
  465. }
  466. for _, tc := range testcases {
  467. tc := tc
  468. t.Run(tc.expect, func(t *testing.T) {
  469. require.Equal(t, tc.expect, tc.endpoint.String())
  470. })
  471. }
  472. }
  473. func TestEndpoint_Validate(t *testing.T) {
  474. var (
  475. ip4 = []byte{1, 2, 3, 4}
  476. ip4in6 = net.IPv4(1, 2, 3, 4)
  477. ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
  478. )
  479. testcases := []struct {
  480. endpoint p2p.Endpoint
  481. expectValid bool
  482. }{
  483. // Valid endpoints.
  484. {p2p.Endpoint{Protocol: "tcp", IP: ip4}, true},
  485. {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, true},
  486. {p2p.Endpoint{Protocol: "tcp", IP: ip6}, true},
  487. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8008}, true},
  488. {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, true},
  489. {p2p.Endpoint{Protocol: "memory", Path: "path"}, true},
  490. // Invalid endpoints.
  491. {p2p.Endpoint{}, false},
  492. {p2p.Endpoint{IP: ip4}, false},
  493. {p2p.Endpoint{Protocol: "tcp"}, false},
  494. {p2p.Endpoint{Protocol: "tcp", IP: []byte{1, 2, 3}}, false},
  495. {p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false},
  496. }
  497. for _, tc := range testcases {
  498. tc := tc
  499. t.Run(tc.endpoint.String(), func(t *testing.T) {
  500. err := tc.endpoint.Validate()
  501. if tc.expectValid {
  502. require.NoError(t, err)
  503. } else {
  504. require.Error(t, err)
  505. }
  506. })
  507. }
  508. }
  509. // dialAccept is a helper that dials b from a and returns both sides of the
  510. // connection.
  511. func dialAccept(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
  512. t.Helper()
  513. endpoints := b.Endpoints()
  514. require.NotEmpty(t, endpoints, "peer not listening on any endpoints")
  515. ctx, cancel := context.WithTimeout(ctx, time.Second)
  516. defer cancel()
  517. acceptCh := make(chan p2p.Connection, 1)
  518. errCh := make(chan error, 1)
  519. go func() {
  520. conn, err := b.Accept(ctx)
  521. errCh <- err
  522. acceptCh <- conn
  523. }()
  524. dialConn, err := a.Dial(ctx, endpoints[0])
  525. require.NoError(t, err)
  526. acceptConn := <-acceptCh
  527. require.NoError(t, <-errCh)
  528. t.Cleanup(func() {
  529. _ = dialConn.Close()
  530. _ = acceptConn.Close()
  531. })
  532. return dialConn, acceptConn
  533. }
  534. // dialAcceptHandshake is a helper that dials and handshakes b from a and
  535. // returns both sides of the connection.
  536. func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
  537. t.Helper()
  538. ab, ba := dialAccept(ctx, t, a, b)
  539. errCh := make(chan error, 1)
  540. go func() {
  541. privKey := ed25519.GenPrivKey()
  542. nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
  543. _, _, err := ba.Handshake(ctx, nodeInfo, privKey)
  544. errCh <- err
  545. }()
  546. privKey := ed25519.GenPrivKey()
  547. nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
  548. _, _, err := ab.Handshake(ctx, nodeInfo, privKey)
  549. require.NoError(t, err)
  550. timer := time.NewTimer(2 * time.Second)
  551. defer timer.Stop()
  552. select {
  553. case err := <-errCh:
  554. require.NoError(t, err)
  555. case <-timer.C:
  556. require.Fail(t, "handshake timed out")
  557. }
  558. return ab, ba
  559. }