From 358b1f23c048dcea15cbf12e0053db7dc3e24ad8 Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Fri, 4 Jun 2021 20:20:36 -0400 Subject: [PATCH] p2p/conn: check for channel id overflow before processing receive msg (backport #6522) (#6528) * p2p/conn: check for channel id overflow before processing receive msg (#6522) Per tendermint spec, each Channel has a globally unique byte id, which is mapped to uint8 in Go. However, the proto PacketMsg.ChannelID field is declared as int32, and when receive the packet, we cast it to a byte without checking for possible overflow. That leads to a malform packet with invalid channel id is sent successfully. To fix it, we just add a check for possible overflow, and return invalid channel id error. Fixed #6521 (cherry picked from commit 1f46a4c90e268def505037a5d42627942f605ef4) --- p2p/conn/connection.go | 9 ++++---- p2p/conn/connection_test.go | 44 +++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/p2p/conn/connection.go b/p2p/conn/connection.go index 65495074f..dfbd76f0a 100644 --- a/p2p/conn/connection.go +++ b/p2p/conn/connection.go @@ -624,8 +624,9 @@ FOR_LOOP: // never block } case *tmp2p.Packet_PacketMsg: - channel, ok := c.channelsIdx[byte(pkt.PacketMsg.ChannelID)] - if !ok || channel == nil { + channelID := byte(pkt.PacketMsg.ChannelID) + channel, ok := c.channelsIdx[channelID] + if pkt.PacketMsg.ChannelID < 0 || pkt.PacketMsg.ChannelID > math.MaxUint8 || !ok || channel == nil { err := fmt.Errorf("unknown channel %X", pkt.PacketMsg.ChannelID) c.Logger.Debug("Connection failed @ recvRoutine", "conn", c, "err", err) c.stopForError(err) @@ -641,9 +642,9 @@ FOR_LOOP: break FOR_LOOP } if msgBytes != nil { - c.Logger.Debug("Received bytes", "chID", pkt.PacketMsg.ChannelID, "msgBytes", fmt.Sprintf("%X", msgBytes)) + c.Logger.Debug("Received bytes", "chID", channelID, "msgBytes", msgBytes) // NOTE: This means the reactor.Receive runs in the same thread as the p2p recv routine - c.onReceive(byte(pkt.PacketMsg.ChannelID), msgBytes) + c.onReceive(channelID, msgBytes) } default: err := fmt.Errorf("unknown message type %v", reflect.TypeOf(packet)) diff --git a/p2p/conn/connection_test.go b/p2p/conn/connection_test.go index 6dadfb486..c41a46c48 100644 --- a/p2p/conn/connection_test.go +++ b/p2p/conn/connection_test.go @@ -587,3 +587,47 @@ func TestConnVectors(t *testing.T) { require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName) } } + +func TestMConnectionChannelOverflow(t *testing.T) { + chOnErr := make(chan struct{}) + chOnRcv := make(chan struct{}) + + mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) + t.Cleanup(stopAll(t, mconnClient, mconnServer)) + + mconnServer.onReceive = func(chID byte, msgBytes []byte) { + chOnRcv <- struct{}{} + } + + client := mconnClient.conn + protoWriter := protoio.NewDelimitedWriter(client) + + var packet = tmp2p.PacketMsg{ + ChannelID: 0x01, + EOF: true, + Data: []byte(`42`), + } + _, err := protoWriter.WriteMsg(mustWrapPacket(&packet)) + require.NoError(t, err) + assert.True(t, expectSend(chOnRcv)) + + packet.ChannelID = int32(1025) + _, err = protoWriter.WriteMsg(mustWrapPacket(&packet)) + require.NoError(t, err) + assert.False(t, expectSend(chOnRcv)) + +} + +type stopper interface { + Stop() error +} + +func stopAll(t *testing.T, stoppers ...stopper) func() { + return func() { + for _, s := range stoppers { + if err := s.Stop(); err != nil { + t.Log(err) + } + } + } +}