You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

698 lines
18 KiB

package conn
import (
"context"
"encoding/hex"
"net"
"sync"
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/internal/libs/protoio"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/libs/service"
tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p"
"github.com/tendermint/tendermint/proto/tendermint/types"
)
// maxPingPongPacketSize is the maximum size, in bytes, that the delimited
// readers in these tests accept when reading packets back from a connection.
const maxPingPongPacketSize = 1 << 10
// createTestMConnection builds an MConnection over conn using the shared test
// configuration. Both the receive and the error callbacks are no-ops.
func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection {
	ignoreReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {}
	ignoreError := func(ctx context.Context, r interface{}) {}
	return createMConnectionWithCallbacks(logger, conn, ignoreReceive, ignoreError)
}
// createMConnectionWithCallbacks returns an MConnection over conn that wires in
// the given callbacks. It has a single channel (ID 0x01, send queue capacity 1)
// and shortened ping/pong timings (90ms ping interval, 45ms pong timeout).
// The connection is not started here; callers invoke Start themselves.
func createMConnectionWithCallbacks(
	logger log.Logger,
	conn net.Conn,
	onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte),
	onError func(ctx context.Context, r interface{}),
) *MConnection {
	config := DefaultMConnConfig()
	config.PingInterval, config.PongTimeout = 90*time.Millisecond, 45*time.Millisecond

	descriptors := []*ChannelDescriptor{
		{ID: 0x01, Priority: 1, SendQueueCapacity: 1},
	}
	return NewMConnectionWithConfig(logger, conn, descriptors, onReceive, onError, config)
}
// TestMConnectionSendFlushStop checks that a message queued with Send is
// flushed to the peer: some bytes arrive on the server side of the pipe within
// three seconds.
func TestMConnectionSendFlushStop(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	clientConn := createTestMConnection(log.TestingLogger(), client)
	err := clientConn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(clientConn))

	msg := []byte("abc")
	assert.True(t, clientConn.Send(0x01, msg))
	msgLength := 14

	// Start the reader in a separate goroutine so the connection can flush.
	// The goroutine only reports its result. All assertions happen on the
	// test goroutine, because t.Error/t.FailNow from other goroutines may run
	// after the test has finished. errCh is buffered so the reader never
	// blocks forever if the test has already timed out. It is unblocked by
	// the pipe being closed during cleanup.
	errCh := make(chan error, 1)
	go func() {
		msgB := make([]byte, msgLength)
		_, err := server.Read(msgB)
		errCh <- err
	}()

	timer := time.NewTimer(3 * time.Second)
	defer timer.Stop()
	select {
	case err := <-errCh:
		require.NoError(t, err)
	case <-timer.C:
		t.Error("timed out waiting for msgs to be read")
	}
}
// TestMConnectionSend checks that Send succeeds on the known channel (the peer
// reads after each send) and fails on a channel the connection does not know.
func TestMConnectionSend(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createTestMConnection(log.TestingLogger(), client)
	require.NoError(t, mconn.Start(ctx))
	t.Cleanup(waitAll(mconn))

	for _, hero := range [][]byte{[]byte("Ant-Man"), []byte("Spider-Man")} {
		assert.True(t, mconn.Send(0x01, hero))
		// Note: subsequent Send/TrySend calls could pass because we are reading from
		// the send queue in a separate goroutine.
		_, readErr := server.Read(make([]byte, len(hero)))
		assert.NoError(t, readErr)
	}

	assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
}
// TestMConnectionReceive checks that a message sent by one MConnection is
// delivered to the peer MConnection's onReceive callback within 500ms.
func TestMConnectionReceive(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	var (
		receivedCh = make(chan []byte)
		errorsCh   = make(chan interface{})
	)
	onReceive := func(ctx context.Context, _ ChannelID, msgBytes []byte) {
		select {
		case <-ctx.Done():
		case receivedCh <- msgBytes:
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case <-ctx.Done():
		case errorsCh <- r:
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	logger := log.TestingLogger()

	receiver := createMConnectionWithCallbacks(logger, client, onReceive, onError)
	require.NoError(t, receiver.Start(ctx))
	t.Cleanup(waitAll(receiver))

	sender := createTestMConnection(logger, server)
	require.NoError(t, sender.Start(ctx))
	t.Cleanup(waitAll(sender))

	msg := []byte("Cyclops")
	assert.True(t, sender.Send(0x01, msg))

	select {
	case got := <-receivedCh:
		assert.Equal(t, msg, got)
	case err := <-errorsCh:
		t.Fatalf("Expected %s, got %+v", msg, err)
	case <-time.After(500 * time.Millisecond):
		t.Fatalf("Did not receive %s message in 500ms", msg)
	}
}
// TestMConnectionPongTimeoutResultsInError checks that when the peer reads the
// connection's ping but never answers with a pong, the connection reports an
// error through onError instead of delivering data.
func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	// Read the ping on the test goroutine. require.NoError ends the test via
	// t.FailNow, which is only valid on the goroutine running the test.
	// Reading here also means a failed read fails the test instead of leaving
	// it blocked waiting for a signal that never arrives.
	var pkt tmp2p.Packet
	_, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
	require.NoError(t, err)

	// Deliberately never answer with a pong. An error should follow within
	// the pong timeout plus some slack.
	pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
	select {
	case msgBytes := <-receivedCh:
		t.Fatalf("Expected error, but got %v", msgBytes)
	case err := <-errorsCh:
		assert.NotNil(t, err)
	case <-time.After(pongTimerExpired):
		t.Fatalf("Expected to receive error after %v", pongTimerExpired)
	}
}
// TestMConnectionMultiplePongsInTheBeginning checks that unsolicited pongs
// sent before any ping do not break the connection. After a regular
// ping/pong exchange it keeps running without reporting data or errors.
func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	// sending 3 pongs in a row (abuse)
	protoWriter := protoio.NewDelimitedWriter(server)
	for i := 0; i < 3; i++ {
		_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
		require.NoError(t, err)
	}

	// Play the peer on the test goroutine. It reads the connection's ping (a
	// length-delimited tmp2p.Packet) and answers with a pong. Doing this here
	// rather than in a spawned goroutine keeps require/t.FailNow on the test
	// goroutine. A failed read then fails the test instead of hanging it.
	var packet tmp2p.Packet
	_, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
	require.NoError(t, err)
	_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
	require.NoError(t, err)

	pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
	select {
	case msgBytes := <-receivedCh:
		t.Fatalf("Expected no data, but got %v", msgBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected no error, but got %v", err)
	case <-time.After(pongTimerExpired):
		assert.True(t, mconn.IsRunning())
	}
}
// TestMConnectionMultiplePings checks that the connection answers each of
// three back-to-back pings and is still running afterwards.
// see https://github.com/tendermint/tendermint/issues/1190
func TestMConnectionMultiplePings(t *testing.T) {
	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, _ ChannelID, msgBytes []byte) {
		select {
		case <-ctx.Done():
		case receivedCh <- msgBytes:
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case <-ctx.Done():
		case errorsCh <- r:
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	require.NoError(t, mconn.Start(ctx))
	t.Cleanup(waitAll(mconn))

	reader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
	writer := protoio.NewDelimitedWriter(server)
	var reply tmp2p.Packet

	// Abuse: send a ping and read the reply, three times in a row.
	for round := 0; round < 3; round++ {
		_, err := writer.WriteMsg(mustWrapPacket(&tmp2p.PacketPing{}))
		require.NoError(t, err)
		_, err = reader.ReadMsg(&reply)
		require.NoError(t, err)
	}

	assert.True(t, mconn.IsRunning())
}
// TestMConnectionPingPongs checks that a connection whose pings are answered
// with pongs keeps running and reports neither data nor errors. It also
// verifies that no goroutines are leaked.
func TestMConnectionPingPongs(t *testing.T) {
	// check that we are not leaking any go-routines
	t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second))

	server, client := net.Pipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
		select {
		case receivedCh <- msgBytes:
		case <-ctx.Done():
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case errorsCh <- r:
		case <-ctx.Done():
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	err := mconn.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(waitAll(mconn))

	// Play the peer on the test goroutine. require/t.FailNow is only valid
	// there, and a failed read must fail the test rather than leave it
	// waiting for a signal that never comes.
	protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
	protoWriter := protoio.NewDelimitedWriter(server)
	// The ping arrives as a wrapped tmp2p.Packet, so decode into the wrapper
	// type rather than tmp2p.PacketPing.
	var pkt tmp2p.Packet
	for round := 0; round < 2; round++ {
		if round > 0 {
			time.Sleep(mconn.config.PingInterval)
		}
		// read ping
		_, err = protoReader.ReadMsg(&pkt)
		require.NoError(t, err)
		// respond with pong
		_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
		require.NoError(t, err)
	}

	pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
	select {
	case msgBytes := <-receivedCh:
		t.Fatalf("Expected no data, but got %v", msgBytes)
	case err := <-errorsCh:
		t.Fatalf("Expected no error, but got %v", err)
	case <-time.After(2 * pongTimerExpired):
		assert.True(t, mconn.IsRunning())
	}
}
// TestMConnectionStopsAndReturnsError checks that closing the underlying
// connection makes the MConnection report an error through onError and stop
// running.
func TestMConnectionStopsAndReturnsError(t *testing.T) {
	server, client := NetPipe()
	t.Cleanup(closeAll(t, client, server))

	receivedCh := make(chan []byte)
	errorsCh := make(chan interface{})
	onReceive := func(ctx context.Context, _ ChannelID, msgBytes []byte) {
		select {
		case <-ctx.Done():
		case receivedCh <- msgBytes:
		}
	}
	onError := func(ctx context.Context, r interface{}) {
		select {
		case <-ctx.Done():
		case errorsCh <- r:
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
	require.NoError(t, mconn.Start(ctx))
	t.Cleanup(waitAll(mconn))

	// Pull the transport out from under the running connection.
	if closeErr := client.Close(); closeErr != nil {
		t.Error(closeErr)
	}

	select {
	case got := <-receivedCh:
		t.Fatalf("Expected error, got %v", got)
	case err := <-errorsCh:
		assert.NotNil(t, err)
		assert.False(t, mconn.IsRunning())
	case <-time.After(500 * time.Millisecond):
		t.Fatal("Did not receive error in 500ms")
	}
}
// newClientAndServerConnsForReadErrors starts two MConnections joined by an
// in-memory pipe and returns them as (client, server).
//
// The client has channels 0x01 and 0x02. The server is built with
// createMConnectionWithCallbacks, which registers only channel 0x01, so client
// traffic on 0x02 is unknown to the server. Received messages are ignored by
// both sides. The server's onError signals chOnErr unless ctx is done. Both
// pipe ends are closed when the test finishes.
func newClientAndServerConnsForReadErrors(
	ctx context.Context,
	t *testing.T,
	chOnErr chan struct{},
) (*MConnection, *MConnection) {
	t.Helper()

	server, client := NetPipe()
	// Close both pipe ends once the test is done. t.Cleanup runs functions in
	// LIFO order, so this runs after any cleanup the caller registers later
	// (such as waiting on the returned connections).
	t.Cleanup(closeAll(t, client, server))

	onReceive := func(context.Context, ChannelID, []byte) {}
	onError := func(context.Context, interface{}) {}

	// create client conn with two channels
	chDescs := []*ChannelDescriptor{
		{ID: 0x01, Priority: 1, SendQueueCapacity: 1},
		{ID: 0x02, Priority: 1, SendQueueCapacity: 1},
	}
	logger := log.TestingLogger()
	mconnClient := NewMConnection(logger.With("module", "client"), client, chDescs, onReceive, onError)
	err := mconnClient.Start(ctx)
	require.NoError(t, err)

	// create server conn with 1 channel
	// it fires on chOnErr when there's an error
	serverLogger := logger.With("module", "server")
	onError = func(ctx context.Context, r interface{}) {
		select {
		case <-ctx.Done():
		case chOnErr <- struct{}{}:
		}
	}
	mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError)
	err = mconnServer.Start(ctx)
	require.NoError(t, err)

	return mconnClient, mconnServer
}
// expectSend reports whether a value is received from ch (or ch is observed
// closed) within five seconds. It returns false if the deadline passes first.
func expectSend(ch chan struct{}) bool {
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		return false
	case <-ch:
		return true
	}
}
// TestMConnectionReadErrorBadEncoding checks that raw bytes which are not a
// properly encoded packet make the server connection report an error.
func TestMConnectionReadErrorBadEncoding(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	// Register the wait immediately (as the other read-error tests do) so the
	// services are still awaited if a later require aborts the test.
	t.Cleanup(waitAll(mconnClient, mconnServer))

	client := mconnClient.conn
	// Write raw bytes directly to the pipe, bypassing the packet encoder.
	_, err := client.Write([]byte{1, 2, 3, 4, 5})
	require.NoError(t, err)
	assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
}
// TestMConnectionReadErrorUnknownChannel checks two things. Sending on a channel
// the client does not have fails locally. A message on a channel the client
// has but the server does not (0x02) makes the server report an error.
func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	// Register the wait immediately so the services are still awaited if an
	// assertion aborts the test early.
	t.Cleanup(waitAll(mconnClient, mconnServer))

	msg := []byte("Ant-Man")

	// fail to send msg on channel unknown by client
	assert.False(t, mconnClient.Send(0x03, msg))

	// send msg on channel unknown by the server.
	// should cause an error
	assert.True(t, mconnClient.Send(0x02, msg))
	assert.True(t, expectSend(chOnErr), "unknown channel")
}
// TestMConnectionReadErrorLongMessage checks both size cases. A packet whose
// payload is exactly MaxPacketMsgPayloadSize is delivered. A payload 100 bytes
// over that limit makes the client write fail and the server report an error.
func TestMConnectionReadErrorLongMessage(t *testing.T) {
	chOnErr := make(chan struct{})
	chOnRcv := make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	mconnServer.onReceive = func(ctx context.Context, _ ChannelID, _ []byte) {
		select {
		case <-ctx.Done():
		case chOnRcv <- struct{}{}:
		}
	}

	writer := protoio.NewDelimitedWriter(mconnClient.conn)
	maxPayload := mconnClient.config.MaxPacketMsgPayloadSize
	writeMsg := func(data []byte) error {
		_, werr := writer.WriteMsg(mustWrapPacket(&tmp2p.PacketMsg{
			ChannelID: 0x01,
			EOF:       true,
			Data:      data,
		}))
		return werr
	}

	// Payload exactly at the limit: accepted and delivered.
	require.NoError(t, writeMsg(make([]byte, maxPayload)))
	assert.True(t, expectSend(chOnRcv), "msg just right")

	// Payload over the limit: the write fails and the server reports an error.
	require.Error(t, writeMsg(make([]byte, maxPayload+100)))
	assert.True(t, expectSend(chOnErr), "msg too long")
}
// TestMConnectionReadErrorUnknownMsgType checks that a delimited protobuf
// message that is not a p2p packet (here a types.Header) makes the server
// report an error.
func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	chOnErr := make(chan struct{})
	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	writer := protoio.NewDelimitedWriter(mconnClient.conn)
	_, err := writer.WriteMsg(&types.Header{ChainID: "x"})
	require.NoError(t, err)

	assert.True(t, expectSend(chOnErr), "unknown msg type")
}
// TestMConnectionTrySend exercises Send on the single test channel (0x01,
// SendQueueCapacity 1, per createMConnectionWithCallbacks). The peer reads only
// once.
//
// NOTE(review): despite its name, this test only calls Send. "TrySend" below is
// just a marker string passed through resultCh.
func TestMConnectionTrySend(t *testing.T) {
server, client := NetPipe()
t.Cleanup(closeAll(t, client, server))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mconn := createTestMConnection(log.TestingLogger(), client)
err := mconn.Start(ctx)
require.NoError(t, err)
t.Cleanup(waitAll(mconn))
msg := []byte("Semicolon-Woman")
// Buffered so the goroutine below can hand back its marker without blocking.
resultCh := make(chan string, 2)
// First send, then read up to len(msg) bytes from the peer side of the pipe.
assert.True(t, mconn.Send(0x01, msg))
_, err = server.Read(make([]byte, len(msg)))
require.NoError(t, err)
// From here on this test never reads from server again.
assert.True(t, mconn.Send(0x01, msg))
go func() {
mconn.Send(0x01, msg)
resultCh <- "TrySend"
}()
// NOTE(review): this Send is expected to fail, presumably because the
// capacity-1 queue is full and nothing drains the pipe. That depends on Send's
// blocking/timeout behavior, which is defined outside this file — confirm.
assert.False(t, mconn.Send(0x01, msg))
// Blocks until the concurrent Send in the goroutine has returned.
assert.Equal(t, "TrySend", <-resultCh)
}
// TestConnVectors pins the hex-encoded wire bytes of each packet type after
// wrapping with mustWrapPacket, so changes to the encoding are caught. Each
// vector runs as its own subtest, so every mismatching vector is reported
// rather than only the first.
func TestConnVectors(t *testing.T) {
	testCases := []struct {
		testName string
		msg      proto.Message
		expBytes string
	}{
		{"PacketPing", &tmp2p.PacketPing{}, "0a00"},
		{"PacketPong", &tmp2p.PacketPong{}, "1200"},
		{"PacketMsg", &tmp2p.PacketMsg{ChannelID: 1, EOF: false, Data: []byte("data transmitted over the wire")}, "1a2208011a1e64617461207472616e736d6974746564206f766572207468652077697265"},
	}

	for _, tc := range testCases {
		tc := tc // capture for the subtest closure (pre-Go 1.22 loop semantics)
		t.Run(tc.testName, func(t *testing.T) {
			pm := mustWrapPacket(tc.msg)
			bz, err := pm.Marshal()
			require.NoError(t, err)
			require.Equal(t, tc.expBytes, hex.EncodeToString(bz))
		})
	}
}
// TestMConnectionChannelOverflow checks two cases. A message on the server's
// known channel (0x01) is delivered. The same message re-addressed to channel
// ID 1025 is not delivered within expectSend's five-second window.
func TestMConnectionChannelOverflow(t *testing.T) {
	chOnErr := make(chan struct{})
	chOnRcv := make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr)
	t.Cleanup(waitAll(mconnClient, mconnServer))

	mconnServer.onReceive = func(ctx context.Context, _ ChannelID, _ []byte) {
		select {
		case <-ctx.Done():
		case chOnRcv <- struct{}{}:
		}
	}

	writer := protoio.NewDelimitedWriter(mconnClient.conn)
	packet := tmp2p.PacketMsg{ChannelID: 0x01, EOF: true, Data: []byte(`42`)}

	_, err := writer.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.True(t, expectSend(chOnRcv))

	// Re-address the same packet to a channel ID the server has no descriptor for.
	packet.ChannelID = 1025
	_, err = writer.WriteMsg(mustWrapPacket(&packet))
	require.NoError(t, err)
	assert.False(t, expectSend(chOnRcv))
}
// waitAll returns a function, meant for t.Cleanup, that blocks until Wait has
// returned for every given service. With several services, their Wait calls
// run concurrently.
func waitAll(waiters ...service.Service) func() {
	return func() {
		if len(waiters) == 0 {
			return
		}
		if len(waiters) == 1 {
			waiters[0].Wait()
			return
		}

		var wg sync.WaitGroup
		wg.Add(len(waiters))
		for _, svc := range waiters {
			go func(s service.Service) {
				defer wg.Done()
				s.Wait()
			}(svc)
		}
		wg.Wait()
	}
}
// closer is anything that can be closed, such as either end of a net.Pipe.
type closer interface {
	Close() error
}

// closeAll returns a function, meant for t.Cleanup, that closes each closer in
// the order given. Close errors are logged with t.Log and do not fail the
// test.
func closeAll(t *testing.T, closers ...closer) func() {
	return func() {
		for i := range closers {
			err := closers[i].Close()
			if err == nil {
				continue
			}
			t.Log(err)
		}
	}
}
}