package conn

import (
	"context"
	"encoding/hex"
	"io"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/fortytw2/leaktest"
	"github.com/gogo/protobuf/proto"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/tendermint/tendermint/internal/libs/protoio"
	"github.com/tendermint/tendermint/libs/log"
	"github.com/tendermint/tendermint/libs/service"
	tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p"
	"github.com/tendermint/tendermint/proto/tendermint/types"
)

// maxPingPongPacketSize bounds the delimited reads used by the ping/pong tests.
const maxPingPongPacketSize = 1024 // bytes

// createTestMConnection builds an MConnection over conn whose receive and
// error callbacks are both no-ops.
func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection {
	return createMConnectionWithCallbacks(logger, conn,
		// onReceive
		func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		},
		// onError
		func(ctx context.Context, r interface{}) {
		})
}

// createMConnectionWithCallbacks builds an MConnection over conn with a single
// channel (ID 0x01, send queue capacity 1) and shortened ping/pong timings
// (250ms ping interval, 500ms pong timeout) so timeout-driven tests run fast.
func createMConnectionWithCallbacks(
	logger log.Logger,
	conn net.Conn,
	onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte),
	onError func(ctx context.Context, r interface{}),
) *MConnection {
	cfg := DefaultMConnConfig()
	cfg.PingInterval = 250 * time.Millisecond
	cfg.PongTimeout = 500 * time.Millisecond
	chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
	c := NewMConnection(logger, conn, chDescs, onReceive, onError, cfg)
	return c
}

func TestMConnectionSendFlushStop(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	clientConn := createTestMConnection(log.TestingLogger(), client)
	err := clientConn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(clientConn))

	msg := []byte("abc")
	assert.True(t, clientConn.Send(0x01, msg))

	msgLength := 14

	// Start the reader in a new goroutine so we can flush. The channel is
	// buffered and the result is always delivered, so the reader can never
	// block forever (e.g. after a timeout) and never has to call t.Error from
	// outside the test goroutine, which could panic once the test returned.
	errCh := make(chan error, 1)
	go func() {
		msgB := make([]byte, msgLength)
		_, err := server.Read(msgB)
		errCh <- err
	}()

	timer := time.NewTimer(3 * time.Second)
	defer timer.Stop()
	select {
	case err := <-errCh:
		assert.NoError(t, err)
	case <-timer.C:
		t.Error("timed out waiting for msgs to be read")
	}
}

func TestMConnectionSend(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createTestMConnection(log.TestingLogger(), client)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	msg := []byte("Ant-Man")
	assert.True(t, mconn.Send(0x01, msg))
	// Note: subsequent Send/TrySend calls could pass because we are reading from
	// the send queue in a separate goroutine.
	_, err = server.Read(make([]byte, len(msg)))
	if err != nil {
		t.Error(err)
	}

	msg = []byte("Spider-Man")
	assert.True(t, mconn.Send(0x01, msg))
	_, err = server.Read(make([]byte, len(msg)))
	if err != nil {
		t.Error(err)
	}

	assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
}

func TestMConnectionReceive(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}
	logger := log.TestingLogger()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError)
	err := mconn1.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn1))

	mconn2 := createTestMConnection(logger, server)
	err = mconn2.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn2))

	msg := []byte("Cyclops")
	assert.True(t, mconn2.Send(0x01, msg))

	select {
	case receivedBytes := <-receivedCh:
		assert.Equal(t, msg, receivedBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected %s, got %+v", msg, err)
	case <-time.After(500 * time.Millisecond):
		t.Fatalf("Did not receive %s message in 500ms", msg)
	}
}

func TestMConnectionWillEventuallyTimeout(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, nil, nil)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))
	require.True(t, mconn.IsRunning())

	go func() {
		// Drain everything the connection writes so its send routine never
		// blocks on the synchronous pipe. Nothing (in particular, no pong) is
		// ever written back. io.Copy returns once the pipe is closed.
		_, _ = io.Copy(io.Discard, server)
	}()

	// Wait for the send routine to die because it never receives a pong.
	select {
	case <-mconn.doneSendRoutine:
		require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout,
			"the connection state reflects that we've passed the pong timeout")
		// since we hit the timeout, things should be shutdown
		require.False(t, mconn.IsRunning())
	case <-time.After(2 * mconn.config.PongTimeout):
		t.Fatal("connection did not hit timeout", mconn.config.PongTimeout)
	}
}

