wire: optimize parsing of long header packets (#4589)

This commit is contained in:
Marten Seemann
2024-07-21 15:22:32 -06:00
committed by GitHub
parent bc642d872d
commit 5f8d146836
15 changed files with 138 additions and 160 deletions

View File

@@ -24,7 +24,7 @@ import (
)
type unpacker interface {
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error)
UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}
@@ -961,7 +961,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
return false
}
packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data, s.version)
packet, err := s.unpacker.UnpackLongHeader(hdr, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false

View File

@@ -752,7 +752,7 @@ var _ = Describe("Connection", func() {
packet := getLongHeaderPacket(hdr, nil)
packet.ecn = protocol.ECNCE
rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any(), conn.version).Return(&unpackedPacket{
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr,
data: []byte{0}, // one PADDING frame
@@ -803,7 +803,7 @@ var _ = Describe("Connection", func() {
})
It("drops a packet when unpacking fails", func() {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrDecryptionFailed)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@@ -1004,7 +1004,7 @@ var _ = Describe("Connection", func() {
Expect(srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
// Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(&unpackedPacket{
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.Encryption1RTT,
hdr: hdr1,
data: []byte{0}, // one PADDING frame
@@ -1032,7 +1032,7 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 1,
}
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable)
packet := getLongHeaderPacket(hdr, nil)
tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size())
Expect(conn.handlePacketImpl(packet)).To(BeFalse())
@@ -1074,7 +1074,7 @@ var _ = Describe("Connection", func() {
It("cuts packets to the right length", func() {
hdrLen, packet := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen + 456 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@@ -1091,7 +1091,7 @@ var _ = Describe("Connection", func() {
It("handles coalesced packets", func() {
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
packet1.ecn = protocol.ECT1
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@@ -1103,7 +1103,7 @@ var _ = Describe("Connection", func() {
}, nil
})
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@@ -1129,8 +1129,8 @@ var _ = Describe("Connection", func() {
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
gomock.InOrder(
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable),
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@@ -1156,7 +1156,7 @@ var _ = Describe("Connection", func() {
wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
Expect(srcConnID).ToNot(Equal(wrongConnID))
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@@ -2599,7 +2599,7 @@ var _ = Describe("Client Connection", func() {
It("changes the connection ID when receiving the first packet from the server", func() {
unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
return &unpackedPacket{
encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{Header: *hdr},
@@ -2655,7 +2655,7 @@ var _ = Describe("Client Connection", func() {
})
Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})))
// now receive a packet with the original source connection ID
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte, _ protocol.Version) (*unpackedPacket, error) {
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ []byte) (*unpackedPacket, error) {
return &unpackedPacket{
hdr: &wire.ExtendedHeader{Header: *hdr},
data: []byte{0},
@@ -3177,7 +3177,7 @@ var _ = Describe("Client Connection", func() {
Expect(hdr2.SrcConnectionID).ToNot(Equal(srcConnID))
// Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(&unpackedPacket{
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: hdr1,
data: []byte{0}, // one PADDING frame

View File

@@ -54,7 +54,7 @@ func Fuzz(data []byte) int {
extHdr = &wire.ExtendedHeader{Header: *hdr}
} else {
var err error
extHdr, err = hdr.ParseExtended(bytes.NewReader(data), version)
extHdr, err = hdr.ParseExtended(data)
if err != nil {
return 0
}

View File

@@ -50,7 +50,7 @@ var _ = Describe("QUIC Proxy", func() {
hdr, data, _, err := wire.ParsePacket(b)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1)
extHdr, err := hdr.ParseExtended(data)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return extHdr.PacketNumber
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@@ -32,66 +31,23 @@ type ExtendedHeader struct {
parsedLen protocol.ByteCount
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
startLen := b.Len()
func (h *ExtendedHeader) parse(data []byte) (bool /* reserved bits valid */, error) {
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return false, err
}
if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
return false, err
}
reservedBitsValid, err := h.parseLongHeader(b, v)
if err != nil {
return false, err
}
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return reservedBitsValid, err
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
if err := h.readPacketNumber(b); err != nil {
return false, err
}
if h.typeByte&0xc != 0 {
return false, nil
}
return true, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.typeByte = data[0]
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
n, err := b.ReadByte()
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen2:
n, err := utils.BigEndian.ReadUint16(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen3:
n, err := utils.BigEndian.ReadUint24(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen4:
n, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
if protocol.ByteCount(len(data)) < h.Header.ParsedLen()+protocol.ByteCount(h.PacketNumberLen) {
return false, io.EOF
}
return nil
pn, err := readPacketNumber(data[h.Header.ParsedLen():], h.PacketNumberLen)
if err != nil {
return true, nil
}
h.PacketNumber = pn
reservedBitsValid := h.typeByte&0xc == 0
h.parsedLen = h.Header.ParsedLen() + protocol.ByteCount(h.PacketNumberLen)
return reservedBitsValid, err
}
// Append appends the Header.

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"log"
"os"
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
@@ -331,3 +332,33 @@ var _ = Describe("Header", func() {
})
})
})
func BenchmarkParseExtendedHeader(b *testing.B) {
data, err := (&ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}),
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}),
Version: protocol.Version1,
Length: 1234,
},
PacketNumber: 0xdecaf,
PacketNumberLen: protocol.PacketNumberLen3,
}).Append(nil, protocol.Version1)
if err != nil {
b.Fatal(err)
}
data = append(data, make([]byte, 1231)...)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
hdr, _, _, err := ParsePacket(data)
if err != nil {
b.Fatal(err)
}
if _, err := hdr.ParseExtended(data); err != nil {
b.Fatal(err)
}
}
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -274,9 +275,9 @@ func (h *Header) ParsedLen() protocol.ByteCount {
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
func (h *Header) ParseExtended(data []byte) (*ExtendedHeader, error) {
extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(b, ver)
reservedBitsValid, err := extHdr.parse(data)
if err != nil {
return nil, err
}
@@ -294,3 +295,20 @@ func (h *Header) toExtendedHeader() *ExtendedHeader {
func (h *Header) PacketType() string {
return h.Type.String()
}
func readPacketNumber(data []byte, pnLen protocol.PacketNumberLen) (protocol.PacketNumber, error) {
var pn protocol.PacketNumber
switch pnLen {
case protocol.PacketNumberLen1:
pn = protocol.PacketNumber(data[0])
case protocol.PacketNumberLen2:
pn = protocol.PacketNumber(utils.BigEndian.Uint16(data[:2]))
case protocol.PacketNumberLen3:
pn = protocol.PacketNumber(utils.BigEndian.Uint24(data[:3]))
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[:4]))
default:
return 0, fmt.Errorf("invalid packet number length: %d", pnLen)
}
return pn, nil
}

View File

@@ -197,12 +197,10 @@ var _ = Describe("Header Parsing", func() {
Expect(hdr.Length).To(Equal(protocol.ByteCount(10)))
Expect(hdr.Version).To(Equal(protocol.Version1))
Expect(rest).To(BeEmpty())
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
extHdr, err := hdr.ParseExtended(data)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef)))
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
Expect(b.Len()).To(Equal(6)) // foobar
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef)))
Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen))
Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 4))
})
@@ -287,12 +285,11 @@ var _ = Describe("Header Parsing", func() {
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
extHdr, err := hdr.ParseExtended(data)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123)))
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2))
Expect(b.Len()).To(BeZero())
Expect(extHdr.ParsedLen()).To(BeEquivalentTo(len(data)))
})
It("parses a Retry packet, for QUIC v1", func() {
@@ -367,7 +364,7 @@ var _ = Describe("Header Parsing", func() {
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake))
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1)
extHdr, err := hdr.ParseExtended(data)
Expect(err).To(MatchError(ErrInvalidReservedBits))
Expect(extHdr).ToNot(BeNil())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1234)))
@@ -394,11 +391,10 @@ var _ = Describe("Header Parsing", func() {
hdrLen := len(data)
data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number
for i := hdrLen; i < len(data); i++ {
data = data[:i]
hdr, _, _, err := ParsePacket(data)
b := data[:i]
hdr, _, _, err := ParsePacket(b)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
_, err = hdr.ParseExtended(b, protocol.Version1)
_, err = hdr.ParseExtended(b)
Expect(err).To(Equal(io.EOF))
}
})
@@ -414,8 +410,7 @@ var _ = Describe("Header Parsing", func() {
data = data[:i]
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
_, err = hdr.ParseExtended(b, protocol.Version1)
_, err = hdr.ParseExtended(data)
Expect(err).To(Equal(io.EOF))
}
})

