diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index ddd4a5632..56fbf3d80 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -5,6 +5,7 @@ import ( "math" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) // The packetNumberGenerator generates the packet number for the next packet @@ -15,6 +16,8 @@ type packetNumberGenerator struct { next protocol.PacketNumber nextToSkip protocol.PacketNumber + + history []protocol.PacketNumber } func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { @@ -37,6 +40,10 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber { p.next++ if p.next == p.nextToSkip { + if len(p.history)+1 > protocol.MaxTrackedSkippedPackets { + p.history = p.history[1:] + } + p.history = append(p.history, p.next) p.next++ p.generateNewSkip() } @@ -60,3 +67,12 @@ func (p *packetNumberGenerator) getRandomNumber() uint16 { num := uint16(b[0])<<8 + uint16(b[1]) return num } + +func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool { + for _, pn := range p.history { + if ack.AcksPacket(pn) { + return false + } + } + return true +} diff --git a/internal/ackhandler/packet_number_generator_test.go b/internal/ackhandler/packet_number_generator_test.go index 73bd914fe..9465a8338 100644 --- a/internal/ackhandler/packet_number_generator_test.go +++ b/internal/ackhandler/packet_number_generator_test.go @@ -4,6 +4,7 @@ import ( "math" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -97,4 +98,45 @@ var _ = Describe("Packet Number Generator", func() { Expect(largest).To(BeNumerically(">", math.MaxUint16-300)) Expect(sum / uint64(rep)).To(BeNumerically("==", uint64(math.MaxUint16/2), 1000)) }) + + It("validates ACK frames", func() { + var skipped []protocol.PacketNumber + var lastPN protocol.PacketNumber + for len(skipped) < 3 { + if png.Peek() > lastPN+1 { + skipped = append(skipped, lastPN+1) + } + lastPN = png.Pop() + } + invalidACK := &wire.AckFrame{ + AckRanges: []wire.AckRange{{Smallest: 1, Largest: lastPN}}, + } + Expect(png.Validate(invalidACK)).To(BeFalse()) + validACK1 := &wire.AckFrame{ + AckRanges: []wire.AckRange{{Smallest: 1, Largest: skipped[0] - 1}}, + } + Expect(png.Validate(validACK1)).To(BeTrue()) + validACK2 := &wire.AckFrame{ + AckRanges: []wire.AckRange{ + {Smallest: 1, Largest: skipped[0] - 1}, + {Smallest: skipped[0] + 1, Largest: skipped[1] - 1}, + {Smallest: skipped[1] + 1, Largest: skipped[2] - 1}, + {Smallest: skipped[2] + 1, Largest: skipped[2] + 100}, + }, + } + Expect(png.Validate(validACK2)).To(BeTrue()) + }) + + It("tracks a maximum number of protocol.MaxTrackedSkippedPackets packets", func() { + var skipped []protocol.PacketNumber + var lastPN protocol.PacketNumber + for len(skipped) < protocol.MaxTrackedSkippedPackets+3 { + if png.Peek() > lastPN+1 { + skipped = append(skipped, lastPN+1) + } + lastPN = png.Pop() + Expect(len(png.history)).To(BeNumerically("<=", protocol.MaxTrackedSkippedPackets)) + } + Expect(len(png.history)).To(Equal(protocol.MaxTrackedSkippedPackets)) + }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 3c134aa00..176f08b5e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -37,7 +37,6 @@ type sentPacketHandler struct { lastSentHandshakePacketTime time.Time nextPacketSendTime time.Time - skippedPackets []protocol.PacketNumber largestAcked protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber @@ -150,10 +149,6 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ { for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { h.logger.Debugf("Skipping packet number %#x", p) - h.skippedPackets = append(h.skippedPackets, p) - if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets { - h.skippedPackets = h.skippedPackets[1:] - } } h.lastSentPacketNumber = packet.PacketNumber @@ -200,7 +195,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe h.largestReceivedPacketWithAck = withPacketNumber h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked) - if h.skippedPacketsAcked(ackFrame) { + if !h.packetNumberGenerator.Validate(ackFrame) { return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number") } @@ -238,8 +233,6 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe return err } h.updateLossDetectionAlarm() - - h.garbageCollectSkippedPackets() return nil } @@ -638,23 +631,3 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration { rto <<= h.rtoCount return utils.MinDuration(rto, maxRTOTimeout) } - -func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool { - for _, p := range h.skippedPackets { - if ackFrame.AcksPacket(p) { - return true - } - } - return false -} - -func (h *sentPacketHandler) garbageCollectSkippedPackets() { - lowestUnacked := h.lowestUnacked() - deleteIndex := 0 - for i, p := range h.skippedPackets { - if p < lowestUnacked { - deleteIndex = i + 1 - } - } - h.skippedPackets = h.skippedPackets[deleteIndex:] -} diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 59158aa24..54a4857c3 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -96,7 +96,6 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(2))) expectInPacketHistory([]protocol.PacketNumber{1, 2}) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - Expect(handler.skippedPackets).To(BeEmpty()) }) It("accepts packet number 0", func() { @@ -106,7 +105,6 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(1))) expectInPacketHistory([]protocol.PacketNumber{0, 1}) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - Expect(handler.skippedPackets).To(BeEmpty()) }) It("stores the sent time", func() { @@ -128,94 +126,6 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.lastSentRetransmittablePacketTime).To(BeZero()) Expect(handler.bytesInFlight).To(BeZero()) }) - - Context("skipped packet numbers", func() { - It("works with non-consecutive packet numbers", func() { - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1})) - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3})) - Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(3))) - expectInPacketHistory([]protocol.PacketNumber{1, 3}) - Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2})) - }) - - It("works with non-retransmittable packets", func() { - handler.SentPacket(nonRetransmittablePacket(&Packet{PacketNumber: 1})) - handler.SentPacket(nonRetransmittablePacket(&Packet{PacketNumber: 3})) - Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2})) - }) - - It("recognizes multiple skipped packets", func() { - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1})) - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3})) - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 5})) - Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 4})) - }) - - It("recognizes multiple consecutive skipped packets", func() { - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1})) - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 4})) - Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 3})) - }) - - It("limits the lengths of the skipped packet slice", func() { - for i := protocol.PacketNumber(0); i < protocol.MaxTrackedSkippedPackets+5; i++ { - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2*i + 1})) - } - Expect(handler.skippedPackets).To(HaveLen(protocol.MaxUndecryptablePackets)) - Expect(handler.skippedPackets[0]).To(Equal(protocol.PacketNumber(10))) - Expect(handler.skippedPackets[protocol.MaxTrackedSkippedPackets-1]).To(Equal(protocol.PacketNumber(10 + 2*(protocol.MaxTrackedSkippedPackets-1)))) - }) - - Context("garbage collection", func() { - It("keeps all packet numbers above the LargestAcked", func() { - handler.skippedPackets = []protocol.PacketNumber{2, 5, 8, 10} - handler.largestAcked = 1 - handler.garbageCollectSkippedPackets() - Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 5, 8, 10})) - }) - - It("doesn't keep packet numbers below the LargestAcked", func() { - handler.skippedPackets = []protocol.PacketNumber{1, 5, 8, 10} - handler.largestAcked = 5 - handler.garbageCollectSkippedPackets() - Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{8, 10})) - }) - - It("deletes all packet numbers if LargestAcked is sufficiently high", func() { - handler.skippedPackets = []protocol.PacketNumber{1, 5, 10} - handler.largestAcked = 15 - handler.garbageCollectSkippedPackets() - Expect(handler.skippedPackets).To(BeEmpty()) - }) - }) - - Context("ACK handling", func() { - BeforeEach(func() { - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 10})) - handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 12})) - }) - - It("rejects ACKs for skipped packets", func() { - ack := &wire.AckFrame{ - AckRanges: []wire.AckRange{{Smallest: 10, Largest: 12}}, - } - err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now()) - Expect(err).To(MatchError("InvalidAckData: Received an ACK for a skipped packet number")) - }) - - It("accepts an ACK that correctly nacks a skipped packet", func() { - ack := &wire.AckFrame{ - AckRanges: []wire.AckRange{ - {Smallest: 12, Largest: 12}, - {Smallest: 10, Largest: 10}, - }, - } - err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.largestAcked).ToNot(BeZero()) - }) - }) - }) }) Context("ACK processing", func() {