From ac87292e879e52f2dfa5a3b774e3251fa3df4c21 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 31 Jan 2021 14:27:25 +0800 Subject: [PATCH] start path MTU discovery when the handshake completes --- integrationtests/self/packetization_test.go | 7 +- integrationtests/self/timeout_test.go | 39 +++++----- mock_mtu_discoverer_test.go | 80 +++++++++++++++++++++ mock_packer_test.go | 28 ++++++++ mockgen.go | 1 + packet_packer.go | 3 + session.go | 48 ++++++++++--- session_test.go | 30 +++++++- 8 files changed, 208 insertions(+), 28 deletions(-) create mode 100644 mock_mtu_discoverer_test.go diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 42084155..3d8f8a2e 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -29,7 +29,10 @@ var _ = Describe("Packetization", func() { server, err = quic.ListenAddr( "localhost:0", getTLSConfig(), - getQuicConfig(&quic.Config{AcceptToken: func(net.Addr, *quic.Token) bool { return true }}), + getQuicConfig(&quic.Config{ + AcceptToken: func(net.Addr, *quic.Token) bool { return true }, + DisablePathMTUDiscovery: true, + }), ) Expect(err).ToNot(HaveOccurred()) serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) @@ -64,7 +67,7 @@ var _ = Describe("Packetization", func() { sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index df67edd5..e312b9ba 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -130,7 +130,7 @@ var _ = Describe("Timeout tests", func() { server, err := quic.ListenAddr( "localhost:0", getTLSConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) Expect(err).ToNot(HaveOccurred()) defer server.Close() @@ -159,7 +159,7 @@ var _ = Describe("Timeout tests", func() { sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), - getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}), ) Expect(err).ToNot(HaveOccurred()) strIn, err := sess.AcceptStream(context.Background()) @@ -200,7 +200,7 @@ var _ = Describe("Timeout tests", func() { server, err := quic.ListenAddr( "localhost:0", getTLSConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) Expect(err).ToNot(HaveOccurred()) defer server.Close() @@ -218,7 +218,11 @@ var _ = Describe("Timeout tests", func() { sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, Tracer: newTracer(func() logging.ConnectionTracer { return tr })}), + getQuicConfig(&quic.Config{ + MaxIdleTimeout: idleTimeout, + Tracer: newTracer(func() logging.ConnectionTracer { return tr }), + DisablePathMTUDiscovery: true, + }), ) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) @@ -246,7 +250,7 @@ var _ = Describe("Timeout tests", func() { server, err := quic.ListenAddr( "localhost:0", getTLSConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) Expect(err).ToNot(HaveOccurred()) defer server.Close() @@ -276,7 +280,7 @@ var _ = Describe("Timeout tests", func() { sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), - getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}), + getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}), ) Expect(err).ToNot(HaveOccurred()) @@ -317,7 +321,7 @@ var _ = Describe("Timeout tests", func() { server, err := quic.ListenAddr( "localhost:0", getTLSConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) Expect(err).ToNot(HaveOccurred()) defer server.Close() @@ -345,8 +349,9 @@ var _ = Describe("Timeout tests", func() { fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ - MaxIdleTimeout: idleTimeout, - KeepAlive: true, + MaxIdleTimeout: idleTimeout, + KeepAlive: true, + DisablePathMTUDiscovery: true, }), ) Expect(err).ToNot(HaveOccurred()) @@ -417,7 +422,7 @@ var _ = Describe("Timeout tests", func() { ln, err := quic.Listen( &faultyConn{PacketConn: conn, MaxPackets: maxPackets}, getTLSConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) Expect(err).ToNot(HaveOccurred()) @@ -434,8 +439,9 @@ var _ = Describe("Timeout tests", func() { fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ - HandshakeIdleTimeout: handshakeTimeout, - MaxIdleTimeout: handshakeTimeout, + HandshakeIdleTimeout: handshakeTimeout, + MaxIdleTimeout: handshakeTimeout, + DisablePathMTUDiscovery: true, }), ) if err != nil { @@ -467,9 +473,10 @@ var _ = Describe("Timeout tests", func() { "localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{ - HandshakeIdleTimeout: handshakeTimeout, - MaxIdleTimeout: handshakeTimeout, - KeepAlive: true, + HandshakeIdleTimeout: handshakeTimeout, + MaxIdleTimeout: handshakeTimeout, + KeepAlive: true, + DisablePathMTUDiscovery: true, }), ) Expect(err).ToNot(HaveOccurred()) @@ -494,7 +501,7 @@ var _ = Describe("Timeout tests", func() { ln.Addr(), "localhost", getTLSClientConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) if err != nil { clientErrChan <- err diff --git a/mock_mtu_discoverer_test.go b/mock_mtu_discoverer_test.go new file mode 100644 index 00000000..530180a4 --- /dev/null +++ b/mock_mtu_discoverer_test.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: mtu_discoverer.go + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockMtuDiscoverer is a mock of MtuDiscoverer interface. +type MockMtuDiscoverer struct { + ctrl *gomock.Controller + recorder *MockMtuDiscovererMockRecorder +} + +// MockMtuDiscovererMockRecorder is the mock recorder for MockMtuDiscoverer. +type MockMtuDiscovererMockRecorder struct { + mock *MockMtuDiscoverer +} + +// NewMockMtuDiscoverer creates a new mock instance. +func NewMockMtuDiscoverer(ctrl *gomock.Controller) *MockMtuDiscoverer { + mock := &MockMtuDiscoverer{ctrl: ctrl} + mock.recorder = &MockMtuDiscovererMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMtuDiscoverer) EXPECT() *MockMtuDiscovererMockRecorder { + return m.recorder +} + +// GetPing mocks base method. +func (m *MockMtuDiscoverer) GetPing() (ackhandler.Frame, protocol.ByteCount) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPing") + ret0, _ := ret[0].(ackhandler.Frame) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// GetPing indicates an expected call of GetPing. +func (mr *MockMtuDiscovererMockRecorder) GetPing() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMtuDiscoverer)(nil).GetPing)) +} + +// NextProbeTime mocks base method. +func (m *MockMtuDiscoverer) NextProbeTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextProbeTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// NextProbeTime indicates an expected call of NextProbeTime. +func (mr *MockMtuDiscovererMockRecorder) NextProbeTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextProbeTime", reflect.TypeOf((*MockMtuDiscoverer)(nil).NextProbeTime)) +} + +// ShouldSendProbe mocks base method. +func (m *MockMtuDiscoverer) ShouldSendProbe(now time.Time) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShouldSendProbe", now) + ret0, _ := ret[0].(bool) + return ret0 +} + +// ShouldSendProbe indicates an expected call of ShouldSendProbe. +func (mr *MockMtuDiscovererMockRecorder) ShouldSendProbe(now interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMtuDiscoverer)(nil).ShouldSendProbe), now) +} diff --git a/mock_packer_test.go b/mock_packer_test.go index fc633e97..b63f0277 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -8,6 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler" protocol "github.com/lucas-clemente/quic-go/internal/protocol" qerr "github.com/lucas-clemente/quic-go/internal/qerr" wire "github.com/lucas-clemente/quic-go/internal/wire" @@ -108,6 +109,21 @@ func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0) } +// PackMTUProbePacket mocks base method. +func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size) + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackMTUProbePacket indicates an expected call of PackMTUProbePacket. +func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), ping, size) +} + // PackPacket mocks base method. func (m *MockPacker) PackPacket() (*packedPacket, error) { m.ctrl.T.Helper() @@ -123,6 +139,18 @@ func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket)) } +// SetMaxPacketSize mocks base method. +func (m *MockPacker) SetMaxPacketSize(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxPacketSize", arg0) +} + +// SetMaxPacketSize indicates an expected call of SetMaxPacketSize. +func (mr *MockPackerMockRecorder) SetMaxPacketSize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxPacketSize", reflect.TypeOf((*MockPacker)(nil).SetMaxPacketSize), arg0) +} + // SetToken mocks base method. func (m *MockPacker) SetToken(arg0 []byte) { m.ctrl.T.Helper() diff --git a/mockgen.go b/mockgen.go index 8f6b0578..d5e9ae10 100644 --- a/mockgen.go +++ b/mockgen.go @@ -15,6 +15,7 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager" //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" //go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" +//go:generate sh -c "./mockgen_private.sh quic mock_mtu_discoverer_test.go github.com/lucas-clemente/quic-go mtuDiscoverer" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner" //go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession" //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler" diff --git a/packet_packer.go b/packet_packer.go index 4eef7497..c9b94d4f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -22,6 +22,9 @@ type packer interface { MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) PackConnectionClose(*qerr.QuicError) (*coalescedPacket, error) + SetMaxPacketSize(protocol.ByteCount) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) + HandleTransportParameters(*wire.TransportParameters) SetToken([]byte) } diff --git a/session.go b/session.go index eb104299..43ca8526 100644 --- a/session.go +++ b/session.go @@ -155,9 +155,10 @@ type session struct { tokenStoreKey string // only set for the client tokenGenerator *handshake.TokenGenerator // only set for the server - unpacker unpacker - frameParser wire.FrameParser - packer packer + unpacker unpacker + frameParser wire.FrameParser + packer packer + mtuDiscoverer mtuDiscoverer // initialized when the handshake completes oneRTTStream cryptoStream // only set for the server cryptoStreamHandler cryptoStreamHandler @@ -731,6 +732,11 @@ func (s *session) maybeResetTimer() { } else { deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) } + if !s.config.DisablePathMTUDiscovery { + if probeTime := s.mtuDiscoverer.NextProbeTime(); !probeTime.IsZero() { + deadline = utils.MinTime(deadline, probeTime) + } + } } if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { @@ -761,6 +767,23 @@ func (s *session) handleHandshakeComplete() { s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() + if !s.config.DisablePathMTUDiscovery { + maxPacketSize := s.peerParams.MaxUDPPayloadSize + if maxPacketSize == 0 { + maxPacketSize = protocol.MaxByteCount + } + maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize) + s.mtuDiscoverer = newMTUDiscoverer( + s.rttStats, + getMaxPacketSize(s.conn.RemoteAddr()), + maxPacketSize, + func(size protocol.ByteCount) { + s.sentPacketHandler.SetMaxDatagramSize(size) + s.packer.SetMaxPacketSize(size) + }, + ) + } + if s.perspective == protocol.PerspectiveServer { s.handshakeConfirmed = true s.sentPacketHandler.SetHandshakeConfirmed() @@ -1584,7 +1607,7 @@ func (s *session) maybeSendAckOnlyPacket() error { if packet == nil { return nil } - s.sendPackedPacket(packet) + s.sendPackedPacket(packet, time.Now()) return nil } @@ -1626,7 +1649,7 @@ func (s *session) sendProbePacket(encLevel protocol.EncryptionLevel) error { if packet == nil || packet.packetContents == nil { return fmt.Errorf("session BUG: couldn't pack %s probe packet", encLevel) } - s.sendPackedPacket(packet) + s.sendPackedPacket(packet, time.Now()) return nil } @@ -1636,8 +1659,8 @@ func (s *session) sendPacket() (bool, error) { } s.windowUpdateQueue.QueueAll() + now := time.Now() if !s.handshakeConfirmed { - now := time.Now() packet, err := s.packer.PackCoalescedPacket() if err != nil || packet == nil { return false, err @@ -1653,16 +1676,23 @@ func (s *session) sendPacket() (bool, error) { s.sendQueue.Send(packet.buffer) return true, nil } + if !s.config.DisablePathMTUDiscovery && s.handshakeComplete && s.mtuDiscoverer.ShouldSendProbe(now) { + packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing()) + if err != nil { + return false, err + } + s.sendPackedPacket(packet, now) + return true, nil + } packet, err := s.packer.PackPacket() if err != nil || packet == nil { return false, err } - s.sendPackedPacket(packet) + s.sendPackedPacket(packet, now) return true, nil } -func (s *session) sendPackedPacket(packet *packedPacket) { - now := time.Now() +func (s *session) sendPackedPacket(packet *packedPacket, now time.Time) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now } diff --git a/session_test.go b/session_test.go index abde2f8a..71edea2d 100644 --- a/session_test.go +++ b/session_test.go @@ -102,7 +102,7 @@ var _ = Describe("Session", func() { destConnID, srcConnID, protocol.StatelessResetToken{}, - populateServerConfig(&Config{}), + populateServerConfig(&Config{DisablePathMTUDiscovery: true}), nil, // tls.Config tokenGenerator, false, @@ -1692,6 +1692,34 @@ var _ = Describe("Session", func() { sess.scheduleSending() // no packet will get sent time.Sleep(50 * time.Millisecond) }) + + It("sends a Path MTU probe packet", func() { + mtuDiscoverer := NewMockMtuDiscoverer(mockCtrl) + sess.mtuDiscoverer = mtuDiscoverer + sess.config.DisablePathMTUDiscovery = false + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + sph.EXPECT().SendMode().Return(ackhandler.SendNone) + written := make(chan struct{}, 1) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + gomock.InOrder( + mtuDiscoverer.EXPECT().NextProbeTime(), + mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true), + mtuDiscoverer.EXPECT().NextProbeTime(), + ) + ping := ackhandler.Frame{Frame: &wire.PingFrame{}} + mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) + packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234)).Return(getPacket(1), nil) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + sess.run() + }() + sess.scheduleSending() + Eventually(written).Should(Receive()) + }) }) Context("scheduling sending", func() {