From 2873125c150042e34f7c564bb221c4d80865856d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 4 Sep 2022 10:44:37 +0300 Subject: [PATCH] move packing of ACK-only short header packets to composeNextPacket --- packet_packer.go | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 8e4aa735..c4716fea 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -321,9 +321,10 @@ func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) } func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { + var pay *payload var encLevel protocol.EncryptionLevel - var ack *wire.AckFrame if !handshakeConfirmed { + var ack *wire.AckFrame ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) if ack != nil { encLevel = protocol.EncryptionInitial @@ -333,24 +334,27 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke encLevel = protocol.EncryptionHandshake } } + + if ack != nil { + pay = &payload{ + ack: ack, + length: ack.Length(p.version), + } + } } - if ack == nil { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true) - if ack == nil { + if pay == nil { + pay = p.composeNextPacket(p.maxPacketSize, true, true) + if pay == nil { return nil, nil } encLevel = protocol.Encryption1RTT } - payload := &payload{ - ack: ack, - length: ack.Length(p.version), - } sealer, hdr, err := p.getSealerAndHeader(encLevel) if err != nil { return nil, err } - return p.writeSinglePacket(hdr, payload, encLevel, sealer) + return p.writeSinglePacket(hdr, pay, encLevel, sealer) } // size is the expected size of the packet, if no padding was applied. @@ -426,7 +430,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { case protocol.Encryption0RTT: appDataHdr, appDataPayload = p.maybeGetAppDataPacketFor0RTT(appDataSealer, maxPacketSize-size) case protocol.Encryption1RTT: - appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, maxPacketSize-size, size) + appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, maxPacketSize-size, size == 0) } if appDataHdr != nil && appDataPayload != nil { size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) @@ -476,7 +480,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if err != nil { return nil, err } - hdr, payload := p.maybeGetShortHeaderPacket(sealer, p.maxPacketSize, 0) + hdr, payload := p.maybeGetShortHeaderPacket(sealer, p.maxPacketSize, true) if payload == nil { return nil, nil } @@ -559,15 +563,15 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize return hdr, payload } -func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, maxPacketSize, currentSize protocol.ByteCount) (*wire.ExtendedHeader, *payload) { +func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, maxPacketSize protocol.ByteCount, ackAllowed bool) (*wire.ExtendedHeader, *payload) { hdr := p.getShortHeader(sealer.KeyPhase()) maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacket(maxPayloadSize, currentSize == 0) + payload := p.maybeGetAppDataPacket(maxPayloadSize, ackAllowed) return hdr, payload } func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { - payload := p.composeNextPacket(maxPayloadSize, ackAllowed) + payload := p.composeNextPacket(maxPayloadSize, false, ackAllowed) // check if we have anything to send if len(payload.frames) == 0 { @@ -590,7 +594,17 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, return payload } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { + if onlyAck { + if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { + payload := &payload{} + payload.ack = ack + payload.length += ack.Length(p.version) + return payload + } + return nil + } + payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} hasData := p.framer.HasData()