diff --git a/internal/wire/header.go b/internal/wire/header.go index 29911684..5ec74b16 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -8,7 +8,6 @@ import ( "io" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" ) @@ -139,18 +138,18 @@ type Header struct { parsedLen protocol.ByteCount // how many bytes were read while parsing this header } -// ParsePacket parses a packet. -// If the packet has a long header, the packet is cut according to the length field. -// If we understand the version, the packet is header up unto the packet number. +// ParsePacket parses a long header packet. +// The packet is cut according to the length field. +// If we understand the version, the packet is parsed up unto the packet number. // Otherwise, only the invariant part of the header is parsed. func ParsePacket(data []byte) (*Header, []byte, []byte, error) { if len(data) == 0 || !IsLongHeaderPacket(data[0]) { return nil, nil, nil, errors.New("not a long header packet") } - hdr, err := parseHeader(bytes.NewReader(data)) + hdr, err := parseHeader(data) if err != nil { - if err == ErrUnsupportedVersion { - return hdr, nil, nil, ErrUnsupportedVersion + if errors.Is(err, ErrUnsupportedVersion) { + return hdr, nil, nil, err } return nil, nil, nil, err } @@ -161,55 +160,55 @@ func ParsePacket(data []byte) (*Header, []byte, []byte, error) { return hdr, data[:packetLen], data[packetLen:], nil } -// ParseHeader parses the header. -// For short header packets: up to the packet number. -// For long header packets: +// ParseHeader parses the header: // * if we understand the version: up to the packet number // * if not, only the invariant part of the header -func parseHeader(b *bytes.Reader) (*Header, error) { - startLen := b.Len() - typeByte, err := b.ReadByte() - if err != nil { - return nil, err +func parseHeader(b []byte) (*Header, error) { + if len(b) == 0 { + return nil, io.EOF } + typeByte := b[0] h := &Header{typeByte: typeByte} - err = h.parseLongHeader(b) - h.parsedLen = protocol.ByteCount(startLen - b.Len()) + l, err := h.parseLongHeader(b[1:]) + h.parsedLen = protocol.ByteCount(l) + 1 return h, err } -func (h *Header) parseLongHeader(b *bytes.Reader) error { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return err +func (h *Header) parseLongHeader(b []byte) (int, error) { + startLen := len(b) + if len(b) < 5 { + return 0, io.EOF } - h.Version = protocol.Version(v) + h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4])) if h.Version != 0 && h.typeByte&0x40 == 0 { - return errors.New("not a QUIC packet") + return startLen - len(b), errors.New("not a QUIC packet") } - destConnIDLen, err := b.ReadByte() - if err != nil { - return err + destConnIDLen := int(b[4]) + if destConnIDLen > protocol.MaxConnIDLen { + return startLen - len(b), protocol.ErrInvalidConnectionIDLen } - h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) - if err != nil { - return err + b = b[5:] + if len(b) < destConnIDLen+1 { + return startLen - len(b), io.EOF } - srcConnIDLen, err := b.ReadByte() - if err != nil { - return err + h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen]) + srcConnIDLen := int(b[destConnIDLen]) + if srcConnIDLen > protocol.MaxConnIDLen { + return startLen - len(b), protocol.ErrInvalidConnectionIDLen } - h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) - if err != nil { - return err + b = b[destConnIDLen+1:] + if len(b) < srcConnIDLen { + return startLen - len(b), io.EOF } + h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen]) + b = b[srcConnIDLen:] if h.Version == 0 { // version negotiation packet - return nil + return startLen - len(b), nil } // If we don't understand the version, we have no idea how to interpret the rest of the bytes if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { - return ErrUnsupportedVersion + return startLen - len(b), ErrUnsupportedVersion } if h.Version == protocol.Version2 { @@ -237,38 +236,35 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { } if h.Type == protocol.PacketTypeRetry { - tokenLen := b.Len() - 16 + tokenLen := len(b) - 16 if tokenLen <= 0 { - return io.EOF + return startLen - len(b), io.EOF } h.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, h.Token); err != nil { - return err - } - _, err := b.Seek(16, io.SeekCurrent) - return err + copy(h.Token, b[:tokenLen]) + return startLen - len(b) + tokenLen + 16, nil } if h.Type == protocol.PacketTypeInitial { - tokenLen, err := quicvarint.Read(b) + tokenLen, n, err := quicvarint.Parse(b) if err != nil { - return err + return startLen - len(b), err } - if tokenLen > uint64(b.Len()) { - return io.EOF + b = b[n:] + if tokenLen > uint64(len(b)) { + return startLen - len(b), io.EOF } h.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, h.Token); err != nil { - return err - } + copy(h.Token, b[:tokenLen]) + b = b[tokenLen:] } - pl, err := quicvarint.Read(b) + pl, n, err := quicvarint.Parse(b) if err != nil { - return err + return 0, err } h.Length = protocol.ByteCount(pl) - return nil + return startLen - len(b) + n, nil } // ParsedLen returns the number of bytes that were consumed when parsing the header diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index af8d931e..a0d1657f 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -505,3 +505,78 @@ func BenchmarkIs0RTTPacket(b *testing.B) { Is0RTTPacket(packets[i%len(packets)]) } } + +func BenchmarkParseInitial(b *testing.B) { + b.Run("without token", func(b *testing.B) { + benchmarkInitialPacketParsing(b, nil) + }) + b.Run("with token", func(b *testing.B) { + token := make([]byte, 32) + rand.Read(token) + benchmarkInitialPacketParsing(b, token) + }) +} + +func benchmarkInitialPacketParsing(b *testing.B, token []byte) { + hdr := Header{ + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + SrcConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), + Length: 1000, + Token: token, + Version: protocol.Version1, + } + data, err := (&ExtendedHeader{ + Header: hdr, + PacketNumber: 0x1337, + PacketNumberLen: 4, + }).Append(nil, protocol.Version1) + if err != nil { + b.Fatal(err) + } + data = append(data, make([]byte, 1000)...) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + h, _, _, err := ParsePacket(data) + if err != nil { + b.Fatal(err) + } + if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID || + !bytes.Equal(h.Token, hdr.Token) { + b.Fatalf("headers don't match: %v vs %v", h, hdr) + } + } +} + +func BenchmarkParseRetry(b *testing.B) { + token := make([]byte, 64) + rand.Read(token) + hdr := &ExtendedHeader{ + Header: Header{ + Type: protocol.PacketTypeRetry, + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), + Token: token, + Version: protocol.Version1, + }, + } + data, err := hdr.Append(nil, hdr.Version) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + h, _, _, err := ParsePacket(data) + if err != nil { + b.Fatal(err) + } + if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID || + !bytes.Equal(h.Token, hdr.Token[:len(hdr.Token)-16]) { + b.Fatalf("headers don't match: %#v vs %#v", h, hdr) + } + } +}