func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	// sending 3 pongs in a row (abuse)
	protoWriter := protoio.NewDelimitedWriter(server)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	// read the ping (one length-delimited packet)
	var packet tmp2p.Packet
	_, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
	require.NoError(t, err)

	// respond with pong
	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
	select {
	case msgBytes := <-receivedCh:
		t.Fatalf("Expected no data, but got %v", msgBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected no error, but got %v", err)
	case <-time.After(pongTimerExpired):
		assert.True(t, mconn.IsRunning())
	}
}

func TestMConnectionMultiplePings(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	// sending 3 pings in a row (abuse)
	// see https://github.com/tendermint/tendermint/issues/1190
	protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
	protoWriter := protoio.NewDelimitedWriter(server)
	var pkt tmp2p.Packet

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
	require.NoError(t, err)

	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
	require.NoError(t, err)

	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
	require.NoError(t, err)

	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	assert.True(t, mconn.IsRunning())
}

func TestMConnectionPingPongs(t *testing.T) {
	// check that we are not leaking any go-routines
	t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))

	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
	protoWriter := protoio.NewDelimitedWriter(server)
	// The wire carries wrapped Packet messages, so decode into tmp2p.Packet
	// (as the other ping/pong tests do) rather than the inner PacketPing.
	var pkt tmp2p.Packet

	// read ping
	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	// respond with pong
	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	time.Sleep(mconn.config.PingInterval)

	// read ping
	_, err = protoReader.ReadMsg(&pkt)
	require.NoError(t, err)

	// respond with pong
	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4
	select {
	case msgBytes := <-receivedCh:
		t.Fatalf("Expected no data, but got %v", msgBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected no error, but got %v", err)
	case <-time.After(2 * pongTimerExpired):
		assert.True(t, mconn.IsRunning())
	}
}

func TestMConnectionStopsAndReturnsError(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	if err := client.Close(); err != nil {
		t.Error(err)
	}

	select {
	case receivedBytes := <-receivedCh:
		t.Fatalf("Expected error, got %v", receivedBytes)
	case err := <-errorsCh:
		assert.NotNil(t, err)
		assert.False(t, mconn.IsRunning())
	case <-time.After(500 * time.Millisecond):
		t.Fatal("Did not receive error in 500ms")
	}
}

// newClientAndServerConnsForReadErrors starts a client MConnection with two
// channels (0x01, 0x02) and a server MConnection with only channel 0x01,
// joined by a NetPipe. The server's onError callback signals chOnErr. The
// underlying pipe ends are closed during test cleanup.
func newClientAndServerConnsForReadErrors(
	ctx context.Context,
	t *testing.T,
	chOnErr chan struct{},
) (*MConnection, *MConnection) {
	t.Helper()

	server, client := NetPipe()
	// Registered before the callers' waitAll cleanup, so (cleanups run LIFO)
	// the pipe is closed only after both connections have stopped.
	t.Cleanup(closeAll(t, client, server))

	onReceive := func(context.Context, ChannelID, []byte) {}
	onError := func(context.Context, interface{}) {}

	// create client conn with two channels
	chDescs := []*ChannelDescriptor{
		{ID: 0x01, Priority: 1, SendQueueCapacity: 1},
		{ID: 0x02, Priority: 1, SendQueueCapacity: 1},
	}
	logger := log.TestingLogger()

	mconnClient := NewMConnection(logger.With("module", "client"), client, chDescs, onReceive, onError, DefaultMConnConfig())
	err := mconnClient.Start(ctx)
	require.NoError(t, err)

	// create server conn with 1 channel
	// it fires on chOnErr when there's an error
	serverLogger := logger.With("module", "server")
	onError = func(ctx context.Context, r interface{}) {
		select {
		case <-ctx.Done():
		case chOnErr <- struct{}{}:
		}
	}

	mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError)
	err = mconnServer.Start(ctx)
	require.NoError(t, err)
	return mconnClient, mconnServer
}

// expectSend reports whether a value arrives on ch within five seconds.
func expectSend(ch chan struct{}) bool {
	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()
	select {
	case <-ch:
		return true
	case <-timer.C:
		return false
	}
}

func TestMConnectionReadErrorBadEncoding(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)

	client := mconnClient.conn

	// Write it.
	_, err := client.Write([]byte{1, 2, 3, 4, 5})
	require.NoError(t, err)
	assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
	t.Cleanup(waitAll(mconnClient, mconnServer))
}

