diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 7e7a1a10c..e807226e8 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -292,10 +292,13 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - if h.qlogger != nil && h.ptoCount != 0 { - h.qlogger.UpdatedPTOCount(0) + // Reset the pto_count unless the client is unsure if the server has validated the client's address. + if h.peerCompletedAddressValidation { + if h.qlogger != nil && h.ptoCount != 0 { + h.qlogger.UpdatedPTOCount(0) + } + h.ptoCount = 0 } - h.ptoCount = 0 h.numProbesToSend = 0 h.setLossDetectionTimer() diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index bdff4896a..c0e279090 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -601,6 +601,19 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) }) + It("reset the PTO count when receiving an ACK", func() { + now := time.Now() + handler.SetHandshakeComplete() + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) + Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + Expect(handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.Encryption1RTT, time.Now())).To(Succeed()) + Expect(handler.ptoCount).To(BeZero()) + }) + It("resets the PTO mode and PTO count when a packet number space is dropped", func() { now := time.Now() handler.SentPacket(ackElicitingPacket(&Packet{ @@ -808,6 +821,18 @@ var _ = Describe("SentPacketHandler", func() { Expect(pto).ToNot(BeZero()) Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", time.Now().Add(pto), 10*time.Millisecond)) }) + + It("doesn't reset the PTO count when receiving an ACK", func() { + now := time.Now() + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) + Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + Expect(handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now())).To(Succeed()) + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + }) }) Context("Packet-based loss detection", func() {