You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

680 lines
18 KiB

  1. package conn
  2. import (
  3. "context"
  4. "encoding/hex"
  5. "io"
  6. "net"
  7. "sync"
  8. "testing"
  9. "time"
  10. "github.com/fortytw2/leaktest"
  11. "github.com/gogo/protobuf/proto"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. "github.com/tendermint/tendermint/internal/libs/protoio"
  15. "github.com/tendermint/tendermint/libs/log"
  16. "github.com/tendermint/tendermint/libs/service"
  17. tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p"
  18. "github.com/tendermint/tendermint/proto/tendermint/types"
  19. )
  // maxPingPongPacketSize is the maximum message size, in bytes, passed to
// protoio.NewDelimitedReader when these tests read ping/pong packets directly
// off the raw pipe.
const maxPingPongPacketSize = 1 << 10 // 1024 bytes
  21. func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection {
  22. return createMConnectionWithCallbacks(logger, conn,
  23. // onRecieve
  24. func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  25. },
  26. // onError
  27. func(ctx context.Context, r interface{}) {
  28. })
  29. }
  30. func createMConnectionWithCallbacks(
  31. logger log.Logger,
  32. conn net.Conn,
  33. onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte),
  34. onError func(ctx context.Context, r interface{}),
  35. ) *MConnection {
  36. cfg := DefaultMConnConfig()
  37. cfg.PingInterval = 250 * time.Millisecond
  38. cfg.PongTimeout = 500 * time.Millisecond
  39. chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
  40. c := NewMConnection(logger, conn, chDescs, onReceive, onError, cfg)
  41. return c
  42. }
  43. func TestMConnectionSendFlushStop(t *testing.T) {
  44. server, client := NetPipe()
  45. t.Cleanup(closeAll(t, client, server))
  46. ctx, cancel := context.WithCancel(context.Background())
  47. defer cancel()
  48. clientConn := createTestMConnection(log.NewNopLogger(), client)
  49. err := clientConn.Start(ctx)
  50. require.NoError(t, err)
  51. t.Cleanup(waitAll(clientConn))
  52. msg := []byte("abc")
  53. assert.True(t, clientConn.Send(0x01, msg))
  54. msgLength := 14
  55. // start the reader in a new routine, so we can flush
  56. errCh := make(chan error)
  57. go func() {
  58. msgB := make([]byte, msgLength)
  59. _, err := server.Read(msgB)
  60. if err != nil {
  61. t.Error(err)
  62. return
  63. }
  64. errCh <- err
  65. }()
  66. timer := time.NewTimer(3 * time.Second)
  67. select {
  68. case <-errCh:
  69. case <-timer.C:
  70. t.Error("timed out waiting for msgs to be read")
  71. }
  72. }
  73. func TestMConnectionSend(t *testing.T) {
  74. server, client := NetPipe()
  75. t.Cleanup(closeAll(t, client, server))
  76. ctx, cancel := context.WithCancel(context.Background())
  77. defer cancel()
  78. mconn := createTestMConnection(log.NewNopLogger(), client)
  79. err := mconn.Start(ctx)
  80. require.NoError(t, err)
  81. t.Cleanup(waitAll(mconn))
  82. msg := []byte("Ant-Man")
  83. assert.True(t, mconn.Send(0x01, msg))
  84. // Note: subsequent Send/TrySend calls could pass because we are reading from
  85. // the send queue in a separate goroutine.
  86. _, err = server.Read(make([]byte, len(msg)))
  87. if err != nil {
  88. t.Error(err)
  89. }
  90. msg = []byte("Spider-Man")
  91. assert.True(t, mconn.Send(0x01, msg))
  92. _, err = server.Read(make([]byte, len(msg)))
  93. if err != nil {
  94. t.Error(err)
  95. }
  96. assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
  97. }
  98. func TestMConnectionReceive(t *testing.T) {
  99. server, client := NetPipe()
  100. t.Cleanup(closeAll(t, client, server))
  101. receivedCh := make(chan []byte)
  102. errorsCh := make(chan interface{})
  103. onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  104. select {
  105. case receivedCh <- msgBytes:
  106. case <-ctx.Done():
  107. }
  108. }
  109. onError := func(ctx context.Context, r interface{}) {
  110. select {
  111. case errorsCh <- r:
  112. case <-ctx.Done():
  113. }
  114. }
  115. logger := log.NewNopLogger()
  116. ctx, cancel := context.WithCancel(context.Background())
  117. defer cancel()
  118. mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError)
  119. err := mconn1.Start(ctx)
  120. require.NoError(t, err)
  121. t.Cleanup(waitAll(mconn1))
  122. mconn2 := createTestMConnection(logger, server)
  123. err = mconn2.Start(ctx)
  124. require.NoError(t, err)
  125. t.Cleanup(waitAll(mconn2))
  126. msg := []byte("Cyclops")
  127. assert.True(t, mconn2.Send(0x01, msg))
  128. select {
  129. case receivedBytes := <-receivedCh:
  130. assert.Equal(t, msg, receivedBytes)
  131. case err := <-errorsCh:
  132. t.Fatalf("Expected %s, got %+v", msg, err)
  133. case <-time.After(500 * time.Millisecond):
  134. t.Fatalf("Did not receive %s message in 500ms", msg)
  135. }
  136. }
  137. func TestMConnectionWillEventuallyTimeout(t *testing.T) {
  138. server, client := net.Pipe()
  139. t.Cleanup(closeAll(t, client, server))
  140. ctx, cancel := context.WithCancel(context.Background())
  141. defer cancel()
  142. mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, nil, nil)
  143. err := mconn.Start(ctx)
  144. require.NoError(t, err)
  145. t.Cleanup(waitAll(mconn))
  146. require.True(t, mconn.IsRunning())
  147. go func() {
  148. // read the send buffer so that the send receive
  149. // doesn't get blocked.
  150. ticker := time.NewTicker(10 * time.Millisecond)
  151. defer ticker.Stop()
  152. for {
  153. select {
  154. case <-ticker.C:
  155. _, _ = io.ReadAll(server)
  156. case <-ctx.Done():
  157. return
  158. }
  159. }
  160. }()
  161. // wait for the send routine to die because it doesn't
  162. select {
  163. case <-mconn.doneSendRoutine:
  164. require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout,
  165. "the connection state reflects that we've passed the pong timeout")
  166. // since we hit the timeout, things should be shutdown
  167. require.False(t, mconn.IsRunning())
  168. case <-time.After(2 * mconn.config.PongTimeout):
  169. t.Fatal("connection did not hit timeout", mconn.config.PongTimeout)
  170. }
  171. }
  172. func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
  173. server, client := net.Pipe()
  174. t.Cleanup(closeAll(t, client, server))
  175. receivedCh := make(chan []byte)
  176. errorsCh := make(chan interface{})
  177. onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  178. select {
  179. case receivedCh <- msgBytes:
  180. case <-ctx.Done():
  181. }
  182. }
  183. onError := func(ctx context.Context, r interface{}) {
  184. select {
  185. case errorsCh <- r:
  186. case <-ctx.Done():
  187. }
  188. }
  189. ctx, cancel := context.WithCancel(context.Background())
  190. defer cancel()
  191. mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
  192. err := mconn.Start(ctx)
  193. require.NoError(t, err)
  194. t.Cleanup(waitAll(mconn))
  195. // sending 3 pongs in a row (abuse)
  196. protoWriter := protoio.NewDelimitedWriter(server)
  197. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  198. require.NoError(t, err)
  199. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  200. require.NoError(t, err)
  201. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  202. require.NoError(t, err)
  203. // read ping (one byte)
  204. var packet tmp2p.Packet
  205. _, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
  206. require.NoError(t, err)
  207. // respond with pong
  208. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  209. require.NoError(t, err)
  210. pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
  211. select {
  212. case msgBytes := <-receivedCh:
  213. t.Fatalf("Expected no data, but got %v", msgBytes)
  214. case err := <-errorsCh:
  215. t.Fatalf("Expected no error, but got %v", err)
  216. case <-time.After(pongTimerExpired):
  217. assert.True(t, mconn.IsRunning())
  218. }
  219. }
  220. func TestMConnectionMultiplePings(t *testing.T) {
  221. server, client := net.Pipe()
  222. t.Cleanup(closeAll(t, client, server))
  223. receivedCh := make(chan []byte)
  224. errorsCh := make(chan interface{})
  225. onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  226. select {
  227. case receivedCh <- msgBytes:
  228. case <-ctx.Done():
  229. }
  230. }
  231. onError := func(ctx context.Context, r interface{}) {
  232. select {
  233. case errorsCh <- r:
  234. case <-ctx.Done():
  235. }
  236. }
  237. ctx, cancel := context.WithCancel(context.Background())
  238. defer cancel()
  239. mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
  240. err := mconn.Start(ctx)
  241. require.NoError(t, err)
  242. t.Cleanup(waitAll(mconn))
  243. // sending 3 pings in a row (abuse)
  244. // see https://github.com/tendermint/tendermint/issues/1190
  245. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  246. protoWriter := protoio.NewDelimitedWriter(server)
  247. var pkt tmp2p.Packet
  248. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  249. require.NoError(t, err)
  250. _, err = protoReader.ReadMsg(&pkt)
  251. require.NoError(t, err)
  252. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  253. require.NoError(t, err)
  254. _, err = protoReader.ReadMsg(&pkt)
  255. require.NoError(t, err)
  256. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  257. require.NoError(t, err)
  258. _, err = protoReader.ReadMsg(&pkt)
  259. require.NoError(t, err)
  260. assert.True(t, mconn.IsRunning())
  261. }
  262. func TestMConnectionPingPongs(t *testing.T) {
  263. // check that we are not leaking any go-routines
  264. t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))
  265. server, client := net.Pipe()
  266. t.Cleanup(closeAll(t, client, server))
  267. receivedCh := make(chan []byte)
  268. errorsCh := make(chan interface{})
  269. onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  270. select {
  271. case receivedCh <- msgBytes:
  272. case <-ctx.Done():
  273. }
  274. }
  275. onError := func(ctx context.Context, r interface{}) {
  276. select {
  277. case errorsCh <- r:
  278. case <-ctx.Done():
  279. }
  280. }
  281. ctx, cancel := context.WithCancel(context.Background())
  282. defer cancel()
  283. mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
  284. err := mconn.Start(ctx)
  285. require.NoError(t, err)
  286. t.Cleanup(waitAll(mconn))
  287. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  288. protoWriter := protoio.NewDelimitedWriter(server)
  289. var pkt tmp2p.PacketPing
  290. // read ping
  291. _, err = protoReader.ReadMsg(&pkt)
  292. require.NoError(t, err)
  293. // respond with pong
  294. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  295. require.NoError(t, err)
  296. time.Sleep(mconn.config.PingInterval)
  297. // read ping
  298. _, err = protoReader.ReadMsg(&pkt)
  299. require.NoError(t, err)
  300. // respond with pong
  301. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  302. require.NoError(t, err)
  303. pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4
  304. select {
  305. case msgBytes := <-receivedCh:
  306. t.Fatalf("Expected no data, but got %v", msgBytes)
  307. case err := <-errorsCh:
  308. t.Fatalf("Expected no error, but got %v", err)
  309. case <-time.After(2 * pongTimerExpired):
  310. assert.True(t, mconn.IsRunning())
  311. }
  312. }
  313. func TestMConnectionStopsAndReturnsError(t *testing.T) {
  314. server, client := NetPipe()
  315. t.Cleanup(closeAll(t, client, server))
  316. receivedCh := make(chan []byte)
  317. errorsCh := make(chan interface{})
  318. onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  319. select {
  320. case receivedCh <- msgBytes:
  321. case <-ctx.Done():
  322. }
  323. }
  324. onError := func(ctx context.Context, r interface{}) {
  325. select {
  326. case errorsCh <- r:
  327. case <-ctx.Done():
  328. }
  329. }
  330. ctx, cancel := context.WithCancel(context.Background())
  331. defer cancel()
  332. mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError)
  333. err := mconn.Start(ctx)
  334. require.NoError(t, err)
  335. t.Cleanup(waitAll(mconn))
  336. if err := client.Close(); err != nil {
  337. t.Error(err)
  338. }
  339. select {
  340. case receivedBytes := <-receivedCh:
  341. t.Fatalf("Expected error, got %v", receivedBytes)
  342. case err := <-errorsCh:
  343. assert.NotNil(t, err)
  344. assert.False(t, mconn.IsRunning())
  345. case <-time.After(500 * time.Millisecond):
  346. t.Fatal("Did not receive error in 500ms")
  347. }
  348. }
  349. func newClientAndServerConnsForReadErrors(
  350. ctx context.Context,
  351. t *testing.T,
  352. chOnErr chan struct{},
  353. ) (*MConnection, *MConnection) {
  354. server, client := NetPipe()
  355. onReceive := func(context.Context, ChannelID, []byte) {}
  356. onError := func(context.Context, interface{}) {}
  357. // create client conn with two channels
  358. chDescs := []*ChannelDescriptor{
  359. {ID: 0x01, Priority: 1, SendQueueCapacity: 1},
  360. {ID: 0x02, Priority: 1, SendQueueCapacity: 1},
  361. }
  362. logger := log.NewNopLogger()
  363. mconnClient := NewMConnection(logger.With("module", "client"), client, chDescs, onReceive, onError, DefaultMConnConfig())
  364. err := mconnClient.Start(ctx)
  365. require.NoError(t, err)
  366. // create server conn with 1 channel
  367. // it fires on chOnErr when there's an error
  368. serverLogger := logger.With("module", "server")
  369. onError = func(ctx context.Context, r interface{}) {
  370. select {
  371. case <-ctx.Done():
  372. case chOnErr <- struct{}{}:
  373. }
  374. }
  375. mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError)
  376. err = mconnServer.Start(ctx)
  377. require.NoError(t, err)
  378. return mconnClient, mconnServer
  379. }
  // expectSend reports whether a receive from ch completes within five