func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)

	msg := []byte("Ant-Man")

	// fail to send msg on channel unknown by client
	assert.False(t, mconnClient.Send(0x03, msg))

	// send msg on channel unknown by the server.
	// should cause an error
	assert.True(t, mconnClient.Send(0x02, msg))
	assert.True(t, expectSend(chOnErr), "unknown channel")
	t.Cleanup(waitAll(mconnClient, mconnServer))
}

func TestMConnectionReadErrorLongMessage(t *testing.T) {
	chOnErr := make(chan struct{})
	chOnRcv := make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case <-ctx.Done():
		case chOnRcv <- struct{}{}:
		}
	}

	client := mconnClient.conn
	protoWriter := protoio.NewDelimitedWriter(client)

	// send a msg that's exactly the maximum payload size
	var packet = tmp2p.PacketMsg{
		ChannelID: 0x01,
		EOF:       true,
		Data:      make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
	}

	_, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.True(t, expectSend(chOnRcv), "msg just right")

	// send a msg that's too long
	packet = tmp2p.PacketMsg{
		ChannelID: 0x01,
		EOF:       true,
		Data:      make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
	}

	_, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.Error(t, err)
	assert.True(t, expectSend(chOnErr), "msg too long")
}

func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	// send msg with unknown msg type
	_, err := protoio.NewDelimitedWriter(mconnClient.conn).WriteMsg(&types.Header{ChainID: "x"})
	require.NoError(t, err)
	assert.True(t, expectSend(chOnErr), "unknown msg type")
}

func TestMConnectionTrySend(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createTestMConnection(log.TestingLogger(), client)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	msg := []byte("Semicolon-Woman")
	resultCh := make(chan string, 2)
	assert.True(t, mconn.Send(0x01, msg))
	_, err = server.Read(make([]byte, len(msg)))
	require.NoError(t, err)
	assert.True(t, mconn.Send(0x01, msg))
	go func() {
		mconn.Send(0x01, msg)
		resultCh <- "TrySend"
	}()
	assert.False(t, mconn.Send(0x01, msg))
	assert.Equal(t, "TrySend", <-resultCh)
}

func TestConnVectors(t *testing.T) {
	testCases := []struct {
		testName string
		msg      proto.Message
		expBytes string
	}{
		{"PacketPing", &tmp2p.PacketPing{}, "0a00"},
		{"PacketPong", &tmp2p.PacketPong{}, "1200"},
		{"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")},
			"1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
	}

	for _, tc := range testCases {
		tc := tc

		pm := mustWrapPacket(tc.msg)
		bz, err := pm.Marshal()
		require.NoError(t, err, tc.testName)

		require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
	}
}

func TestMConnectionChannelOverflow(t *testing.T) {
	chOnErr := make(chan struct{})
	chOnRcv := make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	mconnServer.onReceive = func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case <-ctx.Done():
		case chOnRcv <- struct{}{}:
		}
	}

	client := mconnClient.conn
	protoWriter := protoio.NewDelimitedWriter(client)

	var packet = tmp2p.PacketMsg{
		ChannelID: 0x01,
		EOF:       true,
		Data:      []byte(`42`),
	}
	_, err := protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.True(t, expectSend(chOnRcv))

	// The server only knows channel 0x01, so a packet on channel 1025 must
	// not be delivered to onReceive.
	packet.ChannelID = int32(1025)
	_, err = protoWriter.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.False(t, expectSend(chOnRcv))
}

// waitAll returns a cleanup function that blocks until every given service
// has returned from Wait. With more than one service, they are waited on
// concurrently.
func waitAll(waiters ...service.Service) func() {
	return func() {
		switch len(waiters) {
		case 0:
			return
		case 1:
			waiters[0].Wait()
			return
		default:
			wg := &sync.WaitGroup{}

			for _, w := range waiters {
				wg.Add(1)
				go func(s service.Service) {
					defer wg.Done()
					s.Wait()
				}(w)
			}

			wg.Wait()
		}
	}
}

// closer is anything that can be closed, such as a net.Conn.
type closer interface {
	Close() error
}

// closeAll returns a cleanup function that closes every given closer in
// order, logging (not failing on) any close error.
func closeAll(t *testing.T, closers ...closer) func() {
	return func() {
		for _, s := range closers {
			if err := s.Close(); err != nil {
				t.Log(err)
			}
		}
	}
}