View File

@@ -2,7 +2,6 @@ package wire
import (
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@@ -28,25 +27,15 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
}
pos := 1 + connIDLen
var pn protocol.PacketNumber
switch pnLen {
case protocol.PacketNumberLen1:
pn = protocol.PacketNumber(data[pos])
case protocol.PacketNumberLen2:
pn = protocol.PacketNumber(utils.BigEndian.Uint16(data[pos : pos+2]))
case protocol.PacketNumberLen3:
pn = protocol.PacketNumber(utils.BigEndian.Uint24(data[pos : pos+3]))
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4]))
default:
return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen)
pn, err := readPacketNumber(data[pos:], pnLen)
if err != nil {
return 0, 0, 0, 0, err
}
kp := protocol.KeyPhaseZero
if data[0]&0b100 > 0 {
kp = protocol.KeyPhaseOne
}
var err error
if data[0]&0x18 != 0 {
err = ErrInvalidReservedBits
}

View File

@@ -42,18 +42,18 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
}
// UnpackLongHeader mocks base method.
func (m *MockUnpacker) UnpackLongHeader(arg0 *wire.Header, arg1 time.Time, arg2 []byte, arg3 protocol.Version) (*unpackedPacket, error) {
func (m *MockUnpacker) UnpackLongHeader(arg0 *wire.Header, arg1 []byte) (*unpackedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnpackLongHeader", arg0, arg1, arg2, arg3)
ret := m.ctrl.Call(m, "UnpackLongHeader", arg0, arg1)
ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UnpackLongHeader indicates an expected call of UnpackLongHeader.
func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1, arg2, arg3 any) *MockUnpackerUnpackLongHeaderCall {
func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1 any) *MockUnpackerUnpackLongHeaderCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), arg0, arg1, arg2, arg3)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), arg0, arg1)
return &MockUnpackerUnpackLongHeaderCall{Call: call}
}
@@ -69,13 +69,13 @@ func (c *MockUnpackerUnpackLongHeaderCall) Return(arg0 *unpackedPacket, arg1 err
}
// Do rewrite *gomock.Call.Do
func (c *MockUnpackerUnpackLongHeaderCall) Do(f func(*wire.Header, time.Time, []byte, protocol.Version) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall {
func (c *MockUnpackerUnpackLongHeaderCall) Do(f func(*wire.Header, []byte) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockUnpackerUnpackLongHeaderCall) DoAndReturn(f func(*wire.Header, time.Time, []byte, protocol.Version) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall {
func (c *MockUnpackerUnpackLongHeaderCall) DoAndReturn(f func(*wire.Header, []byte) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -24,7 +24,6 @@ import (
var _ = Describe("Packet packer", func() {
const maxPacketSize protocol.ByteCount = 1357
const version = protocol.Version1
var (
packer *packetPacker
@@ -46,10 +45,9 @@ var _ = Describe("Packet packer", func() {
}
hdr, _, more, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
r := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(r, version)
extHdr, err := hdr.ParseExtended(data)
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() - len(more) + int(extHdr.PacketNumberLen)))
// ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() - len(more) + int(extHdr.PacketNumberLen)))
ExpectWithOffset(1, extHdr.Length+protocol.ByteCount(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4))
data = more
hdrs = append(hdrs, extHdr)
@@ -696,24 +694,21 @@ var _ = Describe("Packet packer", func() {
hdr, _, _, err := wire.ParsePacket(packet.buffer.Data)
Expect(err).ToNot(HaveOccurred())
data := packet.buffer.Data
r := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(r, protocol.Version1)
extHdr, err := hdr.ParseExtended(data)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1))
Expect(r.Len()).To(Equal(4 - 1 /* packet number length */ + sealer.Overhead()))
data = data[extHdr.ParsedLen():]
Expect(data).To(HaveLen(4 - 1 /* packet number length */ + sealer.Overhead()))
// the first bytes of the payload should be a 2 PADDING frames...
firstPayloadByte, err := r.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(firstPayloadByte).To(Equal(byte(0)))
secondPayloadByte, err := r.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(secondPayloadByte).To(Equal(byte(0)))
Expect(data[0]).To(Equal(byte(0)))
Expect(data[1]).To(Equal(byte(0)))
data = data[2:]
// ... followed by the PING
frameParser := wire.NewFrameParser(false)
l, frame, err := frameParser.ParseNext(data[len(data)-r.Len():], protocol.Encryption1RTT, protocol.Version1)
l, frame, err := frameParser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{}))
Expect(r.Len() - l).To(Equal(sealer.Overhead()))
Expect(len(data) - l).To(Equal(sealer.Overhead()))
})
It("pads if payload length + packet number length is smaller than 4", func() {
@@ -1213,24 +1208,21 @@ var _ = Describe("Packet packer", func() {
hdr, _, _, err := wire.ParsePacket(packet.buffer.Data)
Expect(err).ToNot(HaveOccurred())
data := packet.buffer.Data
r := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(r, protocol.Version1)
extHdr, err := hdr.ParseExtended(data)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1))
Expect(r.Len()).To(Equal(4 - 1 /* packet number length */ + sealer.Overhead()))
data = data[extHdr.ParsedLen():]
Expect(data).To(HaveLen(4 - 1 /* packet number length */ + sealer.Overhead()))
// the first bytes of the payload should be a 2 PADDING frames...
firstPayloadByte, err := r.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(firstPayloadByte).To(Equal(byte(0)))
secondPayloadByte, err := r.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(secondPayloadByte).To(Equal(byte(0)))
Expect(data[0]).To(Equal(byte(0)))
Expect(data[1]).To(Equal(byte(0)))
data = data[2:]
// ... followed by the PING
frameParser := wire.NewFrameParser(false)
l, frame, err := frameParser.ParseNext(data[len(data)-r.Len():], protocol.Encryption1RTT, protocol.Version1)
l, frame, err := frameParser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{}))
Expect(r.Len() - l).To(Equal(sealer.Overhead()))
Expect(len(data) - l).To(Equal(sealer.Overhead()))
})
It("adds retransmissions", func() {

View File

@@ -1,7 +1,6 @@
package quic
import (
"bytes"
"fmt"
"time"
@@ -53,7 +52,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader
var decrypted []byte
@@ -65,7 +64,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@@ -75,7 +74,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@@ -85,7 +84,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@@ -125,8 +124,8 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
return pn, pnLen, kp, decrypted, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption.
@@ -187,17 +186,15 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
}
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data, v)
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data)
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err}
}
return extHdr, err
}
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data)
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
hdrLen := hdr.ParsedLen()
if protocol.ByteCount(len(data)) < hdrLen+4+16 {
//nolint:stylecheck
@@ -214,7 +211,7 @@ func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v proto
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
extHdr, parseErr := hdr.ParseExtended(r, v)
extHdr, parseErr := hdr.ParseExtended(data)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr
}

