You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

652 lines
17 KiB

  1. package conn
  2. import (
  3. "context"
  4. "encoding/hex"
  5. "net"
  6. "sync"
  7. "testing"
  8. "time"
  9. "github.com/fortytw2/leaktest"
  10. "github.com/gogo/protobuf/proto"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "github.com/tendermint/tendermint/internal/libs/protoio"
  14. "github.com/tendermint/tendermint/libs/log"
  15. "github.com/tendermint/tendermint/libs/service"
  16. tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p"
  17. "github.com/tendermint/tendermint/proto/tendermint/types"
  18. )
  19. const maxPingPongPacketSize = 1024 // bytes
  20. func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection {
  21. return createMConnectionWithCallbacks(logger, conn,
  22. // onRecieve
  23. func(chID ChannelID, msgBytes []byte) {
  24. },
  25. // onError
  26. func(r interface{}) {
  27. })
  28. }
  29. func createMConnectionWithCallbacks(
  30. logger log.Logger,
  31. conn net.Conn,
  32. onReceive func(chID ChannelID, msgBytes []byte),
  33. onError func(r interface{}),
  34. ) *MConnection {
  35. cfg := DefaultMConnConfig()
  36. cfg.PingInterval = 90 * time.Millisecond
  37. cfg.PongTimeout = 45 * time.Millisecond
  38. chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
  39. c := NewMConnectionWithConfig(logger, conn, chDescs, onReceive, onError, cfg)
  40. return c
  41. }
  42. func TestMConnectionSendFlushStop(t *testing.T) {
  43. server, client := NetPipe()
  44. t.Cleanup(closeAll(t, client, server))
  45. ctx, cancel := context.WithCancel(context.Background())
  46. defer cancel()
  47. clientConn := createTestMConnection(log.TestingLogger(), client)
  48. err := clientConn.Start(ctx)
  49. require.Nil(t, err)
  50. t.Cleanup(waitAll(clientConn))
  51. msg := []byte("abc")
  52. assert.True(t, clientConn.Send(0x01, msg))
  53. msgLength := 14
  54. // start the reader in a new routine, so we can flush
  55. errCh := make(chan error)
  56. go func() {
  57. msgB := make([]byte, msgLength)
  58. _, err := server.Read(msgB)
  59. if err != nil {
  60. t.Error(err)
  61. return
  62. }
  63. errCh <- err
  64. }()
  65. timer := time.NewTimer(3 * time.Second)
  66. select {
  67. case <-errCh:
  68. case <-timer.C:
  69. t.Error("timed out waiting for msgs to be read")
  70. }
  71. }
  72. func TestMConnectionSend(t *testing.T) {
  73. server, client := NetPipe()
  74. t.Cleanup(closeAll(t, client, server))
  75. ctx, cancel := context.WithCancel(context.Background())
  76. defer cancel()
  77. mconn := createTestMConnection(log.TestingLogger(), client)
  78. err := mconn.Start(ctx)
  79. require.Nil(t, err)
  80. t.Cleanup(waitAll(mconn))
  81. msg := []byte("Ant-Man")
  82. assert.True(t, mconn.Send(0x01, msg))
  83. // Note: subsequent Send/TrySend calls could pass because we are reading from
  84. // the send queue in a separate goroutine.
  85. _, err = server.Read(make([]byte, len(msg)))
  86. if err != nil {
  87. t.Error(err)
  88. }
  89. msg = []byte("Spider-Man")
  90. assert.True(t, mconn.Send(0x01, msg))
  91. _, err = server.Read(make([]byte, len(msg)))
  92. if err != nil {
  93. t.Error(err)
  94. }
  95. assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
  96. }
  97. func TestMConnectionReceive(t *testing.T) {
  98. server, client := NetPipe()
  99. t.Cleanup(closeAll(t, client, server))
  100. receivedCh := make(chan []byte)
  101. errorsCh := make(chan interface{})
  102. onReceive := func(chID ChannelID, msgBytes []byte) {
  103. receivedCh <- msgBytes
  104. }
  105. onError := func(r interface{}) {
  106. errorsCh <- r
  107. }
  108. logger := log.TestingLogger()
  109. ctx, cancel := context.WithCancel(context.Background())
  110. defer cancel()
  111. mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError)
  112. err := mconn1.Start(ctx)
  113. require.Nil(t, err)
  114. t.Cleanup(waitAll(mconn1))
  115. mconn2 := createTestMConnection(logger, server)
  116. err = mconn2.Start(ctx)
  117. require.Nil(t, err)
  118. t.Cleanup(waitAll(mconn2))
  119. msg := []byte("Cyclops")
  120. assert.True(t, mconn2.Send(0x01, msg))
  121. select {
  122. case receivedBytes := <-receivedCh:
  123. assert.Equal(t, msg, receivedBytes)
  124. case err := <-errorsCh:
  125. t.Fatalf("Expected %s, got %+v", msg, err)
  126. case <-time.After(500 * time.Millisecond):
  127. t.Fatalf("Did not receive %s message in 500ms", msg)
  128. }
  129. }
  130. func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
  131. server, client := net.Pipe()
  132. t.Cleanup(closeAll(t, client, server))
  133. receivedCh := make(chan []byte)
  134. errorsCh := make(chan interface{})
  135. onReceive := func(chID ChannelID, msgBytes []byte) {
  136. receivedCh <- msgBytes
  137. }
  138. onError := func(r interface{}) {
  139. errorsCh <- r
  140. }
  141. ctx, cancel := context.WithCancel(context.Background())
  142. defer cancel()
  143. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  144. err := mconn.Start(ctx)
  145. require.Nil(t, err)
  146. t.Cleanup(waitAll(mconn))
  147. serverGotPing := make(chan struct{})
  148. go func() {
  149. // read ping
  150. var pkt tmp2p.Packet
  151. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
  152. require.NoError(t, err)
  153. serverGotPing <- struct{}{}
  154. }()
  155. <-serverGotPing
  156. pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
  157. select {
  158. case msgBytes := <-receivedCh:
  159. t.Fatalf("Expected error, but got %v", msgBytes)
  160. case err := <-errorsCh:
  161. assert.NotNil(t, err)
  162. case <-time.After(pongTimerExpired):
  163. t.Fatalf("Expected to receive error after %v", pongTimerExpired)
  164. }
  165. }
  166. func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
  167. server, client := net.Pipe()
  168. t.Cleanup(closeAll(t, client, server))
  169. receivedCh := make(chan []byte)
  170. errorsCh := make(chan interface{})
  171. onReceive := func(chID ChannelID, msgBytes []byte) {
  172. receivedCh <- msgBytes
  173. }
  174. onError := func(r interface{}) {
  175. errorsCh <- r
  176. }
  177. ctx, cancel := context.WithCancel(context.Background())
  178. defer cancel()
  179. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  180. err := mconn.Start(ctx)
  181. require.Nil(t, err)
  182. t.Cleanup(waitAll(mconn))
  183. // sending 3 pongs in a row (abuse)
  184. protoWriter := protoio.NewDelimitedWriter(server)
  185. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  186. require.NoError(t, err)
  187. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  188. require.NoError(t, err)
  189. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  190. require.NoError(t, err)
  191. serverGotPing := make(chan struct{})
  192. go func() {
  193. // read ping (one byte)
  194. var packet tmp2p.Packet
  195. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
  196. require.NoError(t, err)
  197. serverGotPing <- struct{}{}
  198. // respond with pong
  199. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  200. require.NoError(t, err)
  201. }()
  202. <-serverGotPing
  203. pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
  204. select {
  205. case msgBytes := <-receivedCh:
  206. t.Fatalf("Expected no data, but got %v", msgBytes)
  207. case err := <-errorsCh:
  208. t.Fatalf("Expected no error, but got %v", err)
  209. case <-time.After(pongTimerExpired):
  210. assert.True(t, mconn.IsRunning())
  211. }
  212. }
  213. func TestMConnectionMultiplePings(t *testing.T) {
  214. server, client := net.Pipe()
  215. t.Cleanup(closeAll(t, client, server))
  216. receivedCh := make(chan []byte)
  217. errorsCh := make(chan interface{})
  218. onReceive := func(chID ChannelID, msgBytes []byte) {
  219. receivedCh <- msgBytes
  220. }
  221. onError := func(r interface{}) {
  222. errorsCh <- r
  223. }
  224. ctx, cancel := context.WithCancel(context.Background())
  225. defer cancel()
  226. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  227. err := mconn.Start(ctx)
  228. require.Nil(t, err)
  229. t.Cleanup(waitAll(mconn))
  230. // sending 3 pings in a row (abuse)
  231. // see https://github.com/tendermint/tendermint/issues/1190
  232. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  233. protoWriter := protoio.NewDelimitedWriter(server)
  234. var pkt tmp2p.Packet
  235. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  236. require.NoError(t, err)
  237. _, err = protoReader.ReadMsg(&pkt)
  238. require.NoError(t, err)
  239. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  240. require.NoError(t, err)
  241. _, err = protoReader.ReadMsg(&pkt)
  242. require.NoError(t, err)
  243. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  244. require.NoError(t, err)
  245. _, err = protoReader.ReadMsg(&pkt)
  246. require.NoError(t, err)
  247. assert.True(t, mconn.IsRunning())
  248. }
  249. func TestMConnectionPingPongs(t *testing.T) {
  250. // check that we are not leaking any go-routines
  251. t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))
  252. server, client := net.Pipe()
  253. t.Cleanup(closeAll(t, client, server))
  254. receivedCh := make(chan []byte)
  255. errorsCh := make(chan interface{})
  256. onReceive := func(chID ChannelID, msgBytes []byte) {
  257. receivedCh <- msgBytes
  258. }
  259. onError := func(r interface{}) {
  260. errorsCh <- r
  261. }
  262. ctx, cancel := context.WithCancel(context.Background())
  263. defer cancel()
  264. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  265. err := mconn.Start(ctx)
  266. require.Nil(t, err)
  267. t.Cleanup(waitAll(mconn))
  268. serverGotPing := make(chan struct{})
  269. go func() {
  270. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  271. protoWriter := protoio.NewDelimitedWriter(server)
  272. var pkt tmp2p.PacketPing
  273. // read ping
  274. _, err = protoReader.ReadMsg(&pkt)
  275. require.NoError(t, err)
  276. serverGotPing <- struct{}{}
  277. // respond with pong
  278. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  279. require.NoError(t, err)
  280. time.Sleep(mconn.config.PingInterval)
  281. // read ping
  282. _, err = protoReader.ReadMsg(&pkt)
  283. require.NoError(t, err)
  284. serverGotPing <- struct{}{}
  285. // respond with pong
  286. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  287. require.NoError(t, err)
  288. }()
  289. <-serverGotPing
  290. <-serverGotPing
  291. pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
  292. select {
  293. case msgBytes := <-receivedCh:
  294. t.Fatalf("Expected no data, but got %v", msgBytes)
  295. case err := <-errorsCh:
  296. t.Fatalf("Expected no error, but got %v", err)
  297. case <-time.After(2 * pongTimerExpired):
  298. assert.True(t, mconn.IsRunning())
  299. }
  300. }
  301. func TestMConnectionStopsAndReturnsError(t *testing.T) {
  302. server, client := NetPipe()
  303. t.Cleanup(closeAll(t, client, server))
  304. receivedCh := make(chan []byte)
  305. errorsCh := make(chan interface{})
  306. onReceive := func(chID ChannelID, msgBytes []byte) {
  307. receivedCh <- msgBytes
  308. }
  309. onError := func(r interface{}) {
  310. errorsCh <- r
  311. }
  312. ctx, cancel := context.WithCancel(context.Background())
  313. defer cancel()
  314. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  315. err := mconn.Start(ctx)
  316. require.Nil(t, err)
  317. t.Cleanup(waitAll(mconn))
  318. if err := client.Close(); err != nil {
  319. t.Error(err)
  320. }
  321. select {
  322. case receivedBytes := <-receivedCh:
  323. t.Fatalf("Expected error, got %v", receivedBytes)
  324. case err := <-errorsCh:
  325. assert.NotNil(t, err)
  326. assert.False(t, mconn.IsRunning())
  327. case <-time.After(500 * time.Millisecond):
  328. t.Fatal("Did not receive error in 500ms")
  329. }
  330. }
  331. func newClientAndServerConnsForReadErrors(
  332. ctx context.Context,
  333. t *testing.T,
  334. chOnErr chan struct{},
  335. ) (*MConnection, *MConnection) {
  336. server, client := NetPipe()
  337. onReceive := func(chID ChannelID, msgBytes []byte) {}
  338. onError := func(r interface{}) {}
  339. // create client conn with two channels
  340. chDescs := []*ChannelDescriptor{
  341. {ID: 0x01, Priority: 1, SendQueueCapacity: 1},
  342. {ID: 0x02, Priority: 1, SendQueueCapacity: 1},
  343. }
  344. logger := log.TestingLogger()
  345. mconnClient := NewMConnection(logger.With("module", "client"), client, chDescs, onReceive, onError)
  346. err := mconnClient.Start(ctx)
  347. require.Nil(t, err)
  348. // create server conn with 1 channel
  349. // it fires on chOnErr when there's an error
  350. serverLogger := logger.With("module", "server")
  351. onError = func(r interface{}) {
  352. chOnErr <- struct{}{}
  353. }
  354. mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError)
  355. err = mconnServer.Start(ctx)
  356. require.Nil(t, err)
  357. return mconnClient, mconnServer
  358. }
  359. func expectSend(ch chan struct{}) bool {
  360. after := time.After(time.Second * 5)
  361. select {
  362. case <-ch:
  363. return true
  364. case <-after:
  365. return false
  366. }
  367. }
  368. func TestMConnectionReadErrorBadEncoding(t *testing.T) {
  369. ctx, cancel := context.WithCancel(context.Background())
  370. defer cancel()
  371. chOnErr := make(chan struct{})
  372. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  373. client := mconnClient.conn
  374. // Write it.
  375. _, err := client.Write([]byte{1, 2, 3, 4, 5})
  376. require.NoError(t, err)
  377. assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
  378. t.Cleanup(waitAll(mconnClient, mconnServer))
  379. }
  380. func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
  381. ctx, cancel := context.WithCancel(context.Background())
  382. defer cancel()
  383. chOnErr := make(chan struct{})
  384. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  385. msg := []byte("Ant-Man")
  386. // fail to send msg on channel unknown by client
  387. assert.False(t, mconnClient.Send(0x03, msg))
  388. // send msg on channel unknown by the server.
  389. // should cause an error
  390. assert.True(t, mconnClient.Send(0x02, msg))
  391. assert.True(t, expectSend(chOnErr), "unknown channel")
  392. t.Cleanup(waitAll(mconnClient, mconnServer))
  393. }
  394. func TestMConnectionReadErrorLongMessage(t *testing.T) {
  395. chOnErr := make(chan struct{})
  396. chOnRcv := make(chan struct{})
  397. ctx, cancel := context.WithCancel(context.Background())
  398. defer cancel()
  399. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  400. t.Cleanup(waitAll(mconnClient, mconnServer))
  401. mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
  402. chOnRcv <- struct{}{}
  403. }
  404. client := mconnClient.conn
  405. protoWriter := protoio.NewDelimitedWriter(client)
  406. // send msg thats just right
  407. var packet = tmp2p.PacketMsg{
  408. ChannelID: 0x01,
  409. EOF: true,
  410. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
  411. }
  412. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  413. require.NoError(t, err)
  414. assert.True(t, expectSend(chOnRcv), "msg just right")
  415. // send msg thats too long
  416. packet = tmp2p.PacketMsg{
  417. ChannelID: 0x01,
  418. EOF: true,
  419. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
  420. }
  421. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  422. require.Error(t, err)
  423. assert.True(t, expectSend(chOnErr), "msg too long")
  424. }
  425. func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
  426. ctx, cancel := context.WithCancel(context.Background())
  427. defer cancel()
  428. chOnErr := make(chan struct{})
  429. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  430. t.Cleanup(waitAll(mconnClient, mconnServer))
  431. // send msg with unknown msg type
  432. _, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"})
  433. require.NoError(t, err)
  434. assert.True(t, expectSend(chOnErr), "unknown msg type")
  435. }
  436. func TestMConnectionTrySend(t *testing.T) {
  437. server, client := NetPipe()
  438. t.Cleanup(closeAll(t, client, server))
  439. ctx, cancel := context.WithCancel(context.Background())
  440. defer cancel()
  441. mconn := createTestMConnection(log.TestingLogger(), client)
  442. err := mconn.Start(ctx)
  443. require.Nil(t, err)
  444. t.Cleanup(waitAll(mconn))
  445. msg := []byte("Semicolon-Woman")
  446. resultCh := make(chan string, 2)
  447. assert.True(t, mconn.Send(0x01, msg))
  448. _, err = server.Read(make([]byte, len(msg)))
  449. require.NoError(t, err)
  450. assert.True(t, mconn.Send(0x01, msg))
  451. go func() {
  452. mconn.Send(0x01, msg)
  453. resultCh <- "TrySend"
  454. }()
  455. assert.False(t, mconn.Send(0x01, msg))
  456. assert.Equal(t, "TrySend", <-resultCh)
  457. }
  458. func TestConnVectors(t *testing.T) {
  459. testCases := []struct {
  460. testName string
  461. msg proto.Message
  462. expBytes string
  463. }{
  464. {"PacketPing", &tmp2p.PacketPing{}, "0a00"},
  465. {"PacketPong", &tmp2p.PacketPong{}, "1200"},
  466. {"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
  467. }
  468. for _, tc := range testCases {
  469. tc := tc
  470. pm := mustWrapPacket(tc.msg)
  471. bz, err := pm.Marshal()
  472. require.NoError(t, err, tc.testName)
  473. require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
  474. }
  475. }
  476. func TestMConnectionChannelOverflow(t *testing.T) {
  477. chOnErr := make(chan struct{})
  478. chOnRcv := make(chan struct{})
  479. ctx, cancel := context.WithCancel(context.Background())
  480. defer cancel()
  481. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  482. t.Cleanup(waitAll(mconnClient, mconnServer))
  483. mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
  484. chOnRcv <- struct{}{}
  485. }
  486. client := mconnClient.conn
  487. protoWriter := protoio.NewDelimitedWriter(client)
  488. var packet = tmp2p.PacketMsg{
  489. ChannelID: 0x01,
  490. EOF: true,
  491. Data: []byte(`42`),
  492. }
  493. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  494. require.NoError(t, err)
  495. assert.True(t, expectSend(chOnRcv))
  496. packet.ChannelID = int32(1025)
  497. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  498. require.NoError(t, err)
  499. assert.False(t, expectSend(chOnRcv))
  500. }
  501. func waitAll(waiters ...service.Service) func() {
  502. return func() {
  503. switch len(waiters) {
  504. case 0:
  505. return
  506. case 1:
  507. waiters[0].Wait()
  508. return
  509. default:
  510. wg := &sync.WaitGroup{}
  511. for _, w := range waiters {
  512. wg.Add(1)
  513. go func(s service.Service) {
  514. defer wg.Done()
  515. s.Wait()
  516. }(w)
  517. }
  518. wg.Wait()
  519. }
  520. }
  521. }
  522. type closer interface {
  523. Close() error
  524. }
  525. func closeAll(t *testing.T, closers ...closer) func() {
  526. return func() {
  527. for _, s := range closers {
  528. if err := s.Close(); err != nil {
  529. t.Log(err)
  530. }
  531. }
  532. }
  533. }