diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 0341ae79b..725384014 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -307,13 +307,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } } - lostPackets, err := h.detectLostPackets(rcvTime, encLevel) - if err != nil { + if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return err } - for _, p := range lostPackets { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) - } for _, p := range ackedPackets { if p.skippedPacket { return fmt.Errorf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel) @@ -508,7 +504,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } } -func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]*Packet, error) { +func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error { pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} @@ -521,68 +517,64 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E // Packets sent before this time are deemed lost. lostSendTime := now.Add(-lossDelay) - var lostPackets []*Packet - if err := pnSpace.history.Iterate(func(packet *Packet) (bool, error) { - if packet.PacketNumber > pnSpace.largestAcked { + priorInFlight := h.bytesInFlight + return pnSpace.history.Iterate(func(p *Packet) (bool, error) { + if p.PacketNumber > pnSpace.largestAcked { return false, nil } - if packet.declaredLost || packet.skippedPacket { + if p.declaredLost || p.skippedPacket { return true, nil } - if packet.SendTime.Before(lostSendTime) { - lostPackets = append(lostPackets, packet) - if h.tracer != nil { - h.tracer.LostPacket(packet.EncryptionLevel, packet.PacketNumber, logging.PacketLossTimeThreshold) + var packetLost bool + if p.SendTime.Before(lostSendTime) { + packetLost = true + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) } - } else if pnSpace.largestAcked >= packet.PacketNumber+packetThreshold { - lostPackets = append(lostPackets, packet) if h.tracer != nil { - h.tracer.LostPacket(packet.EncryptionLevel, packet.PacketNumber, logging.PacketLossReorderingThreshold) + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) + } + } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { + packetLost = true + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) + } + if h.tracer != nil { + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) } } else if pnSpace.lossTime.IsZero() { // Note: This conditional is only entered once per call - lossTime := packet.SendTime.Add(lossDelay) + lossTime := p.SendTime.Add(lossDelay) if h.logger.Debug() { - h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", packet.PacketNumber, encLevel, lossDelay, lossTime) + h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime) } pnSpace.lossTime = lossTime } - return true, nil - }); err != nil { - return nil, err - } - - if h.logger.Debug() && len(lostPackets) > 0 { - pns := make([]protocol.PacketNumber, len(lostPackets)) - for i, p := range lostPackets { - pns[i] = p.PacketNumber - } - h.logger.Debugf("\tlost packets (%d): %d", len(pns), pns) - } - - for _, p := range lostPackets { - p.declaredLost = true - h.queueFramesForRetransmission(p) - // the bytes in flight need to be reduced no matter if this packet will be retransmitted - h.removeFromBytesInFlight(p) - if h.traceCallback != nil { - frames := make([]wire.Frame, 0, len(p.Frames)) - for _, f := range p.Frames { - frames = append(frames, f.Frame) + if packetLost { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + p.declaredLost = true + h.queueFramesForRetransmission(p) + // the bytes in flight need to be reduced no matter if this packet will be retransmitted + h.removeFromBytesInFlight(p) + if h.traceCallback != nil { + frames := make([]wire.Frame, 0, len(p.Frames)) + for _, f := range p.Frames { + frames = append(frames, f.Frame) + } + h.traceCallback(quictrace.Event{ + Time: now, + EventType: quictrace.PacketLost, + EncryptionLevel: p.EncryptionLevel, + PacketNumber: p.PacketNumber, + PacketSize: p.Length, + Frames: frames, + TransportState: h.GetStats(), + }) } - h.traceCallback(quictrace.Event{ - Time: now, - EventType: quictrace.PacketLost, - EncryptionLevel: p.EncryptionLevel, - PacketNumber: p.PacketNumber, - PacketSize: p.Length, - Frames: frames, - TransportState: h.GetStats(), - }) } - } - return lostPackets, nil + return true, nil + }) } func (h *sentPacketHandler) OnLossDetectionTimeout() error { @@ -609,15 +601,7 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection - priorInFlight := h.bytesInFlight - lostPackets, err := h.detectLostPackets(time.Now(), encLevel) - if err != nil { - return err - } - for _, p := range lostPackets { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) - } - return nil + return h.detectLostPackets(time.Now(), encLevel) } // PTO