From 1f34236948b722c3cfd2b37bbe49f2de7142ace3 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Tue, 7 Jul 2015 18:35:21 -0700 Subject: [PATCH] Limit binary data to 21MB --- binary/binary.go | 6 ++++++ binary/byteslice.go | 9 +++++++++ binary/reflect.go | 4 ++++ binary/string.go | 5 +++++ p2p/connection.go | 30 ++++++++++++++++++++---------- 5 files changed, 44 insertions(+), 10 deletions(-) diff --git a/binary/binary.go b/binary/binary.go index 1a5656937..a613767e7 100644 --- a/binary/binary.go +++ b/binary/binary.go @@ -2,10 +2,16 @@ package binary import ( "encoding/json" + "errors" "io" "reflect" ) +// TODO document and maybe make it configurable. +const MaxBinaryReadSize = 21 * 1024 * 1024 + +var ErrMaxBinaryReadSizeReached = errors.New("Error: max binary read size reached") + func ReadBinary(o interface{}, r io.Reader, n *int64, err *error) interface{} { rv, rt := reflect.ValueOf(o), reflect.TypeOf(o) if rv.Kind() == reflect.Ptr { diff --git a/binary/byteslice.go b/binary/byteslice.go index 2e93ab938..205a502c7 100644 --- a/binary/byteslice.go +++ b/binary/byteslice.go @@ -19,6 +19,10 @@ func ReadByteSlice(r io.Reader, n *int64, err *error) []byte { if *err != nil { return nil } + if MaxBinaryReadSize < *n+int64(length) { + *err = ErrMaxBinaryReadSizeReached + return nil + } var buf, tmpBuf []byte // read one ByteSliceChunk at a time and append @@ -50,6 +54,11 @@ func ReadByteSlices(r io.Reader, n *int64, err *error) [][]byte { if *err != nil { return nil } + if MaxBinaryReadSize < *n+int64(length) { + *err = ErrMaxBinaryReadSizeReached + return nil + } + bzz := make([][]byte, length) for i := 0; i < length; i++ { bz := ReadByteSlice(r, n, err) diff --git a/binary/reflect.go b/binary/reflect.go index 0ef267dc4..55e60d239 100644 --- a/binary/reflect.go +++ b/binary/reflect.go @@ -273,6 +273,10 @@ func readReflectBinary(rv reflect.Value, rt reflect.Type, opts Options, r io.Rea if *err != nil { return } + if MaxBinaryReadSize < *n { + *err = ErrMaxBinaryReadSizeReached + return + } } sliceRv = reflect.AppendSlice(sliceRv, tmpSliceRv) } diff --git a/binary/string.go b/binary/string.go index fb0bfc7d8..d05744b04 100644 --- a/binary/string.go +++ b/binary/string.go @@ -14,6 +14,11 @@ func ReadString(r io.Reader, n *int64, err *error) string { if *err != nil { return "" } + if MaxBinaryReadSize < *n+int64(length) { + *err = ErrMaxBinaryReadSizeReached + return "" + } + buf := make([]byte, length) ReadFull(buf, r, n, err) return string(buf) diff --git a/p2p/connection.go b/p2p/connection.go index aa4b4394c..a8021c777 100644 --- a/p2p/connection.go +++ b/p2p/connection.go @@ -403,13 +403,13 @@ FOR_LOOP: // do nothing log.Debug("Receive Pong") case packetTypeMsg: - pkt, n, err := msgPacket{}, new(int64), new(error) - binary.ReadBinaryPtr(&pkt, c.bufReader, n, err) - c.recvMonitor.Update(int(*n)) - if *err != nil { + pkt, n, err := msgPacket{}, int64(0), error(nil) + binary.ReadBinaryPtr(&pkt, c.bufReader, &n, &err) + c.recvMonitor.Update(int(n)) + if err != nil { if atomic.LoadUint32(&c.stopped) != 1 { - log.Warn("Connection failed @ recvRoutine", "connection", c, "error", *err) - c.stopForError(*err) + log.Warn("Connection failed @ recvRoutine", "connection", c, "error", err) + c.stopForError(err) } break FOR_LOOP } @@ -417,7 +417,14 @@ FOR_LOOP: if !ok || channel == nil { panic(Fmt("Unknown channel %X", pkt.ChannelId)) } - msgBytes := channel.recvMsgPacket(pkt) + msgBytes, err := channel.recvMsgPacket(pkt) + if err != nil { + if atomic.LoadUint32(&c.stopped) != 1 { + log.Warn("Connection failed @ recvRoutine", "connection", c, "error", err) + c.stopForError(err) + } + break FOR_LOOP + } if msgBytes != nil { log.Debug("Received bytes", "chId", pkt.ChannelId, "msgBytes", msgBytes) c.onReceive(pkt.ChannelId, msgBytes) @@ -569,15 +576,18 @@ func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int64, err error) { // Handles incoming msgPackets. Returns a msg bytes if msg is complete. // Not goroutine-safe -func (ch *Channel) recvMsgPacket(packet msgPacket) []byte { +func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) { log.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet) + if binary.MaxBinaryReadSize < len(ch.recving)+len(packet.Bytes) { + return nil, binary.ErrMaxBinaryReadSizeReached + } ch.recving = append(ch.recving, packet.Bytes...) if packet.EOF == byte(0x01) { msgBytes := ch.recving ch.recving = make([]byte, 0, defaultRecvBufferCapacity) - return msgBytes + return msgBytes, nil } - return nil + return nil, nil } // Call this periodically to update stats for throttling purposes.