implement a function to pack MTU probe packets

This commit is contained in:
Marten Seemann
2021-01-25 16:41:28 +08:00
parent 91a314258d
commit fb5a45ac53
2 changed files with 55 additions and 24 deletions

View File

@@ -279,9 +279,9 @@ func (p *packetPacker) PackConnectionClose(quicErr *qerr.QuicError) (*coalescedP
}
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.paddingLen(payloads[i].frames, size)
paddingLen = p.initialPaddingLen(payloads[i].frames, size)
}
c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i])
c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false)
if err != nil {
return nil, err
}
@@ -335,9 +335,8 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke
return p.writeSinglePacket(hdr, payload, encLevel, sealer)
}
// only works for Initial packets
// The size is the expected size of the packet, if no padding was applied.
func (p *packetPacker) paddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount {
// size is the expected size of the packet, if no padding was applied.
func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount {
// For the server, only ack-eliciting Initial packets need to be padded.
if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) {
return 0
@@ -421,22 +420,22 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
packets: make([]*packetContents, 0, numPackets),
}
if initialPayload != nil {
padding := p.paddingLen(initialPayload.frames, size)
cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer)
padding := p.initialPaddingLen(initialPayload.frames, size)
cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
}
if handshakePayload != nil {
cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer)
cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
}
if appDataPayload != nil {
cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer)
cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false)
if err != nil {
return nil, err
}
@@ -457,7 +456,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
if hdr.IsLongHeader {
encLevel = protocol.Encryption0RTT
}
cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer)
cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer, false)
if err != nil {
return nil, err
}
@@ -670,10 +669,10 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
var padding protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
padding = p.paddingLen(payload.frames, size)
padding = p.initialPaddingLen(payload.frames, size)
}
buffer := getPacketBuffer()
cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer)
cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false)
if err != nil {
return nil, err
}
@@ -683,6 +682,28 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
}, nil
}
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) {
payload := &payload{
frames: []ackhandler.Frame{ping},
length: ping.Length(p.version),
}
buffer := getPacketBuffer()
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
hdr := p.getShortHeader(sealer.KeyPhase())
padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead())
contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true)
if err != nil {
return nil, err
}
return &packedPacket{
buffer: buffer,
packetContents: contents,
}, nil
}
func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) {
switch encLevel {
case protocol.EncryptionInitial:
@@ -762,9 +783,9 @@ func (p *packetPacker) writeSinglePacket(
buffer := getPacketBuffer()
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.paddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead()))
paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead()))
}
contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer)
contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false)
if err != nil {
return nil, err
}
@@ -774,14 +795,7 @@ func (p *packetPacker) writeSinglePacket(
}, nil
}
func (p *packetPacker) appendPacket(
buffer *packetBuffer,
header *wire.ExtendedHeader,
payload *payload,
padding protocol.ByteCount, // add padding such that the packet has this length. 0 for no padding.
encLevel protocol.EncryptionLevel,
sealer sealer,
) (*packetContents, error) {
func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen)
if payload.length < 4-pnLen {
@@ -816,8 +830,10 @@ func (p *packetPacker) appendPacket(
if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length {
return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize)
}
if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize {
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
if !isMTUProbePacket {
if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize {
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
}
}
raw := buffer.Data

View File

@@ -1439,6 +1439,21 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(packet).To(BeNil())
})
It("packs an MTU probe packet", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil)
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43))
ping := ackhandler.Frame{Frame: &wire.PingFrame{}}
const probePacketSize = maxPacketSize + 42
p, err := packer.PackMTUProbePacket(ping, probePacketSize)
Expect(err).ToNot(HaveOccurred())
Expect(p.length).To(BeEquivalentTo(probePacketSize))
Expect(p.header.IsLongHeader).To(BeFalse())
Expect(p.header.PacketNumber).To(Equal(protocol.PacketNumber(0x43)))
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
Expect(p.buffer.Data).To(HaveLen(int(probePacketSize)))
})
})
})
})