diff --git a/internal/congestion/bandwidth.go b/internal/congestion/bandwidth.go index 1d03abbb..3ad827d2 100644 --- a/internal/congestion/bandwidth.go +++ b/internal/congestion/bandwidth.go @@ -1,7 +1,6 @@ package congestion import ( - "math" "time" "github.com/quic-go/quic-go/internal/protocol" @@ -10,8 +9,6 @@ import ( // Bandwidth of a connection type Bandwidth uint64 -const infBandwidth Bandwidth = math.MaxUint64 - const ( // BitsPerSecond is 1 bit per second BitsPerSecond Bandwidth = 1 diff --git a/internal/congestion/pacer.go b/internal/congestion/pacer.go index 92757eed..7656f529 100644 --- a/internal/congestion/pacer.go +++ b/internal/congestion/pacer.go @@ -1,6 +1,7 @@ package congestion import ( + "math" "time" "github.com/quic-go/quic-go/internal/monotime" @@ -48,8 +49,13 @@ func (p *pacer) Budget(now monotime.Time) protocol.ByteCount { if p.lastSentTime.IsZero() { return p.maxBurstSize() } - budget := p.budgetAtLastSent + (protocol.ByteCount(p.adjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 - if budget < 0 { // protect against overflows + delta := now.Sub(p.lastSentTime) + var added protocol.ByteCount + if delta > 0 { + added = p.timeScaledBandwidth(uint64(delta.Nanoseconds())) + } + budget := p.budgetAtLastSent + added + if added > 0 && budget < p.budgetAtLastSent { budget = protocol.MaxByteCount } return min(p.maxBurstSize(), budget) @@ -57,11 +63,30 @@ func (p *pacer) Budget(now monotime.Time) protocol.ByteCount { func (p *pacer) maxBurstSize() protocol.ByteCount { return max( - protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9, + p.timeScaledBandwidth(uint64((protocol.MinPacingDelay + protocol.TimerGranularity).Nanoseconds())), maxBurstSizePackets*p.maxDatagramSize, ) } +// timeScaledBandwidth calculates the number of bytes that may be sent within +// a given time interval (ns nanoseconds), based on the current bandwidth estimate. +// It caps the scaled value to the maximum allowed burst and handles overflows. +func (p *pacer) timeScaledBandwidth(ns uint64) protocol.ByteCount { + bw := p.adjustedBandwidth() + if bw == 0 { + return 0 + } + const nsPerSecond = 1e9 + maxBurst := maxBurstSizePackets * p.maxDatagramSize + var scaled protocol.ByteCount + if ns > math.MaxUint64/bw { + scaled = maxBurst + } else { + scaled = protocol.ByteCount(bw * ns / nsPerSecond) + } + return scaled +} + // TimeUntilSend returns when the next packet should be sent. // It returns zero if a packet can be sent immediately. func (p *pacer) TimeUntilSend() monotime.Time { diff --git a/internal/congestion/pacer_test.go b/internal/congestion/pacer_test.go index ca413548..795cd217 100644 --- a/internal/congestion/pacer_test.go +++ b/internal/congestion/pacer_test.go @@ -7,6 +7,8 @@ import ( "time" "github.com/quic-go/quic-go/internal/monotime" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/stretchr/testify/require" ) @@ -110,10 +112,43 @@ func TestPacerFastPacing(t *testing.T) { } func TestPacerNoOverflows(t *testing.T) { - p := newPacer(func() Bandwidth { return infBandwidth }) + p := newPacer(func() Bandwidth { return math.MaxUint64 }) now := monotime.Now() p.SentPacket(now, initialMaxDatagramSize) for range 100000 { require.NotZero(t, p.Budget(now.Add(time.Duration(rand.Int64N(math.MaxInt64))))) } + + burstCount := 1 + for p.Budget(now) > 0 { + burstCount++ + p.SentPacket(now, initialMaxDatagramSize) + } + require.Equal(t, maxBurstSizePackets, burstCount) + require.Zero(t, p.Budget(now)) + + next := p.TimeUntilSend() + require.Equal(t, next.Sub(now), protocol.MinPacingDelay) + require.Greater(t, p.Budget(next), initialMaxDatagramSize) +} + +func BenchmarkPacer(b *testing.B) { + const bandwidth = 50 * initialMaxDatagramSize // 50 full-size packets per second + p := newPacer(func() Bandwidth { return Bandwidth(bandwidth) * BytesPerSecond * 4 / 5 }) + + now := monotime.Now() + + var i int + for b.Loop() { + i++ + for p.Budget(now) > 0 { + p.SentPacket(now, initialMaxDatagramSize) + } + next := p.TimeUntilSend() + if i%2 == 0 { + now = next + } else { + now = now.Add(100 * time.Millisecond) + } + } }