forked from quic-go/quic-go
wire: refactor header parsing to use quicvarint.Parse (#4481)
* wire: add benchmark tests for initial and retry header parsing * wire: refactor header parsing to use quicvarint.Parse * wire: simplify tracking of parsed length for Long Header parsing
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
@@ -139,18 +138,18 @@ type Header struct {
|
||||
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
|
||||
}
|
||||
|
||||
// ParsePacket parses a packet.
|
||||
// If the packet has a long header, the packet is cut according to the length field.
|
||||
// If we understand the version, the packet is header up unto the packet number.
|
||||
// ParsePacket parses a long header packet.
|
||||
// The packet is cut according to the length field.
|
||||
// If we understand the version, the packet is parsed up unto the packet number.
|
||||
// Otherwise, only the invariant part of the header is parsed.
|
||||
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
|
||||
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
|
||||
return nil, nil, nil, errors.New("not a long header packet")
|
||||
}
|
||||
hdr, err := parseHeader(bytes.NewReader(data))
|
||||
hdr, err := parseHeader(data)
|
||||
if err != nil {
|
||||
if err == ErrUnsupportedVersion {
|
||||
return hdr, nil, nil, ErrUnsupportedVersion
|
||||
if errors.Is(err, ErrUnsupportedVersion) {
|
||||
return hdr, nil, nil, err
|
||||
}
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -161,55 +160,55 @@ func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
|
||||
return hdr, data[:packetLen], data[packetLen:], nil
|
||||
}
|
||||
|
||||
// ParseHeader parses the header.
|
||||
// For short header packets: up to the packet number.
|
||||
// For long header packets:
|
||||
// ParseHeader parses the header:
|
||||
// * if we understand the version: up to the packet number
|
||||
// * if not, only the invariant part of the header
|
||||
func parseHeader(b *bytes.Reader) (*Header, error) {
|
||||
startLen := b.Len()
|
||||
typeByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func parseHeader(b []byte) (*Header, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, io.EOF
|
||||
}
|
||||
typeByte := b[0]
|
||||
|
||||
h := &Header{typeByte: typeByte}
|
||||
err = h.parseLongHeader(b)
|
||||
h.parsedLen = protocol.ByteCount(startLen - b.Len())
|
||||
l, err := h.parseLongHeader(b[1:])
|
||||
h.parsedLen = protocol.ByteCount(l) + 1
|
||||
return h, err
|
||||
}
|
||||
|
||||
func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return err
|
||||
func (h *Header) parseLongHeader(b []byte) (int, error) {
|
||||
startLen := len(b)
|
||||
if len(b) < 5 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
h.Version = protocol.Version(v)
|
||||
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
|
||||
if h.Version != 0 && h.typeByte&0x40 == 0 {
|
||||
return errors.New("not a QUIC packet")
|
||||
return startLen - len(b), errors.New("not a QUIC packet")
|
||||
}
|
||||
destConnIDLen, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
destConnIDLen := int(b[4])
|
||||
if destConnIDLen > protocol.MaxConnIDLen {
|
||||
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
|
||||
if err != nil {
|
||||
return err
|
||||
b = b[5:]
|
||||
if len(b) < destConnIDLen+1 {
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
srcConnIDLen, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
|
||||
srcConnIDLen := int(b[destConnIDLen])
|
||||
if srcConnIDLen > protocol.MaxConnIDLen {
|
||||
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
|
||||
if err != nil {
|
||||
return err
|
||||
b = b[destConnIDLen+1:]
|
||||
if len(b) < srcConnIDLen {
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
|
||||
b = b[srcConnIDLen:]
|
||||
if h.Version == 0 { // version negotiation packet
|
||||
return nil
|
||||
return startLen - len(b), nil
|
||||
}
|
||||
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
|
||||
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
|
||||
return ErrUnsupportedVersion
|
||||
return startLen - len(b), ErrUnsupportedVersion
|
||||
}
|
||||
|
||||
if h.Version == protocol.Version2 {
|
||||
@@ -237,38 +236,35 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
tokenLen := b.Len() - 16
|
||||
tokenLen := len(b) - 16
|
||||
if tokenLen <= 0 {
|
||||
return io.EOF
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
h.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := b.Seek(16, io.SeekCurrent)
|
||||
return err
|
||||
copy(h.Token, b[:tokenLen])
|
||||
return startLen - len(b) + tokenLen + 16, nil
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
tokenLen, err := quicvarint.Read(b)
|
||||
tokenLen, n, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return startLen - len(b), err
|
||||
}
|
||||
if tokenLen > uint64(b.Len()) {
|
||||
return io.EOF
|
||||
b = b[n:]
|
||||
if tokenLen > uint64(len(b)) {
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
h.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
copy(h.Token, b[:tokenLen])
|
||||
b = b[tokenLen:]
|
||||
}
|
||||
|
||||
pl, err := quicvarint.Read(b)
|
||||
pl, n, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
h.Length = protocol.ByteCount(pl)
|
||||
return nil
|
||||
return startLen - len(b) + n, nil
|
||||
}
|
||||
|
||||
// ParsedLen returns the number of bytes that were consumed when parsing the header
|
||||
|
||||
@@ -505,3 +505,78 @@ func BenchmarkIs0RTTPacket(b *testing.B) {
|
||||
Is0RTTPacket(packets[i%len(packets)])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseInitial(b *testing.B) {
|
||||
b.Run("without token", func(b *testing.B) {
|
||||
benchmarkInitialPacketParsing(b, nil)
|
||||
})
|
||||
b.Run("with token", func(b *testing.B) {
|
||||
token := make([]byte, 32)
|
||||
rand.Read(token)
|
||||
benchmarkInitialPacketParsing(b, token)
|
||||
})
|
||||
}
|
||||
|
||||
func benchmarkInitialPacketParsing(b *testing.B, token []byte) {
|
||||
hdr := Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
|
||||
Length: 1000,
|
||||
Token: token,
|
||||
Version: protocol.Version1,
|
||||
}
|
||||
data, err := (&ExtendedHeader{
|
||||
Header: hdr,
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumberLen: 4,
|
||||
}).Append(nil, protocol.Version1)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
data = append(data, make([]byte, 1000)...)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
h, _, _, err := ParsePacket(data)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID ||
|
||||
!bytes.Equal(h.Token, hdr.Token) {
|
||||
b.Fatalf("headers don't match: %v vs %v", h, hdr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseRetry(b *testing.B) {
|
||||
token := make([]byte, 64)
|
||||
rand.Read(token)
|
||||
hdr := &ExtendedHeader{
|
||||
Header: Header{
|
||||
Type: protocol.PacketTypeRetry,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
|
||||
Token: token,
|
||||
Version: protocol.Version1,
|
||||
},
|
||||
}
|
||||
data, err := hdr.Append(nil, hdr.Version)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
h, _, _, err := ParsePacket(data)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID ||
|
||||
!bytes.Equal(h.Token, hdr.Token[:len(hdr.Token)-16]) {
|
||||
b.Fatalf("headers don't match: %#v vs %#v", h, hdr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user