From 192bc8dc2a37e5154dc8265f0ae8a065fa8117be Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 17 Nov 2020 13:08:06 +0700 Subject: [PATCH] account for the size of the header when packing 1-RTT probe packets --- packet_packer.go | 37 ++++++++++++++++++------------------ packet_packer_test.go | 44 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 84b5cb5c1..858fcd714 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -467,7 +467,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) maybeGetCryptoPacket(maxSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { var s cryptoStream var hasRetransmission bool //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. @@ -494,19 +494,19 @@ func (p *packetPacker) maybeGetCryptoPacket(maxSize, currentSize protocol.ByteCo if ack != nil { payload.ack = ack payload.length = ack.Length(p.version) - maxSize -= payload.length + maxPacketSize -= payload.length } hdr := p.getLongHeader(encLevel) - maxSize -= hdr.GetLength(p.version) + maxPacketSize -= hdr.GetLength(p.version) if hasRetransmission { for { var f wire.Frame //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s switch encLevel { case protocol.EncryptionInitial: - f = p.retransmissionQueue.GetInitialFrame(maxSize) + f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) case protocol.EncryptionHandshake: - f = p.retransmissionQueue.GetHandshakeFrame(maxSize) + f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) } if f == nil { break @@ -514,10 +514,10 @@ func (p *packetPacker) maybeGetCryptoPacket(maxSize, currentSize protocol.ByteCo payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) frameLen := f.Length(p.version) payload.length += frameLen - maxSize -= frameLen + maxPacketSize -= frameLen } } else if s.HasData() { - cf := s.PopCryptoFrame(maxSize) + cf := s.PopCryptoFrame(maxPacketSize) payload.frames = []ackhandler.Frame{{Frame: cf}} payload.length += cf.Length(p.version) } @@ -547,18 +547,19 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol } maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, currentSize, encLevel) + payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) return sealer, hdr, payload } -func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) *payload { - payload := p.composeNextPacket(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) +func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { + payload := p.composeNextPacket(maxPayloadSize, ackAllowed) // check if we have anything to send - if len(payload.frames) == 0 && payload.ack == nil { - return nil - } - if len(payload.frames) == 0 { // the packet only contains an ACK + if len(payload.frames) == 0 { + if payload.ack == nil { + return nil + } + // the packet only contains an ACK if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { ping := &wire.PingFrame{} payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping}) @@ -642,14 +643,12 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( return nil, err } sealer = oneRTTSealer - payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.Encryption1RTT) - if payload != nil { - hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - } + hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) + payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) default: panic("unknown encryption level") } - if hdr == nil { + if payload == nil { return nil, nil } size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) diff --git a/packet_packer_test.go b/packet_packer_test.go index cf1f38500..49c9f23b9 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -1322,6 +1322,25 @@ var _ = Describe("Packet packer", func() { parsePacket(packet.buffer.Data) }) + It("packs a full size Handshake probe packet", func() { + f := &wire.CryptoFrame{Data: make([]byte, 2000)} + retransmissionQueue.AddHandshake(f) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + handshakeStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + Expect(packet.length).To(Equal(maxPacketSize)) + parsePacket(packet.buffer.Data) + }) + It("packs a 1-RTT probe packet", func() { f := &wire.StreamFrame{Data: []byte("1-RTT")} retransmissionQueue.AddInitial(f) @@ -1341,8 +1360,33 @@ var _ = Describe("Packet packer", func() { Expect(packet.frames[0].Frame).To(Equal(f)) }) + It("packs a full size 1-RTT probe packet", func() { + f := &wire.StreamFrame{Data: make([]byte, 2000)} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + framer.EXPECT().HasData().Return(true) + expectAppendControlFrames() + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxSize protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + sf, split := f.MaybeSplitOffFrame(maxSize, packer.version) + Expect(split).To(BeTrue()) + return append(fs, ackhandler.Frame{Frame: sf}), sf.Length(packer.version) + }) + + packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(packet.length).To(Equal(maxPacketSize)) + }) + It("returns nil if there's no probe data to send", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) framer.EXPECT().HasData()