diff --git a/congestion/cubic_sender.go b/congestion/cubic_sender.go index 5036a9e3..792bf823 100644 --- a/congestion/cubic_sender.go +++ b/congestion/cubic_sender.go @@ -21,6 +21,8 @@ type cubicSender struct { stats connectionStats cubic *Cubic + reno bool + // Track the largest packet that has been sent. largestSentPacketNumber protocol.PacketNumber @@ -55,20 +57,23 @@ type cubicSender struct { // ACK counter for the Reno implementation. congestionWindowCount uint64 - reno bool + initialCongestionWindow protocol.PacketNumber + initialMaxCongestionWindow protocol.PacketNumber } // NewCubicSender makes a new cubic sender -func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, maxCongestionWindow protocol.PacketNumber) SendAlgorithm { +func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithm { return &cubicSender{ - rttStats: rttStats, - congestionWindow: initialCongestionWindow, - minCongestionWindow: defaultMinimumCongestionWindow, - slowstartThreshold: maxCongestionWindow, - maxTCPCongestionWindow: maxCongestionWindow, - numConnections: defaultNumConnections, - cubic: NewCubic(clock), - reno: reno, + rttStats: rttStats, + initialCongestionWindow: initialCongestionWindow, + initialMaxCongestionWindow: initialMaxCongestionWindow, + congestionWindow: initialCongestionWindow, + minCongestionWindow: defaultMinimumCongestionWindow, + slowstartThreshold: initialMaxCongestionWindow, + maxTCPCongestionWindow: initialMaxCongestionWindow, + numConnections: defaultNumConnections, + cubic: NewCubic(clock), + reno: reno, } } @@ -272,3 +277,18 @@ func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { c.slowstartThreshold = c.congestionWindow / 2 c.congestionWindow = c.minCongestionWindow } + +// OnConnectionMigration is called when the connection is migrated (?) +func (c *cubicSender) OnConnectionMigration() { + c.hybridSlowStart.Restart() + c.prr = PrrSender{} + c.largestSentPacketNumber = 0 + c.largestAckedPacketNumber = 0 + c.largestSentAtLastCutback = 0 + c.lastCutbackExitedSlowstart = false + c.cubic.Reset() + c.congestionWindowCount = 0 + c.congestionWindow = c.initialCongestionWindow + c.slowstartThreshold = c.initialMaxCongestionWindow + c.maxTCPCongestionWindow = c.initialMaxCongestionWindow +} diff --git a/congestion/cubic_sender_test.go b/congestion/cubic_sender_test.go index 81147a92..7e5800c6 100644 --- a/congestion/cubic_sender_test.go +++ b/congestion/cubic_sender_test.go @@ -464,4 +464,36 @@ var _ = Describe("Cubic Sender", func() { expected_send_window += protocol.DefaultTCPMSS Expect(sender.GetCongestionWindow()).To(Equal(expected_send_window)) }) + + It("reset after connection migration", func() { + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + Expect(sender.SlowstartThreshold()).To(Equal(protocol.MaxCongestionWindow)) + + // Starts with slow start. + sender.SetNumEmulatedConnections(1) + const kNumberOfAcks = 10 + for i := 0; i < kNumberOfAcks; i++ { + // Send our full send window. + SendAvailableSendWindow() + AckNPackets(2) + } + SendAvailableSendWindow() + expected_send_window := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * kNumberOfAcks) + Expect(sender.GetCongestionWindow()).To(Equal(expected_send_window)) + + // Loses a packet to exit slow start. + LoseNPackets(1) + + // We should now have fallen out of slow start with a reduced window. Slow + // start threshold is also updated. + expected_send_window = uint64(float32(expected_send_window) * renoBeta) + Expect(sender.GetCongestionWindow()).To(Equal(expected_send_window)) + Expect(sender.SlowstartThreshold()).To(Equal(protocol.PacketNumber(expected_send_window / protocol.DefaultTCPMSS))) + + // Resets cwnd and slow start threshold on connection migrations. + sender.OnConnectionMigration() + Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP)) + Expect(sender.SlowstartThreshold()).To(Equal(protocol.MaxCongestionWindow)) + Expect(sender.HybridSlowStart().Started()).To(BeFalse()) + }) }) diff --git a/congestion/interface.go b/congestion/interface.go index a9cb90d2..1ba80019 100644 --- a/congestion/interface.go +++ b/congestion/interface.go @@ -15,6 +15,7 @@ type SendAlgorithm interface { SetNumEmulatedConnections(n int) OnRetransmissionTimeout(packetsRetransmitted bool) InRecovery() bool + OnConnectionMigration() HybridSlowStart() *HybridSlowStart // only for testing SlowstartThreshold() protocol.PacketNumber // only for testing