From 3eb8a134cb3e54c086e349a167945d937437e4e8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 1 Jun 2025 15:52:47 +0800 Subject: [PATCH] congestion: migrate tests away from Ginkgo (#5193) --- internal/congestion/bandwidth_test.go | 12 +- internal/congestion/congestion_suite_test.go | 13 - internal/congestion/cubic_sender_test.go | 1031 +++++++++-------- internal/congestion/cubic_test.go | 395 +++---- internal/congestion/hybrid_slow_start_test.go | 109 +- 5 files changed, 784 insertions(+), 776 deletions(-) delete mode 100644 internal/congestion/congestion_suite_test.go diff --git a/internal/congestion/bandwidth_test.go b/internal/congestion/bandwidth_test.go index c832be14..2545d582 100644 --- a/internal/congestion/bandwidth_test.go +++ b/internal/congestion/bandwidth_test.go @@ -1,14 +1,12 @@ package congestion import ( + "testing" "time" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("Bandwidth", func() { - It("converts from time delta", func() { - Expect(BandwidthFromDelta(1, time.Millisecond)).To(Equal(1000 * BytesPerSecond)) - }) -}) +func TestBandwidthFromDelta(t *testing.T) { + require.Equal(t, 1000*BytesPerSecond, BandwidthFromDelta(1, time.Millisecond)) +} diff --git a/internal/congestion/congestion_suite_test.go b/internal/congestion/congestion_suite_test.go deleted file mode 100644 index 447360f1..00000000 --- a/internal/congestion/congestion_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package congestion - -import ( - "testing" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -func TestCongestion(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Congestion Suite") -} diff --git a/internal/congestion/cubic_sender_test.go b/internal/congestion/cubic_sender_test.go index 08655fbf..82866d28 100644 --- a/internal/congestion/cubic_sender_test.go +++ b/internal/congestion/cubic_sender_test.go @@ -1,13 +1,14 @@ package congestion import ( + "fmt" + "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) const ( @@ -27,501 +28,563 @@ func (c *mockClock) Advance(d time.Duration) { const MaxCongestionWindow = 200 * maxDatagramSize -var _ = Describe("Cubic Sender", func() { - var ( - sender *cubicSender - clock mockClock - bytesInFlight protocol.ByteCount - packetNumber protocol.PacketNumber - ackedPacketNumber protocol.PacketNumber - rttStats utils.RTTStats - ) +type testCubicSender struct { + sender *cubicSender + clock *mockClock + rttStats *utils.RTTStats + bytesInFlight protocol.ByteCount + packetNumber protocol.PacketNumber + ackedPacketNumber protocol.PacketNumber +} - BeforeEach(func() { - bytesInFlight = 0 - packetNumber = 1 - ackedPacketNumber = 0 - clock = mockClock{} - rttStats = utils.RTTStats{} - sender = newCubicSender( +func newTestCubicSender(cubic bool) *testCubicSender { + clock := mockClock{} + rttStats := utils.RTTStats{} + return &testCubicSender{ + clock: &clock, + rttStats: &rttStats, + packetNumber: 1, + sender: newCubicSender( &clock, &rttStats, - true, /*reno*/ + !cubic, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil, - ) - }) + ), + } +} - SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int { - var packetsSent int - for sender.CanSend(bytesInFlight) { - sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true) - packetNumber++ - packetsSent++ - bytesInFlight += packetLength - } - return packetsSent +func (s *testCubicSender) SendAvailableSendWindowLen(packetLength protocol.ByteCount) int { + var packetsSent int + for s.sender.CanSend(s.bytesInFlight) { + s.sender.OnPacketSent(s.clock.Now(), s.bytesInFlight, s.packetNumber, packetLength, true) + s.packetNumber++ + packetsSent++ + s.bytesInFlight += packetLength + } + return packetsSent +} + +func (s *testCubicSender) AckNPackets(n int) { + s.rttStats.UpdateRTT(60*time.Millisecond, 0) + s.sender.MaybeExitSlowStart() + for range n { + s.ackedPacketNumber++ + s.sender.OnPacketAcked(s.ackedPacketNumber, maxDatagramSize, s.bytesInFlight, s.clock.Now()) + } + s.bytesInFlight -= protocol.ByteCount(n) * maxDatagramSize + s.clock.Advance(time.Millisecond) +} + +func (s *testCubicSender) LoseNPacketsLen(n int, packetLength protocol.ByteCount) { + for range n { + s.ackedPacketNumber++ + s.sender.OnCongestionEvent(s.ackedPacketNumber, packetLength, s.bytesInFlight) + } + s.bytesInFlight -= protocol.ByteCount(n) * packetLength +} + +func (s *testCubicSender) LosePacket(number protocol.PacketNumber) { + s.sender.OnCongestionEvent(number, maxDatagramSize, s.bytesInFlight) + s.bytesInFlight -= maxDatagramSize +} + +func (s *testCubicSender) SendAvailableSendWindow() int { + return s.SendAvailableSendWindowLen(maxDatagramSize) +} + +func (s *testCubicSender) LoseNPackets(n int) { + s.LoseNPacketsLen(n, maxDatagramSize) +} + +func TestCubicSenderStartup(t *testing.T) { + sender := newTestCubicSender(false) + + // At startup make sure we are at the default. + require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) + + // Make sure we can send. + require.Zero(t, sender.sender.TimeUntilSend(0)) + require.True(t, sender.sender.CanSend(sender.bytesInFlight)) + + // And that window is un-affected. + require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) + + // Fill the send window with data, then verify that we can't send. + sender.SendAvailableSendWindow() + require.False(t, sender.sender.CanSend(sender.bytesInFlight)) +} + +func TestCubicSenderPacing(t *testing.T) { + sender := newTestCubicSender(false) + + // Set up RTT and advance clock + sender.rttStats.UpdateRTT(10*time.Millisecond, 0) + sender.clock.Advance(time.Hour) + + // Fill the send window with data, then verify that we can't send. + sender.SendAvailableSendWindow() + sender.AckNPackets(1) + + // Check that we can't send immediately due to pacing + delay := sender.sender.TimeUntilSend(sender.bytesInFlight) + require.NotZero(t, delay) + require.Less(t, delay.Sub(time.Time(*sender.clock)), time.Hour) +} + +func TestCubicSenderApplicationLimitedSlowStart(t *testing.T) { + sender := newTestCubicSender(false) + + // At startup make sure we can send. + require.True(t, sender.sender.CanSend(0)) + require.Zero(t, sender.sender.TimeUntilSend(0)) + + // Send exactly 10 packets and ensure the CWND ends at 14 packets. + const numberOfAcks = 5 + sender.SendAvailableSendWindow() + for range numberOfAcks { + sender.AckNPackets(2) } - // Normal is that TCP acks every other segment. - AckNPackets := func(n int) { - rttStats.UpdateRTT(60*time.Millisecond, 0) + bytesToSend := sender.sender.GetCongestionWindow() + // It's expected 2 acks will arrive when the bytes_in_flight are greater than + // half the CWND. + require.Equal(t, defaultWindowTCP+maxDatagramSize*2*2, bytesToSend) +} + +func TestCubicSenderExponentialSlowStart(t *testing.T) { + sender := newTestCubicSender(false) + + // At startup make sure we can send. + require.True(t, sender.sender.CanSend(0)) + require.Zero(t, sender.sender.TimeUntilSend(0)) + require.Equal(t, infBandwidth, sender.sender.BandwidthEstimate()) + + const numberOfAcks = 20 + for range numberOfAcks { + // Send our full send window. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + } + + cwnd := sender.sender.GetCongestionWindow() + require.Equal(t, defaultWindowTCP+maxDatagramSize*2*numberOfAcks, cwnd) + require.Equal(t, BandwidthFromDelta(cwnd, sender.rttStats.SmoothedRTT()), sender.sender.BandwidthEstimate()) +} + +func TestCubicSenderSlowStartPacketLoss(t *testing.T) { + sender := newTestCubicSender(false) + + const numberOfAcks = 10 + for range numberOfAcks { + // Send our full send window. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + } + sender.SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Lose a packet to exit slow start. + sender.LoseNPackets(1) + packetsInRecoveryWindow := expectedSendWindow / maxDatagramSize + + // We should now have fallen out of slow start with a reduced window. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Recovery phase. We need to ack every packet in the recovery window before + // we exit recovery. + numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize + sender.AckNPackets(int(packetsInRecoveryWindow)) + sender.SendAvailableSendWindow() + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // We need to ack an entire window before we increase CWND by 1. + fmt.Println(numberOfPacketsInWindow) + sender.AckNPackets(int(numberOfPacketsInWindow) - 2) + sender.SendAvailableSendWindow() + fmt.Println(sender.clock.Now()) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Next ack should increase cwnd by 1. + sender.AckNPackets(1) + expectedSendWindow += maxDatagramSize + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Now RTO and ensure slow start gets reset. + require.True(t, sender.sender.hybridSlowStart.Started()) + sender.sender.OnRetransmissionTimeout(true) + require.False(t, sender.sender.hybridSlowStart.Started()) +} + +func TestCubicSenderSlowStartPacketLossPRR(t *testing.T) { + sender := newTestCubicSender(false) + + // Test based on the first example in RFC6937. + // Ack 10 packets in 5 acks to raise the CWND to 20, as in the example. + const numberOfAcks = 5 + for range numberOfAcks { + // Send our full send window. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + } + sender.SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + sender.LoseNPackets(1) + + // We should now have fallen out of slow start with a reduced window. + sendWindowBeforeLoss := expectedSendWindow + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Testing TCP proportional rate reduction. + // We should send packets paced over the received acks for the remaining + // outstanding packets. The number of packets before we exit recovery is the + // original CWND minus the packet that has been lost and the one which + // triggered the loss. + remainingPacketsInRecovery := sendWindowBeforeLoss/maxDatagramSize - 2 + + for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ { + sender.AckNPackets(1) + sender.SendAvailableSendWindow() + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + } + + // We need to ack another window before we increase CWND by 1. + numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize + for range numberOfPacketsInWindow { + sender.AckNPackets(1) + require.Equal(t, 1, sender.SendAvailableSendWindow()) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + } + + sender.AckNPackets(1) + expectedSendWindow += maxDatagramSize + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) +} + +func TestCubicSenderSlowStartBurstPacketLossPRR(t *testing.T) { + sender := newTestCubicSender(false) + + // Test based on the second example in RFC6937, though we also implement + // forward acknowledgements, so the first two incoming acks will trigger + // PRR immediately. + // Ack 20 packets in 10 acks to raise the CWND to 30. + const numberOfAcks = 10 + for range numberOfAcks { + // Send our full send window. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + } + sender.SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Lose one more than the congestion window reduction, so that after loss, + // bytes_in_flight is lesser than the congestion window. + sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow)) + numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/maxDatagramSize + 1 + sender.LoseNPackets(int(numPacketsToLose)) + // Immediately after the loss, ensure at least one packet can be sent. + // Losses without subsequent acks can occur with timer based loss detection. + require.True(t, sender.sender.CanSend(sender.bytesInFlight)) + sender.AckNPackets(1) + + // We should now have fallen out of slow start with a reduced window. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Only 2 packets should be allowed to be sent, per PRR-SSRB + require.Equal(t, 2, sender.SendAvailableSendWindow()) + + // Ack the next packet, which triggers another loss. + sender.LoseNPackets(1) + sender.AckNPackets(1) + + // Send 2 packets to simulate PRR-SSRB. + require.Equal(t, 2, sender.SendAvailableSendWindow()) + + // Ack the next packet, which triggers another loss. + sender.LoseNPackets(1) + sender.AckNPackets(1) + + // Send 2 packets to simulate PRR-SSRB. + require.Equal(t, 2, sender.SendAvailableSendWindow()) + + // Exit recovery and return to sending at the new rate. + for range numberOfAcks { + sender.AckNPackets(1) + require.Equal(t, 1, sender.SendAvailableSendWindow()) + } +} + +func TestCubicSenderRTOCongestionWindow(t *testing.T) { + sender := newTestCubicSender(false) + + require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) + require.Equal(t, protocol.MaxByteCount, sender.sender.slowStartThreshold) + + // Expect the window to decrease to the minimum once the RTO fires + // and slow start threshold to be set to 1/2 of the CWND. + sender.sender.OnRetransmissionTimeout(true) + require.Equal(t, 2*maxDatagramSize, sender.sender.GetCongestionWindow()) + require.Equal(t, 5*maxDatagramSize, sender.sender.slowStartThreshold) +} + +func TestCubicSenderTCPCubicResetEpochOnQuiescence(t *testing.T) { + sender := newTestCubicSender(true) + + const maxCongestionWindow = 50 + const maxCongestionWindowBytes = maxCongestionWindow * maxDatagramSize + + numSent := sender.SendAvailableSendWindow() + + // Make sure we fall out of slow start. + savedCwnd := sender.sender.GetCongestionWindow() + sender.LoseNPackets(1) + require.Greater(t, savedCwnd, sender.sender.GetCongestionWindow()) + + // Ack the rest of the outstanding packets to get out of recovery. + for i := 1; i < numSent; i++ { + sender.AckNPackets(1) + } + require.Zero(t, sender.bytesInFlight) + + // Send a new window of data and ack all; cubic growth should occur. + savedCwnd = sender.sender.GetCongestionWindow() + numSent = sender.SendAvailableSendWindow() + for range numSent { + sender.AckNPackets(1) + } + require.Less(t, savedCwnd, sender.sender.GetCongestionWindow()) + require.Greater(t, maxCongestionWindowBytes, sender.sender.GetCongestionWindow()) + require.Zero(t, sender.bytesInFlight) + + // Quiescent time of 100 seconds + sender.clock.Advance(100 * time.Second) + + // Send new window of data and ack one packet. Cubic epoch should have + // been reset; ensure cwnd increase is not dramatic. + savedCwnd = sender.sender.GetCongestionWindow() + sender.SendAvailableSendWindow() + sender.AckNPackets(1) + require.InDelta(t, float64(savedCwnd), float64(sender.sender.GetCongestionWindow()), float64(maxDatagramSize)) + require.Greater(t, maxCongestionWindowBytes, sender.sender.GetCongestionWindow()) +} + +func TestCubicSenderMultipleLossesInOneWindow(t *testing.T) { + sender := newTestCubicSender(false) + + sender.SendAvailableSendWindow() + initialWindow := sender.sender.GetCongestionWindow() + sender.LosePacket(sender.ackedPacketNumber + 1) + postLossWindow := sender.sender.GetCongestionWindow() + require.True(t, initialWindow > postLossWindow) + sender.LosePacket(sender.ackedPacketNumber + 3) + require.Equal(t, postLossWindow, sender.sender.GetCongestionWindow()) + sender.LosePacket(sender.packetNumber - 1) + require.Equal(t, postLossWindow, sender.sender.GetCongestionWindow()) + + // Lose a later packet and ensure the window decreases. + sender.LosePacket(sender.packetNumber) + require.True(t, postLossWindow > sender.sender.GetCongestionWindow()) +} + +func TestCubicSender1ConnectionCongestionAvoidanceAtEndOfRecovery(t *testing.T) { + sender := newTestCubicSender(false) + + // Ack 10 packets in 5 acks to raise the CWND to 20. + const numberOfAcks = 5 + for range numberOfAcks { + // Send our full send window. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + } + sender.SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + sender.LoseNPackets(1) + + // We should now have fallen out of slow start with a reduced window. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // No congestion window growth should occur in recovery phase, i.e., until the + // currently outstanding 20 packets are acked. + for range 10 { + // Send our full send window. + sender.SendAvailableSendWindow() + require.True(t, sender.sender.InRecovery()) + sender.AckNPackets(2) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + } + require.False(t, sender.sender.InRecovery()) + + // Out of recovery now. Congestion window should not grow during RTT. + for i := protocol.ByteCount(0); i < expectedSendWindow/maxDatagramSize-2; i += 2 { + // Send our full send window. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + } + + // Next ack should cause congestion window to grow by 1MSS. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + expectedSendWindow += maxDatagramSize + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) +} + +func TestCubicSenderNoPRR(t *testing.T) { + sender := newTestCubicSender(false) + + sender.SendAvailableSendWindow() + sender.LoseNPackets(9) + sender.AckNPackets(1) + + require.Equal(t, protocol.ByteCount(renoBeta*float32(defaultWindowTCP)), sender.sender.GetCongestionWindow()) + windowInPackets := int(renoBeta * float32(defaultWindowTCP) / float32(maxDatagramSize)) + numSent := sender.SendAvailableSendWindow() + require.Equal(t, windowInPackets, numSent) +} + +func TestCubicSenderResetAfterConnectionMigration(t *testing.T) { + sender := newTestCubicSender(false) + + require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) + require.Equal(t, protocol.MaxByteCount, sender.sender.slowStartThreshold) + + // Starts with slow start. + const numberOfAcks = 10 + for range numberOfAcks { + // Send our full send window. + sender.SendAvailableSendWindow() + sender.AckNPackets(2) + } + sender.SendAvailableSendWindow() + expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + + // Loses a packet to exit slow start. + sender.LoseNPackets(1) + + // We should now have fallen out of slow start with a reduced window. Slow + // start threshold is also updated. + expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) + require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) + require.Equal(t, expectedSendWindow, sender.sender.slowStartThreshold) + + // Resets cwnd and slow start threshold on connection migrations. + sender.sender.OnConnectionMigration() + require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) + require.Equal(t, MaxCongestionWindow, sender.sender.slowStartThreshold) + require.False(t, sender.sender.hybridSlowStart.Started()) +} + +func TestCubicSenderSlowStartsUpToMaximumCongestionWindow(t *testing.T) { + clock := mockClock{} + rttStats := utils.RTTStats{} + const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize + sender := newCubicSender( + &clock, + &rttStats, + true, + protocol.InitialPacketSize, + initialCongestionWindowPackets*maxDatagramSize, + initialMaxCongestionWindow, + nil, + ) + + for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { sender.MaybeExitSlowStart() - for i := 0; i < n; i++ { - ackedPacketNumber++ - sender.OnPacketAcked(ackedPacketNumber, maxDatagramSize, bytesInFlight, clock.Now()) - } - bytesInFlight -= protocol.ByteCount(n) * maxDatagramSize - clock.Advance(time.Millisecond) + sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now()) + } + require.Equal(t, initialMaxCongestionWindow, sender.GetCongestionWindow()) +} + +func TestCubicSenderMaximumPacketSizeReduction(t *testing.T) { + sender := newTestCubicSender(false) + require.Panics(t, func() { sender.sender.SetMaxDatagramSize(initialMaxDatagramSize - 1) }) +} + +func TestCubicSenderSlowStartsPacketSizeIncrease(t *testing.T) { + clock := mockClock{} + rttStats := utils.RTTStats{} + const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize + sender := newCubicSender( + &clock, + &rttStats, + true, + protocol.InitialPacketSize, + initialCongestionWindowPackets*maxDatagramSize, + initialMaxCongestionWindow, + nil, + ) + const packetSize = initialMaxDatagramSize + 100 + sender.SetMaxDatagramSize(packetSize) + for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { + sender.OnPacketAcked(protocol.PacketNumber(i), packetSize, sender.GetCongestionWindow(), clock.Now()) + } + const maxCwnd = protocol.MaxCongestionWindowPackets * packetSize + require.True(t, sender.GetCongestionWindow() > maxCwnd) + require.True(t, sender.GetCongestionWindow() <= maxCwnd+packetSize) +} + +func TestCubicSenderLimitCwndIncreaseInCongestionAvoidance(t *testing.T) { + // Enable Cubic. + clock := mockClock{} + rttStats := utils.RTTStats{} + sender := newCubicSender( + &clock, + &rttStats, + false, + protocol.InitialPacketSize, + initialCongestionWindowPackets*maxDatagramSize, + MaxCongestionWindow, + nil, + ) + testSender := &testCubicSender{ + sender: sender, + clock: &clock, + rttStats: &rttStats, } - LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) { - for i := 0; i < n; i++ { - ackedPacketNumber++ - sender.OnCongestionEvent(ackedPacketNumber, packetLength, bytesInFlight) - } - bytesInFlight -= protocol.ByteCount(n) * packetLength + numSent := testSender.SendAvailableSendWindow() + + // Make sure we fall out of slow start. + savedCwnd := sender.GetCongestionWindow() + testSender.LoseNPackets(1) + require.Greater(t, savedCwnd, sender.GetCongestionWindow()) + + // Ack the rest of the outstanding packets to get out of recovery. + for i := 1; i < numSent; i++ { + testSender.AckNPackets(1) } + require.Equal(t, protocol.ByteCount(0), testSender.bytesInFlight) - // Does not increment acked_packet_number_. - LosePacket := func(number protocol.PacketNumber) { - sender.OnCongestionEvent(number, maxDatagramSize, bytesInFlight) - bytesInFlight -= maxDatagramSize + savedCwnd = sender.GetCongestionWindow() + testSender.SendAvailableSendWindow() + + // Ack packets until the CWND increases. + for sender.GetCongestionWindow() == savedCwnd { + testSender.AckNPackets(1) + testSender.SendAvailableSendWindow() } - - SendAvailableSendWindow := func() int { return SendAvailableSendWindowLen(maxDatagramSize) } - LoseNPackets := func(n int) { LoseNPacketsLen(n, maxDatagramSize) } - - It("has the right values at startup", func() { - // At startup make sure we are at the default. - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - // Make sure we can send. - Expect(sender.TimeUntilSend(0)).To(BeZero()) - Expect(sender.CanSend(bytesInFlight)).To(BeTrue()) - // And that window is un-affected. - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - - // Fill the send window with data, then verify that we can't send. - SendAvailableSendWindow() - Expect(sender.CanSend(bytesInFlight)).To(BeFalse()) - }) - - It("paces", func() { - rttStats.UpdateRTT(10*time.Millisecond, 0) - clock.Advance(time.Hour) - // Fill the send window with data, then verify that we can't send. - SendAvailableSendWindow() - AckNPackets(1) - delay := sender.TimeUntilSend(bytesInFlight) - Expect(delay).ToNot(BeZero()) - Expect(delay.Sub(clock.Now())).To(BeNumerically("<", time.Hour)) - }) - - It("application limited slow start", func() { - // Send exactly 10 packets and ensure the CWND ends at 14 packets. - const numberOfAcks = 5 - // At startup make sure we can send. - Expect(sender.CanSend(0)).To(BeTrue()) - Expect(sender.TimeUntilSend(0)).To(BeZero()) - - SendAvailableSendWindow() - for i := 0; i < numberOfAcks; i++ { - AckNPackets(2) - } - bytesToSend := sender.GetCongestionWindow() - // It's expected 2 acks will arrive when the bytes_in_flight are greater than - // half the CWND. - Expect(bytesToSend).To(Equal(defaultWindowTCP + maxDatagramSize*2*2)) - }) - - It("exponential slow start", func() { - const numberOfAcks = 20 - // At startup make sure we can send. - Expect(sender.CanSend(0)).To(BeTrue()) - Expect(sender.TimeUntilSend(0)).To(BeZero()) - Expect(sender.BandwidthEstimate()).To(Equal(infBandwidth)) - // Make sure we can send. - Expect(sender.TimeUntilSend(0)).To(BeZero()) - - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - cwnd := sender.GetCongestionWindow() - Expect(cwnd).To(Equal(defaultWindowTCP + maxDatagramSize*2*numberOfAcks)) - Expect(sender.BandwidthEstimate()).To(Equal(BandwidthFromDelta(cwnd, rttStats.SmoothedRTT()))) - }) - - It("slow start packet loss", func() { - const numberOfAcks = 10 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Lose a packet to exit slow start. - LoseNPackets(1) - packetsInRecoveryWindow := expectedSendWindow / maxDatagramSize - - // We should now have fallen out of slow start with a reduced window. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Recovery phase. We need to ack every packet in the recovery window before - // we exit recovery. - numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize - AckNPackets(int(packetsInRecoveryWindow)) - SendAvailableSendWindow() - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // We need to ack an entire window before we increase CWND by 1. - AckNPackets(int(numberOfPacketsInWindow) - 2) - SendAvailableSendWindow() - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Next ack should increase cwnd by 1. - AckNPackets(1) - expectedSendWindow += maxDatagramSize - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Now RTO and ensure slow start gets reset. - Expect(sender.hybridSlowStart.Started()).To(BeTrue()) - sender.OnRetransmissionTimeout(true) - Expect(sender.hybridSlowStart.Started()).To(BeFalse()) - }) - - It("slow start packet loss PRR", func() { - // Test based on the first example in RFC6937. - // Ack 10 packets in 5 acks to raise the CWND to 20, as in the example. - const numberOfAcks = 5 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - LoseNPackets(1) - - // We should now have fallen out of slow start with a reduced window. - sendWindowBeforeLoss := expectedSendWindow - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Testing TCP proportional rate reduction. - // We should send packets paced over the received acks for the remaining - // outstanding packets. The number of packets before we exit recovery is the - // original CWND minus the packet that has been lost and the one which - // triggered the loss. - remainingPacketsInRecovery := sendWindowBeforeLoss/maxDatagramSize - 2 - - for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ { - AckNPackets(1) - SendAvailableSendWindow() - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - - // We need to ack another window before we increase CWND by 1. - numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize - for i := protocol.ByteCount(0); i < numberOfPacketsInWindow; i++ { - AckNPackets(1) - Expect(SendAvailableSendWindow()).To(Equal(1)) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - - AckNPackets(1) - expectedSendWindow += maxDatagramSize - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - }) - - It("slow start burst packet loss PRR", func() { - // Test based on the second example in RFC6937, though we also implement - // forward acknowledgements, so the first two incoming acks will trigger - // PRR immediately. - // Ack 20 packets in 10 acks to raise the CWND to 30. - const numberOfAcks = 10 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Lose one more than the congestion window reduction, so that after loss, - // bytes_in_flight is lesser than the congestion window. - sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow)) - numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/maxDatagramSize + 1 - LoseNPackets(int(numPacketsToLose)) - // Immediately after the loss, ensure at least one packet can be sent. - // Losses without subsequent acks can occur with timer based loss detection. - Expect(sender.CanSend(bytesInFlight)).To(BeTrue()) - AckNPackets(1) - - // We should now have fallen out of slow start with a reduced window. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Only 2 packets should be allowed to be sent, per PRR-SSRB - Expect(SendAvailableSendWindow()).To(Equal(2)) - - // Ack the next packet, which triggers another loss. - LoseNPackets(1) - AckNPackets(1) - - // Send 2 packets to simulate PRR-SSRB. - Expect(SendAvailableSendWindow()).To(Equal(2)) - - // Ack the next packet, which triggers another loss. - LoseNPackets(1) - AckNPackets(1) - - // Send 2 packets to simulate PRR-SSRB. - Expect(SendAvailableSendWindow()).To(Equal(2)) - - // Exit recovery and return to sending at the new rate. - for i := 0; i < numberOfAcks; i++ { - AckNPackets(1) - Expect(SendAvailableSendWindow()).To(Equal(1)) - } - }) - - It("RTO congestion window", func() { - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - Expect(sender.slowStartThreshold).To(Equal(protocol.MaxByteCount)) - - // Expect the window to decrease to the minimum once the RTO fires - // and slow start threshold to be set to 1/2 of the CWND. - sender.OnRetransmissionTimeout(true) - Expect(sender.GetCongestionWindow()).To(Equal(2 * maxDatagramSize)) - Expect(sender.slowStartThreshold).To(Equal(5 * maxDatagramSize)) - }) - - It("RTO congestion window no retransmission", func() { - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - - // Expect the window to remain unchanged if the RTO fires but no - // packets are retransmitted. - sender.OnRetransmissionTimeout(false) - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - }) - - It("tcp cubic reset epoch on quiescence", func() { - const maxCongestionWindow = 50 - const maxCongestionWindowBytes = maxCongestionWindow * maxDatagramSize - sender = newCubicSender(&clock, &rttStats, false, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindowBytes, nil) - - numSent := SendAvailableSendWindow() - - // Make sure we fall out of slow start. - savedCwnd := sender.GetCongestionWindow() - LoseNPackets(1) - Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow())) - - // Ack the rest of the outstanding packets to get out of recovery. - for i := 1; i < numSent; i++ { - AckNPackets(1) - } - Expect(bytesInFlight).To(BeZero()) - - // Send a new window of data and ack all; cubic growth should occur. - savedCwnd = sender.GetCongestionWindow() - numSent = SendAvailableSendWindow() - for i := 0; i < numSent; i++ { - AckNPackets(1) - } - Expect(savedCwnd).To(BeNumerically("<", sender.GetCongestionWindow())) - Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow())) - Expect(bytesInFlight).To(BeZero()) - - // Quiescent time of 100 seconds - clock.Advance(100 * time.Second) - - // Send new window of data and ack one packet. Cubic epoch should have - // been reset; ensure cwnd increase is not dramatic. - savedCwnd = sender.GetCongestionWindow() - SendAvailableSendWindow() - AckNPackets(1) - Expect(savedCwnd).To(BeNumerically("~", sender.GetCongestionWindow(), maxDatagramSize)) - Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow())) - }) - - It("multiple losses in one window", func() { - SendAvailableSendWindow() - initialWindow := sender.GetCongestionWindow() - LosePacket(ackedPacketNumber + 1) - postLossWindow := sender.GetCongestionWindow() - Expect(initialWindow).To(BeNumerically(">", postLossWindow)) - LosePacket(ackedPacketNumber + 3) - Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow)) - LosePacket(packetNumber - 1) - Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow)) - - // Lose a later packet and ensure the window decreases. - LosePacket(packetNumber) - Expect(postLossWindow).To(BeNumerically(">", sender.GetCongestionWindow())) - }) - - It("1 connection congestion avoidance at end of recovery", func() { - // Ack 10 packets in 5 acks to raise the CWND to 20. - const numberOfAcks = 5 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - LoseNPackets(1) - - // We should now have fallen out of slow start with a reduced window. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // No congestion window growth should occur in recovery phase, i.e., until the - // currently outstanding 20 packets are acked. - for i := 0; i < 10; i++ { - // Send our full send window. - SendAvailableSendWindow() - Expect(sender.InRecovery()).To(BeTrue()) - AckNPackets(2) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - Expect(sender.InRecovery()).To(BeFalse()) - - // Out of recovery now. Congestion window should not grow during RTT. - for i := protocol.ByteCount(0); i < expectedSendWindow/maxDatagramSize-2; i += 2 { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - } - - // Next ack should cause congestion window to grow by 1MSS. - SendAvailableSendWindow() - AckNPackets(2) - expectedSendWindow += maxDatagramSize - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - }) - - It("no PRR", func() { - SendAvailableSendWindow() - LoseNPackets(9) - AckNPackets(1) - - Expect(sender.GetCongestionWindow()).To(Equal(protocol.ByteCount(renoBeta * float32(defaultWindowTCP)))) - windowInPackets := renoBeta * float32(defaultWindowTCP) / float32(maxDatagramSize) - numSent := SendAvailableSendWindow() - Expect(numSent).To(BeEquivalentTo(windowInPackets)) - }) - - It("reset after connection migration", func() { - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - Expect(sender.slowStartThreshold).To(Equal(protocol.MaxByteCount)) - - // Starts with slow start. - const numberOfAcks = 10 - for i := 0; i < numberOfAcks; i++ { - // Send our full send window. - SendAvailableSendWindow() - AckNPackets(2) - } - SendAvailableSendWindow() - expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - - // Loses a packet to exit slow start. - LoseNPackets(1) - - // We should now have fallen out of slow start with a reduced window. Slow - // start threshold is also updated. - expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) - Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow)) - Expect(sender.slowStartThreshold).To(Equal(expectedSendWindow)) - - // Resets cwnd and slow start threshold on connection migrations. - sender.OnConnectionMigration() - Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) - Expect(sender.slowStartThreshold).To(Equal(MaxCongestionWindow)) - Expect(sender.hybridSlowStart.Started()).To(BeFalse()) - }) - - It("slow starts up to the maximum congestion window", func() { - const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize - sender = newCubicSender(&clock, &rttStats, true, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) - - for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { - sender.MaybeExitSlowStart() - sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now()) - } - Expect(sender.GetCongestionWindow()).To(Equal(initialMaxCongestionWindow)) - }) - - It("doesn't allow reductions of the maximum packet size", func() { - Expect(func() { sender.SetMaxDatagramSize(initialMaxDatagramSize - 1) }).To(Panic()) - }) - - It("slow starts up to maximum congestion window, if larger packets are sent", func() { - const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize - sender = newCubicSender(&clock, &rttStats, true, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) - const packetSize = initialMaxDatagramSize + 100 - sender.SetMaxDatagramSize(packetSize) - for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { - sender.OnPacketAcked(protocol.PacketNumber(i), packetSize, sender.GetCongestionWindow(), clock.Now()) - } - const maxCwnd = protocol.MaxCongestionWindowPackets * packetSize - Expect(sender.GetCongestionWindow()).To(And( - BeNumerically(">", maxCwnd), - BeNumerically("<=", maxCwnd+packetSize), - )) - }) - - It("limit cwnd increase in congestion avoidance", func() { - // Enable Cubic. - sender = newCubicSender(&clock, &rttStats, false, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil) - numSent := SendAvailableSendWindow() - - // Make sure we fall out of slow start. - savedCwnd := sender.GetCongestionWindow() - LoseNPackets(1) - Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow())) - - // Ack the rest of the outstanding packets to get out of recovery. - for i := 1; i < numSent; i++ { - AckNPackets(1) - } - Expect(bytesInFlight).To(BeZero()) - - savedCwnd = sender.GetCongestionWindow() - SendAvailableSendWindow() - - // Ack packets until the CWND increases. - for sender.GetCongestionWindow() == savedCwnd { - AckNPackets(1) - SendAvailableSendWindow() - } - // Bytes in flight may be larger than the CWND if the CWND isn't an exact - // multiple of the packet sizes being sent. - Expect(bytesInFlight).To(BeNumerically(">=", sender.GetCongestionWindow())) - savedCwnd = sender.GetCongestionWindow() - - // Advance time 2 seconds waiting for an ack. - clock.Advance(2 * time.Second) - - // Ack two packets. The CWND should increase by only one packet. - AckNPackets(2) - Expect(sender.GetCongestionWindow()).To(Equal(savedCwnd + maxDatagramSize)) - }) -}) + // Bytes in flight may be larger than the CWND if the CWND isn't an exact + // multiple of the packet sizes being sent. + require.GreaterOrEqual(t, testSender.bytesInFlight, sender.GetCongestionWindow()) + savedCwnd = sender.GetCongestionWindow() + + // Advance time 2 seconds waiting for an ack. + clock.Advance(2 * time.Second) + + // Ack two packets. The CWND should increase by only one packet. + testSender.AckNPackets(2) + require.Equal(t, savedCwnd+maxDatagramSize, sender.GetCongestionWindow()) +} diff --git a/internal/congestion/cubic_test.go b/internal/congestion/cubic_test.go index c1ad621d..181c94a1 100644 --- a/internal/congestion/cubic_test.go +++ b/internal/congestion/cubic_test.go @@ -2,12 +2,11 @@ package congestion import ( "math" + "testing" "time" "github.com/quic-go/quic-go/internal/protocol" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) const ( @@ -18,223 +17,189 @@ const ( maxCubicTimeInterval = 30 * time.Millisecond ) -var _ = Describe("Cubic", func() { - var ( - clock mockClock - cubic *Cubic - ) +func renoCwnd(currentCwnd protocol.ByteCount) protocol.ByteCount { + return currentCwnd + protocol.ByteCount(float32(maxDatagramSize)*nConnectionAlpha*float32(maxDatagramSize)/float32(currentCwnd)) +} - BeforeEach(func() { - clock = mockClock{} - cubic = NewCubic(&clock) - cubic.SetNumConnections(int(numConnections)) - }) +func cubicConvexCwnd(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount { + offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000 + deltaCongestionWindow := 410 * offset * offset * offset * maxDatagramSize >> 40 + return initialCwnd + deltaCongestionWindow +} - renoCwnd := func(currentCwnd protocol.ByteCount) protocol.ByteCount { - return currentCwnd + protocol.ByteCount(float32(maxDatagramSize)*nConnectionAlpha*float32(maxDatagramSize)/float32(currentCwnd)) - } +func TestCubicAboveOriginWithTighterBounds(t *testing.T) { + clock := mockClock{} + cubic := NewCubic(&clock) + cubic.SetNumConnections(int(numConnections)) - cubicConvexCwnd := func(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount { - offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000 - deltaCongestionWindow := 410 * offset * offset * offset * maxDatagramSize >> 40 - return initialCwnd + deltaCongestionWindow - } + // Convex growth. + const rttMin = 100 * time.Millisecond + const rttMinS = float32(rttMin/time.Millisecond) / 1000.0 + currentCwnd := 10 * maxDatagramSize + initialCwnd := currentCwnd - It("works above origin (with tighter bounds)", func() { - // Convex growth. - const rttMin = 100 * time.Millisecond - const rttMinS = float32(rttMin/time.Millisecond) / 1000.0 - currentCwnd := 10 * maxDatagramSize - initialCwnd := currentCwnd + clock.Advance(time.Millisecond) + initialTime := clock.Now() + expectedFirstCwnd := renoCwnd(currentCwnd) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, initialTime) + require.Equal(t, expectedFirstCwnd, currentCwnd) - clock.Advance(time.Millisecond) - initialTime := clock.Now() - expectedFirstCwnd := renoCwnd(currentCwnd) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, initialTime) - Expect(expectedFirstCwnd).To(Equal(currentCwnd)) + // Normal TCP phase. + // The maximum number of expected reno RTTs can be calculated by + // finding the point where the cubic curve and the reno curve meet. + maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2) + for range maxRenoRtts { + numAcksThisEpoch := int(float32(currentCwnd/maxDatagramSize) / nConnectionAlpha) - // Normal TCP phase. - // The maximum number of expected reno RTTs can be calculated by - // finding the point where the cubic curve and the reno curve meet. - maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2) - for i := 0; i < maxRenoRtts; i++ { - // Alternatively, we expect it to increase by one, every time we - // receive current_cwnd/Alpha acks back. (This is another way of - // saying we expect cwnd to increase by approximately Alpha once - // we receive current_cwnd number ofacks back). - numAcksThisEpoch := int(float32(currentCwnd/maxDatagramSize) / nConnectionAlpha) - - initialCwndThisEpoch := currentCwnd - for n := 0; n < numAcksThisEpoch; n++ { - // Call once per ACK. - expectedNextCwnd := renoCwnd(currentCwnd) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - Expect(currentCwnd).To(Equal(expectedNextCwnd)) - } - // Our byte-wise Reno implementation is an estimate. We expect - // the cwnd to increase by approximately one MSS every - // cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as - // half a packet for smaller values of current_cwnd. - cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch - Expect(cwndChangeThisEpoch).To(BeNumerically("~", maxDatagramSize, maxDatagramSize/2)) - clock.Advance(100 * time.Millisecond) - } - - for i := 0; i < 54; i++ { - maxAcksThisEpoch := currentCwnd / maxDatagramSize - interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond - for n := 0; n < int(maxAcksThisEpoch); n++ { - clock.Advance(interval) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) - // If we allow per-ack updates, every update is a small cubic update. - Expect(currentCwnd).To(Equal(expectedCwnd)) - } - } - expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - Expect(currentCwnd).To(Equal(expectedCwnd)) - }) - - It("works above the origin with fine grained cubing", func() { - // Start the test with an artificially large cwnd to prevent Reno - // from over-taking cubic. - currentCwnd := 1000 * maxDatagramSize - initialCwnd := currentCwnd - rttMin := 100 * time.Millisecond - clock.Advance(time.Millisecond) - initialTime := clock.Now() - - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - clock.Advance(600 * time.Millisecond) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - - // We expect the algorithm to perform only non-zero, fine-grained cubic - // increases on every ack in this case. - for i := 0; i < 100; i++ { - clock.Advance(10 * time.Millisecond) - expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) - nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - // Make sure we are performing cubic increases. - Expect(nextCwnd).To(Equal(expectedCwnd)) - // Make sure that these are non-zero, less-than-packet sized increases. - Expect(nextCwnd).To(BeNumerically(">", currentCwnd)) - cwndDelta := nextCwnd - currentCwnd - Expect(maxDatagramSize / 10).To(BeNumerically(">", cwndDelta)) - currentCwnd = nextCwnd - } - }) - - It("handles per ack updates", func() { - // Start the test with a large cwnd and RTT, to force the first - // increase to be a cubic increase. - initialCwndPackets := 150 - currentCwnd := protocol.ByteCount(initialCwndPackets) * maxDatagramSize - rttMin := 350 * time.Millisecond - - // Initialize the epoch - clock.Advance(time.Millisecond) - // Keep track of the growth of the reno-equivalent cwnd. - rCwnd := renoCwnd(currentCwnd) - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - initialCwnd := currentCwnd - - // Simulate the return of cwnd packets in less than - // MaxCubicInterval() time. - maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha) - interval := maxCubicTimeInterval / time.Duration(maxAcks+1) - - // In this scenario, the first increase is dictated by the cubic - // equation, but it is less than one byte, so the cwnd doesn't - // change. Normally, without per-ack increases, any cwnd plateau - // will cause the cwnd to be pinned for MaxCubicTimeInterval(). If - // we enable per-ack updates, the cwnd will continue to grow, - // regardless of the temporary plateau. - clock.Advance(interval) - rCwnd = renoCwnd(rCwnd) - Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(currentCwnd)) - for i := 1; i < maxAcks; i++ { - clock.Advance(interval) - nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - rCwnd = renoCwnd(rCwnd) - // The window shoud increase on every ack. - Expect(nextCwnd).To(BeNumerically(">", currentCwnd)) - Expect(nextCwnd).To(Equal(rCwnd)) - currentCwnd = nextCwnd - } - - // After all the acks are returned from the epoch, we expect the - // cwnd to have increased by nearly one packet. (Not exactly one - // packet, because our byte-wise Reno algorithm is always a slight - // under-estimation). Without per-ack updates, the current_cwnd - // would otherwise be unchanged. - minimumExpectedIncrease := maxDatagramSize * 9 / 10 - Expect(currentCwnd).To(BeNumerically(">", initialCwnd+minimumExpectedIncrease)) - }) - - It("handles loss events", func() { - rttMin := 100 * time.Millisecond - currentCwnd := 422 * maxDatagramSize - expectedCwnd := renoCwnd(currentCwnd) - // Initialize the state. - clock.Advance(time.Millisecond) - Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd)) - - // On the first loss, the last max congestion window is set to the - // congestion window before the loss. - preLossCwnd := currentCwnd - Expect(cubic.lastMaxCongestionWindow).To(BeZero()) - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - Expect(cubic.lastMaxCongestionWindow).To(Equal(preLossCwnd)) - currentCwnd = expectedCwnd - - // On the second loss, the current congestion window has not yet - // reached the last max congestion window. The last max congestion - // window will be reduced by an additional backoff factor to allow - // for competition. - preLossCwnd = currentCwnd - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - currentCwnd = expectedCwnd - Expect(preLossCwnd).To(BeNumerically(">", cubic.lastMaxCongestionWindow)) - expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax) - Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax)) - Expect(expectedCwnd).To(BeNumerically("<", cubic.lastMaxCongestionWindow)) - // Simulate an increase, and check that we are below the origin. - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - Expect(cubic.lastMaxCongestionWindow).To(BeNumerically(">", currentCwnd)) - - // On the final loss, simulate the condition where the congestion - // window had a chance to grow nearly to the last congestion window. - currentCwnd = cubic.lastMaxCongestionWindow - 1 - preLossCwnd = currentCwnd - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - expectedLastMax = preLossCwnd - Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax)) - }) - - It("works below origin", func() { - // Concave growth. - rttMin := 100 * time.Millisecond - currentCwnd := 422 * maxDatagramSize - expectedCwnd := renoCwnd(currentCwnd) - // Initialize the state. - clock.Advance(time.Millisecond) - Expect(cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd)) - - expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) - Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd)) - currentCwnd = expectedCwnd - // First update after loss to initialize the epoch. - currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) - // Cubic phase. - for i := 0; i < 40; i++ { - clock.Advance(100 * time.Millisecond) + initialCwndThisEpoch := currentCwnd + for range numAcksThisEpoch { + // Call once per ACK. + expectedNextCwnd := renoCwnd(currentCwnd) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + require.Equal(t, expectedNextCwnd, currentCwnd) } - expectedCwnd = 553632 * maxDatagramSize / 1460 - Expect(currentCwnd).To(Equal(expectedCwnd)) - }) -}) + cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch + require.InDelta(t, float64(maxDatagramSize), float64(cwndChangeThisEpoch), float64(maxDatagramSize)/2) + clock.Advance(100 * time.Millisecond) + } + + for range 54 { + maxAcksThisEpoch := currentCwnd / maxDatagramSize + interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond + for range int(maxAcksThisEpoch) { + clock.Advance(interval) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) + require.Equal(t, expectedCwnd, currentCwnd) + } + } + expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + require.Equal(t, expectedCwnd, currentCwnd) +} + +func TestCubicAboveOriginWithFineGrainedCubing(t *testing.T) { + clock := mockClock{} + cubic := NewCubic(&clock) + cubic.SetNumConnections(int(numConnections)) + + currentCwnd := 1000 * maxDatagramSize + initialCwnd := currentCwnd + rttMin := 100 * time.Millisecond + clock.Advance(time.Millisecond) + initialTime := clock.Now() + + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + clock.Advance(600 * time.Millisecond) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + + for i := 0; i < 100; i++ { + clock.Advance(10 * time.Millisecond) + expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) + nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + require.Equal(t, expectedCwnd, nextCwnd) + require.Greater(t, nextCwnd, currentCwnd) + cwndDelta := nextCwnd - currentCwnd + require.Less(t, cwndDelta, maxDatagramSize/10) + currentCwnd = nextCwnd + } +} + +func TestCubicHandlesPerAckUpdates(t *testing.T) { + clock := mockClock{} + cubic := NewCubic(&clock) + cubic.SetNumConnections(int(numConnections)) + + initialCwndPackets := 150 + currentCwnd := protocol.ByteCount(initialCwndPackets) * maxDatagramSize + rttMin := 350 * time.Millisecond + + clock.Advance(time.Millisecond) + rCwnd := renoCwnd(currentCwnd) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + initialCwnd := currentCwnd + + maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha) + interval := maxCubicTimeInterval / time.Duration(maxAcks+1) + + clock.Advance(interval) + rCwnd = renoCwnd(rCwnd) + require.Equal(t, currentCwnd, cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())) + + for range maxAcks - 1 { + clock.Advance(interval) + nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + rCwnd = renoCwnd(rCwnd) + require.Greater(t, nextCwnd, currentCwnd) + require.Equal(t, rCwnd, nextCwnd) + currentCwnd = nextCwnd + } + + minimumExpectedIncrease := maxDatagramSize * 9 / 10 + require.Greater(t, currentCwnd, initialCwnd+minimumExpectedIncrease) +} + +func TestCubicHandlesLossEvents(t *testing.T) { + clock := mockClock{} + cubic := NewCubic(&clock) + cubic.SetNumConnections(int(numConnections)) + + rttMin := 100 * time.Millisecond + currentCwnd := 422 * maxDatagramSize + expectedCwnd := renoCwnd(currentCwnd) + + clock.Advance(time.Millisecond) + require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())) + + preLossCwnd := currentCwnd + require.Zero(t, cubic.lastMaxCongestionWindow) + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) + require.Equal(t, preLossCwnd, cubic.lastMaxCongestionWindow) + currentCwnd = expectedCwnd + + preLossCwnd = currentCwnd + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) + currentCwnd = expectedCwnd + require.Greater(t, preLossCwnd, cubic.lastMaxCongestionWindow) + expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax) + require.Equal(t, expectedLastMax, cubic.lastMaxCongestionWindow) + require.Less(t, expectedCwnd, cubic.lastMaxCongestionWindow) + + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + require.Greater(t, cubic.lastMaxCongestionWindow, currentCwnd) + + currentCwnd = cubic.lastMaxCongestionWindow - 1 + preLossCwnd = currentCwnd + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) + expectedLastMax = preLossCwnd + require.Equal(t, expectedLastMax, cubic.lastMaxCongestionWindow) +} + +func TestCubicBelowOrigin(t *testing.T) { + clock := mockClock{} + cubic := NewCubic(&clock) + cubic.SetNumConnections(int(numConnections)) + + rttMin := 100 * time.Millisecond + currentCwnd := 422 * maxDatagramSize + expectedCwnd := renoCwnd(currentCwnd) + + clock.Advance(time.Millisecond) + require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())) + + expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) + require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) + currentCwnd = expectedCwnd + + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + + for range 40 { + clock.Advance(100 * time.Millisecond) + currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) + } + expectedCwnd = 553632 * maxDatagramSize / 1460 + require.Equal(t, expectedCwnd, currentCwnd) +} diff --git a/internal/congestion/hybrid_slow_start_test.go b/internal/congestion/hybrid_slow_start_test.go index 5d951799..c325529b 100644 --- a/internal/congestion/hybrid_slow_start_test.go +++ b/internal/congestion/hybrid_slow_start_test.go @@ -1,73 +1,68 @@ package congestion import ( + "testing" "time" "github.com/quic-go/quic-go/internal/protocol" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("Hybrid slow start", func() { - var slowStart HybridSlowStart +func TestHybridSlowStartSimpleCase(t *testing.T) { + slowStart := HybridSlowStart{} - BeforeEach(func() { - slowStart = HybridSlowStart{} - }) + packetNumber := protocol.PacketNumber(1) + endPacketNumber := protocol.PacketNumber(3) + slowStart.StartReceiveRound(endPacketNumber) - It("works in a simple case", func() { - packetNumber := protocol.PacketNumber(1) - endPacketNumber := protocol.PacketNumber(3) - slowStart.StartReceiveRound(endPacketNumber) + packetNumber++ + require.False(t, slowStart.IsEndOfRound(packetNumber)) + // Test duplicates. + require.False(t, slowStart.IsEndOfRound(packetNumber)) + + packetNumber++ + require.False(t, slowStart.IsEndOfRound(packetNumber)) + packetNumber++ + require.True(t, slowStart.IsEndOfRound(packetNumber)) + + // Test without a new registered end_packet_number; + packetNumber++ + require.True(t, slowStart.IsEndOfRound(packetNumber)) + + endPacketNumber = 20 + slowStart.StartReceiveRound(endPacketNumber) + for packetNumber < endPacketNumber { packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) + require.False(t, slowStart.IsEndOfRound(packetNumber)) + } + packetNumber++ + require.True(t, slowStart.IsEndOfRound(packetNumber)) +} - // Test duplicates. - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) +func TestHybridSlowStartWithDelay(t *testing.T) { + slowStart := HybridSlowStart{} + const rtt = 60 * time.Millisecond + // We expect to detect the increase at +1/8 of the RTT; hence at a typical + // RTT of 60ms the detection will happen at 67.5 ms. + const hybridStartMinSamples = 8 // Number of acks required to trigger. - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) + endPacketNumber := protocol.PacketNumber(1) + endPacketNumber++ + slowStart.StartReceiveRound(endPacketNumber) - // Test without a new registered end_packet_number; - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) - - endPacketNumber = 20 - slowStart.StartReceiveRound(endPacketNumber) - for packetNumber < endPacketNumber { - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse()) - } - packetNumber++ - Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue()) - }) - - It("works with delay", func() { - rtt := 60 * time.Millisecond - // We expect to detect the increase at +1/8 of the RTT; hence at a typical - // RTT of 60ms the detection will happen at 67.5 ms. - const hybridStartMinSamples = 8 // Number of acks required to trigger. - - endPacketNumber := protocol.PacketNumber(1) - endPacketNumber++ - slowStart.StartReceiveRound(endPacketNumber) - - // Will not trigger since our lowest RTT in our burst is the same as the long - // term RTT provided. - for n := 0; n < hybridStartMinSamples; n++ { - Expect(slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)).To(BeFalse()) - } - endPacketNumber++ - slowStart.StartReceiveRound(endPacketNumber) - for n := 1; n < hybridStartMinSamples; n++ { - Expect(slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)).To(BeFalse()) - } - // Expect to trigger since all packets in this burst was above the long term - // RTT provided. - Expect(slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)).To(BeTrue()) - }) -}) + // Will not trigger since our lowest RTT in our burst is the same as the long + // term RTT provided. + for n := 0; n < hybridStartMinSamples; n++ { + require.False(t, slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)) + } + endPacketNumber++ + slowStart.StartReceiveRound(endPacketNumber) + for n := 1; n < hybridStartMinSamples; n++ { + require.False(t, slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)) + } + // Expect to trigger since all packets in this burst was above the long term + // RTT provided. + require.True(t, slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)) +}