reject sending of DATAGRAM frames that exceed the current MTU (#4497)

* reject sending of datagrams that exceed the current MTU

* check datagram size in PMTUD integration test
This commit is contained in:
Marten Seemann
2024-05-07 15:06:00 +08:00
committed by GitHub
parent f1b473d925
commit 66f968b9ff
2 changed files with 47 additions and 8 deletions

View File

@@ -155,6 +155,8 @@ type connection struct {
packer packer packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
maxPayloadSizeEstimate atomic.Uint32
initialStream cryptoStream initialStream cryptoStream
handshakeStream cryptoStream handshakeStream cryptoStream
oneRTTStream cryptoStream // only set for the server oneRTTStream cryptoStream // only set for the server
@@ -274,9 +276,10 @@ var newConnection = func(
) )
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup() s.preSetup()
initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr())
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0, 0,
getMaxPacketSize(s.conn.RemoteAddr()), initialPacketSize,
s.rttStats, s.rttStats,
clientAddressValidated, clientAddressValidated,
s.conn.capabilities().ECN, s.conn.capabilities().ECN,
@@ -284,7 +287,8 @@ var newConnection = func(
s.tracer, s.tracer,
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize)))
params := &wire.TransportParameters{ params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@@ -383,9 +387,10 @@ var newClientConnection = func(
) )
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup() s.preSetup()
initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr())
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber, initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()), initialPacketSize,
s.rttStats, s.rttStats,
false, // has no effect false, // has no effect
s.conn.capabilities().ECN, s.conn.capabilities().ECN,
@@ -393,7 +398,8 @@ var newClientConnection = func(
s.tracer, s.tracer,
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize)))
oneRTTStream := newCryptoStream() oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{ params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@@ -2352,13 +2358,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
} }
} }
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu)))
s.sentPacketHandler.SetMaxDatagramSize(mtu)
}
func (s *connection) SendDatagram(p []byte) error { func (s *connection) SendDatagram(p []byte) error {
if !s.supportsDatagrams() { if !s.supportsDatagrams() {
return errors.New("datagram support disabled") return errors.New("datagram support disabled")
} }
f := &wire.DatagramFrame{DataLenPresent: true} f := &wire.DatagramFrame{DataLenPresent: true}
maxDataLen := f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) // The payload size estimate is conservative.
// Under many circumstances we could send a few more bytes.
maxDataLen := min(
f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version),
protocol.ByteCount(s.maxPayloadSizeEstimate.Load()),
)
if protocol.ByteCount(len(p)) > maxDataLen { if protocol.ByteCount(len(p)) > maxDataLen {
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)} return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
} }
@@ -2391,3 +2407,10 @@ func (s *connection) NextConnection() Connection {
s.streamsMap.UseResetMaps() s.streamsMap.UseResetMaps()
return s return s
} }
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
// It is not very sophisticated: it just subtracts the size of header (assuming the maximum
// connection ID length), and the size of the encryption tag.
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */
}

View File

@@ -17,10 +17,14 @@ import (
var _ = Describe("DPLPMTUD", func() { var _ = Describe("DPLPMTUD", func() {
It("discovers the MTU", func() { It("discovers the MTU", func() {
const rtt = 100 * time.Millisecond rtt := scaleDuration(10 * time.Millisecond)
const mtu = 1400 const mtu = 1400
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true})) ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer ln.Close() defer ln.Close()
go func() { go func() {
@@ -73,7 +77,7 @@ var _ = Describe("DPLPMTUD", func() {
context.Background(), context.Background(),
proxy.LocalAddr(), proxy.LocalAddr(),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfig(nil), getQuicConfig(&quic.Config{EnableDatagrams: true}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "") defer conn.CloseWithError(0, "")
@@ -87,15 +91,27 @@ var _ = Describe("DPLPMTUD", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRDataLong)) Expect(data).To(Equal(PRDataLong))
}() }()
err = conn.SendDatagram(make([]byte, 2000))
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
_, err = str.Write(PRDataLong) _, err = str.Write(PRDataLong)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str.Close() str.Close()
Eventually(done, 20*time.Second).Should(BeClosed()) Eventually(done, 20*time.Second).Should(BeClosed())
err = conn.SendDatagram(make([]byte, 2000))
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
mx.Lock() mx.Lock()
defer mx.Unlock() defer mx.Unlock()
fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu) fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu)
fmt.Fprintf(GinkgoWriter, "max datagram size: initial: %d, final: %d\n", initialMaxDatagramSize, finalMaxDatagramSize)
fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu) fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu)
Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25)) Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25))
const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
Expect(initialMaxDatagramSize).To(BeNumerically(">=", 1252-maxDiff))
Expect(finalMaxDatagramSize).To(BeNumerically(">=", maxPacketSizeClient-maxDiff))
// MTU discovery was disabled on the server side
Expect(maxPacketSizeServer).To(Equal(1252))
}) })
}) })