From 37f1a3fdda61e7759fb4c92f273e4fd7ac4808ee Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 1 Apr 2020 13:37:17 +0700 Subject: [PATCH] simplify removing of acked packets from packet history --- internal/ackhandler/sent_packet_handler.go | 47 +++++++++++----------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 3ace110cd..e9a216ce8 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -271,19 +271,13 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } + priorInFlight := h.bytesInFlight ackedPackets, err := h.determineNewlyAckedPackets(ack, encLevel) if err != nil || len(ackedPackets) == 0 { return err } - priorInFlight := h.bytesInFlight for _, p := range ackedPackets { - if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { - h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) - } - if err := h.onPacketAcked(p); err != nil { - return err - } if p.includedInBytesInFlight { h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) } @@ -352,6 +346,28 @@ func (h *sentPacketHandler) determineNewlyAckedPackets( } h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns) } + + for _, p := range ackedPackets { + if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil { + continue + } + if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { + h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) + } + + for _, f := range p.Frames { + if f.OnAcked != nil { + f.OnAcked(f.Frame) + } + } + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + } + if err := pnSpace.history.Remove(p.PacketNumber); err != nil { + return nil, err + } + } + return ackedPackets, err } @@ -563,23 +579,6 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } -func (h *sentPacketHandler) onPacketAcked(p *Packet) error { - pnSpace := h.getPacketNumberSpace(p.EncryptionLevel) - if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil { - return nil - } - - for _, f := range p.Frames { - if f.OnAcked != nil { - f.OnAcked(f.Frame) - } - } - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } - return pnSpace.history.Remove(p.PacketNumber) -} - func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel)