From 416d88990bab1bf3cffe049c6ee670c1116640da Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 26 Nov 2020 13:31:26 +0700 Subject: [PATCH] only skip packet numbers in the application data packet number space --- .../ackhandler/packet_number_generator.go | 46 ++++++++++++++----- .../packet_number_generator_test.go | 21 +++++++-- internal/ackhandler/sent_packet_handler.go | 24 ++++++---- 3 files changed, 67 insertions(+), 24 deletions(-) diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index f203892af..84ba25322 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -8,10 +8,35 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) -// The packetNumberGenerator generates the packet number for the next packet +type packetNumberGenerator interface { + Peek() protocol.PacketNumber + Pop() protocol.PacketNumber +} + +type sequentialPacketNumberGenerator struct { + next protocol.PacketNumber +} + +var _ packetNumberGenerator = &sequentialPacketNumberGenerator{} + +func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator { + return &sequentialPacketNumberGenerator{next: initial} +} + +func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { + return p.next +} + +func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { + next := p.next + p.next++ + return next +} + +// The skippingPacketNumberGenerator generates the packet number for the next packet // it randomly skips a packet number every averagePeriod packets (on average). // It is guaranteed to never skip two consecutive packet numbers. -type packetNumberGenerator struct { +type skippingPacketNumberGenerator struct { rand *mrand.Rand averagePeriod protocol.PacketNumber @@ -19,10 +44,12 @@ type packetNumberGenerator struct { nextToSkip protocol.PacketNumber } -func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { +var _ packetNumberGenerator = &skippingPacketNumberGenerator{} + +func newSkippingPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) packetNumberGenerator { b := make([]byte, 8) rand.Read(b) // it's not the end of the world if we don't get perfect random here - g := &packetNumberGenerator{ + g := &skippingPacketNumberGenerator{ rand: mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(b)))), next: initial, averagePeriod: averagePeriod, @@ -31,16 +58,13 @@ func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *pac return g } -func (p *packetNumberGenerator) Peek() protocol.PacketNumber { +func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { return p.next } -func (p *packetNumberGenerator) Pop() protocol.PacketNumber { +func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { next := p.next - - // generate a new packet number for the next packet - p.next++ - + p.next++ // generate a new packet number for the next packet if p.next == p.nextToSkip { p.next++ p.generateNewSkip() @@ -48,7 +72,7 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber { return next } -func (p *packetNumberGenerator) generateNewSkip() { +func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rand.Int31n(int32(2*p.averagePeriod))) } diff --git a/internal/ackhandler/packet_number_generator_test.go b/internal/ackhandler/packet_number_generator_test.go index a862a9c45..4321284f5 100644 --- a/internal/ackhandler/packet_number_generator_test.go +++ b/internal/ackhandler/packet_number_generator_test.go @@ -6,16 +6,29 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Packet Number Generator", func() { - var png *packetNumberGenerator +var _ = Describe("Sequential Packet Number Generator", func() { + It("generates sequential packet numbers", func() { + const initialPN protocol.PacketNumber = 123 + png := newSequentialPacketNumberGenerator(initialPN) + + for i := initialPN; i < initialPN+1000; i++ { + Expect(png.Peek()).To(Equal(i)) + Expect(png.Peek()).To(Equal(i)) + Expect(png.Pop()).To(Equal(i)) + } + }) +}) + +var _ = Describe("Skipping Packet Number Generator", func() { + var png *skippingPacketNumberGenerator const initialPN protocol.PacketNumber = 8 BeforeEach(func() { - png = newPacketNumberGenerator(initialPN, 100) + png = newSkippingPacketNumberGenerator(initialPN, 100).(*skippingPacketNumberGenerator) }) It("can be initialized to return any first packet number", func() { - png = newPacketNumberGenerator(12345, 100) + png = newSkippingPacketNumberGenerator(12345, 100).(*skippingPacketNumberGenerator) Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 2d791a1fb..c9dd12800 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -25,7 +25,7 @@ const ( type packetNumberSpace struct { history *sentPacketHistory - pns *packetNumberGenerator + pns packetNumberGenerator lossTime time.Time lastAckElicitingPacketTime time.Time @@ -34,10 +34,16 @@ type packetNumberSpace struct { largestSent protocol.PacketNumber } -func newPacketNumberSpace(initialPN protocol.PacketNumber, rttStats *utils.RTTStats) *packetNumberSpace { +func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { + var pns packetNumberGenerator + if skipPNs { + pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketAveragePeriodLength) + } else { + pns = newSequentialPacketNumberGenerator(initialPN) + } return &packetNumberSpace{ history: newSentPacketHistory(rttStats), - pns: newPacketNumberGenerator(initialPN, protocol.SkipPacketAveragePeriodLength), + pns: pns, largestSent: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, } @@ -94,7 +100,7 @@ var ( ) func newSentPacketHandler( - initialPacketNumber protocol.PacketNumber, + initialPN protocol.PacketNumber, rttStats *utils.RTTStats, pers protocol.Perspective, tracer logging.ConnectionTracer, @@ -110,9 +116,9 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerAddressValidated: pers == protocol.PerspectiveClient, - initialPackets: newPacketNumberSpace(initialPacketNumber, rttStats), - handshakePackets: newPacketNumberSpace(0, rttStats), - appDataPackets: newPacketNumberSpace(0, rttStats), + initialPackets: newPacketNumberSpace(initialPN, false, rttStats), + handshakePackets: newPacketNumberSpace(0, false, rttStats), + appDataPackets: newPacketNumberSpace(0, true, rttStats), rttStats: rttStats, congestion: congestion, perspective: pers, @@ -765,8 +771,8 @@ func (h *sentPacketHandler) ResetForRetry() error { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } - h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), h.rttStats) - h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), h.rttStats) + h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) + h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil {