forked from quic-go/quic-go
reject sending of DATAGRAM frames that exceed the current MTU (#4497)
* reject sending of datagrams that exceed the current MTU * check datagram size in PMTUD integration test
This commit is contained in:
@@ -155,6 +155,8 @@ type connection struct {
|
|||||||
packer packer
|
packer packer
|
||||||
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
|
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
|
||||||
|
|
||||||
|
maxPayloadSizeEstimate atomic.Uint32
|
||||||
|
|
||||||
initialStream cryptoStream
|
initialStream cryptoStream
|
||||||
handshakeStream cryptoStream
|
handshakeStream cryptoStream
|
||||||
oneRTTStream cryptoStream // only set for the server
|
oneRTTStream cryptoStream // only set for the server
|
||||||
@@ -274,9 +276,10 @@ var newConnection = func(
|
|||||||
)
|
)
|
||||||
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||||
s.preSetup()
|
s.preSetup()
|
||||||
|
initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr())
|
||||||
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
||||||
0,
|
0,
|
||||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
initialPacketSize,
|
||||||
s.rttStats,
|
s.rttStats,
|
||||||
clientAddressValidated,
|
clientAddressValidated,
|
||||||
s.conn.capabilities().ECN,
|
s.conn.capabilities().ECN,
|
||||||
@@ -284,7 +287,8 @@ var newConnection = func(
|
|||||||
s.tracer,
|
s.tracer,
|
||||||
s.logger,
|
s.logger,
|
||||||
)
|
)
|
||||||
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
|
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased)
|
||||||
|
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize)))
|
||||||
params := &wire.TransportParameters{
|
params := &wire.TransportParameters{
|
||||||
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||||
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||||
@@ -383,9 +387,10 @@ var newClientConnection = func(
|
|||||||
)
|
)
|
||||||
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||||
s.preSetup()
|
s.preSetup()
|
||||||
|
initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr())
|
||||||
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
||||||
initialPacketNumber,
|
initialPacketNumber,
|
||||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
initialPacketSize,
|
||||||
s.rttStats,
|
s.rttStats,
|
||||||
false, // has no effect
|
false, // has no effect
|
||||||
s.conn.capabilities().ECN,
|
s.conn.capabilities().ECN,
|
||||||
@@ -393,7 +398,8 @@ var newClientConnection = func(
|
|||||||
s.tracer,
|
s.tracer,
|
||||||
s.logger,
|
s.logger,
|
||||||
)
|
)
|
||||||
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
|
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased)
|
||||||
|
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize)))
|
||||||
oneRTTStream := newCryptoStream()
|
oneRTTStream := newCryptoStream()
|
||||||
params := &wire.TransportParameters{
|
params := &wire.TransportParameters{
|
||||||
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||||
@@ -2352,13 +2358,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
|
||||||
|
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu)))
|
||||||
|
s.sentPacketHandler.SetMaxDatagramSize(mtu)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *connection) SendDatagram(p []byte) error {
|
func (s *connection) SendDatagram(p []byte) error {
|
||||||
if !s.supportsDatagrams() {
|
if !s.supportsDatagrams() {
|
||||||
return errors.New("datagram support disabled")
|
return errors.New("datagram support disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
f := &wire.DatagramFrame{DataLenPresent: true}
|
f := &wire.DatagramFrame{DataLenPresent: true}
|
||||||
maxDataLen := f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version)
|
// The payload size estimate is conservative.
|
||||||
|
// Under many circumstances we could send a few more bytes.
|
||||||
|
maxDataLen := min(
|
||||||
|
f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version),
|
||||||
|
protocol.ByteCount(s.maxPayloadSizeEstimate.Load()),
|
||||||
|
)
|
||||||
if protocol.ByteCount(len(p)) > maxDataLen {
|
if protocol.ByteCount(len(p)) > maxDataLen {
|
||||||
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
|
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
|
||||||
}
|
}
|
||||||
@@ -2391,3 +2407,10 @@ func (s *connection) NextConnection() Connection {
|
|||||||
s.streamsMap.UseResetMaps()
|
s.streamsMap.UseResetMaps()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
|
||||||
|
// It is not very sophisticated: it just subtracts the size of header (assuming the maximum
|
||||||
|
// connection ID length), and the size of the encryption tag.
|
||||||
|
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
|
||||||
|
return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,10 +17,14 @@ import (
|
|||||||
|
|
||||||
var _ = Describe("DPLPMTUD", func() {
|
var _ = Describe("DPLPMTUD", func() {
|
||||||
It("discovers the MTU", func() {
|
It("discovers the MTU", func() {
|
||||||
const rtt = 100 * time.Millisecond
|
rtt := scaleDuration(10 * time.Millisecond)
|
||||||
const mtu = 1400
|
const mtu = 1400
|
||||||
|
|
||||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}))
|
ln, err := quic.ListenAddr(
|
||||||
|
"localhost:0",
|
||||||
|
getTLSConfig(),
|
||||||
|
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
|
||||||
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
go func() {
|
go func() {
|
||||||
@@ -73,7 +77,7 @@ var _ = Describe("DPLPMTUD", func() {
|
|||||||
context.Background(),
|
context.Background(),
|
||||||
proxy.LocalAddr(),
|
proxy.LocalAddr(),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn.CloseWithError(0, "")
|
defer conn.CloseWithError(0, "")
|
||||||
@@ -87,15 +91,27 @@ var _ = Describe("DPLPMTUD", func() {
|
|||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(data).To(Equal(PRDataLong))
|
Expect(data).To(Equal(PRDataLong))
|
||||||
}()
|
}()
|
||||||
|
err = conn.SendDatagram(make([]byte, 2000))
|
||||||
|
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
|
||||||
|
initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
|
||||||
_, err = str.Write(PRDataLong)
|
_, err = str.Write(PRDataLong)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str.Close()
|
str.Close()
|
||||||
Eventually(done, 20*time.Second).Should(BeClosed())
|
Eventually(done, 20*time.Second).Should(BeClosed())
|
||||||
|
err = conn.SendDatagram(make([]byte, 2000))
|
||||||
|
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
|
||||||
|
finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
|
||||||
|
|
||||||
mx.Lock()
|
mx.Lock()
|
||||||
defer mx.Unlock()
|
defer mx.Unlock()
|
||||||
fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu)
|
fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu)
|
||||||
|
fmt.Fprintf(GinkgoWriter, "max datagram size: initial: %d, final: %d\n", initialMaxDatagramSize, finalMaxDatagramSize)
|
||||||
fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu)
|
fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu)
|
||||||
Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25))
|
Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25))
|
||||||
|
const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
|
||||||
|
Expect(initialMaxDatagramSize).To(BeNumerically(">=", 1252-maxDiff))
|
||||||
|
Expect(finalMaxDatagramSize).To(BeNumerically(">=", maxPacketSizeClient-maxDiff))
|
||||||
|
// MTU discovery was disabled on the server side
|
||||||
|
Expect(maxPacketSizeServer).To(Equal(1252))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user