View File

@@ -61,7 +61,7 @@ var _ = Describe("Packet Unpacker", func() {
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
_, err := unpacker.UnpackLongHeader(hdr, data)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
var headerErr *headerParseError
Expect(errors.As(err, &headerErr)).To(BeTrue())
@@ -98,7 +98,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
packet, err := unpacker.UnpackLongHeader(hdr, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(packet.data).To(Equal([]byte("decrypted")))
@@ -123,7 +123,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
packet, err := unpacker.UnpackLongHeader(hdr, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
Expect(packet.data).To(Equal([]byte("decrypted")))
@@ -172,7 +172,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil),
)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
_, err := unpacker.UnpackLongHeader(hdr, append(hdrRaw, payload...))
Expect(err).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet",
@@ -214,7 +214,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
_, err := unpacker.UnpackLongHeader(hdr, append(hdrRaw, payload...))
Expect(err).To(MatchError(unpackErr))
})
@@ -235,7 +235,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
_, err := unpacker.UnpackLongHeader(hdr, append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
})
@@ -268,7 +268,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
_, err := unpacker.UnpackLongHeader(hdr, append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
})
@@ -322,7 +322,7 @@ var _ = Describe("Packet Unpacker", func() {
for i := 1; i <= 100; i++ {
data = append(data, uint8(i))
}
packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
packet, err := unpacker.UnpackLongHeader(hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331)))
})

View File

@@ -808,7 +808,7 @@ func (s *baseServer) maybeSendInvalidToken(p rejectedPacket) {
hdr := p.hdr
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
extHdr, err := unpackLongHeader(opener, hdr, data)
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
if err != nil {

View File

@@ -91,7 +91,7 @@ var _ = Describe("Server", func() {
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
extHdr, err := unpackLongHeader(opener, replyHdr, b)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())