diff --git a/packet_packer.go b/packet_packer.go index 6e248220f..004454729 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -279,9 +279,9 @@ func (p *packetPacker) PackConnectionClose(quicErr *qerr.QuicError) (*coalescedP } var paddingLen protocol.ByteCount if encLevel == protocol.EncryptionInitial { - paddingLen = p.paddingLen(payloads[i].frames, size) + paddingLen = p.initialPaddingLen(payloads[i].frames, size) } - c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i]) + c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false) if err != nil { return nil, err } @@ -335,9 +335,8 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke return p.writeSinglePacket(hdr, payload, encLevel, sealer) } -// only works for Initial packets -// The size is the expected size of the packet, if no padding was applied. -func (p *packetPacker) paddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { +// size is the expected size of the packet, if no padding was applied. +func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { // For the server, only ack-eliciting Initial packets need to be padded. if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { return 0 @@ -421,22 +420,22 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { packets: make([]*packetContents, 0, numPackets), } if initialPayload != nil { - padding := p.paddingLen(initialPayload.frames, size) - cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer) + padding := p.initialPaddingLen(initialPayload.frames, size) + cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false) if err != nil { return nil, err } packet.packets = append(packet.packets, cont) } if handshakePayload != nil { - cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer) + cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false) if err != nil { return nil, err } packet.packets = append(packet.packets, cont) } if appDataPayload != nil { - cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer) + cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false) if err != nil { return nil, err } @@ -457,7 +456,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if hdr.IsLongHeader { encLevel = protocol.Encryption0RTT } - cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer) + cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer, false) if err != nil { return nil, err } @@ -670,10 +669,10 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) var padding protocol.ByteCount if encLevel == protocol.EncryptionInitial { - padding = p.paddingLen(payload.frames, size) + padding = p.initialPaddingLen(payload.frames, size) } buffer := getPacketBuffer() - cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer) + cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false) if err != nil { return nil, err } @@ -683,6 +682,28 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( }, nil } +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { + payload := &payload{ + frames: []ackhandler.Frame{ping}, + length: ping.Length(p.version), + } + buffer := getPacketBuffer() + sealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + hdr := p.getShortHeader(sealer.KeyPhase()) + padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) + contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true) + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: contents, + }, nil +} + func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { switch encLevel { case protocol.EncryptionInitial: @@ -762,9 +783,9 @@ func (p *packetPacker) writeSinglePacket( buffer := getPacketBuffer() var paddingLen protocol.ByteCount if encLevel == protocol.EncryptionInitial { - paddingLen = p.paddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead())) + paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead())) } - contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer) + contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false) if err != nil { return nil, err } @@ -774,14 +795,7 @@ func (p *packetPacker) writeSinglePacket( }, nil } -func (p *packetPacker) appendPacket( - buffer *packetBuffer, - header *wire.ExtendedHeader, - payload *payload, - padding protocol.ByteCount, // add padding such that the packet has this length. 0 for no padding. - encLevel protocol.EncryptionLevel, - sealer sealer, -) (*packetContents, error) { +func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if payload.length < 4-pnLen { @@ -816,8 +830,10 @@ func (p *packetPacker) appendPacket( if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) } - if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { - return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + if !isMTUProbePacket { + if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { + return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + } } raw := buffer.Data diff --git a/packet_packer_test.go b/packet_packer_test.go index 516d2087c..847480c7d 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -1439,6 +1439,21 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packet).To(BeNil()) }) + + It("packs an MTU probe packet", func() { + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) + ping := ackhandler.Frame{Frame: &wire.PingFrame{}} + const probePacketSize = maxPacketSize + 42 + p, err := packer.PackMTUProbePacket(ping, probePacketSize) + Expect(err).ToNot(HaveOccurred()) + Expect(p.length).To(BeEquivalentTo(probePacketSize)) + Expect(p.header.IsLongHeader).To(BeFalse()) + Expect(p.header.PacketNumber).To(Equal(protocol.PacketNumber(0x43))) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.buffer.Data).To(HaveLen(int(probePacketSize))) + }) }) }) })