implement a buffer pool for STREAM frames

This commit is contained in:
Marten Seemann
2019-09-04 16:45:39 +07:00
parent 326ec9e16e
commit 5ea33cd31e
70 changed files with 193 additions and 48 deletions

View File

@@ -81,6 +81,11 @@ const MaxNonAckElicitingAcks = 19
// prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000
// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame
// that we use the buffer for. This protects against a DoS where an attacker would send us
// very small STREAM frames to consume a lot of memory.
const MinStreamFrameBufferSize = 128
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
// This limits the size of the ClientHello and Certificates that can be received.
const MaxCryptoStreamOffset = 16 * (1 << 10)

33
internal/wire/pool.go Normal file
View File

@@ -0,0 +1,33 @@
package wire
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var pool sync.Pool
func init() {
pool.New = func() interface{} {
return &StreamFrame{
Data: make([]byte, 0, protocol.MaxReceivePacketSize),
fromPool: true,
}
}
}
func getStreamFrame() *StreamFrame {
f := pool.Get().(*StreamFrame)
return f
}
func putStreamFrame(f *StreamFrame) {
if !f.fromPool {
return
}
if protocol.ByteCount(cap(f.Data)) != protocol.MaxReceivePacketSize {
panic("wire.PutStreamFrame called with packet of wrong size!")
}
pool.Put(f)
}

View File

@@ -0,0 +1,24 @@
package wire
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Pool", func() {
It("gets and puts STREAM frames", func() {
f := getStreamFrame()
putStreamFrame(f)
})
It("panics when putting a STREAM frame with a wrong capacity", func() {
f := getStreamFrame()
f.Data = []byte("foobar")
Expect(func() { putStreamFrame(f) }).To(Panic())
})
It("accepts STREAM frames not from the buffer, but ignores them", func() {
f := &StreamFrame{Data: []byte("foobar")}
putStreamFrame(f)
})
})

View File

@@ -17,6 +17,8 @@ type StreamFrame struct {
DataLenPresent bool
Offset protocol.ByteCount
Data []byte
fromPool bool
}
func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) {
@@ -26,45 +28,53 @@ func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF
}
hasOffset := typeByte&0x4 > 0
frame := &StreamFrame{
FinBit: typeByte&0x1 > 0,
DataLenPresent: typeByte&0x2 > 0,
}
fin := typeByte&0x1 > 0
hasDataLen := typeByte&0x2 > 0
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
frame.StreamID = protocol.StreamID(streamID)
var offset uint64
if hasOffset {
offset, err := utils.ReadVarInt(r)
offset, err = utils.ReadVarInt(r)
if err != nil {
return nil, err
}
frame.Offset = protocol.ByteCount(offset)
}
var dataLen uint64
if frame.DataLenPresent {
if hasDataLen {
var err error
dataLen, err = utils.ReadVarInt(r)
if err != nil {
return nil, err
}
// shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet
// reading the packet contents would result in EOF when attempting to READ
if dataLen > uint64(r.Len()) {
return nil, io.EOF
}
} else {
// The rest of the packet is data
dataLen = uint64(r.Len())
}
var frame *StreamFrame
if dataLen < protocol.MinStreamFrameBufferSize {
frame = &StreamFrame{Data: make([]byte, dataLen)}
} else {
frame = getStreamFrame()
// The STREAM frame can't be larger than the StreamFrame we obtained from the buffer,
// since those StreamFrames have a buffer length of the maximum packet size.
if dataLen > uint64(cap(frame.Data)) {
return nil, io.EOF
}
frame.Data = frame.Data[:dataLen]
}
frame.StreamID = protocol.StreamID(streamID)
frame.Offset = protocol.ByteCount(offset)
frame.FinBit = fin
frame.DataLenPresent = hasDataLen
if dataLen != 0 {
frame.Data = make([]byte, dataLen)
if _, err := io.ReadFull(r, frame.Data); err != nil {
// this should never happen, since we already checked the dataLen earlier
return nil, err
}
}
@@ -156,16 +166,25 @@ func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version pro
if n == 0 {
return nil, true
}
newFrame := &StreamFrame{
FinBit: false,
StreamID: f.StreamID,
Offset: f.Offset,
Data: f.Data[:n],
DataLenPresent: f.DataLenPresent,
}
f.Data = f.Data[n:]
new := getStreamFrame()
new.StreamID = f.StreamID
new.Offset = f.Offset
new.FinBit = false
new.DataLenPresent = f.DataLenPresent
// swap the data slices
new.Data, f.Data = f.Data, new.Data
new.fromPool, f.fromPool = f.fromPool, new.fromPool
f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n]
copy(f.Data, new.Data[n:])
new.Data = new.Data[:n]
f.Offset += n
return newFrame, true
return new, true
}
func (f *StreamFrame) PutBack() {
putStreamFrame(f)
}

