You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

641 lines
17 KiB

  1. package conn
  2. import (
  3. "context"
  4. "encoding/hex"
  5. "net"
  6. "testing"
  7. "time"
  8. "github.com/fortytw2/leaktest"
  9. "github.com/gogo/protobuf/proto"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "github.com/tendermint/tendermint/internal/libs/protoio"
  13. "github.com/tendermint/tendermint/libs/log"
  14. tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p"
  15. "github.com/tendermint/tendermint/proto/tendermint/types"
  16. )
  17. const maxPingPongPacketSize = 1024 // bytes
  18. func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection {
  19. return createMConnectionWithCallbacks(logger, conn,
  20. // onRecieve
  21. func(chID ChannelID, msgBytes []byte) {
  22. },
  23. // onError
  24. func(r interface{}) {
  25. })
  26. }
  27. func createMConnectionWithCallbacks(
  28. logger log.Logger,
  29. conn net.Conn,
  30. onReceive func(chID ChannelID, msgBytes []byte),
  31. onError func(r interface{}),
  32. ) *MConnection {
  33. cfg := DefaultMConnConfig()
  34. cfg.PingInterval = 90 * time.Millisecond
  35. cfg.PongTimeout = 45 * time.Millisecond
  36. chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
  37. c := NewMConnectionWithConfig(logger, conn, chDescs, onReceive, onError, cfg)
  38. return c
  39. }
  40. func TestMConnectionSendFlushStop(t *testing.T) {
  41. server, client := NetPipe()
  42. t.Cleanup(closeAll(t, client, server))
  43. ctx, cancel := context.WithCancel(context.Background())
  44. defer cancel()
  45. clientConn := createTestMConnection(log.TestingLogger(), client)
  46. err := clientConn.Start(ctx)
  47. require.Nil(t, err)
  48. t.Cleanup(stopAll(t, clientConn))
  49. msg := []byte("abc")
  50. assert.True(t, clientConn.Send(0x01, msg))
  51. msgLength := 14
  52. // start the reader in a new routine, so we can flush
  53. errCh := make(chan error)
  54. go func() {
  55. msgB := make([]byte, msgLength)
  56. _, err := server.Read(msgB)
  57. if err != nil {
  58. t.Error(err)
  59. return
  60. }
  61. errCh <- err
  62. }()
  63. timer := time.NewTimer(3 * time.Second)
  64. select {
  65. case <-errCh:
  66. case <-timer.C:
  67. t.Error("timed out waiting for msgs to be read")
  68. }
  69. }
  70. func TestMConnectionSend(t *testing.T) {
  71. server, client := NetPipe()
  72. t.Cleanup(closeAll(t, client, server))
  73. ctx, cancel := context.WithCancel(context.Background())
  74. defer cancel()
  75. mconn := createTestMConnection(log.TestingLogger(), client)
  76. err := mconn.Start(ctx)
  77. require.Nil(t, err)
  78. t.Cleanup(stopAll(t, mconn))
  79. msg := []byte("Ant-Man")
  80. assert.True(t, mconn.Send(0x01, msg))
  81. // Note: subsequent Send/TrySend calls could pass because we are reading from
  82. // the send queue in a separate goroutine.
  83. _, err = server.Read(make([]byte, len(msg)))
  84. if err != nil {
  85. t.Error(err)
  86. }
  87. msg = []byte("Spider-Man")
  88. assert.True(t, mconn.Send(0x01, msg))
  89. _, err = server.Read(make([]byte, len(msg)))
  90. if err != nil {
  91. t.Error(err)
  92. }
  93. assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
  94. }
  95. func TestMConnectionReceive(t *testing.T) {
  96. server, client := NetPipe()
  97. t.Cleanup(closeAll(t, client, server))
  98. receivedCh := make(chan []byte)
  99. errorsCh := make(chan interface{})
  100. onReceive := func(chID ChannelID, msgBytes []byte) {
  101. receivedCh <- msgBytes
  102. }
  103. onError := func(r interface{}) {
  104. errorsCh <- r
  105. }
  106. logger := log.TestingLogger()
  107. ctx, cancel := context.WithCancel(context.Background())
  108. defer cancel()
  109. mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError)
  110. err := mconn1.Start(ctx)
  111. require.Nil(t, err)
  112. t.Cleanup(stopAll(t, mconn1))
  113. mconn2 := createTestMConnection(logger, server)
  114. err = mconn2.Start(ctx)
  115. require.Nil(t, err)
  116. t.Cleanup(stopAll(t, mconn2))
  117. msg := []byte("Cyclops")
  118. assert.True(t, mconn2.Send(0x01, msg))
  119. select {
  120. case receivedBytes := <-receivedCh:
  121. assert.Equal(t, msg, receivedBytes)
  122. case err := <-errorsCh:
  123. t.Fatalf("Expected %s, got %+v", msg, err)
  124. case <-time.After(500 * time.Millisecond):
  125. t.Fatalf("Did not receive %s message in 500ms", msg)
  126. }
  127. }
  128. func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
  129. server, client := net.Pipe()
  130. t.Cleanup(closeAll(t, client, server))
  131. receivedCh := make(chan []byte)
  132. errorsCh := make(chan interface{})
  133. onReceive := func(chID ChannelID, msgBytes []byte) {
  134. receivedCh <- msgBytes
  135. }
  136. onError := func(r interface{}) {
  137. errorsCh <- r
  138. }
  139. ctx, cancel := context.WithCancel(context.Background())
  140. defer cancel()
  141. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  142. err := mconn.Start(ctx)
  143. require.Nil(t, err)
  144. t.Cleanup(stopAll(t, mconn))
  145. serverGotPing := make(chan struct{})
  146. go func() {
  147. // read ping
  148. var pkt tmp2p.Packet
  149. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
  150. require.NoError(t, err)
  151. serverGotPing <- struct{}{}
  152. }()
  153. <-serverGotPing
  154. pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
  155. select {
  156. case msgBytes := <-receivedCh:
  157. t.Fatalf("Expected error, but got %v", msgBytes)
  158. case err := <-errorsCh:
  159. assert.NotNil(t, err)
  160. case <-time.After(pongTimerExpired):
  161. t.Fatalf("Expected to receive error after %v", pongTimerExpired)
  162. }
  163. }
  164. func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
  165. server, client := net.Pipe()
  166. t.Cleanup(closeAll(t, client, server))
  167. receivedCh := make(chan []byte)
  168. errorsCh := make(chan interface{})
  169. onReceive := func(chID ChannelID, msgBytes []byte) {
  170. receivedCh <- msgBytes
  171. }
  172. onError := func(r interface{}) {
  173. errorsCh <- r
  174. }
  175. ctx, cancel := context.WithCancel(context.Background())
  176. defer cancel()
  177. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  178. err := mconn.Start(ctx)
  179. require.Nil(t, err)
  180. t.Cleanup(stopAll(t, mconn))
  181. // sending 3 pongs in a row (abuse)
  182. protoWriter := protoio.NewDelimitedWriter(server)
  183. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  184. require.NoError(t, err)
  185. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  186. require.NoError(t, err)
  187. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  188. require.NoError(t, err)
  189. serverGotPing := make(chan struct{})
  190. go func() {
  191. // read ping (one byte)
  192. var packet tmp2p.Packet
  193. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
  194. require.NoError(t, err)
  195. serverGotPing <- struct{}{}
  196. // respond with pong
  197. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  198. require.NoError(t, err)
  199. }()
  200. <-serverGotPing
  201. pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
  202. select {
  203. case msgBytes := <-receivedCh:
  204. t.Fatalf("Expected no data, but got %v", msgBytes)
  205. case err := <-errorsCh:
  206. t.Fatalf("Expected no error, but got %v", err)
  207. case <-time.After(pongTimerExpired):
  208. assert.True(t, mconn.IsRunning())
  209. }
  210. }
  211. func TestMConnectionMultiplePings(t *testing.T) {
  212. server, client := net.Pipe()
  213. t.Cleanup(closeAll(t, client, server))
  214. receivedCh := make(chan []byte)
  215. errorsCh := make(chan interface{})
  216. onReceive := func(chID ChannelID, msgBytes []byte) {
  217. receivedCh <- msgBytes
  218. }
  219. onError := func(r interface{}) {
  220. errorsCh <- r
  221. }
  222. ctx, cancel := context.WithCancel(context.Background())
  223. defer cancel()
  224. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  225. err := mconn.Start(ctx)
  226. require.Nil(t, err)
  227. t.Cleanup(stopAll(t, mconn))
  228. // sending 3 pings in a row (abuse)
  229. // see https://github.com/tendermint/tendermint/issues/1190
  230. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  231. protoWriter := protoio.NewDelimitedWriter(server)
  232. var pkt tmp2p.Packet
  233. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  234. require.NoError(t, err)
  235. _, err = protoReader.ReadMsg(&pkt)
  236. require.NoError(t, err)
  237. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  238. require.NoError(t, err)
  239. _, err = protoReader.ReadMsg(&pkt)
  240. require.NoError(t, err)
  241. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  242. require.NoError(t, err)
  243. _, err = protoReader.ReadMsg(&pkt)
  244. require.NoError(t, err)
  245. assert.True(t, mconn.IsRunning())
  246. }
  247. func TestMConnectionPingPongs(t *testing.T) {
  248. // check that we are not leaking any go-routines
  249. t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))
  250. server, client := net.Pipe()
  251. t.Cleanup(closeAll(t, client, server))
  252. receivedCh := make(chan []byte)
  253. errorsCh := make(chan interface{})
  254. onReceive := func(chID ChannelID, msgBytes []byte) {
  255. receivedCh <- msgBytes
  256. }
  257. onError := func(r interface{}) {
  258. errorsCh <- r
  259. }
  260. ctx, cancel := context.WithCancel(context.Background())
  261. defer cancel()
  262. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  263. err := mconn.Start(ctx)
  264. require.Nil(t, err)
  265. t.Cleanup(stopAll(t, mconn))
  266. serverGotPing := make(chan struct{})
  267. go func() {
  268. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  269. protoWriter := protoio.NewDelimitedWriter(server)
  270. var pkt tmp2p.PacketPing
  271. // read ping
  272. _, err = protoReader.ReadMsg(&pkt)
  273. require.NoError(t, err)
  274. serverGotPing <- struct{}{}
  275. // respond with pong
  276. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  277. require.NoError(t, err)
  278. time.Sleep(mconn.config.PingInterval)
  279. // read ping
  280. _, err = protoReader.ReadMsg(&pkt)
  281. require.NoError(t, err)
  282. serverGotPing <- struct{}{}
  283. // respond with pong
  284. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  285. require.NoError(t, err)
  286. }()
  287. <-serverGotPing
  288. <-serverGotPing
  289. pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
  290. select {
  291. case msgBytes := <-receivedCh:
  292. t.Fatalf("Expected no data, but got %v", msgBytes)
  293. case err := <-errorsCh:
  294. t.Fatalf("Expected no error, but got %v", err)
  295. case <-time.After(2 * pongTimerExpired):
  296. assert.True(t, mconn.IsRunning())
  297. }
  298. }
  299. func TestMConnectionStopsAndReturnsError(t *testing.T) {
  300. server, client := NetPipe()
  301. t.Cleanup(closeAll(t, client, server))
  302. receivedCh := make(chan []byte)
  303. errorsCh := make(chan interface{})
  304. onReceive := func(chID ChannelID, msgBytes []byte) {
  305. receivedCh <- msgBytes
  306. }
  307. onError := func(r interface{}) {
  308. errorsCh <- r
  309. }
  310. ctx, cancel := context.WithCancel(context.Background())
  311. defer cancel()
  312. mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
  313. err := mconn.Start(ctx)
  314. require.Nil(t, err)
  315. t.Cleanup(stopAll(t, mconn))
  316. if err := client.Close(); err != nil {
  317. t.Error(err)
  318. }
  319. select {
  320. case receivedBytes := <-receivedCh:
  321. t.Fatalf("Expected error, got %v", receivedBytes)
  322. case err := <-errorsCh:
  323. assert.NotNil(t, err)
  324. assert.False(t, mconn.IsRunning())
  325. case <-time.After(500 * time.Millisecond):
  326. t.Fatal("Did not receive error in 500ms")
  327. }
  328. }
  329. func newClientAndServerConnsForReadErrors(
  330. ctx context.Context,
  331. t *testing.T,
  332. chOnErr chan struct{},
  333. ) (*MConnection, *MConnection) {
  334. server, client := NetPipe()
  335. onReceive := func(chID ChannelID, msgBytes []byte) {}
  336. onError := func(r interface{}) {}
  337. // create client conn with two channels
  338. chDescs := []*ChannelDescriptor{
  339. {ID: 0x01, Priority: 1, SendQueueCapacity: 1},
  340. {ID: 0x02, Priority: 1, SendQueueCapacity: 1},
  341. }
  342. logger := log.TestingLogger()
  343. mconnClient := NewMConnection(logger.With("module", "client"), client, chDescs, onReceive, onError)
  344. err := mconnClient.Start(ctx)
  345. require.Nil(t, err)
  346. // create server conn with 1 channel
  347. // it fires on chOnErr when there's an error
  348. serverLogger := logger.With("module", "server")
  349. onError = func(r interface{}) {
  350. chOnErr <- struct{}{}
  351. }
  352. mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError)
  353. err = mconnServer.Start(ctx)
  354. require.Nil(t, err)
  355. return mconnClient, mconnServer
  356. }
  357. func expectSend(ch chan struct{}) bool {
  358. after := time.After(time.Second * 5)
  359. select {
  360. case <-ch:
  361. return true
  362. case <-after:
  363. return false
  364. }
  365. }
  366. func TestMConnectionReadErrorBadEncoding(t *testing.T) {
  367. ctx, cancel := context.WithCancel(context.Background())
  368. defer cancel()
  369. chOnErr := make(chan struct{})
  370. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  371. client := mconnClient.conn
  372. // Write it.
  373. _, err := client.Write([]byte{1, 2, 3, 4, 5})
  374. require.NoError(t, err)
  375. assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
  376. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  377. }
  378. func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
  379. ctx, cancel := context.WithCancel(context.Background())
  380. defer cancel()
  381. chOnErr := make(chan struct{})
  382. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  383. msg := []byte("Ant-Man")
  384. // fail to send msg on channel unknown by client
  385. assert.False(t, mconnClient.Send(0x03, msg))
  386. // send msg on channel unknown by the server.
  387. // should cause an error
  388. assert.True(t, mconnClient.Send(0x02, msg))
  389. assert.True(t, expectSend(chOnErr), "unknown channel")
  390. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  391. }
  392. func TestMConnectionReadErrorLongMessage(t *testing.T) {
  393. chOnErr := make(chan struct{})
  394. chOnRcv := make(chan struct{})
  395. ctx, cancel := context.WithCancel(context.Background())
  396. defer cancel()
  397. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  398. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  399. mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
  400. chOnRcv <- struct{}{}
  401. }
  402. client := mconnClient.conn
  403. protoWriter := protoio.NewDelimitedWriter(client)
  404. // send msg thats just right
  405. var packet = tmp2p.PacketMsg{
  406. ChannelID: 0x01,
  407. EOF: true,
  408. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
  409. }
  410. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  411. require.NoError(t, err)
  412. assert.True(t, expectSend(chOnRcv), "msg just right")
  413. // send msg thats too long
  414. packet = tmp2p.PacketMsg{
  415. ChannelID: 0x01,
  416. EOF: true,
  417. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
  418. }
  419. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  420. require.Error(t, err)
  421. assert.True(t, expectSend(chOnErr), "msg too long")
  422. }
  423. func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
  424. ctx, cancel := context.WithCancel(context.Background())
  425. defer cancel()
  426. chOnErr := make(chan struct{})
  427. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  428. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  429. // send msg with unknown msg type
  430. _, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"})
  431. require.NoError(t, err)
  432. assert.True(t, expectSend(chOnErr), "unknown msg type")
  433. }
  434. func TestMConnectionTrySend(t *testing.T) {
  435. server, client := NetPipe()
  436. t.Cleanup(closeAll(t, client, server))
  437. ctx, cancel := context.WithCancel(context.Background())
  438. defer cancel()
  439. mconn := createTestMConnection(log.TestingLogger(), client)
  440. err := mconn.Start(ctx)
  441. require.Nil(t, err)
  442. t.Cleanup(stopAll(t, mconn))
  443. msg := []byte("Semicolon-Woman")
  444. resultCh := make(chan string, 2)
  445. assert.True(t, mconn.Send(0x01, msg))
  446. _, err = server.Read(make([]byte, len(msg)))
  447. require.NoError(t, err)
  448. assert.True(t, mconn.Send(0x01, msg))
  449. go func() {
  450. mconn.Send(0x01, msg)
  451. resultCh <- "TrySend"
  452. }()
  453. assert.False(t, mconn.Send(0x01, msg))
  454. assert.Equal(t, "TrySend", <-resultCh)
  455. }
  456. // nolint:lll //ignore line length for tests
  457. func TestConnVectors(t *testing.T) {
  458. testCases := []struct {
  459. testName string
  460. msg proto.Message
  461. expBytes string
  462. }{
  463. {"PacketPing", &tmp2p.PacketPing{}, "0a00"},
  464. {"PacketPong", &tmp2p.PacketPong{}, "1200"},
  465. {"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
  466. }
  467. for _, tc := range testCases {
  468. tc := tc
  469. pm := mustWrapPacket(tc.msg)
  470. bz, err := pm.Marshal()
  471. require.NoError(t, err, tc.testName)
  472. require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
  473. }
  474. }
  475. func TestMConnectionChannelOverflow(t *testing.T) {
  476. chOnErr := make(chan struct{})
  477. chOnRcv := make(chan struct{})
  478. ctx, cancel := context.WithCancel(context.Background())
  479. defer cancel()
  480. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
  481. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  482. mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
  483. chOnRcv <- struct{}{}
  484. }
  485. client := mconnClient.conn
  486. protoWriter := protoio.NewDelimitedWriter(client)
  487. var packet = tmp2p.PacketMsg{
  488. ChannelID: 0x01,
  489. EOF: true,
  490. Data: []byte(`42`),
  491. }
  492. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  493. require.NoError(t, err)
  494. assert.True(t, expectSend(chOnRcv))
  495. packet.ChannelID = int32(1025)
  496. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  497. require.NoError(t, err)
  498. assert.False(t, expectSend(chOnRcv))
  499. }
  500. type stopper interface {
  501. Stop() error
  502. }
  503. func stopAll(t *testing.T, stoppers ...stopper) func() {
  504. return func() {
  505. for _, s := range stoppers {
  506. if err := s.Stop(); err != nil {
  507. t.Log(err)
  508. }
  509. }
  510. }
  511. }
  512. type closer interface {
  513. Close() error
  514. }
  515. func closeAll(t *testing.T, closers ...closer) func() {
  516. return func() {
  517. for _, s := range closers {
  518. if err := s.Close(); err != nil {
  519. t.Log(err)
  520. }
  521. }
  522. }
  523. }