diff --git a/connection.go b/connection.go index b1e9aaee3..0b009fbd6 100644 --- a/connection.go +++ b/connection.go @@ -155,6 +155,8 @@ type connection struct { packer packer mtuDiscoverer mtuDiscoverer // initialized when the handshake completes + maxPayloadSizeEstimate atomic.Uint32 + initialStream cryptoStream handshakeStream cryptoStream oneRTTStream cryptoStream // only set for the server @@ -274,9 +276,10 @@ var newConnection = func( ) s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.preSetup() + initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr()) s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( 0, - getMaxPacketSize(s.conn.RemoteAddr()), + initialPacketSize, s.rttStats, clientAddressValidated, s.conn.capabilities().ECN, @@ -284,7 +287,8 @@ var newConnection = func( s.tracer, s.logger, ) - s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) + s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased) + s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize))) params := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -383,9 +387,10 @@ var newClientConnection = func( ) s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.preSetup() + initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr()) s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( initialPacketNumber, - getMaxPacketSize(s.conn.RemoteAddr()), + initialPacketSize, s.rttStats, false, // has no effect s.conn.capabilities().ECN, @@ -393,7 +398,8 @@ var newClientConnection = func( s.tracer, s.logger, ) - s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) + s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased) + s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize))) oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -2352,13 +2358,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { } } +func (s *connection) onMTUIncreased(mtu protocol.ByteCount) { + s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu))) + s.sentPacketHandler.SetMaxDatagramSize(mtu) +} + func (s *connection) SendDatagram(p []byte) error { if !s.supportsDatagrams() { return errors.New("datagram support disabled") } f := &wire.DatagramFrame{DataLenPresent: true} - maxDataLen := f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) + // The payload size estimate is conservative. + // Under many circumstances we could send a few more bytes. + maxDataLen := min( + f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version), + protocol.ByteCount(s.maxPayloadSizeEstimate.Load()), + ) if protocol.ByteCount(len(p)) > maxDataLen { return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)} } @@ -2391,3 +2407,10 @@ func (s *connection) NextConnection() Connection { s.streamsMap.UseResetMaps() return s } + +// estimateMaxPayloadSize estimates the maximum payload size for short header packets. +// It is not very sophisticated: it just subtracts the size of header (assuming the maximum +// connection ID length), and the size of the encryption tag. +func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount { + return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */ +} diff --git a/integrationtests/self/dplpmtud_test.go b/integrationtests/self/dplpmtud_test.go index 4b5aae08d..402726322 100644 --- a/integrationtests/self/dplpmtud_test.go +++ b/integrationtests/self/dplpmtud_test.go @@ -17,10 +17,14 @@ import ( var _ = Describe("DPLPMTUD", func() { It("discovers the MTU", func() { - const rtt = 100 * time.Millisecond + rtt := scaleDuration(10 * time.Millisecond) const mtu = 1400 - ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true})) + ln, err := quic.ListenAddr( + "localhost:0", + getTLSConfig(), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}), + ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() go func() { @@ -73,7 +77,7 @@ var _ = Describe("DPLPMTUD", func() { context.Background(), proxy.LocalAddr(), getTLSClientConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{EnableDatagrams: true}), ) Expect(err).ToNot(HaveOccurred()) defer conn.CloseWithError(0, "") @@ -87,15 +91,27 @@ var _ = Describe("DPLPMTUD", func() { Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRDataLong)) }() + err = conn.SendDatagram(make([]byte, 2000)) + Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{})) + initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize _, err = str.Write(PRDataLong) Expect(err).ToNot(HaveOccurred()) str.Close() Eventually(done, 20*time.Second).Should(BeClosed()) + err = conn.SendDatagram(make([]byte, 2000)) + Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{})) + finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize mx.Lock() defer mx.Unlock() fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu) + fmt.Fprintf(GinkgoWriter, "max datagram size: initial: %d, final: %d\n", initialMaxDatagramSize, finalMaxDatagramSize) fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu) Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25)) + const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead + Expect(initialMaxDatagramSize).To(BeNumerically(">=", 1252-maxDiff)) + Expect(finalMaxDatagramSize).To(BeNumerically(">=", maxPacketSizeClient-maxDiff)) + // MTU discovery was disabled on the server side + Expect(maxPacketSizeServer).To(Equal(1252)) }) })