diff --git a/connection.go b/connection.go index 5afc7c6b9..5202ff58c 100644 --- a/connection.go +++ b/connection.go @@ -147,7 +147,7 @@ type connection struct { packer packer mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received - maxPayloadSizeEstimate atomic.Uint32 + currentMTUEstimate atomic.Uint32 initialStream *cryptoStream handshakeStream *cryptoStream @@ -279,7 +279,7 @@ var newConnection = func( s.tracer, s.logger, ) - s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) + s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID) params := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -391,7 +391,7 @@ var newClientConnection = func( s.tracer, s.logger, ) - s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) + s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -1586,6 +1586,13 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr return err } } + // If one of the acknowledged packets was a Path MTU probe packet, this might have increased the Path MTU estimate. + if s.mtuDiscoverer != nil { + if mtu := s.mtuDiscoverer.CurrentSize(); mtu > protocol.ByteCount(s.currentMTUEstimate.Load()) { + s.currentMTUEstimate.Store(uint32(mtu)) + s.sentPacketHandler.SetMaxDatagramSize(mtu) + } + } return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } @@ -1851,7 +1858,6 @@ func (s *connection) applyTransportParameters() { s.rttStats, protocol.ByteCount(s.config.InitialPacketSize), maxPacketSize, - s.onMTUIncreased, s.tracer, ) } @@ -2319,11 +2325,6 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { s.framer.RemoveActiveStream(id) } -func (s *connection) onMTUIncreased(mtu protocol.ByteCount) { - s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu))) - s.sentPacketHandler.SetMaxDatagramSize(mtu) -} - func (s *connection) SendDatagram(p []byte) error { if !s.supportsDatagrams() { return errors.New("datagram support disabled") @@ -2334,7 +2335,7 @@ func (s *connection) SendDatagram(p []byte) error { // Under many circumstances we could send a few more bytes. maxDataLen := min( f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version), - protocol.ByteCount(s.maxPayloadSizeEstimate.Load()), + protocol.ByteCount(s.currentMTUEstimate.Load()), ) if protocol.ByteCount(len(p)) > maxDataLen { return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)} diff --git a/mtu_discoverer.go b/mtu_discoverer.go index ee636a6d1..d2d1ac77a 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -88,7 +88,6 @@ const ( type mtuFinder struct { lastProbeTime time.Time - mtuIncreased func(protocol.ByteCount) rttStats *utils.RTTStats @@ -107,15 +106,13 @@ var _ mtuDiscoverer = &mtuFinder{} func newMTUDiscoverer( rttStats *utils.RTTStats, start, max protocol.ByteCount, - mtuIncreased func(protocol.ByteCount), tracer *logging.ConnectionTracer, ) *mtuFinder { f := &mtuFinder{ - inFlight: protocol.InvalidByteCount, - min: start, - rttStats: rttStats, - mtuIncreased: mtuIncreased, - tracer: tracer, + inFlight: protocol.InvalidByteCount, + min: start, + rttStats: rttStats, + tracer: tracer, } for i := range f.lost { if i == 0 { @@ -207,7 +204,6 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { if h.tracer != nil && h.tracer.UpdatedMTU != nil { h.tracer.UpdatedMTU(size, h.done()) } - h.mtuIncreased(size) } func (h *mtuFinderAckHandler) OnLost(wire.Frame) { diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go index 6d87b9ecb..415d06387 100644 --- a/mtu_discoverer_test.go +++ b/mtu_discoverer_test.go @@ -17,7 +17,7 @@ func TestMTUDiscovererTiming(t *testing.T) { const rtt = 100 * time.Millisecond var rttStats utils.RTTStats rttStats.UpdateRTT(rtt, 0) - d := newMTUDiscoverer(&rttStats, 1000, 2000, func(s protocol.ByteCount) {}, nil) + d := newMTUDiscoverer(&rttStats, 1000, 2000, nil) now := time.Now() require.False(t, d.ShouldSendProbe(now)) @@ -37,21 +37,20 @@ func TestMTUDiscovererTiming(t *testing.T) { } func TestMTUDiscovererAckAndLoss(t *testing.T) { - var mtu protocol.ByteCount - d := newMTUDiscoverer(&utils.RTTStats{}, 1000, 2000, func(s protocol.ByteCount) { mtu = s }, nil) + d := newMTUDiscoverer(&utils.RTTStats{}, 1000, 2000, nil) // we use an RTT of 0 here, so we don't have to advance the timer on every step now := time.Now() ping, size := d.GetPing(now) require.Equal(t, protocol.ByteCount(1500), size) // the MTU is reduced if the frame is lost ping.Handler.OnLost(ping.Frame) - require.Zero(t, mtu) // no change to the MTU yet + require.Equal(t, protocol.ByteCount(1000), d.CurrentSize()) // no change to the MTU yet require.True(t, d.ShouldSendProbe(now)) ping, size = d.GetPing(now) require.Equal(t, protocol.ByteCount(1250), size) ping.Handler.OnAcked(ping.Frame) - require.Equal(t, protocol.ByteCount(1250), mtu) // the MTU is increased + require.Equal(t, protocol.ByteCount(1250), d.CurrentSize()) // the MTU is increased // Even though the 1500 byte MTU probe packet was lost, we try again with a higher MTU. // This protects against regular (non-MTU-related) packet loss. @@ -59,7 +58,7 @@ func TestMTUDiscovererAckAndLoss(t *testing.T) { ping, size = d.GetPing(now) require.Greater(t, size, protocol.ByteCount(1500)) ping.Handler.OnAcked(ping.Frame) - require.Equal(t, size, mtu) + require.Equal(t, size, d.CurrentSize()) // We continue probing until the MTU is close to the maximum. var steps int @@ -91,13 +90,9 @@ func testMTUDiscovererMTUDiscovery(t *testing.T) { rttStats.UpdateRTT(rtt, 0) maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1 - currentMTU := startMTU var tracedMTU protocol.ByteCount var tracerDone bool - d := newMTUDiscoverer( - &rttStats, - startMTU, maxMTU, - func(s protocol.ByteCount) { currentMTU = s }, + d := newMTUDiscoverer(&rttStats, startMTU, maxMTU, &logging.ConnectionTracer{ UpdatedMTU: func(mtu logging.ByteCount, done bool) { tracedMTU = mtu @@ -122,6 +117,7 @@ func testMTUDiscovererMTUDiscovery(t *testing.T) { } now = now.Add(mtuProbeDelay * rtt) } + currentMTU := d.CurrentSize() diff := realMTU - currentMTU require.GreaterOrEqual(t, diff, protocol.ByteCount(0)) if maxMTU > currentMTU+maxMTU { @@ -151,15 +147,10 @@ func testMTUDiscovererWithRandomLoss(t *testing.T) { require.Equal(t, rtt, rttStats.SmoothedRTT()) maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1 - currentMTU := startMTU var tracedMTU protocol.ByteCount var tracerDone bool - d := newMTUDiscoverer( - rttStats, - startMTU, - maxMTU, - func(s protocol.ByteCount) { currentMTU = s }, + d := newMTUDiscoverer(rttStats, startMTU, maxMTU, &logging.ConnectionTracer{ UpdatedMTU: func(mtu logging.ByteCount, done bool) { tracedMTU = mtu @@ -195,6 +186,7 @@ func testMTUDiscovererWithRandomLoss(t *testing.T) { now = now.Add(mtuProbeDelay * rtt) } + currentMTU := d.CurrentSize() diff := realMTU - currentMTU require.GreaterOrEqual(t, diff, protocol.ByteCount(0)) if maxMTU > currentMTU+maxMTU {