You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

590 lines
16 KiB

  1. package conn
  2. import (
  3. "encoding/hex"
  4. "net"
  5. "testing"
  6. "time"
  7. "github.com/fortytw2/leaktest"
  8. "github.com/gogo/protobuf/proto"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "github.com/tendermint/tendermint/internal/libs/protoio"
  12. "github.com/tendermint/tendermint/libs/log"
  13. tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p"
  14. "github.com/tendermint/tendermint/proto/tendermint/types"
  15. )
  // maxPingPongPacketSize is the maximum message size, in bytes, handed to
// protoio.NewDelimitedReader when the tests read ping/pong packets straight
// off the raw pipe.
const maxPingPongPacketSize = 1 << 10 // 1024 bytes
  17. func createTestMConnection(conn net.Conn) *MConnection {
  18. onReceive := func(chID ChannelID, msgBytes []byte) {
  19. }
  20. onError := func(r interface{}) {
  21. }
  22. c := createMConnectionWithCallbacks(conn, onReceive, onError)
  23. c.SetLogger(log.TestingLogger())
  24. return c
  25. }
  26. func createMConnectionWithCallbacks(
  27. conn net.Conn,
  28. onReceive func(chID ChannelID, msgBytes []byte),
  29. onError func(r interface{}),
  30. ) *MConnection {
  31. cfg := DefaultMConnConfig()
  32. cfg.PingInterval = 90 * time.Millisecond
  33. cfg.PongTimeout = 45 * time.Millisecond
  34. chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
  35. c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg)
  36. c.SetLogger(log.TestingLogger())
  37. return c
  38. }
  39. func TestMConnectionSendFlushStop(t *testing.T) {
  40. server, client := NetPipe()
  41. t.Cleanup(closeAll(t, client, server))
  42. clientConn := createTestMConnection(client)
  43. err := clientConn.Start()
  44. require.Nil(t, err)
  45. t.Cleanup(stopAll(t, clientConn))
  46. msg := []byte("abc")
  47. assert.True(t, clientConn.Send(0x01, msg))
  48. msgLength := 14
  49. // start the reader in a new routine, so we can flush
  50. errCh := make(chan error)
  51. go func() {
  52. msgB := make([]byte, msgLength)
  53. _, err := server.Read(msgB)
  54. if err != nil {
  55. t.Error(err)
  56. return
  57. }
  58. errCh <- err
  59. }()
  60. timer := time.NewTimer(3 * time.Second)
  61. select {
  62. case <-errCh:
  63. case <-timer.C:
  64. t.Error("timed out waiting for msgs to be read")
  65. }
  66. }
  67. func TestMConnectionSend(t *testing.T) {
  68. server, client := NetPipe()
  69. t.Cleanup(closeAll(t, client, server))
  70. mconn := createTestMConnection(client)
  71. err := mconn.Start()
  72. require.Nil(t, err)
  73. t.Cleanup(stopAll(t, mconn))
  74. msg := []byte("Ant-Man")
  75. assert.True(t, mconn.Send(0x01, msg))
  76. // Note: subsequent Send/TrySend calls could pass because we are reading from
  77. // the send queue in a separate goroutine.
  78. _, err = server.Read(make([]byte, len(msg)))
  79. if err != nil {
  80. t.Error(err)
  81. }
  82. msg = []byte("Spider-Man")
  83. assert.True(t, mconn.Send(0x01, msg))
  84. _, err = server.Read(make([]byte, len(msg)))
  85. if err != nil {
  86. t.Error(err)
  87. }
  88. assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
  89. }
  90. func TestMConnectionReceive(t *testing.T) {
  91. server, client := NetPipe()
  92. t.Cleanup(closeAll(t, client, server))
  93. receivedCh := make(chan []byte)
  94. errorsCh := make(chan interface{})
  95. onReceive := func(chID ChannelID, msgBytes []byte) {
  96. receivedCh <- msgBytes
  97. }
  98. onError := func(r interface{}) {
  99. errorsCh <- r
  100. }
  101. mconn1 := createMConnectionWithCallbacks(client, onReceive, onError)
  102. err := mconn1.Start()
  103. require.Nil(t, err)
  104. t.Cleanup(stopAll(t, mconn1))
  105. mconn2 := createTestMConnection(server)
  106. err = mconn2.Start()
  107. require.Nil(t, err)
  108. t.Cleanup(stopAll(t, mconn2))
  109. msg := []byte("Cyclops")
  110. assert.True(t, mconn2.Send(0x01, msg))
  111. select {
  112. case receivedBytes := <-receivedCh:
  113. assert.Equal(t, msg, receivedBytes)
  114. case err := <-errorsCh:
  115. t.Fatalf("Expected %s, got %+v", msg, err)
  116. case <-time.After(500 * time.Millisecond):
  117. t.Fatalf("Did not receive %s message in 500ms", msg)
  118. }
  119. }
  120. func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
  121. server, client := net.Pipe()
  122. t.Cleanup(closeAll(t, client, server))
  123. receivedCh := make(chan []byte)
  124. errorsCh := make(chan interface{})
  125. onReceive := func(chID ChannelID, msgBytes []byte) {
  126. receivedCh <- msgBytes
  127. }
  128. onError := func(r interface{}) {
  129. errorsCh <- r
  130. }
  131. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  132. err := mconn.Start()
  133. require.Nil(t, err)
  134. t.Cleanup(stopAll(t, mconn))
  135. serverGotPing := make(chan struct{})
  136. go func() {
  137. // read ping
  138. var pkt tmp2p.Packet
  139. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
  140. require.NoError(t, err)
  141. serverGotPing <- struct{}{}
  142. }()
  143. <-serverGotPing
  144. pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
  145. select {
  146. case msgBytes := <-receivedCh:
  147. t.Fatalf("Expected error, but got %v", msgBytes)
  148. case err := <-errorsCh:
  149. assert.NotNil(t, err)
  150. case <-time.After(pongTimerExpired):
  151. t.Fatalf("Expected to receive error after %v", pongTimerExpired)
  152. }
  153. }
  154. func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
  155. server, client := net.Pipe()
  156. t.Cleanup(closeAll(t, client, server))
  157. receivedCh := make(chan []byte)
  158. errorsCh := make(chan interface{})
  159. onReceive := func(chID ChannelID, msgBytes []byte) {
  160. receivedCh <- msgBytes
  161. }
  162. onError := func(r interface{}) {
  163. errorsCh <- r
  164. }
  165. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  166. err := mconn.Start()
  167. require.Nil(t, err)
  168. t.Cleanup(stopAll(t, mconn))
  169. // sending 3 pongs in a row (abuse)
  170. protoWriter := protoio.NewDelimitedWriter(server)
  171. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  172. require.NoError(t, err)
  173. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  174. require.NoError(t, err)
  175. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  176. require.NoError(t, err)
  177. serverGotPing := make(chan struct{})
  178. go func() {
  179. // read ping (one byte)
  180. var packet tmp2p.Packet
  181. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
  182. require.NoError(t, err)
  183. serverGotPing <- struct{}{}
  184. // respond with pong
  185. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  186. require.NoError(t, err)
  187. }()
  188. <-serverGotPing
  189. pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
  190. select {
  191. case msgBytes := <-receivedCh:
  192. t.Fatalf("Expected no data, but got %v", msgBytes)
  193. case err := <-errorsCh:
  194. t.Fatalf("Expected no error, but got %v", err)
  195. case <-time.After(pongTimerExpired):
  196. assert.True(t, mconn.IsRunning())
  197. }
  198. }
  199. func TestMConnectionMultiplePings(t *testing.T) {
  200. server, client := net.Pipe()
  201. t.Cleanup(closeAll(t, client, server))
  202. receivedCh := make(chan []byte)
  203. errorsCh := make(chan interface{})
  204. onReceive := func(chID ChannelID, msgBytes []byte) {
  205. receivedCh <- msgBytes
  206. }
  207. onError := func(r interface{}) {
  208. errorsCh <- r
  209. }
  210. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  211. err := mconn.Start()
  212. require.Nil(t, err)
  213. t.Cleanup(stopAll(t, mconn))
  214. // sending 3 pings in a row (abuse)
  215. // see https://github.com/tendermint/tendermint/issues/1190
  216. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  217. protoWriter := protoio.NewDelimitedWriter(server)
  218. var pkt tmp2p.Packet
  219. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  220. require.NoError(t, err)
  221. _, err = protoReader.ReadMsg(&pkt)
  222. require.NoError(t, err)
  223. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  224. require.NoError(t, err)
  225. _, err = protoReader.ReadMsg(&pkt)
  226. require.NoError(t, err)
  227. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  228. require.NoError(t, err)
  229. _, err = protoReader.ReadMsg(&pkt)
  230. require.NoError(t, err)
  231. assert.True(t, mconn.IsRunning())
  232. }
  233. func TestMConnectionPingPongs(t *testing.T) {
  234. // check that we are not leaking any go-routines
  235. t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))
  236. server, client := net.Pipe()
  237. t.Cleanup(closeAll(t, client, server))
  238. receivedCh := make(chan []byte)
  239. errorsCh := make(chan interface{})
  240. onReceive := func(chID ChannelID, msgBytes []byte) {
  241. receivedCh <- msgBytes
  242. }
  243. onError := func(r interface{}) {
  244. errorsCh <- r
  245. }
  246. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  247. err := mconn.Start()
  248. require.Nil(t, err)
  249. t.Cleanup(stopAll(t, mconn))
  250. serverGotPing := make(chan struct{})
  251. go func() {
  252. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  253. protoWriter := protoio.NewDelimitedWriter(server)
  254. var pkt tmp2p.PacketPing
  255. // read ping
  256. _, err = protoReader.ReadMsg(&pkt)
  257. require.NoError(t, err)
  258. serverGotPing <- struct{}{}
  259. // respond with pong
  260. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  261. require.NoError(t, err)
  262. time.Sleep(mconn.config.PingInterval)
  263. // read ping
  264. _, err = protoReader.ReadMsg(&pkt)
  265. require.NoError(t, err)
  266. serverGotPing <- struct{}{}
  267. // respond with pong
  268. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  269. require.NoError(t, err)
  270. }()
  271. <-serverGotPing
  272. <-serverGotPing
  273. pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
  274. select {
  275. case msgBytes := <-receivedCh:
  276. t.Fatalf("Expected no data, but got %v", msgBytes)
  277. case err := <-errorsCh:
  278. t.Fatalf("Expected no error, but got %v", err)
  279. case <-time.After(2 * pongTimerExpired):
  280. assert.True(t, mconn.IsRunning())
  281. }
  282. }
  283. func TestMConnectionStopsAndReturnsError(t *testing.T) {
  284. server, client := NetPipe()
  285. t.Cleanup(closeAll(t, client, server))
  286. receivedCh := make(chan []byte)
  287. errorsCh := make(chan interface{})
  288. onReceive := func(chID ChannelID, msgBytes []byte) {
  289. receivedCh <- msgBytes
  290. }
  291. onError := func(r interface{}) {
  292. errorsCh <- r
  293. }
  294. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  295. err := mconn.Start()
  296. require.Nil(t, err)
  297. t.Cleanup(stopAll(t, mconn))
  298. if err := client.Close(); err != nil {
  299. t.Error(err)
  300. }
  301. select {
  302. case receivedBytes := <-receivedCh:
  303. t.Fatalf("Expected error, got %v", receivedBytes)
  304. case err := <-errorsCh:
  305. assert.NotNil(t, err)
  306. assert.False(t, mconn.IsRunning())
  307. case <-time.After(500 * time.Millisecond):
  308. t.Fatal("Did not receive error in 500ms")
  309. }
  310. }
  311. func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) (*MConnection, *MConnection) {
  312. server, client := NetPipe()
  313. onReceive := func(chID ChannelID, msgBytes []byte) {}
  314. onError := func(r interface{}) {}
  315. // create client conn with two channels
  316. chDescs := []*ChannelDescriptor{
  317. {ID: 0x01, Priority: 1, SendQueueCapacity: 1},
  318. {ID: 0x02, Priority: 1, SendQueueCapacity: 1},
  319. }
  320. mconnClient := NewMConnection(client, chDescs, onReceive, onError)
  321. mconnClient.SetLogger(log.TestingLogger().With("module", "client"))
  322. err := mconnClient.Start()
  323. require.Nil(t, err)
  324. // create server conn with 1 channel
  325. // it fires on chOnErr when there's an error
  326. serverLogger := log.TestingLogger().With("module", "server")
  327. onError = func(r interface{}) {
  328. chOnErr <- struct{}{}
  329. }
  330. mconnServer := createMConnectionWithCallbacks(server, onReceive, onError)
  331. mconnServer.SetLogger(serverLogger)
  332. err = mconnServer.Start()
  333. require.Nil(t, err)
  334. return mconnClient, mconnServer
  335. }
  // expectSend waits up to five seconds for a value on ch (or for ch to be
