From 04642c9e4d7aec3cd8c02c6088418ba7c195f90d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 27 Jan 2021 12:48:15 +0800 Subject: [PATCH] add a MTU discoverer that performs a binary search to determine the MTU --- mtu_discoverer.go | 76 +++++++++++++++++++++++++++ mtu_discoverer_test.go | 116 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+) create mode 100644 mtu_discoverer.go create mode 100644 mtu_discoverer_test.go diff --git a/mtu_discoverer.go b/mtu_discoverer.go new file mode 100644 index 000000000..d2c6c222a --- /dev/null +++ b/mtu_discoverer.go @@ -0,0 +1,76 @@ +package quic + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +const ( + // At some point, we have to stop searching for a higher MTU. + // We're happy to send a packet that's 10 bytes smaller than the actual MTU. + maxMTUDiff = 20 + // send a probe packet every mtuProbeDelay RTTs + mtuProbeDelay = 5 +) + +type mtuDiscoverer struct { + lastProbeTime time.Time + probeInFlight bool + mtuIncreased func(protocol.ByteCount) + + rttStats *utils.RTTStats + current protocol.ByteCount + max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) +} + +func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuDiscoverer { + return &mtuDiscoverer{ + current: start, + rttStats: rttStats, + lastProbeTime: time.Now(), // to make sure the first probe packet is not sent immediately + mtuIncreased: mtuIncreased, + max: max, + } +} + +func (d *mtuDiscoverer) done() bool { + return d.max-d.current <= maxMTUDiff+1 +} + +func (d *mtuDiscoverer) ShouldSendProbe(now time.Time) bool { + if d.probeInFlight || d.done() { + return false + } + return !now.Before(d.NextProbeTime()) +} + +// NextProbeTime returns the time when the next probe packet should be sent. +// It returns the zero value if no probe packet should be sent. +func (d *mtuDiscoverer) NextProbeTime() time.Time { + if d.probeInFlight || d.done() { + return time.Time{} + } + return d.lastProbeTime.Add(mtuProbeDelay * d.rttStats.SmoothedRTT()) +} + +func (d *mtuDiscoverer) GetPing() (ackhandler.Frame, protocol.ByteCount) { + size := (d.max + d.current) / 2 + d.lastProbeTime = time.Now() + d.probeInFlight = true + return ackhandler.Frame{ + Frame: &wire.PingFrame{}, + OnLost: func(wire.Frame) { + d.probeInFlight = false + d.max = size + }, + OnAcked: func(wire.Frame) { + d.probeInFlight = false + d.current = size + d.mtuIncreased(size) + }, + }, size +} diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go new file mode 100644 index 000000000..60182e64e --- /dev/null +++ b/mtu_discoverer_test.go @@ -0,0 +1,116 @@ +package quic + +import ( + "math/rand" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + + "github.com/lucas-clemente/quic-go/internal/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MTU Discoverer", func() { + const ( + rtt = 100 * time.Millisecond + startMTU protocol.ByteCount = 1000 + maxMTU protocol.ByteCount = 2000 + ) + + var ( + d *mtuDiscoverer + rttStats *utils.RTTStats + now time.Time + discoveredMTU protocol.ByteCount + ) + + BeforeEach(func() { + rttStats = &utils.RTTStats{} + rttStats.SetInitialRTT(rtt) + Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) + d = newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { discoveredMTU = s }) + now = time.Now() + _ = discoveredMTU + }) + + It("only allows a probe 5 RTTs after the handshake completes", func() { + Expect(d.ShouldSendProbe(now)).To(BeFalse()) + Expect(d.ShouldSendProbe(now.Add(rtt * 9 / 2))).To(BeFalse()) + Expect(d.NextProbeTime()).To(BeTemporally("~", now.Add(5*rtt), scaleDuration(20*time.Millisecond))) + Expect(d.ShouldSendProbe(now.Add(rtt * 5))).To(BeTrue()) + }) + + It("doesn't allow a probe if another probe is still in flight", func() { + ping, _ := d.GetPing() + Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeFalse()) + Expect(d.NextProbeTime()).To(BeZero()) + ping.OnLost(ping.Frame) + Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeTrue()) + Expect(d.NextProbeTime()).ToNot(BeZero()) + }) + + It("tries a lower size when a probe is lost", func() { + ping, size := d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1500))) + ping.OnLost(ping.Frame) + _, size = d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1250))) + }) + + It("tries a higher size and calls the callback when a probe is acknowledged", func() { + ping, size := d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1500))) + ping.OnAcked(ping.Frame) + Expect(discoveredMTU).To(Equal(protocol.ByteCount(1500))) + _, size = d.GetPing() + Expect(size).To(Equal(protocol.ByteCount(1750))) + }) + + It("stops discovery after getting close enough to the MTU", func() { + var sizes []protocol.ByteCount + t := now.Add(5 * rtt) + for d.ShouldSendProbe(t) { + ping, size := d.GetPing() + ping.OnAcked(ping.Frame) + sizes = append(sizes, size) + t = t.Add(5 * rtt) + } + Expect(sizes).To(Equal([]protocol.ByteCount{1500, 1750, 1875, 1937, 1968, 1984})) + Expect(d.ShouldSendProbe(t.Add(10 * rtt))).To(BeFalse()) + Expect(d.NextProbeTime()).To(BeZero()) + }) + + It("finds the MTU", func() { + const rep = 3000 + var maxDiff protocol.ByteCount + for i := 0; i < rep; i++ { + max := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1 + currentMTU := startMTU + d := newMTUDiscoverer(rttStats, startMTU, max, func(s protocol.ByteCount) { currentMTU = s }) + now := time.Now() + realMTU := protocol.ByteCount(rand.Intn(int(max-startMTU))) + startMTU + t := now.Add(mtuProbeDelay * rtt) + var count int + for d.ShouldSendProbe(t) { + if count > 25 { + Fail("too many iterations") + } + count++ + + ping, size := d.GetPing() + if size <= realMTU { + ping.OnAcked(ping.Frame) + } else { + ping.OnLost(ping.Frame) + } + t = t.Add(mtuProbeDelay * rtt) + } + diff := realMTU - currentMTU + Expect(diff).To(BeNumerically(">=", 0)) + maxDiff = utils.MaxByteCount(maxDiff, diff) + } + Expect(maxDiff).To(BeEquivalentTo(maxMTUDiff)) + }) +})