diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 8fe5fec7..22c5e6c1 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -38,6 +38,7 @@ type SentPacketHandler interface { TimeUntilSend() time.Time // HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment. HasPacingBudget() bool + SetMaxDatagramSize(count protocol.ByteCount) // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index d4dd6498..199f560d 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -707,6 +707,10 @@ func (h *sentPacketHandler) HasPacingBudget() bool { return h.congestion.HasPacingBudget() } +func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { + h.congestion.SetMaxDatagramSize(s) +} + func (h *sentPacketHandler) isAmplificationLimited() bool { if h.peerAddressValidated { return false diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index 20a75af2..868a264c 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -11,12 +11,13 @@ import ( const ( // maxDatagramSize is the default maximum packet size used in the Linux TCP implementation. // Used in QUIC for congestion window computations in bytes. - maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) - maxBurstBytes = 3 * maxDatagramSize - renoBeta = 0.7 // Reno backoff factor. - maxCongestionWindow = protocol.MaxCongestionWindowPackets * maxDatagramSize - minCongestionWindow = 2 * maxDatagramSize - initialCongestionWindow = 32 * maxDatagramSize + maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) + initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) + maxBurstPackets = 3 + renoBeta = 0.7 // Reno backoff factor. + initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize + minCongestionWindowPackets = 2 + initialCongestionWindow = 32 * initialMaxDatagramSize ) type cubicSender struct { @@ -44,12 +45,6 @@ type cubicSender struct { // Congestion window in packets. congestionWindow protocol.ByteCount - // Minimum congestion window in packets. - minCongestionWindow protocol.ByteCount - - // Maximum congestion window. - maxCongestionWindow protocol.ByteCount - // Slow start congestion window in bytes, aka ssthresh. slowStartThreshold protocol.ByteCount @@ -59,6 +54,8 @@ type cubicSender struct { initialCongestionWindow protocol.ByteCount initialMaxCongestionWindow protocol.ByteCount + maxDatagramSize protocol.ByteCount + lastState logging.CongestionState tracer logging.ConnectionTracer } @@ -70,7 +67,7 @@ var ( // NewCubicSender makes a new cubic sender func NewCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, tracer logging.ConnectionTracer) *cubicSender { - return newCubicSender(clock, rttStats, reno, initialCongestionWindow, maxCongestionWindow, tracer) + return newCubicSender(clock, rttStats, reno, initialCongestionWindow, initialMaxCongestionWindow, tracer) } func newCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount, tracer logging.ConnectionTracer) *cubicSender { @@ -82,13 +79,12 @@ func newCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, initialCon initialCongestionWindow: initialCongestionWindow, initialMaxCongestionWindow: initialMaxCongestionWindow, congestionWindow: initialCongestionWindow, - minCongestionWindow: minCongestionWindow, slowStartThreshold: protocol.MaxByteCount, - maxCongestionWindow: initialMaxCongestionWindow, cubic: NewCubic(clock), clock: clock, reno: reno, tracer: tracer, + maxDatagramSize: initialMaxDatagramSize, } c.pacer = newPacer(c.BandwidthEstimate) if c.tracer != nil { @@ -104,12 +100,20 @@ func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time { } func (c *cubicSender) HasPacingBudget() bool { - return c.pacer.Budget(c.clock.Now()) >= maxDatagramSize + return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize +} + +func (c *cubicSender) maxCongestionWindow() protocol.ByteCount { + return c.maxDatagramSize * protocol.MaxCongestionWindowPackets +} + +func (c *cubicSender) minCongestionWindow() protocol.ByteCount { + return c.maxDatagramSize * minCongestionWindowPackets } func (c *cubicSender) OnPacketSent( sentTime time.Time, - bytesInFlight protocol.ByteCount, + _ protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool, @@ -139,7 +143,8 @@ func (c *cubicSender) GetCongestionWindow() protocol.ByteCount { } func (c *cubicSender) MaybeExitSlowStart() { - if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/maxDatagramSize) { + if c.InSlowStart() && + c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { // exit slow start c.slowStartThreshold = c.congestionWindow c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) @@ -162,11 +167,7 @@ func (c *cubicSender) OnPacketAcked( } } -func (c *cubicSender) OnPacketLost( - packetNumber protocol.PacketNumber, - lostBytes protocol.ByteCount, - priorInFlight protocol.ByteCount, -) { +func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // already sent should be treated as a single loss event, since it's expected. if packetNumber <= c.largestSentAtLastCutback { @@ -180,8 +181,8 @@ func (c *cubicSender) OnPacketLost( } else { c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) } - if c.congestionWindow < c.minCongestionWindow { - c.congestionWindow = c.minCongestionWindow + if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { + c.congestionWindow = minCwnd } c.slowStartThreshold = c.congestionWindow c.largestSentAtLastCutback = c.largestSentPacketNumber @@ -205,12 +206,12 @@ func (c *cubicSender) maybeIncreaseCwnd( c.maybeTraceStateChange(logging.CongestionStateApplicationLimited) return } - if c.congestionWindow >= c.maxCongestionWindow { + if c.congestionWindow >= c.maxCongestionWindow() { return } if c.InSlowStart() { // TCP slow start, exponential growth, increase by one for each ACK. - c.congestionWindow += maxDatagramSize + c.congestionWindow += c.maxDatagramSize c.maybeTraceStateChange(logging.CongestionStateSlowStart) return } @@ -219,12 +220,12 @@ func (c *cubicSender) maybeIncreaseCwnd( if c.reno { // Classic Reno congestion avoidance. c.numAckedPackets++ - if c.numAckedPackets >= uint64(c.congestionWindow/maxDatagramSize) { - c.congestionWindow += maxDatagramSize + if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { + c.congestionWindow += c.maxDatagramSize c.numAckedPackets = 0 } } else { - c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) } } @@ -235,7 +236,7 @@ func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool { } availableBytes := congestionWindow - bytesInFlight slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 - return slowStartLimited || availableBytes <= maxBurstBytes + return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize } // BandwidthEstimate returns the current bandwidth estimate @@ -257,7 +258,7 @@ func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { c.hybridSlowStart.Restart() c.cubic.Reset() c.slowStartThreshold = c.congestionWindow / 2 - c.congestionWindow = c.minCongestionWindow + c.congestionWindow = c.minCongestionWindow() } // OnConnectionMigration is called when the connection is migrated (?) @@ -271,7 +272,6 @@ func (c *cubicSender) OnConnectionMigration() { c.numAckedPackets = 0 c.congestionWindow = c.initialCongestionWindow c.slowStartThreshold = c.initialMaxCongestionWindow - c.maxCongestionWindow = c.initialMaxCongestionWindow } func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { @@ -281,3 +281,14 @@ func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { c.tracer.UpdatedCongestionState(new) c.lastState = new } + +func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) { + if s < c.maxDatagramSize { + panic("congestion BUG: decreased max datagram size") + } + cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() + c.maxDatagramSize = s + if cwndIsMinCwnd { + c.congestionWindow = c.minCongestionWindow() + } +} diff --git a/internal/congestion/cubic_sender_test.go b/internal/congestion/cubic_sender_test.go index 6836ea99..8560feab 100644 --- a/internal/congestion/cubic_sender_test.go +++ b/internal/congestion/cubic_sender_test.go @@ -46,7 +46,7 @@ var _ = Describe("Cubic Sender", func() { }) SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int { - packetsSent := 0 + var packetsSent int for sender.CanSend(bytesInFlight) { sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true) packetNumber++ @@ -449,15 +449,32 @@ var _ = Describe("Cubic Sender", func() { Expect(sender.hybridSlowStart.Started()).To(BeFalse()) }) - It("default max cwnd", func() { - sender = newCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindow, nil) + It("slow starts up to the maximum congestion window", func() { + sender = newCubicSender(&clock, rttStats, true, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) - defaultMaxCongestionWindowPackets := maxCongestionWindow / maxDatagramSize - for i := 1; i < int(defaultMaxCongestionWindowPackets); i++ { + for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { sender.MaybeExitSlowStart() sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now()) } - Expect(sender.GetCongestionWindow()).To(Equal(maxCongestionWindow)) + Expect(sender.GetCongestionWindow()).To(Equal(initialMaxCongestionWindow)) + }) + + It("doesn't allow reductions of the maximum packet size", func() { + Expect(func() { sender.SetMaxDatagramSize(initialMaxDatagramSize - 1) }).To(Panic()) + }) + + It("slow starts up to maximum congestion window, if larger packets are sent", func() { + sender = newCubicSender(&clock, rttStats, true, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil) + const packetSize = initialMaxDatagramSize + 100 + sender.SetMaxDatagramSize(packetSize) + for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { + sender.OnPacketAcked(protocol.PacketNumber(i), packetSize, sender.GetCongestionWindow(), clock.Now()) + } + const maxCwnd = protocol.MaxCongestionWindowPackets * packetSize + Expect(sender.GetCongestionWindow()).To(And( + BeNumerically(">", maxCwnd), + BeNumerically("<=", maxCwnd+packetSize), + )) }) It("limit cwnd increase in congestion avoidance", func() { diff --git a/internal/congestion/interface.go b/internal/congestion/interface.go index 94ee88d9..5157383f 100644 --- a/internal/congestion/interface.go +++ b/internal/congestion/interface.go @@ -16,6 +16,7 @@ type SendAlgorithm interface { OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) OnRetransmissionTimeout(packetsRetransmitted bool) + SetMaxDatagramSize(protocol.ByteCount) } // A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index df1ef044..c4c9ef7b 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -212,6 +212,18 @@ func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) } +// SetMaxDatagramSize mocks base method. +func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxDatagramSize", arg0) +} + +// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. +func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) +} + // TimeUntilSend mocks base method. func (m *MockSentPacketHandler) TimeUntilSend() time.Time { m.ctrl.T.Helper() diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go index d9267db5..1b9aef3f 100644 --- a/internal/mocks/congestion.go +++ b/internal/mocks/congestion.go @@ -165,6 +165,18 @@ func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) } +// SetMaxDatagramSize mocks base method. +func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxDatagramSize", arg0) +} + +// SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) +} + // TimeUntilSend mocks base method. func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) time.Time { m.ctrl.T.Helper()