diff --git a/connection.go b/connection.go index aaf48a3fe..0404fb195 100644 --- a/connection.go +++ b/connection.go @@ -1783,6 +1783,7 @@ func (s *connection) applyTransportParameters() { protocol.ByteCount(s.config.InitialPacketSize), maxPacketSize, s.onMTUIncreased, + s.tracer, ) } diff --git a/integrationtests/self/mtu_test.go b/integrationtests/self/mtu_test.go index 686a236c4..64de6ab16 100644 --- a/integrationtests/self/mtu_test.go +++ b/integrationtests/self/mtu_test.go @@ -11,6 +11,7 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -46,7 +47,8 @@ var _ = Describe("DPLPMTUD", func() { }() var mx sync.Mutex - var maxPacketSizeClient, maxPacketSizeServer int + var maxPacketSizeServer int + var clientPacketSizes []int serverPort := ln.Addr().(*net.UDPAddr).Port proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), @@ -59,9 +61,7 @@ var _ = Describe("DPLPMTUD", func() { defer mx.Unlock() switch dir { case quicproxy.DirectionIncoming: - if len(packet) > maxPacketSizeClient { - maxPacketSizeClient = len(packet) - } + clientPacketSizes = append(clientPacketSizes, len(packet)) case quicproxy.DirectionOutgoing: if len(packet) > maxPacketSizeServer { maxPacketSizeServer = len(packet) @@ -80,6 +80,12 @@ var _ = Describe("DPLPMTUD", func() { defer udpConn.Close() tr := &quic.Transport{Conn: udpConn} defer tr.Close() + var mtus []logging.ByteCount + mtuTracer := &logging.ConnectionTracer{ + UpdatedMTU: func(mtu logging.ByteCount, _ bool) { + mtus = append(mtus, mtu) + }, + } conn, err := tr.Dial( context.Background(), proxy.LocalAddr(), @@ -87,6 +93,9 @@ var _ = Describe("DPLPMTUD", func() { getQuicConfig(&quic.Config{ InitialPacketSize: protocol.MinInitialPacketSize, EnableDatagrams: true, + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return mtuTracer + }, }), ) Expect(err).ToNot(HaveOccurred()) @@ -114,6 +123,8 @@ var _ = Describe("DPLPMTUD", func() { mx.Lock() defer mx.Unlock() + Expect(mtus).ToNot(BeEmpty()) + maxPacketSizeClient := int(mtus[len(mtus)-1]) fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu) fmt.Fprintf(GinkgoWriter, "max datagram size: initial: %d, final: %d\n", initialMaxDatagramSize, finalMaxDatagramSize) fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu) @@ -123,6 +134,16 @@ var _ = Describe("DPLPMTUD", func() { Expect(finalMaxDatagramSize).To(BeNumerically(">=", maxPacketSizeClient-maxDiff)) // MTU discovery was disabled on the server side Expect(maxPacketSizeServer).To(Equal(1234)) + + var numPacketsLargerThanDiscoveredMTU int + for _, s := range clientPacketSizes { + if s > maxPacketSizeClient { + numPacketsLargerThanDiscoveredMTU++ + } + } + // The client shouldn't have sent any packets larger than the MTU it discovered, + // except for at most one MTU probe packet. + Expect(numPacketsLargerThanDiscoveredMTU).To(BeNumerically("<=", 1)) }) It("uses the initial packet size", func() { diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index 8e8be8a5b..bd3a39c17 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -68,6 +68,9 @@ func NewMockConnectionTracer(ctrl *gomock.Controller) (*logging.ConnectionTracer LostPacket: func(encLevel logging.EncryptionLevel, pn logging.PacketNumber, reason logging.PacketLossReason) { t.LostPacket(encLevel, pn, reason) }, + UpdatedMTU: func(mtu logging.ByteCount, done bool) { + t.UpdatedMTU(mtu, done) + }, UpdatedCongestionState: func(state logging.CongestionState) { t.UpdatedCongestionState(state) }, diff --git a/internal/mocks/logging/internal/connection_tracer.go b/internal/mocks/logging/internal/connection_tracer.go index e02fab6bc..0bf082119 100644 --- a/internal/mocks/logging/internal/connection_tracer.go +++ b/internal/mocks/logging/internal/connection_tracer.go @@ -1052,6 +1052,42 @@ func (c *MockConnectionTracerUpdatedKeyFromTLSCall) DoAndReturn(f func(protocol. return c } +// UpdatedMTU mocks base method. +func (m *MockConnectionTracer) UpdatedMTU(arg0 protocol.ByteCount, arg1 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedMTU", arg0, arg1) +} + +// UpdatedMTU indicates an expected call of UpdatedMTU. +func (mr *MockConnectionTracerMockRecorder) UpdatedMTU(arg0, arg1 any) *MockConnectionTracerUpdatedMTUCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMTU", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMTU), arg0, arg1) + return &MockConnectionTracerUpdatedMTUCall{Call: call} +} + +// MockConnectionTracerUpdatedMTUCall wrap *gomock.Call +type MockConnectionTracerUpdatedMTUCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerUpdatedMTUCall) Return() *MockConnectionTracerUpdatedMTUCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerUpdatedMTUCall) Do(f func(protocol.ByteCount, bool)) *MockConnectionTracerUpdatedMTUCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerUpdatedMTUCall) DoAndReturn(f func(protocol.ByteCount, bool)) *MockConnectionTracerUpdatedMTUCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // UpdatedMetrics mocks base method. func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { m.ctrl.T.Helper() diff --git a/internal/mocks/logging/mockgen.go b/internal/mocks/logging/mockgen.go index a9850e642..3759aa1ad 100644 --- a/internal/mocks/logging/mockgen.go +++ b/internal/mocks/logging/mockgen.go @@ -34,6 +34,7 @@ type ConnectionTracer interface { ReceivedShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, []logging.Frame) BufferedPacket(logging.PacketType, logging.ByteCount) DroppedPacket(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) + UpdatedMTU(mtu logging.ByteCount, done bool) UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) AcknowledgedPacket(logging.EncryptionLevel, logging.PacketNumber) LostPacket(logging.EncryptionLevel, logging.PacketNumber, logging.PacketLossReason) diff --git a/logging/connection_tracer.go b/logging/connection_tracer.go index 7f54d6cda..a354bbd94 100644 --- a/logging/connection_tracer.go +++ b/logging/connection_tracer.go @@ -24,6 +24,7 @@ type ConnectionTracer struct { UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) AcknowledgedPacket func(EncryptionLevel, PacketNumber) LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedMTU func(mtu ByteCount, done bool) UpdatedCongestionState func(CongestionState) UpdatedPTOCount func(value uint32) UpdatedKeyFromTLS func(EncryptionLevel, Perspective) @@ -168,6 +169,13 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra } } }, + UpdatedMTU: func(mtu ByteCount, done bool) { + for _, t := range tracers { + if t.UpdatedMTU != nil { + t.UpdatedMTU(mtu, done) + } + } + }, UpdatedCongestionState: func(state CongestionState) { for _, t := range tracers { if t.UpdatedCongestionState != nil { diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 4a9fe6773..7c519f384 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -201,6 +201,12 @@ var _ = Describe("Tracing", func() { tracer.DroppedPacket(PacketTypeInitial, 42, 1337, PacketDropHeaderParseError) }) + It("traces the UpdatedMTU event", func() { + tr1.EXPECT().UpdatedMTU(ByteCount(1337), true) + tr2.EXPECT().UpdatedMTU(ByteCount(1337), true) + tracer.UpdatedMTU(1337, true) + }) + It("traces the UpdatedCongestionState event", func() { tr1.EXPECT().UpdatedCongestionState(CongestionStateRecovery) tr2.EXPECT().UpdatedCongestionState(CongestionStateRecovery) diff --git a/mtu_discoverer.go b/mtu_discoverer.go index cb3fbf06f..fb7c09a92 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -7,6 +7,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" ) type mtuDiscoverer interface { @@ -34,17 +35,25 @@ type mtuFinder struct { inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight current protocol.ByteCount max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) + + tracer *logging.ConnectionTracer } var _ mtuDiscoverer = &mtuFinder{} -func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder { +func newMTUDiscoverer( + rttStats *utils.RTTStats, + start, max protocol.ByteCount, + mtuIncreased func(protocol.ByteCount), + tracer *logging.ConnectionTracer, +) *mtuFinder { return &mtuFinder{ inFlight: protocol.InvalidByteCount, current: start, max: max, rttStats: rttStats, mtuIncreased: mtuIncreased, + tracer: tracer, } } @@ -79,7 +88,7 @@ func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { f.inFlight = size return ackhandler.Frame{ Frame: &wire.PingFrame{}, - Handler: (*mtuFinderAckHandler)(f), + Handler: &mtuFinderAckHandler{f}, }, size } @@ -87,7 +96,9 @@ func (f *mtuFinder) CurrentSize() protocol.ByteCount { return f.current } -type mtuFinderAckHandler mtuFinder +type mtuFinderAckHandler struct { + *mtuFinder +} var _ ackhandler.FrameHandler = &mtuFinderAckHandler{} @@ -98,6 +109,9 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { } h.inFlight = protocol.InvalidByteCount h.current = size + if h.tracer != nil && h.tracer.UpdatedMTU != nil { + h.tracer.UpdatedMTU(size, h.done()) + } h.mtuIncreased(size) } diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go index 497a2cc74..fc42f5396 100644 --- a/mtu_discoverer_test.go +++ b/mtu_discoverer_test.go @@ -6,6 +6,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -29,7 +30,13 @@ var _ = Describe("MTU Discoverer", func() { rttStats = &utils.RTTStats{} rttStats.SetInitialRTT(rtt) Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) - d = newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { discoveredMTU = s }) + d = newMTUDiscoverer( + rttStats, + startMTU, + maxMTU, + func(s protocol.ByteCount) { discoveredMTU = s }, + nil, + ) d.Start() now = time.Now() }) @@ -78,7 +85,7 @@ var _ = Describe("MTU Discoverer", func() { }) It("doesn't do discovery before being started", func() { - d := newMTUDiscoverer(rttStats, startMTU, protocol.MaxByteCount, func(s protocol.ByteCount) {}) + d := newMTUDiscoverer(rttStats, startMTU, protocol.MaxByteCount, func(s protocol.ByteCount) {}, nil) for i := 0; i < 5; i++ { Expect(d.ShouldSendProbe(time.Now())).To(BeFalse()) } @@ -90,7 +97,20 @@ var _ = Describe("MTU Discoverer", func() { for i := 0; i < rep; i++ { maxMTU := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1 currentMTU := startMTU - d := newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { currentMTU = s }) + var tracedMTU protocol.ByteCount + var tracerDone bool + d := newMTUDiscoverer( + rttStats, + startMTU, + maxMTU, + func(s protocol.ByteCount) { currentMTU = s }, + &logging.ConnectionTracer{ + UpdatedMTU: func(mtu logging.ByteCount, done bool) { + tracedMTU = mtu + tracerDone = done + }, + }, + ) d.Start() now := time.Now() realMTU := protocol.ByteCount(rand.Intn(int(maxMTU-startMTU))) + startMTU @@ -113,6 +133,10 @@ var _ = Describe("MTU Discoverer", func() { diff := realMTU - currentMTU Expect(diff).To(BeNumerically(">=", 0)) maxDiff = max(maxDiff, diff) + if maxMTU > currentMTU+maxMTU { + Expect(tracedMTU).To(Equal(currentMTU)) + Expect(tracerDone).To(BeTrue()) + } } Expect(maxDiff).To(BeEquivalentTo(maxMTUDiff)) }) diff --git a/qlog/connection_tracer.go b/qlog/connection_tracer.go index ce6c0db3f..3ea0ce11c 100644 --- a/qlog/connection_tracer.go +++ b/qlog/connection_tracer.go @@ -76,6 +76,9 @@ func NewConnectionTracer(w io.WriteCloser, p logging.Perspective, odcid protocol LostPacket: func(encLevel protocol.EncryptionLevel, pn protocol.PacketNumber, lossReason logging.PacketLossReason) { t.LostPacket(encLevel, pn, lossReason) }, + UpdatedMTU: func(mtu logging.ByteCount, done bool) { + t.UpdatedMTU(mtu, done) + }, UpdatedCongestionState: func(state logging.CongestionState) { t.UpdatedCongestionState(state) }, @@ -367,6 +370,10 @@ func (t *connectionTracer) LostPacket(encLevel protocol.EncryptionLevel, pn prot }) } +func (t *connectionTracer) UpdatedMTU(mtu protocol.ByteCount, done bool) { + t.recordEvent(time.Now(), &eventMTUUpdated{mtu: mtu, done: done}) +} + func (t *connectionTracer) UpdatedCongestionState(state logging.CongestionState) { t.recordEvent(time.Now(), &eventCongestionStateUpdated{state: congestionState(state)}) } diff --git a/qlog/connection_tracer_test.go b/qlog/connection_tracer_test.go index e2c27e077..a6207c777 100644 --- a/qlog/connection_tracer_test.go +++ b/qlog/connection_tracer_test.go @@ -94,7 +94,7 @@ var _ = Describe("Tracing", func() { Expect(m).To(HaveKey("title")) Expect(m).To(HaveKey("trace")) trace := m["trace"].(map[string]interface{}) - Expect(trace).To(HaveKey(("common_fields"))) + Expect(trace).To(HaveKey("common_fields")) commonFields := trace["common_fields"].(map[string]interface{}) Expect(commonFields).To(HaveKeyWithValue("ODCID", "deadbeef")) Expect(commonFields).To(HaveKeyWithValue("group_id", "deadbeef")) @@ -723,6 +723,17 @@ var _ = Describe("Tracing", func() { Expect(ev).To(HaveKeyWithValue("trigger", "reordering_threshold")) }) + It("records MTU discovery updates", func() { + tracer.UpdatedMTU(1337, true) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:mtu_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("mtu", float64(1337))) + Expect(ev).To(HaveKeyWithValue("done", true)) + }) + It("records congestion state updates", func() { tracer.UpdatedCongestionState(logging.CongestionStateCongestionAvoidance) tracer.Close() diff --git a/qlog/event.go b/qlog/event.go index c70ee5486..7602b55ce 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -294,6 +294,20 @@ type metrics struct { PacketsInFlight int } +type eventMTUUpdated struct { + mtu protocol.ByteCount + done bool +} + +func (e eventMTUUpdated) Category() category { return categoryRecovery } +func (e eventMTUUpdated) Name() string { return "mtu_updated" } +func (e eventMTUUpdated) IsNil() bool { return false } + +func (e eventMTUUpdated) MarshalJSONObject(enc *gojay.Encoder) { + enc.Uint64Key("mtu", uint64(e.mtu)) + enc.BoolKey("done", e.done) +} + type eventMetricsUpdated struct { Last *metrics Current *metrics