From 8f4257a8832986178080de35a8b2c15e80a0b570 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 7 Sep 2020 14:19:43 +0700 Subject: [PATCH] delete unacknowledged packets from the packet history after 3 PTOs --- internal/ackhandler/interfaces.go | 1 + internal/ackhandler/sent_packet_handler.go | 75 ++++++++++--------- .../ackhandler/sent_packet_handler_test.go | 20 ++++- internal/ackhandler/sent_packet_history.go | 33 ++++++-- .../ackhandler/sent_packet_history_test.go | 39 +++++++++- 5 files changed, 123 insertions(+), 45 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index db739fb44..42d83d983 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -18,6 +18,7 @@ type Packet struct { SendTime time.Time includedInBytesInFlight bool + declaredLost bool } // SentPacketHandler handles ACKs received for outgoing packets diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index e1bfa4eac..a4e3f4e0a 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -35,9 +35,9 @@ type packetNumberSpace struct { largestSent protocol.PacketNumber } -func newPacketNumberSpace(initialPN protocol.PacketNumber) *packetNumberSpace { +func newPacketNumberSpace(initialPN protocol.PacketNumber, rttStats *utils.RTTStats) *packetNumberSpace { return &packetNumberSpace{ - history: newSentPacketHistory(), + history: newSentPacketHistory(rttStats), pns: newPacketNumberGenerator(initialPN, protocol.SkipPacketAveragePeriodLength), largestSent: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -109,9 +109,9 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerAddressValidated: pers == protocol.PerspectiveClient, - initialPackets: newPacketNumberSpace(initialPacketNumber), - handshakePackets: newPacketNumberSpace(0), - appDataPackets: newPacketNumberSpace(0), + initialPackets: newPacketNumberSpace(initialPacketNumber, rttStats), + handshakePackets: newPacketNumberSpace(0, rttStats), + appDataPackets: newPacketNumberSpace(0, rttStats), rttStats: rttStats, congestion: congestion, perspective: pers, @@ -131,6 +131,13 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { h.dropPackets(encLevel) } +func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) { + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + p.includedInBytesInFlight = false + } +} + func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { // The server won't await address validation after the handshake is confirmed. // This applies even if we didn't receive an ACK for a Handshake packet. @@ -141,9 +148,7 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.history.Iterate(func(p *Packet) (bool, error) { - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } + h.removeFromBytesInFlight(p) return true, nil }) } @@ -160,9 +165,7 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { return false, nil } h.queueFramesForRetransmission(p) - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } + h.removeFromBytesInFlight(p) h.appDataPackets.history.Remove(p.PacketNumber) return true, nil }) @@ -301,7 +304,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } } - lostPackets, err := h.detectAndRemoveLostPackets(rcvTime, encLevel) + lostPackets, err := h.detectLostPackets(rcvTime, encLevel) if err != nil { return err } @@ -309,9 +312,10 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) } for _, p := range ackedPackets { - if p.includedInBytesInFlight { + if p.includedInBytesInFlight && !p.declaredLost { h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) } + h.removeFromBytesInFlight(p) } // Reset the pto_count unless the client is unsure if the server has validated the client's address. @@ -323,6 +327,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } h.numProbesToSend = 0 + pnSpace.history.DeleteOldPackets(rcvTime) h.setLossDetectionTimer() return nil } @@ -385,9 +390,6 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL f.OnAcked(f.Frame) } } - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, err } @@ -500,7 +502,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } } -func (h *sentPacketHandler) detectAndRemoveLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]*Packet, error) { +func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]*Packet, error) { pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} @@ -518,6 +520,9 @@ func (h *sentPacketHandler) detectAndRemoveLostPackets(now time.Time, encLevel p if packet.PacketNumber > pnSpace.largestAcked { return false, nil } + if packet.declaredLost { + return true, nil + } if packet.SendTime.Before(lostSendTime) { lostPackets = append(lostPackets, packet) @@ -551,14 +556,10 @@ func (h *sentPacketHandler) detectAndRemoveLostPackets(now time.Time, encLevel p } for _, p := range lostPackets { + p.declaredLost = true h.queueFramesForRetransmission(p) // the bytes in flight need to be reduced no matter if this packet will be retransmitted - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } - if err := pnSpace.history.Remove(p.PacketNumber); err != nil { - return nil, err - } + h.removeFromBytesInFlight(p) if h.traceCallback != nil { frames := make([]wire.Frame, 0, len(p.Frames)) for _, f := range p.Frames { @@ -603,7 +604,7 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { } // Early retransmit or time loss detection priorInFlight := h.bytesInFlight - lostPackets, err := h.detectAndRemoveLostPackets(time.Now(), encLevel) + lostPackets, err := h.detectLostPackets(time.Now(), encLevel) if err != nil { return err } @@ -740,22 +741,21 @@ func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) return false } h.queueFramesForRetransmission(p) - // TODO: don't remove the packet here + // TODO: don't declare the packet lost here. // Keep track of acknowledged frames instead. - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } - if err := pnSpace.history.Remove(p.PacketNumber); err != nil { - // should never happen. We just got this packet from the history. - panic(err) - } + h.removeFromBytesInFlight(p) + p.declaredLost = true return true } func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { + if len(p.Frames) == 0 { + panic("no frames") + } for _, f := range p.Frames { f.OnLost(f.Frame) } + p.Frames = nil } func (h *sentPacketHandler) ResetForRetry() error { @@ -765,13 +765,18 @@ func (h *sentPacketHandler) ResetForRetry() error { if firstPacketSendTime.IsZero() { firstPacketSendTime = p.SendTime } + if p.declaredLost { + return true, nil + } h.queueFramesForRetransmission(p) return true, nil }) // All application data packets sent at this point are 0-RTT packets. // In the case of a Retry, we can assume that the server dropped all of them. h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { - h.queueFramesForRetransmission(p) + if !p.declaredLost { + h.queueFramesForRetransmission(p) + } return true, nil }) @@ -787,8 +792,8 @@ func (h *sentPacketHandler) ResetForRetry() error { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } - h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop()) - h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop()) + h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), h.rttStats) + h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), h.rttStats) oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 4808ad519..1883930fa 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -26,7 +26,7 @@ var _ = Describe("SentPacketHandler", func() { JustBeforeEach(func() { lostPackets = nil - rttStats := &utils.RTTStats{} + rttStats := utils.NewRTTStats() handler = newSentPacketHandler(42, rttStats, perspective, nil, nil, utils.DefaultLogger) streamFrame = wire.StreamFrame{ StreamID: 5, @@ -86,7 +86,14 @@ var _ = Describe("SentPacketHandler", func() { expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { pnSpace := handler.getPacketNumberSpace(encLevel) - ExpectWithOffset(1, pnSpace.history.Len()).To(Equal(len(expected))) + var length int + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + if !p.declaredLost { + length++ + } + return true, nil + }) + ExpectWithOffset(1, length).To(Equal(len(expected))) for _, p := range expected { ExpectWithOffset(2, pnSpace.history.packetMap).To(HaveKey(p)) } @@ -190,7 +197,14 @@ var _ = Describe("SentPacketHandler", func() { Context("acks the right packets", func() { expectInPacketHistoryOrLost := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { pnSpace := handler.getPacketNumberSpace(encLevel) - ExpectWithOffset(1, pnSpace.history.Len()+len(lostPackets)).To(Equal(len(expected))) + var length int + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + if !p.declaredLost { + length++ + } + return true, nil + }) + ExpectWithOffset(1, length+len(lostPackets)).To(Equal(len(expected))) expectedLoop: for _, p := range expected { if _, ok := pnSpace.history.packetMap[p]; ok { diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index 6d792d5d3..c7e37a076 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -2,17 +2,21 @@ package ackhandler import ( "fmt" + "time" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) type sentPacketHistory struct { + rttStats *utils.RTTStats packetList *PacketList packetMap map[protocol.PacketNumber]*PacketElement } -func newSentPacketHistory() *sentPacketHistory { +func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { return &sentPacketHistory{ + rttStats: rttStats, packetList: NewPacketList(), packetMap: make(map[protocol.PacketNumber]*PacketElement), } @@ -40,10 +44,12 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err // FirstOutStanding returns the first outstanding packet. func (h *sentPacketHistory) FirstOutstanding() *Packet { - if !h.HasOutstandingPackets() { - return nil + for el := h.packetList.Front(); el != nil; el = el.Next() { + if !el.Value.declaredLost { + return &el.Value + } } - return &h.packetList.Front().Value + return nil } func (h *sentPacketHistory) Len() int { @@ -61,5 +67,22 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { } func (h *sentPacketHistory) HasOutstandingPackets() bool { - return h.packetList.Len() > 0 + return h.FirstOutstanding() != nil +} + +func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { + maxAge := 3 * h.rttStats.PTO(false) + var nextEl *PacketElement + for el := h.packetList.Front(); el != nil; el = nextEl { + nextEl = el.Next() + p := el.Value + if p.SendTime.After(now.Add(-maxAge)) { + break + } + if !p.declaredLost { // should only happen in the case of drastic RTT changes + continue + } + delete(h.packetMap, p.PacketNumber) + h.packetList.Remove(el) + } } diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go index 8f4e616b0..bc214ed1d 100644 --- a/internal/ackhandler/sent_packet_history_test.go +++ b/internal/ackhandler/sent_packet_history_test.go @@ -2,6 +2,9 @@ package ackhandler import ( "errors" + "time" + + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -9,7 +12,10 @@ import ( ) var _ = Describe("SentPacketHistory", func() { - var hist *sentPacketHistory + var ( + hist *sentPacketHistory + rttStats *utils.RTTStats + ) expectInHistory := func(packetNumbers []protocol.PacketNumber) { ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers))) @@ -26,7 +32,8 @@ var _ = Describe("SentPacketHistory", func() { } BeforeEach(func() { - hist = newSentPacketHistory() + rttStats = utils.NewRTTStats() + hist = newSentPacketHistory(rttStats) }) It("saves sent packets", func() { @@ -155,4 +162,32 @@ var _ = Describe("SentPacketHistory", func() { Expect(hist.HasOutstandingPackets()).To(BeFalse()) }) }) + + Context("deleting old packets", func() { + const pto = 3 * time.Second + + BeforeEach(func() { + rttStats.UpdateRTT(time.Second, 0, time.Time{}) + Expect(rttStats.PTO(false)).To(Equal(pto)) + }) + + It("deletes old packets after 3 PTOs", func() { + now := time.Now() + hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}) + Expect(hist.Len()).To(Equal(1)) + hist.DeleteOldPackets(now.Add(-time.Nanosecond)) + Expect(hist.Len()).To(Equal(1)) + hist.DeleteOldPackets(now) + Expect(hist.Len()).To(BeZero()) + }) + + It("doesn't delete a packet if it hasn't been declared lost yet", func() { + now := time.Now() + hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}) + hist.SentPacket(&Packet{PacketNumber: 11, SendTime: now.Add(-3 * pto), declaredLost: false}) + Expect(hist.Len()).To(Equal(2)) + hist.DeleteOldPackets(now) + Expect(hist.Len()).To(Equal(1)) + }) + }) })