diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 3944b630..beb5c2ec 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -411,8 +411,7 @@ func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.Encryptio lossTime = h.handshakePackets.lossTime encLevel = protocol.EncryptionHandshake } - if h.handshakeComplete && - (lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime))) { + if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) { lossTime = h.appDataPackets.lossTime encLevel = protocol.Encryption1RTT } @@ -420,24 +419,41 @@ func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.Encryptio } // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime -func (h *sentPacketHandler) getEarliestSentTimeAndSpace() (time.Time, protocol.EncryptionLevel) { - var encLevel protocol.EncryptionLevel - var sentTime time.Time +func (h *sentPacketHandler) getPTOTimeAndSpace() (time.Time, protocol.EncryptionLevel) { + if !h.hasOutstandingPackets() { + t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount) + if h.initialPackets != nil { + return t, protocol.EncryptionInitial + } + return t, protocol.EncryptionHandshake + } + + var ( + encLevel protocol.EncryptionLevel + pto time.Time + ) if h.initialPackets != nil { - sentTime = h.initialPackets.lastAckElicitingPacketTime encLevel = protocol.EncryptionInitial + if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { + pto = t.Add(h.rttStats.PTO(false) << h.ptoCount) + } } - if h.handshakePackets != nil && (sentTime.IsZero() || (!h.handshakePackets.lastAckElicitingPacketTime.IsZero() && h.handshakePackets.lastAckElicitingPacketTime.Before(sentTime))) { - sentTime = h.handshakePackets.lastAckElicitingPacketTime - encLevel = protocol.EncryptionHandshake + if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { + t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(false) << h.ptoCount) + if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { + pto = t + encLevel = protocol.EncryptionHandshake + } } - if h.handshakeComplete && - (sentTime.IsZero() || (!h.appDataPackets.lastAckElicitingPacketTime.IsZero() && h.appDataPackets.lastAckElicitingPacketTime.Before(sentTime))) { - sentTime = h.appDataPackets.lastAckElicitingPacketTime - encLevel = protocol.Encryption1RTT + if h.handshakeComplete && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { + t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) + if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { + pto = t + encLevel = protocol.Encryption1RTT + } } - return sentTime, encLevel + return pto, encLevel } func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { @@ -480,14 +496,8 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } // PTO alarm - sentTime, encLevel := h.getEarliestSentTimeAndSpace() - if sentTime.IsZero() { - if h.peerCompletedAddressValidation { - panic("didn't expect sentTime to be zero") - } - sentTime = time.Now() - } - h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount) + ptoTime, encLevel := h.getPTOTimeAndSpace() + h.alarm = ptoTime if h.qlogger != nil && h.alarm != oldAlarm { h.qlogger.SetLossTimer(qlog.TimerTypePTO, encLevel, h.alarm) } @@ -607,7 +617,7 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { } // PTO - _, encLevel = h.getEarliestSentTimeAndSpace() + _, encLevel = h.getPTOTimeAndSpace() if h.logger.Debug() { h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) }