You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

582 lines
15 KiB

  1. package conn
  2. import (
  3. "encoding/hex"
  4. "net"
  5. "testing"
  6. "time"
  7. "github.com/fortytw2/leaktest"
  8. "github.com/gogo/protobuf/proto"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "github.com/tendermint/tendermint/internal/libs/protoio"
  12. "github.com/tendermint/tendermint/libs/log"
  13. tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p"
  14. "github.com/tendermint/tendermint/proto/tendermint/types"
  15. )
  16. const maxPingPongPacketSize = 1024 // bytes
  17. func createTestMConnection(conn net.Conn) *MConnection {
  18. onReceive := func(chID byte, msgBytes []byte) {
  19. }
  20. onError := func(r interface{}) {
  21. }
  22. c := createMConnectionWithCallbacks(conn, onReceive, onError)
  23. c.SetLogger(log.TestingLogger())
  24. return c
  25. }
  26. func createMConnectionWithCallbacks(
  27. conn net.Conn,
  28. onReceive func(chID byte, msgBytes []byte),
  29. onError func(r interface{}),
  30. ) *MConnection {
  31. cfg := DefaultMConnConfig()
  32. cfg.PingInterval = 90 * time.Millisecond
  33. cfg.PongTimeout = 45 * time.Millisecond
  34. chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
  35. c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg)
  36. c.SetLogger(log.TestingLogger())
  37. return c
  38. }
  39. func TestMConnectionSendFlushStop(t *testing.T) {
  40. server, client := NetPipe()
  41. t.Cleanup(closeAll(t, client, server))
  42. clientConn := createTestMConnection(client)
  43. err := clientConn.Start()
  44. require.Nil(t, err)
  45. t.Cleanup(stopAll(t, clientConn))
  46. msg := []byte("abc")
  47. assert.True(t, clientConn.Send(0x01, msg))
  48. msgLength := 14
  49. // start the reader in a new routine, so we can flush
  50. errCh := make(chan error)
  51. go func() {
  52. msgB := make([]byte, msgLength)
  53. _, err := server.Read(msgB)
  54. if err != nil {
  55. t.Error(err)
  56. return
  57. }
  58. errCh <- err
  59. }()
  60. // stop the conn - it should flush all conns
  61. clientConn.FlushStop()
  62. timer := time.NewTimer(3 * time.Second)
  63. select {
  64. case <-errCh:
  65. case <-timer.C:
  66. t.Error("timed out waiting for msgs to be read")
  67. }
  68. }
  69. func TestMConnectionSend(t *testing.T) {
  70. server, client := NetPipe()
  71. t.Cleanup(closeAll(t, client, server))
  72. mconn := createTestMConnection(client)
  73. err := mconn.Start()
  74. require.Nil(t, err)
  75. t.Cleanup(stopAll(t, mconn))
  76. msg := []byte("Ant-Man")
  77. assert.True(t, mconn.Send(0x01, msg))
  78. // Note: subsequent Send/TrySend calls could pass because we are reading from
  79. // the send queue in a separate goroutine.
  80. _, err = server.Read(make([]byte, len(msg)))
  81. if err != nil {
  82. t.Error(err)
  83. }
  84. assert.True(t, mconn.CanSend(0x01))
  85. msg = []byte("Spider-Man")
  86. assert.True(t, mconn.TrySend(0x01, msg))
  87. _, err = server.Read(make([]byte, len(msg)))
  88. if err != nil {
  89. t.Error(err)
  90. }
  91. assert.False(t, mconn.CanSend(0x05), "CanSend should return false because channel is unknown")
  92. assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
  93. }
  94. func TestMConnectionReceive(t *testing.T) {
  95. server, client := NetPipe()
  96. t.Cleanup(closeAll(t, client, server))
  97. receivedCh := make(chan []byte)
  98. errorsCh := make(chan interface{})
  99. onReceive := func(chID byte, msgBytes []byte) {
  100. receivedCh <- msgBytes
  101. }
  102. onError := func(r interface{}) {
  103. errorsCh <- r
  104. }
  105. mconn1 := createMConnectionWithCallbacks(client, onReceive, onError)
  106. err := mconn1.Start()
  107. require.Nil(t, err)
  108. t.Cleanup(stopAll(t, mconn1))
  109. mconn2 := createTestMConnection(server)
  110. err = mconn2.Start()
  111. require.Nil(t, err)
  112. t.Cleanup(stopAll(t, mconn2))
  113. msg := []byte("Cyclops")
  114. assert.True(t, mconn2.Send(0x01, msg))
  115. select {
  116. case receivedBytes := <-receivedCh:
  117. assert.Equal(t, msg, receivedBytes)
  118. case err := <-errorsCh:
  119. t.Fatalf("Expected %s, got %+v", msg, err)
  120. case <-time.After(500 * time.Millisecond):
  121. t.Fatalf("Did not receive %s message in 500ms", msg)
  122. }
  123. }
  124. func TestMConnectionStatus(t *testing.T) {
  125. server, client := NetPipe()
  126. t.Cleanup(closeAll(t, client, server))
  127. mconn := createTestMConnection(client)
  128. err := mconn.Start()
  129. require.Nil(t, err)
  130. t.Cleanup(stopAll(t, mconn))
  131. status := mconn.Status()
  132. assert.NotNil(t, status)
  133. assert.Zero(t, status.Channels[0].SendQueueSize)
  134. }
  135. func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
  136. server, client := net.Pipe()
  137. t.Cleanup(closeAll(t, client, server))
  138. receivedCh := make(chan []byte)
  139. errorsCh := make(chan interface{})
  140. onReceive := func(chID byte, msgBytes []byte) {
  141. receivedCh <- msgBytes
  142. }
  143. onError := func(r interface{}) {
  144. errorsCh <- r
  145. }
  146. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  147. err := mconn.Start()
  148. require.Nil(t, err)
  149. t.Cleanup(stopAll(t, mconn))
  150. serverGotPing := make(chan struct{})
  151. go func() {
  152. // read ping
  153. var pkt tmp2p.Packet
  154. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
  155. require.NoError(t, err)
  156. serverGotPing <- struct{}{}
  157. }()
  158. <-serverGotPing
  159. pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
  160. select {
  161. case msgBytes := <-receivedCh:
  162. t.Fatalf("Expected error, but got %v", msgBytes)
  163. case err := <-errorsCh:
  164. assert.NotNil(t, err)
  165. case <-time.After(pongTimerExpired):
  166. t.Fatalf("Expected to receive error after %v", pongTimerExpired)
  167. }
  168. }
  169. func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
  170. server, client := net.Pipe()
  171. t.Cleanup(closeAll(t, client, server))
  172. receivedCh := make(chan []byte)
  173. errorsCh := make(chan interface{})
  174. onReceive := func(chID byte, msgBytes []byte) {
  175. receivedCh <- msgBytes
  176. }
  177. onError := func(r interface{}) {
  178. errorsCh <- r
  179. }
  180. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  181. err := mconn.Start()
  182. require.Nil(t, err)
  183. t.Cleanup(stopAll(t, mconn))
  184. // sending 3 pongs in a row (abuse)
  185. protoWriter := protoio.NewDelimitedWriter(server)
  186. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  187. require.NoError(t, err)
  188. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  189. require.NoError(t, err)
  190. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  191. require.NoError(t, err)
  192. serverGotPing := make(chan struct{})
  193. go func() {
  194. // read ping (one byte)
  195. var packet tmp2p.Packet
  196. _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
  197. require.NoError(t, err)
  198. serverGotPing <- struct{}{}
  199. // respond with pong
  200. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  201. require.NoError(t, err)
  202. }()
  203. <-serverGotPing
  204. pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
  205. select {
  206. case msgBytes := <-receivedCh:
  207. t.Fatalf("Expected no data, but got %v", msgBytes)
  208. case err := <-errorsCh:
  209. t.Fatalf("Expected no error, but got %v", err)
  210. case <-time.After(pongTimerExpired):
  211. assert.True(t, mconn.IsRunning())
  212. }
  213. }
  214. func TestMConnectionMultiplePings(t *testing.T) {
  215. server, client := net.Pipe()
  216. t.Cleanup(closeAll(t, client, server))
  217. receivedCh := make(chan []byte)
  218. errorsCh := make(chan interface{})
  219. onReceive := func(chID byte, msgBytes []byte) {
  220. receivedCh <- msgBytes
  221. }
  222. onError := func(r interface{}) {
  223. errorsCh <- r
  224. }
  225. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  226. err := mconn.Start()
  227. require.Nil(t, err)
  228. t.Cleanup(stopAll(t, mconn))
  229. // sending 3 pings in a row (abuse)
  230. // see https://github.com/tendermint/tendermint/issues/1190
  231. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  232. protoWriter := protoio.NewDelimitedWriter(server)
  233. var pkt tmp2p.Packet
  234. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  235. require.NoError(t, err)
  236. _, err = protoReader.ReadMsg(&pkt)
  237. require.NoError(t, err)
  238. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  239. require.NoError(t, err)
  240. _, err = protoReader.ReadMsg(&pkt)
  241. require.NoError(t, err)
  242. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
  243. require.NoError(t, err)
  244. _, err = protoReader.ReadMsg(&pkt)
  245. require.NoError(t, err)
  246. assert.True(t, mconn.IsRunning())
  247. }
  248. func TestMConnectionPingPongs(t *testing.T) {
  249. // check that we are not leaking any go-routines
  250. t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))
  251. server, client := net.Pipe()
  252. t.Cleanup(closeAll(t, client, server))
  253. receivedCh := make(chan []byte)
  254. errorsCh := make(chan interface{})
  255. onReceive := func(chID byte, msgBytes []byte) {
  256. receivedCh <- msgBytes
  257. }
  258. onError := func(r interface{}) {
  259. errorsCh <- r
  260. }
  261. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  262. err := mconn.Start()
  263. require.Nil(t, err)
  264. t.Cleanup(stopAll(t, mconn))
  265. serverGotPing := make(chan struct{})
  266. go func() {
  267. protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
  268. protoWriter := protoio.NewDelimitedWriter(server)
  269. var pkt tmp2p.PacketPing
  270. // read ping
  271. _, err = protoReader.ReadMsg(&pkt)
  272. require.NoError(t, err)
  273. serverGotPing <- struct{}{}
  274. // respond with pong
  275. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  276. require.NoError(t, err)
  277. time.Sleep(mconn.config.PingInterval)
  278. // read ping
  279. _, err = protoReader.ReadMsg(&pkt)
  280. require.NoError(t, err)
  281. serverGotPing <- struct{}{}
  282. // respond with pong
  283. _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
  284. require.NoError(t, err)
  285. }()
  286. <-serverGotPing
  287. <-serverGotPing
  288. pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
  289. select {
  290. case msgBytes := <-receivedCh:
  291. t.Fatalf("Expected no data, but got %v", msgBytes)
  292. case err := <-errorsCh:
  293. t.Fatalf("Expected no error, but got %v", err)
  294. case <-time.After(2 * pongTimerExpired):
  295. assert.True(t, mconn.IsRunning())
  296. }
  297. }
  298. func TestMConnectionStopsAndReturnsError(t *testing.T) {
  299. server, client := NetPipe()
  300. t.Cleanup(closeAll(t, client, server))
  301. receivedCh := make(chan []byte)
  302. errorsCh := make(chan interface{})
  303. onReceive := func(chID byte, msgBytes []byte) {
  304. receivedCh <- msgBytes
  305. }
  306. onError := func(r interface{}) {
  307. errorsCh <- r
  308. }
  309. mconn := createMConnectionWithCallbacks(client, onReceive, onError)
  310. err := mconn.Start()
  311. require.Nil(t, err)
  312. t.Cleanup(stopAll(t, mconn))
  313. if err := client.Close(); err != nil {
  314. t.Error(err)
  315. }
  316. select {
  317. case receivedBytes := <-receivedCh:
  318. t.Fatalf("Expected error, got %v", receivedBytes)
  319. case err := <-errorsCh:
  320. assert.NotNil(t, err)
  321. assert.False(t, mconn.IsRunning())
  322. case <-time.After(500 * time.Millisecond):
  323. t.Fatal("Did not receive error in 500ms")
  324. }
  325. }
  326. func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) (*MConnection, *MConnection) {
  327. server, client := NetPipe()
  328. onReceive := func(chID byte, msgBytes []byte) {}
  329. onError := func(r interface{}) {}
  330. // create client conn with two channels
  331. chDescs := []*ChannelDescriptor{
  332. {ID: 0x01, Priority: 1, SendQueueCapacity: 1},
  333. {ID: 0x02, Priority: 1, SendQueueCapacity: 1},
  334. }
  335. mconnClient := NewMConnection(client, chDescs, onReceive, onError)
  336. mconnClient.SetLogger(log.TestingLogger().With("module", "client"))
  337. err := mconnClient.Start()
  338. require.Nil(t, err)
  339. // create server conn with 1 channel
  340. // it fires on chOnErr when there's an error
  341. serverLogger := log.TestingLogger().With("module", "server")
  342. onError = func(r interface{}) {
  343. chOnErr <- struct{}{}
  344. }
  345. mconnServer := createMConnectionWithCallbacks(server, onReceive, onError)
  346. mconnServer.SetLogger(serverLogger)
  347. err = mconnServer.Start()
  348. require.Nil(t, err)
  349. return mconnClient, mconnServer
  350. }
  351. func expectSend(ch chan struct{}) bool {
  352. after := time.After(time.Second * 5)
  353. select {
  354. case <-ch:
  355. return true
  356. case <-after:
  357. return false
  358. }
  359. }
  360. func TestMConnectionReadErrorBadEncoding(t *testing.T) {
  361. chOnErr := make(chan struct{})
  362. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  363. client := mconnClient.conn
  364. // Write it.
  365. _, err := client.Write([]byte{1, 2, 3, 4, 5})
  366. require.NoError(t, err)
  367. assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
  368. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  369. }
  370. func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
  371. chOnErr := make(chan struct{})
  372. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  373. msg := []byte("Ant-Man")
  374. // fail to send msg on channel unknown by client
  375. assert.False(t, mconnClient.Send(0x03, msg))
  376. // send msg on channel unknown by the server.
  377. // should cause an error
  378. assert.True(t, mconnClient.Send(0x02, msg))
  379. assert.True(t, expectSend(chOnErr), "unknown channel")
  380. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  381. }
  382. func TestMConnectionReadErrorLongMessage(t *testing.T) {
  383. chOnErr := make(chan struct{})
  384. chOnRcv := make(chan struct{})
  385. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  386. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  387. mconnServer.onReceive = func(chID byte, msgBytes []byte) {
  388. chOnRcv <- struct{}{}
  389. }
  390. client := mconnClient.conn
  391. protoWriter := protoio.NewDelimitedWriter(client)
  392. // send msg thats just right
  393. var packet = tmp2p.PacketMsg{
  394. ChannelID: 0x01,
  395. EOF: true,
  396. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
  397. }
  398. _, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
  399. require.NoError(t, err)
  400. assert.True(t, expectSend(chOnRcv), "msg just right")
  401. // send msg thats too long
  402. packet = tmp2p.PacketMsg{
  403. ChannelID: 0x01,
  404. EOF: true,
  405. Data: make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
  406. }
  407. _, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
  408. require.Error(t, err)
  409. assert.True(t, expectSend(chOnErr), "msg too long")
  410. }
  411. func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
  412. chOnErr := make(chan struct{})
  413. mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
  414. t.Cleanup(stopAll(t, mconnClient, mconnServer))
  415. // send msg with unknown msg type
  416. _, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"})
  417. require.NoError(t, err)
  418. assert.True(t, expectSend(chOnErr), "unknown msg type")
  419. }
  420. func TestMConnectionTrySend(t *testing.T) {
  421. server, client := NetPipe()
  422. t.Cleanup(closeAll(t, client, server))
  423. mconn := createTestMConnection(client)
  424. err := mconn.Start()
  425. require.Nil(t, err)
  426. t.Cleanup(stopAll(t, mconn))
  427. msg := []byte("Semicolon-Woman")
  428. resultCh := make(chan string, 2)
  429. assert.True(t, mconn.TrySend(0x01, msg))
  430. _, err = server.Read(make([]byte, len(msg)))
  431. require.NoError(t, err)
  432. assert.True(t, mconn.CanSend(0x01))
  433. assert.True(t, mconn.TrySend(0x01, msg))
  434. assert.False(t, mconn.CanSend(0x01))
  435. go func() {
  436. mconn.TrySend(0x01, msg)
  437. resultCh <- "TrySend"
  438. }()
  439. assert.False(t, mconn.CanSend(0x01))
  440. assert.False(t, mconn.TrySend(0x01, msg))
  441. assert.Equal(t, "TrySend", <-resultCh)
  442. }
  443. // nolint:lll //ignore line length for tests
  444. func TestConnVectors(t *testing.T) {
  445. testCases := []struct {
  446. testName string
  447. msg proto.Message
  448. expBytes string
  449. }{
  450. {"PacketPing", &tmp2p.PacketPing{}, "0a00"},
  451. {"PacketPong", &tmp2p.PacketPong{}, "1200"},
  452. {"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
  453. }
  454. for _, tc := range testCases {
  455. tc := tc
  456. pm := mustWrapPacket(tc.msg)
  457. bz, err := pm.Marshal()
  458. require.NoError(t, err, tc.testName)
  459. require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
  460. }
  461. }
  462. type stopper interface {
  463. Stop() error
  464. }
  465. func stopAll(t *testing.T, stoppers ...stopper) func() {
  466. return func() {
  467. for _, s := range stoppers {
  468. if err := s.Stop(); err != nil {
  469. t.Log(err)
  470. }
  471. }
  472. }
  473. }
  474. type closer interface {
  475. Close() error
  476. }
  477. func closeAll(t *testing.T, closers ...closer) func() {
  478. return func() {
  479. for _, s := range closers {
  480. if err := s.Close(); err != nil {
  481. t.Log(err)
  482. }
  483. }
  484. }
  485. }