From 7bc2ba6b8113baed66134689046f5d8b4427b31f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 4 Sep 2022 11:27:04 +0300 Subject: [PATCH] simplify packing of long header ACK-only packets --- packet_packer.go | 100 +++++++++++++++++++---------------------------- 1 file changed, 40 insertions(+), 60 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index c4716fea0..6010411b0 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -322,23 +322,26 @@ func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { var pay *payload - var encLevel protocol.EncryptionLevel + var sealer sealer + var hdr *wire.ExtendedHeader + encLevel := protocol.EncryptionInitial if !handshakeConfirmed { - var ack *wire.AckFrame - ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) - if ack != nil { - encLevel = protocol.EncryptionInitial - } else { - ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) - if ack != nil { - encLevel = protocol.EncryptionHandshake + hdr, pay = p.maybeGetCryptoPacket(p.maxPacketSize, protocol.EncryptionInitial, true, true) + if pay != nil { + var err error + sealer, err = p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, err } - } - - if ack != nil { - pay = &payload{ - ack: ack, - length: ack.Length(p.version), + } else { + encLevel = protocol.EncryptionHandshake + hdr, pay = p.maybeGetCryptoPacket(p.maxPacketSize, protocol.EncryptionHandshake, true, true) + if pay != nil { + var err error + sealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, err + } } } } @@ -348,12 +351,14 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke return nil, nil } encLevel = protocol.Encryption1RTT + s, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + hdr = p.getShortHeader(s.KeyPhase()) + sealer = s } - sealer, hdr, err := p.getSealerAndHeader(encLevel) - if err != nil { - return nil, err - } return p.writeSinglePacket(hdr, pay, encLevel, sealer) } @@ -387,7 +392,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { } var size protocol.ByteCount if initialSealer != nil { - initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial) + initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, false, size == 0) if initialPayload != nil { size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) numPackets++ @@ -403,7 +408,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { return nil, err } if handshakeSealer != nil { - handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake) + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, false, size == 0) if handshakePayload != nil { s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) size += s @@ -495,7 +500,17 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { + if onlyAck { + if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { + var payload payload + payload.ack = ack + payload.length = ack.Length(p.version) + return p.getLongHeader(encLevel), &payload + } + return nil, nil + } + var s cryptoStream var hasRetransmission bool //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. @@ -510,7 +525,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol. hasData := s.HasData() var ack *wire.AckFrame - if encLevel == protocol.EncryptionInitial || currentSize == 0 { + if ackAllowed { ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) } if !hasData && !hasRetransmission && ack == nil { @@ -677,14 +692,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( if err != nil { return nil, err } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial) + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true) case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil { return nil, err } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake) + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true) case protocol.Encryption1RTT: oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { @@ -738,41 +753,6 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B }, nil } -func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { - switch encLevel { - case protocol.EncryptionInitial: - sealer, err := p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionInitial) - return sealer, hdr, nil - case protocol.Encryption0RTT: - sealer, err := p.cryptoSetup.Get0RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.Encryption0RTT) - return sealer, hdr, nil - case protocol.EncryptionHandshake: - sealer, err := p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionHandshake) - return sealer, hdr, nil - case protocol.Encryption1RTT: - sealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getShortHeader(sealer.KeyPhase()) - return sealer, hdr, nil - default: - return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel) - } -} - func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdr := &wire.ExtendedHeader{}