wire: use quicvarint.Parse when parsing frames (#4484)

* wire: add benchmarks for the frame parser

* wire: use quicvarint.Parse when parsing frames

* wire: always use io.EOF for too short frames
This commit is contained in:
Marten Seemann
2024-05-05 19:28:28 +08:00
committed by GitHub
parent 1514095afb
commit f12ee48617
38 changed files with 572 additions and 453 deletions

View File

@@ -1,9 +1,9 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"reflect"
"github.com/quic-go/quic-go/internal/protocol"
@@ -38,8 +38,6 @@ const (
// The FrameParser parses QUIC frames, one by one.
type FrameParser struct {
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
ackDelayExponent uint8
supportsDatagrams bool
@@ -51,7 +49,6 @@ type FrameParser struct {
// NewFrameParser creates a new frame parser.
func NewFrameParser(supportsDatagrams bool) *FrameParser {
return &FrameParser{
r: *bytes.NewReader(nil),
supportsDatagrams: supportsDatagrams,
ackFrame: &AckFrame{},
}
@@ -60,45 +57,46 @@ func NewFrameParser(supportsDatagrams bool) *FrameParser {
// ParseNext parses the next frame.
// It skips PADDING frames.
func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
startLen := len(data)
p.r.Reset(data)
frame, err := p.parseNext(&p.r, encLevel, v)
n := startLen - p.r.Len()
p.r.Reset(nil)
return n, frame, err
frame, l, err := p.parseNext(data, encLevel, v)
return l, frame, err
}
func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
for r.Len() != 0 {
typ, err := quicvarint.Read(r)
func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
var parsed int
for len(b) != 0 {
typ, l, err := quicvarint.Parse(b)
parsed += l
if err != nil {
return nil, &qerr.TransportError{
return nil, parsed, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
b = b[l:]
if typ == 0x0 { // skip PADDING frames
continue
}
f, err := p.parseFrame(r, typ, encLevel, v)
f, l, err := p.parseFrame(b, typ, encLevel, v)
parsed += l
if err != nil {
return nil, &qerr.TransportError{
return nil, parsed, &qerr.TransportError{
FrameType: typ,
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
return f, nil
return f, parsed, nil
}
return nil, nil
return nil, parsed, nil
}
func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
func (p *FrameParser) parseFrame(b []byte, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
var frame Frame
var err error
var l int
if typ&0xf8 == 0x8 {
frame, err = parseStreamFrame(r, typ, v)
frame, l, err = parseStreamFrame(b, typ, v)
} else {
switch typ {
case pingFrameType:
@@ -109,43 +107,43 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
ackDelayExponent = protocol.DefaultAckDelayExponent
}
p.ackFrame.Reset()
err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v)
l, err = parseAckFrame(p.ackFrame, b, typ, ackDelayExponent, v)
frame = p.ackFrame
case resetStreamFrameType:
frame, err = parseResetStreamFrame(r, v)
frame, l, err = parseResetStreamFrame(b, v)
case stopSendingFrameType:
frame, err = parseStopSendingFrame(r, v)
frame, l, err = parseStopSendingFrame(b, v)
case cryptoFrameType:
frame, err = parseCryptoFrame(r, v)
frame, l, err = parseCryptoFrame(b, v)
case newTokenFrameType:
frame, err = parseNewTokenFrame(r, v)
frame, l, err = parseNewTokenFrame(b, v)
case maxDataFrameType:
frame, err = parseMaxDataFrame(r, v)
frame, l, err = parseMaxDataFrame(b, v)
case maxStreamDataFrameType:
frame, err = parseMaxStreamDataFrame(r, v)
frame, l, err = parseMaxStreamDataFrame(b, v)
case bidiMaxStreamsFrameType, uniMaxStreamsFrameType:
frame, err = parseMaxStreamsFrame(r, typ, v)
frame, l, err = parseMaxStreamsFrame(b, typ, v)
case dataBlockedFrameType:
frame, err = parseDataBlockedFrame(r, v)
frame, l, err = parseDataBlockedFrame(b, v)
case streamDataBlockedFrameType:
frame, err = parseStreamDataBlockedFrame(r, v)
frame, l, err = parseStreamDataBlockedFrame(b, v)
case bidiStreamBlockedFrameType, uniStreamBlockedFrameType:
frame, err = parseStreamsBlockedFrame(r, typ, v)
frame, l, err = parseStreamsBlockedFrame(b, typ, v)
case newConnectionIDFrameType:
frame, err = parseNewConnectionIDFrame(r, v)
frame, l, err = parseNewConnectionIDFrame(b, v)
case retireConnectionIDFrameType:
frame, err = parseRetireConnectionIDFrame(r, v)
frame, l, err = parseRetireConnectionIDFrame(b, v)
case pathChallengeFrameType:
frame, err = parsePathChallengeFrame(r, v)
frame, l, err = parsePathChallengeFrame(b, v)
case pathResponseFrameType:
frame, err = parsePathResponseFrame(r, v)
frame, l, err = parsePathResponseFrame(b, v)
case connectionCloseFrameType, applicationCloseFrameType:
frame, err = parseConnectionCloseFrame(r, typ, v)
frame, l, err = parseConnectionCloseFrame(b, typ, v)
case handshakeDoneFrameType:
frame = &HandshakeDoneFrame{}
case 0x30, 0x31:
if p.supportsDatagrams {
frame, err = parseDatagramFrame(r, typ, v)
frame, l, err = parseDatagramFrame(b, typ, v)
break
}
fallthrough
@@ -154,12 +152,12 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
}
}
if err != nil {
return nil, err
return nil, 0, err
}
if !p.isAllowedAtEncLevel(frame, encLevel) {
return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
return nil, l, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
}
return frame, nil
return frame, l, nil
}
func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
@@ -190,3 +188,10 @@ func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
p.ackDelayExponent = exp
}
func replaceUnexpectedEOF(e error) error {
if e == io.ErrUnexpectedEOF {
return io.EOF
}
return e
}