// seconds. It returns true on a value or a closed channel, false on timeout.
func expectSend(ch chan struct{}) bool {
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		return false
	case <-ch:
		return true
	}
}
  389. func TestMConnectionReadErrorBadEncoding(t *testing.T) {
  390. ctx, cancel := context.WithCancel(context.Background())
  391. defer cancel()
  392. chOnErr := make(chan struct{})
  393. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  394. client := mconnClient.conn
  395. // Write it.
  396. _, err := client.Write([]byte{1, 2, 3, 4, 5})
  397. require.NoError(t, err)
  398. assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
  399. t.Cleanup(waitAll(mconnClient, mconnServer))
  400. }
  401. func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
  402. ctx, cancel := context.WithCancel(context.Background())
  403. defer cancel()
  404. chOnErr := make(chan struct{})
  405. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  406. msg := []byte("Ant-Man")
  407. // fail to send msg on channel unknown by client
  408. assert.False(t, mconnClient.Send(0x03, msg))
  409. // send msg on channel unknown by the server.
  410. // should cause an error
  411. assert.True(t, mconnClient.Send(0x02, msg))
  412. assert.True(t, expectSend(chOnErr), "unknown channel")
  413. t.Cleanup(waitAll(mconnClient, mconnServer))
  414. }
  415. func TestMConnectionReadErrorLongMessage(t *testing.T) {
  416. chOnErr := make(chan struct{})
  417. chOnRcv := make(chan struct{})
  418. ctx, cancel := context.WithCancel(context.Background())
  419. defer cancel()
  420. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  421. t.Cleanup(waitAll(mconnClient, mconnServer))
  422. mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  423. select {
  424. case <-ctx.Done():
  425. case chOnRcv <- struct{}{}:
  426. }
  427. }
  428. client := mconnClient.conn
  429. protoWriter := protoio.NewDelimitedWriter(client)
  430. // send msg thats just right
  431. var packet = tmp2p.PacketMsg{
  432. ChannelID: 0x01,
  433. EOF: true,
  434. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
  435. }
  436. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  437. require.NoError(t, err)
  438. assert.True(t, expectSend(chOnRcv), "msg just right")
  439. // send msg thats too long
  440. packet = tmp2p.PacketMsg{
  441. ChannelID: 0x01,
  442. EOF: true,
  443. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
  444. }
  445. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  446. require.Error(t, err)
  447. assert.True(t, expectSend(chOnErr), "msg too long")
  448. }
  449. func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
  450. ctx, cancel := context.WithCancel(context.Background())
  451. defer cancel()
  452. chOnErr := make(chan struct{})
  453. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  454. t.Cleanup(waitAll(mconnClient, mconnServer))
  455. // send msg with unknown msg type
  456. _, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"})
  457. require.NoError(t, err)
  458. assert.True(t, expectSend(chOnErr), "unknown msg type")
  459. }
  460. func TestMConnectionTrySend(t *testing.T) {
  461. server, client := NetPipe()
  462. t.Cleanup(closeAll(t, client, server))
  463. ctx, cancel := context.WithCancel(context.Background())
  464. defer cancel()
  465. mconn := createTestMConnection(log.NewNopLogger(), client)
  466. err := mconn.Start(ctx)
  467. require.NoError(t, err)
  468. t.Cleanup(waitAll(mconn))
  469. msg := []byte("Semicolon-Woman")
  470. resultCh := make(chan string, 2)
  471. assert.True(t, mconn.Send(0x01, msg))
  472. _, err = server.Read(make([]byte, len(msg)))
  473. require.NoError(t, err)
  474. assert.True(t, mconn.Send(0x01, msg))
  475. go func() {
  476. mconn.Send(0x01, msg)
  477. resultCh <- "TrySend"
  478. }()
  479. assert.False(t, mconn.Send(0x01, msg))
  480. assert.Equal(t, "TrySend", <-resultCh)
  481. }
  482. func TestConnVectors(t *testing.T) {
  483. testCases := []struct {
  484. testName string
  485. msg proto.Message
  486. expBytes string
  487. }{
  488. {"PacketPing", &tmp2p.PacketPing{}, "0a00"},
  489. {"PacketPong", &tmp2p.PacketPong{}, "1200"},
  490. {"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
  491. }
  492. for _, tc := range testCases {
  493. tc := tc
  494. pm := mustWrapPacket(tc.msg)
  495. bz, err := pm.Marshal()
  496. require.NoError(t, err, tc.testName)
  497. require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
  498. }
  499. }
  500. func TestMConnectionChannelOverflow(t *testing.T) {
  501. chOnErr := make(chan struct{})
  502. chOnRcv := make(chan struct{})
  503. ctx, cancel := context.WithCancel(context.Background())
  504. defer cancel()
  505. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  506. t.Cleanup(waitAll(mconnClient, mconnServer))
  507. mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
  508. select {
  509. case <-ctx.Done():
  510. case chOnRcv <- struct{}{}:
  511. }
  512. }
  513. client := mconnClient.conn
  514. protoWriter := protoio.NewDelimitedWriter(client)
  515. var packet = tmp2p.PacketMsg{
  516. ChannelID: 0x01,
  517. EOF: true,
  518. Data: []byte(`42`),
  519. }
  520. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  521. require.NoError(t, err)
  522. assert.True(t, expectSend(chOnRcv))
  523. packet.ChannelID = int32(1025)
  524. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  525. require.NoError(t, err)
  526. assert.False(t, expectSend(chOnRcv))
  527. }
  528. func waitAll(waiters ...service.Service) func() {
  529. return func() {
  530. switch len(waiters) {
  531. case 0:
  532. return
  533. case 1:
  534. waiters[0].Wait()
  535. return
  536. default:
  537. wg := &sync.WaitGroup{}
  538. for _, w := range waiters {
  539. wg.Add(1)
  540. go func(s service.Service) {
  541. defer wg.Done()
  542. s.Wait()
  543. }(w)
  544. }
  545. wg.Wait()
  546. }
  547. }
  548. }
  // closer is anything with a Close method, such as a net.Conn.
type closer interface {
	Close() error
}

// closeAll returns a cleanup func that closes each of closers in the order
// given. Close errors are reported with t.Log rather than failing the test.
func closeAll(t *testing.T, closers ...closer) func() {
	return func() {
		for i := range closers {
			err := closers[i].Close()
			if err == nil {
				continue
			}
			t.Log(err)
		}
	}
}