// closed). It returns true if one arrives and false if the deadline passes
// first.
func expectSend(ch chan struct{}) bool {
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		return false
	case <-ch:
		return true
	}
}
  345. func TestMConnectionReadErrorBadEncoding(t *testing.T) {
  346. chOnErr := make(chan struct{})
  347. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  348. client := mconnClient.conn
  349. // Write it.
  350. _, err := client.Write([]byte{1, 2, 3, 4, 5})
  351. require.NoError(t, err)
  352. assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
  353. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  354. }
  355. func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
  356. chOnErr := make(chan struct{})
  357. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  358. msg := []byte("Ant-Man")
  359. // fail to send msg on channel unknown by client
  360. assert.False(t, mconnClient.Send(0x03, msg))
  361. // send msg on channel unknown by the server.
  362. // should cause an error
  363. assert.True(t, mconnClient.Send(0x02, msg))
  364. assert.True(t, expectSend(chOnErr), "unknown channel")
  365. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  366. }
  367. func TestMConnectionReadErrorLongMessage(t *testing.T) {
  368. chOnErr := make(chan struct{})
  369. chOnRcv := make(chan struct{})
  370. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  371. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  372. mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
  373. chOnRcv <- struct{}{}
  374. }
  375. client := mconnClient.conn
  376. protoWriter := protoio.NewDelimitedWriter(client)
  377. // send msg thats just right
  378. var packet = tmp2p.PacketMsg{
  379. ChannelID: 0x01,
  380. EOF: true,
  381. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
  382. }
  383. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  384. require.NoError(t, err)
  385. assert.True(t, expectSend(chOnRcv), "msg just right")
  386. // send msg thats too long
  387. packet = tmp2p.PacketMsg{
  388. ChannelID: 0x01,
  389. EOF: true,
  390. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
  391. }
  392. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  393. require.Error(t, err)
  394. assert.True(t, expectSend(chOnErr), "msg too long")
  395. }
  396. func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
  397. chOnErr := make(chan struct{})
  398. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  399. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  400. // send msg with unknown msg type
  401. _, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"})
  402. require.NoError(t, err)
  403. assert.True(t, expectSend(chOnErr), "unknown msg type")
  404. }
  405. func TestMConnectionTrySend(t *testing.T) {
  406. server, client := NetPipe()
  407. t.Cleanup(closeAll(t, client, server))
  408. mconn := createTestMConnection(client)
  409. err := mconn.Start()
  410. require.Nil(t, err)
  411. t.Cleanup(stopAll(t, mconn))
  412. msg := []byte("Semicolon-Woman")
  413. resultCh := make(chan string, 2)
  414. assert.True(t, mconn.Send(0x01, msg))
  415. _, err = server.Read(make([]byte, len(msg)))
  416. require.NoError(t, err)
  417. assert.True(t, mconn.Send(0x01, msg))
  418. go func() {
  419. mconn.Send(0x01, msg)
  420. resultCh <- "TrySend"
  421. }()
  422. assert.False(t, mconn.Send(0x01, msg))
  423. assert.Equal(t, "TrySend", <-resultCh)
  424. }
  425. // nolint:lll //ignore line length for tests
  426. func TestConnVectors(t *testing.T) {
  427. testCases := []struct {
  428. testName string
  429. msg proto.Message
  430. expBytes string
  431. }{
  432. {"PacketPing", &tmp2p.PacketPing{}, "0a00"},
  433. {"PacketPong", &tmp2p.PacketPong{}, "1200"},
  434. {"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
  435. }
  436. for _, tc := range testCases {
  437. tc := tc
  438. pm := mustWrapPacket(tc.msg)
  439. bz, err := pm.Marshal()
  440. require.NoError(t, err, tc.testName)
  441. require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
  442. }
  443. }
  444. func TestMConnectionChannelOverflow(t *testing.T) {
  445. chOnErr := make(chan struct{})
  446. chOnRcv := make(chan struct{})
  447. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  448. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  449. mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) {
  450. chOnRcv <- struct{}{}
  451. }
  452. client := mconnClient.conn
  453. protoWriter := protoio.NewDelimitedWriter(client)
  454. var packet = tmp2p.PacketMsg{
  455. ChannelID: 0x01,
  456. EOF: true,
  457. Data: []byte(`42`),
  458. }
  459. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  460. require.NoError(t, err)
  461. assert.True(t, expectSend(chOnRcv))
  462. packet.ChannelID = int32(1025)
  463. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  464. require.NoError(t, err)
  465. assert.False(t, expectSend(chOnRcv))
  466. }
  // stopper is anything that can be stopped, such as a started *MConnection.
type stopper interface {
	Stop() error
}

// stopAll returns a function, suitable for t.Cleanup, that stops every given
// stopper in argument order. A Stop error is only logged via t.Log, so one
// failure does not prevent the remaining stoppers from being stopped.
func stopAll(t *testing.T, stoppers ...stopper) func() {
	return func() {
		for i := range stoppers {
			err := stoppers[i].Stop()
			if err == nil {
				continue
			}
			t.Log(err)
		}
	}
}
  // closer is anything that can be closed, such as either end of a net.Pipe.
type closer interface {
	Close() error
}

// closeAll returns a function, suitable for t.Cleanup, that closes every given
// closer in argument order. A Close error is only logged via t.Log, so closing
// an already-closed value does not fail the test or skip the rest.
func closeAll(t *testing.T, closers ...closer) func() {
	return func() {
		for i := range closers {
			err := closers[i].Close()
			if err == nil {
				continue
			}
			t.Log(err)
		}
	}
}