wire: refactor header parsing to use quicvarint.Parse (#4481)

* wire: add benchmark tests for initial and retry header parsing

* wire: refactor header parsing to use quicvarint.Parse

* wire: simplify tracking of parsed length for Long Header parsing
This commit is contained in:
Marten Seemann
2024-05-05 20:48:06 +08:00
committed by GitHub
parent f12ee48617
commit 347a4afc51
2 changed files with 125 additions and 54 deletions

View File

@@ -8,7 +8,6 @@ import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -139,18 +138,18 @@ type Header struct {
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParsePacket parses a packet.
// If the packet has a long header, the packet is cut according to the length field.
// If we understand the version, the packet is header up unto the packet number.
// ParsePacket parses a long header packet.
// The packet is cut according to the length field.
// If we understand the version, the packet is parsed up unto the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
return nil, nil, nil, errors.New("not a long header packet")
}
hdr, err := parseHeader(bytes.NewReader(data))
hdr, err := parseHeader(data)
if err != nil {
if err == ErrUnsupportedVersion {
return hdr, nil, nil, ErrUnsupportedVersion
if errors.Is(err, ErrUnsupportedVersion) {
return hdr, nil, nil, err
}
return nil, nil, nil, err
}
@@ -161,55 +160,55 @@ func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
return hdr, data[:packetLen], data[packetLen:], nil
}
// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// ParseHeader parses the header:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b *bytes.Reader) (*Header, error) {
startLen := b.Len()
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
func parseHeader(b []byte) (*Header, error) {
if len(b) == 0 {
return nil, io.EOF
}
typeByte := b[0]
h := &Header{typeByte: typeByte}
err = h.parseLongHeader(b)
h.parsedLen = protocol.ByteCount(startLen - b.Len())
l, err := h.parseLongHeader(b[1:])
h.parsedLen = protocol.ByteCount(l) + 1
return h, err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
func (h *Header) parseLongHeader(b []byte) (int, error) {
startLen := len(b)
if len(b) < 5 {
return 0, io.EOF
}
h.Version = protocol.Version(v)
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
if h.Version != 0 && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
return startLen - len(b), errors.New("not a QUIC packet")
}
destConnIDLen, err := b.ReadByte()
if err != nil {
return err
destConnIDLen := int(b[4])
if destConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
if err != nil {
return err
b = b[5:]
if len(b) < destConnIDLen+1 {
return startLen - len(b), io.EOF
}
srcConnIDLen, err := b.ReadByte()
if err != nil {
return err
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
srcConnIDLen := int(b[destConnIDLen])
if srcConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
if err != nil {
return err
b = b[destConnIDLen+1:]
if len(b) < srcConnIDLen {
return startLen - len(b), io.EOF
}
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
b = b[srcConnIDLen:]
if h.Version == 0 { // version negotiation packet
return nil
return startLen - len(b), nil
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return ErrUnsupportedVersion
return startLen - len(b), ErrUnsupportedVersion
}
if h.Version == protocol.Version2 {
@@ -237,38 +236,35 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
}
if h.Type == protocol.PacketTypeRetry {
tokenLen := b.Len() - 16
tokenLen := len(b) - 16
if tokenLen <= 0 {
return io.EOF
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
_, err := b.Seek(16, io.SeekCurrent)
return err
copy(h.Token, b[:tokenLen])
return startLen - len(b) + tokenLen + 16, nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := quicvarint.Read(b)
tokenLen, n, err := quicvarint.Parse(b)
if err != nil {
return err
return startLen - len(b), err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
b = b[n:]
if tokenLen > uint64(len(b)) {
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
copy(h.Token, b[:tokenLen])
b = b[tokenLen:]
}
pl, err := quicvarint.Read(b)
pl, n, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, err
}
h.Length = protocol.ByteCount(pl)
return nil
return startLen - len(b) + n, nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the header

View File

@@ -505,3 +505,78 @@ func BenchmarkIs0RTTPacket(b *testing.B) {
Is0RTTPacket(packets[i%len(packets)])
}
}
func BenchmarkParseInitial(b *testing.B) {
b.Run("without token", func(b *testing.B) {
benchmarkInitialPacketParsing(b, nil)
})
b.Run("with token", func(b *testing.B) {
token := make([]byte, 32)
rand.Read(token)
benchmarkInitialPacketParsing(b, token)
})
}
func benchmarkInitialPacketParsing(b *testing.B, token []byte) {
hdr := Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
SrcConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
Length: 1000,
Token: token,
Version: protocol.Version1,
}
data, err := (&ExtendedHeader{
Header: hdr,
PacketNumber: 0x1337,
PacketNumberLen: 4,
}).Append(nil, protocol.Version1)
if err != nil {
b.Fatal(err)
}
data = append(data, make([]byte, 1000)...)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
h, _, _, err := ParsePacket(data)
if err != nil {
b.Fatal(err)
}
if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID ||
!bytes.Equal(h.Token, hdr.Token) {
b.Fatalf("headers don't match: %v vs %v", h, hdr)
}
}
}
func BenchmarkParseRetry(b *testing.B) {
token := make([]byte, 64)
rand.Read(token)
hdr := &ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeRetry,
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
Token: token,
Version: protocol.Version1,
},
}
data, err := hdr.Append(nil, hdr.Version)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
h, _, _, err := ParsePacket(data)
if err != nil {
b.Fatal(err)
}
if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID ||
!bytes.Equal(h.Token, hdr.Token[:len(hdr.Token)-16]) {
b.Fatalf("headers don't match: %#v vs %#v", h, hdr)
}
}
}