diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 5a83cc95..ee193d90 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -276,11 +276,13 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En if err != nil || len(ackedPackets) == 0 { return err } - - if err := h.detectAndRemoveLostPackets(rcvTime, encLevel, priorInFlight); err != nil { + lostPackets, err := h.detectAndRemoveLostPackets(rcvTime, encLevel) + if err != nil { return err } - + for _, p := range lostPackets { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } for _, p := range ackedPackets { if p.includedInBytesInFlight { h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) @@ -301,10 +303,7 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu return h.lowestNotConfirmedAcked } -func (h *sentPacketHandler) detectAndRemoveAckedPackets( - ack *wire.AckFrame, - encLevel protocol.EncryptionLevel, -) ([]*Packet, error) { +func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { pnSpace := h.getPacketNumberSpace(encLevel) var ackedPackets []*Packet ackRangeIndex := 0 @@ -449,11 +448,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount) } -func (h *sentPacketHandler) detectAndRemoveLostPackets( - now time.Time, - encLevel protocol.EncryptionLevel, - priorInFlight protocol.ByteCount, -) error { +func (h *sentPacketHandler) detectAndRemoveLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]*Packet, error) { pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} @@ -506,7 +501,6 @@ func (h *sentPacketHandler) detectAndRemoveLostPackets( // the bytes in flight need to be reduced no matter if this packet will be retransmitted if p.includedInBytesInFlight { h.bytesInFlight -= p.Length - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) } pnSpace.history.Remove(p.PacketNumber) if h.traceCallback != nil { @@ -525,7 +519,7 @@ func (h *sentPacketHandler) detectAndRemoveLostPackets( }) } } - return nil + return lostPackets, nil } func (h *sentPacketHandler) OnLossDetectionTimeout() error { @@ -549,7 +543,14 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } // Early retransmit or time loss detection - return h.detectAndRemoveLostPackets(time.Now(), encLevel, h.bytesInFlight) + priorInFlight := h.bytesInFlight + lostPackets, err := h.detectAndRemoveLostPackets(time.Now(), encLevel) + if err != nil { + return err + } + for _, p := range lostPackets { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } } // PTO