diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 83d75d8b..8dead2a9 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -329,7 +329,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } // PTO alarm - h.alarm = h.lastSentAckElicitingPacketTime.Add(h.rttStats.PTO() << h.ptoCount) + h.alarm = h.lastSentAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) } func (h *sentPacketHandler) detectLostPackets( diff --git a/internal/congestion/rtt_stats.go b/internal/congestion/rtt_stats.go index 0b17fc10..9ae42706 100644 --- a/internal/congestion/rtt_stats.go +++ b/internal/congestion/rtt_stats.go @@ -46,13 +46,19 @@ func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } // MeanDeviation gets the mean deviation func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } +// MaxAckDelay gets the max_ack_delay advertized by the peer func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } -func (r *RTTStats) PTO() time.Duration { +// PTO gets the probe timeout duration. +func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { if r.SmoothedRTT() == 0 { return 2 * defaultInitialRTT } - return r.SmoothedRTT() + utils.MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + r.MaxAckDelay() + pto := r.SmoothedRTT() + utils.MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + if includeMaxAckDelay { + pto += r.MaxAckDelay() + } + return pto } // UpdateRTT updates the RTT based on a new sample. diff --git a/internal/congestion/rtt_stats_test.go b/internal/congestion/rtt_stats_test.go index c899dde0..fe722281 100644 --- a/internal/congestion/rtt_stats_test.go +++ b/internal/congestion/rtt_stats_test.go @@ -66,13 +66,14 @@ var _ = Describe("RTT stats", func() { rttStats.UpdateRTT(rtt, 0, time.Time{}) Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) Expect(rttStats.MeanDeviation()).To(Equal(rtt / 2)) - Expect(rttStats.PTO()).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) + Expect(rttStats.PTO(false)).To(Equal(rtt + 4*(rtt/2))) + Expect(rttStats.PTO(true)).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) }) It("uses the granularity for computing the PTO for short RTTs", func() { rtt := time.Microsecond rttStats.UpdateRTT(rtt, 0, time.Time{}) - Expect(rttStats.PTO()).To(Equal(rtt + protocol.TimerGranularity)) + Expect(rttStats.PTO(true)).To(Equal(rtt + protocol.TimerGranularity)) }) It("ExpireSmoothedMetrics", func() { diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 2444bc50..5cc58511 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -99,7 +99,7 @@ func (a *updatableAEAD) rollKeys(now time.Time) { a.numRcvdWithCurrentKey = 0 a.numSentWithCurrentKey = 0 a.prevRcvAEAD = a.rcvAEAD - a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO()) + a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true)) a.rcvAEAD = a.nextRcvAEAD a.sendAEAD = a.nextSendAEAD diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 51000a44..54d071de 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -151,7 +151,7 @@ var _ = Describe("Updatable AEAD", func() { It("drops keys 3 PTOs after a key update", func() { now := time.Now() rttStats.UpdateRTT(10*time.Millisecond, 0, now) - pto := rttStats.PTO() + pto := rttStats.PTO(true) encrypted01 := client.Seal(nil, msg, 0x42, ad) encrypted02 := client.Seal(nil, msg, 0x43, ad) // receive the first packet with key phase 0