View File

@@ -2,6 +2,7 @@ package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@@ -78,6 +79,16 @@ var _ = Describe("STREAM frame", func() {
Expect(err).To(MatchError("FRAME_ENCODING_ERROR: stream data overflows maximum offset"))
})
It("rejects frames that claim to be longer than the packet size", func() {
data := []byte{0x8 ^ 0x2}
data = append(data, encodeVarInt(0x12345)...) // stream ID
data = append(data, encodeVarInt(uint64(protocol.MaxReceivePacketSize)+1)...) // data length
data = append(data, make([]byte, protocol.MaxReceivePacketSize+1)...)
r := bytes.NewReader(data)
_, err := parseStreamFrame(r, versionIETFFrames)
Expect(err).To(Equal(io.EOF))
})
It("errors on EOFs", func() {
data := []byte{0x8 ^ 0x4 ^ 0x2}
data = append(data, encodeVarInt(0x12345)...) // stream ID
@@ -93,6 +104,40 @@ var _ = Describe("STREAM frame", func() {
})
})
Context("using the buffer", func() {
It("uses the buffer for long STREAM frames", func() {
data := []byte{0x8}
data = append(data, encodeVarInt(0x12345)...) // stream ID
data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...)
r := bytes.NewReader(data)
frame, err := parseStreamFrame(r, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345)))
Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)))
Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize))
Expect(frame.FinBit).To(BeFalse())
Expect(frame.fromPool).To(BeTrue())
Expect(r.Len()).To(BeZero())
Expect(frame.PutBack).ToNot(Panic())
})
It("doesn't use the buffer for short STREAM frames", func() {
data := []byte{0x8}
data = append(data, encodeVarInt(0x12345)...) // stream ID
data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...)
r := bytes.NewReader(data)
frame, err := parseStreamFrame(r, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345)))
Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)))
Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize - 1))
Expect(frame.FinBit).To(BeFalse())
Expect(frame.fromPool).To(BeFalse())
Expect(r.Len()).To(BeZero())
Expect(frame.PutBack).ToNot(Panic())
})
})
Context("when writing", func() {
It("writes a frame without offset", func() {
f := &StreamFrame{
@@ -294,6 +339,7 @@ var _ = Describe("STREAM frame", func() {
frame, needsSplit = f.MaybeSplitOffFrame(f.Length(versionIETFFrames)-1, versionIETFFrames)
Expect(needsSplit).To(BeTrue())
Expect(frame.DataLen()).To(BeEquivalentTo(99))
f.PutBack()
})
It("keeps the data len", func() {
@@ -353,6 +399,7 @@ var _ = Describe("STREAM frame", func() {
Expect(f).To(BeNil())
}
for i := minFrameSize; i < size; i++ {
f.fromPool = false
f.Data = make([]byte, size)
f, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames)
Expect(needsSplit).To(BeTrue())
@@ -376,6 +423,7 @@ var _ = Describe("STREAM frame", func() {
}
var frameOneByteTooSmallCounter int
for i := minFrameSize; i < size; i++ {
f.fromPool = false
f.Data = make([]byte, size)
newFrame, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames)
Expect(needsSplit).To(BeTrue())