diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 0f151110..c6325a3e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -116,6 +116,16 @@ func newSentPacketHandler( } func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial { + // This function is called when the crypto setup seals a Handshake packet. + // If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space + // before SentPacket() was called for that Initial packet. + return + } + h.dropPackets(encLevel) +} + +func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { // remove outstanding packets from bytes_in_flight if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { pnSpace := h.getPacketNumberSpace(encLevel) @@ -153,6 +163,10 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { } func (h *sentPacketHandler) SentPacket(packet *Packet) { + // For the client, drop the Initial packet number space when the first Handshake packet is sent. + if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { + h.dropPackets(protocol.EncryptionInitial) + } isAckEliciting := h.sentPacketImpl(packet) if isAckEliciting { h.getPacketNumberSpace(packet.EncryptionLevel).history.SentPacket(packet) diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 3cbe6190..ab2d64e0 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -767,9 +767,8 @@ var _ = Describe("SentPacketHandler", func() { protocol.EncryptionInitial, time.Now(), )).To(Succeed()) - handler.DropPackets(protocol.EncryptionInitial) // Initial keys are dropped when a Handshake packet is received. - handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1})) + handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1})) // also drops Initial packets Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) @@ -853,7 +852,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.ReceivedAck(ack, protocol.EncryptionHandshake, time.Now())).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) }) - It("deletes Initial packets", func() { + It("deletes Initial packets, as a server", func() { for i := protocol.PacketNumber(0); i < 6; i++ { handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: i, @@ -874,6 +873,37 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) }) + Context("deleting Initials", func() { + BeforeEach(func() { perspective = protocol.PerspectiveClient }) + + It("deletes Initials, as a client", func() { + for i := protocol.PacketNumber(0); i < 6; i++ { + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionInitial, + })) + } + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) + handler.DropPackets(protocol.EncryptionInitial) + // DropPackets should be ignored for clients and the Initial packet number space. + // It has to be possible to send another Initial packets after this function was called. + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 10, + EncryptionLevel: protocol.EncryptionInitial, + })) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7))) + // Sending a Handshake packet triggers dropping of Initials. + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + EncryptionLevel: protocol.EncryptionHandshake, + })) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) + Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission + Expect(handler.initialPackets).To(BeNil()) + Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) + }) + }) + It("deletes Handshake packets", func() { for i := protocol.PacketNumber(0); i < 6; i++ { handler.SentPacket(ackElicitingPacket(&Packet{ diff --git a/packet_packer.go b/packet_packer.go index 2e0c3b2b..b52b43e0 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -264,6 +264,15 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke return p.writeSinglePacket(hdr, payload, encLevel, sealer) } +func (p *packetPacker) padPacket(buffer *packetBuffer) { + if dataLen := len(buffer.Data); dataLen < protocol.MinInitialPacketSize { + buffer.Data = buffer.Data[:protocol.MinInitialPacketSize] + for n := dataLen; n < protocol.MinInitialPacketSize; n++ { + buffer.Data[n] = 0 + } + } +} + // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. @@ -278,6 +287,11 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { buffer.Release() return nil, nil } + + if p.perspective == protocol.PerspectiveClient && packet.packets[0].header.Type == protocol.PacketTypeInitial { + p.padPacket(buffer) + } + return packet, nil } @@ -507,6 +521,9 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( if err != nil { return nil, err } + if p.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial { + p.padPacket(buffer) + } return &packedPacket{ buffer: buffer, packetContents: contents, @@ -620,13 +637,7 @@ func (p *packetPacker) appendPacket( var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if encLevel != protocol.Encryption1RTT { - if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { - headerLen := header.GetLength(p.version) - header.Length = pnLen + protocol.MinInitialPacketSize - headerLen - paddingLen = protocol.ByteCount(protocol.MinInitialPacketSize-sealer.Overhead()) - headerLen - payload.length - } else { - header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length - } + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length } else if payload.length < 4-pnLen { paddingLen = 4 - pnLen - payload.length } diff --git a/packet_packer_test.go b/packet_packer_test.go index 1456252a..b7124d5a 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -197,7 +197,7 @@ var _ = Describe("Packet packer", func() { sealer.EXPECT().Overhead().Return(7).AnyTimes() sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { - return append(src, bytes.Repeat([]byte{0}, sealer.Overhead())...) + return append(src, bytes.Repeat([]byte{'s'}, sealer.Overhead())...) }).AnyTimes() return sealer } @@ -683,6 +683,43 @@ var _ = Describe("Packet packer", func() { Expect(rest).To(BeEmpty()) }) + It("packs a coalesced packet with Initial / 0-RTT, and pads it", func() { + packer.perspective = protocol.PerspectiveClient + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + // don't EXPECT any calls for a Handshake ACK frame + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} + }) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.buffer.Data).To(HaveLen(protocol.MinInitialPacketSize)) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + hdr, _, rest, err := wire.ParsePacket(p.buffer.Data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) + hdr, _, rest, err = wire.ParsePacket(rest, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(rest).To(Equal(make([]byte, len(rest)))) + }) + It("packs a coalesced packet with Handshake / 1-RTT", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) @@ -824,7 +861,6 @@ var _ = Describe("Packet packer", func() { Expect(p.packets[0].frames).To(HaveLen(1)) cf := p.packets[0].frames[0].Frame.(*wire.CryptoFrame) Expect(cf.Data).To(Equal([]byte("foobar"))) - checkLength(p.buffer.Data) }) It("adds an ACK frame", func() { @@ -866,9 +902,29 @@ var _ = Describe("Packet packer", func() { Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames[0].Frame).To(Equal(f)) + Expect(packet.buffer.Len()).To(BeNumerically("<", protocol.MinInitialPacketSize)) checkLength(packet.buffer.Data) }) + It("packs an Initial probe packet and pads it, for the client", func() { + packer.perspective = protocol.PerspectiveClient + f := &wire.CryptoFrame{Data: []byte("Initial")} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + initialStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(packet.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(Equal(f)) + }) + It("packs a Handshake probe packet", func() { f := &wire.CryptoFrame{Data: []byte("Handshake")} retransmissionQueue.AddHandshake(f)