Browse Source

Limit binary data to 21MB

pull/102/head
Jae Kwon 10 years ago
parent
commit
1f34236948
5 changed files with 44 additions and 10 deletions
  1. +6
    -0
      binary/binary.go
  2. +9
    -0
      binary/byteslice.go
  3. +4
    -0
      binary/reflect.go
  4. +5
    -0
      binary/string.go
  5. +20
    -10
      p2p/connection.go

+ 6
- 0
binary/binary.go View File

@ -2,10 +2,16 @@ package binary
import ( import (
"encoding/json" "encoding/json"
"errors"
"io" "io"
"reflect" "reflect"
) )
// TODO document and maybe make it configurable.
const MaxBinaryReadSize = 21 * 1024 * 1024
var ErrMaxBinaryReadSizeReached = errors.New("Error: max binary read size reached")
func ReadBinary(o interface{}, r io.Reader, n *int64, err *error) interface{} { func ReadBinary(o interface{}, r io.Reader, n *int64, err *error) interface{} {
rv, rt := reflect.ValueOf(o), reflect.TypeOf(o) rv, rt := reflect.ValueOf(o), reflect.TypeOf(o)
if rv.Kind() == reflect.Ptr { if rv.Kind() == reflect.Ptr {


+ 9
- 0
binary/byteslice.go View File

@ -19,6 +19,10 @@ func ReadByteSlice(r io.Reader, n *int64, err *error) []byte {
if *err != nil { if *err != nil {
return nil return nil
} }
if MaxBinaryReadSize < *n+int64(length) {
*err = ErrMaxBinaryReadSizeReached
return nil
}
var buf, tmpBuf []byte var buf, tmpBuf []byte
// read one ByteSliceChunk at a time and append // read one ByteSliceChunk at a time and append
@ -50,6 +54,11 @@ func ReadByteSlices(r io.Reader, n *int64, err *error) [][]byte {
if *err != nil { if *err != nil {
return nil return nil
} }
if MaxBinaryReadSize < *n+int64(length) {
*err = ErrMaxBinaryReadSizeReached
return nil
}
bzz := make([][]byte, length) bzz := make([][]byte, length)
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
bz := ReadByteSlice(r, n, err) bz := ReadByteSlice(r, n, err)


+ 4
- 0
binary/reflect.go View File

@ -273,6 +273,10 @@ func readReflectBinary(rv reflect.Value, rt reflect.Type, opts Options, r io.Rea
if *err != nil { if *err != nil {
return return
} }
if MaxBinaryReadSize < *n {
*err = ErrMaxBinaryReadSizeReached
return
}
} }
sliceRv = reflect.AppendSlice(sliceRv, tmpSliceRv) sliceRv = reflect.AppendSlice(sliceRv, tmpSliceRv)
} }


+ 5
- 0
binary/string.go View File

@ -14,6 +14,11 @@ func ReadString(r io.Reader, n *int64, err *error) string {
if *err != nil { if *err != nil {
return "" return ""
} }
if MaxBinaryReadSize < *n+int64(length) {
*err = ErrMaxBinaryReadSizeReached
return ""
}
buf := make([]byte, length) buf := make([]byte, length)
ReadFull(buf, r, n, err) ReadFull(buf, r, n, err)
return string(buf) return string(buf)


+ 20
- 10
p2p/connection.go View File

@ -403,13 +403,13 @@ FOR_LOOP:
// do nothing // do nothing
log.Debug("Receive Pong") log.Debug("Receive Pong")
case packetTypeMsg: case packetTypeMsg:
pkt, n, err := msgPacket{}, new(int64), new(error)
binary.ReadBinaryPtr(&pkt, c.bufReader, n, err)
c.recvMonitor.Update(int(*n))
if *err != nil {
pkt, n, err := msgPacket{}, int64(0), error(nil)
binary.ReadBinaryPtr(&pkt, c.bufReader, &n, &err)
c.recvMonitor.Update(int(n))
if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 { if atomic.LoadUint32(&c.stopped) != 1 {
log.Warn("Connection failed @ recvRoutine", "connection", c, "error", *err)
c.stopForError(*err)
log.Warn("Connection failed @ recvRoutine", "connection", c, "error", err)
c.stopForError(err)
} }
break FOR_LOOP break FOR_LOOP
} }
@ -417,7 +417,14 @@ FOR_LOOP:
if !ok || channel == nil { if !ok || channel == nil {
panic(Fmt("Unknown channel %X", pkt.ChannelId)) panic(Fmt("Unknown channel %X", pkt.ChannelId))
} }
msgBytes := channel.recvMsgPacket(pkt)
msgBytes, err := channel.recvMsgPacket(pkt)
if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 {
log.Warn("Connection failed @ recvRoutine", "connection", c, "error", err)
c.stopForError(err)
}
break FOR_LOOP
}
if msgBytes != nil { if msgBytes != nil {
log.Debug("Received bytes", "chId", pkt.ChannelId, "msgBytes", msgBytes) log.Debug("Received bytes", "chId", pkt.ChannelId, "msgBytes", msgBytes)
c.onReceive(pkt.ChannelId, msgBytes) c.onReceive(pkt.ChannelId, msgBytes)
@ -569,15 +576,18 @@ func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int64, err error) {
// Handles incoming msgPackets. Returns a msg bytes if msg is complete. // Handles incoming msgPackets. Returns a msg bytes if msg is complete.
// Not goroutine-safe // Not goroutine-safe
func (ch *Channel) recvMsgPacket(packet msgPacket) []byte {
func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) {
log.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet) log.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet)
if binary.MaxBinaryReadSize < len(ch.recving)+len(packet.Bytes) {
return nil, binary.ErrMaxBinaryReadSizeReached
}
ch.recving = append(ch.recving, packet.Bytes...) ch.recving = append(ch.recving, packet.Bytes...)
if packet.EOF == byte(0x01) { if packet.EOF == byte(0x01) {
msgBytes := ch.recving msgBytes := ch.recving
ch.recving = make([]byte, 0, defaultRecvBufferCapacity) ch.recving = make([]byte, 0, defaultRecvBufferCapacity)
return msgBytes
return msgBytes, nil
} }
return nil
return nil, nil
} }
// Call this periodically to update stats for throttling purposes. // Call this periodically to update stats for throttling purposes.


Loading…
Cancel
Save