From 017ed984f3c3ad9afe908e3fec2a8834872f696b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 21 Mar 2025 12:28:19 +0700 Subject: [PATCH] ackhandler: fix panic in probe packet tracking logic (#4998) Under certain circumstances (loss and acknowledgment patterns), the probe packet tracking logic could run into a nil-pointer dereference. --- internal/ackhandler/sent_packet_handler.go | 7 +- .../ackhandler/sent_packet_handler_test.go | 95 +++++++++++++++++++ 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 7c3cf8927..83d2736ac 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -460,10 +460,10 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL } if p.isPathProbePacket { probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber) - if probePacket == nil { - panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber)) + // the probe packet might already have been declared lost + if probePacket != nil { + h.ackedPackets = append(h.ackedPackets, probePacket) } - h.ackedPackets = append(h.ackedPackets, probePacket) continue } h.ackedPackets = append(h.ackedPackets, p) @@ -658,7 +658,6 @@ func (h *sentPacketHandler) detectLostPathProbes(now time.Time) { for _, f := range p.Frames { f.Handler.OnLost(f.Frame) } - h.appDataPackets.history.Remove(p.PacketNumber) h.appDataPackets.history.RemovePathProbe(p.PacketNumber) } } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index de1335ed8..54d5e6335 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -1,7 +1,9 @@ package ackhandler import ( + "encoding/binary" "fmt" + "math/rand/v2" "slices" "testing" "time" @@ -39,6 +41,11 @@ type packetTracker struct { Lost []protocol.PacketNumber } +func (t *packetTracker) Reset() { + t.Acked = nil + t.Lost = nil +} + func (t *packetTracker) NewPingFrame(pn protocol.PacketNumber) Frame { return Frame{ Frame: &wire.PingFrame{}, @@ -1227,3 +1234,91 @@ func TestSentPacketHandlerPathProbeAckAndLoss(t *testing.T) { require.Equal(t, t2.Add(pathProbePacketLossTimeout), sph.GetLossDetectionTimeout()) } + +// The packet tracking logic is pretty complex. +// We test it with a randomized approach, to make sure that it doesn't panic under any circumstances. +func TestSentPacketHandlerRandomized(t *testing.T) { + seed := uint64(time.Now().UnixNano()) + for i := range 5 { + t.Run(fmt.Sprintf("run %d (seed %d)", i+1, seed), func(t *testing.T) { + testSentPacketHandlerRandomized(t, seed) + }) + seed++ + } +} + +func testSentPacketHandlerRandomized(t *testing.T, seed uint64) { + var b [32]byte + binary.BigEndian.PutUint64(b[:], seed) + r := rand.New(rand.NewChaCha8(b)) + + var rttStats utils.RTTStats + rtt := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, 1000 * time.Millisecond}[r.IntN(3)] + t.Logf("rtt: %dms", rtt.Milliseconds()) + rttStats.UpdateRTT(rtt, 0) // RTT of the original path + + randDuration := func(min, max time.Duration) time.Duration { + return time.Duration(rand.Int64N(int64(max-min))) + min + } + + sph := newSentPacketHandler( + 0, + 1200, + &rttStats, + true, + false, + protocol.PerspectiveClient, + nil, + utils.DefaultLogger, + ) + sph.DropPackets(protocol.EncryptionInitial, time.Now()) + sph.DropPackets(protocol.EncryptionHandshake, time.Now()) + + var packets packetTracker + sendPacket := func(ti time.Time, isPathProbe bool) protocol.PacketNumber { + pn := sph.PopPacketNumber(protocol.Encryption1RTT) + sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, isPathProbe) + return pn + } + + now := time.Now() + start := now + var pns []protocol.PacketNumber + for range 4 { + isProbe := r.Int()%2 == 0 + pn := sendPacket(now, isProbe) + t.Logf("t=%dms: sending packet %d (probe packet: %t)", now.Sub(start).Milliseconds(), pn, isProbe) + pns = append(pns, pn) + now = now.Add(randDuration(0, 500*time.Millisecond)) + if r.Int()%3 == 0 { + sph.OnLossDetectionTimeout(now) + t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost) + packets.Reset() + now = now.Add(randDuration(0, 500*time.Millisecond)) + } + if r.Int()%3 == 0 { + // acknowledge up to 2 random packet numbers from the pns slice + var ackPns []protocol.PacketNumber + if len(pns) > 0 { + numToAck := min(1+r.IntN(2), len(pns)) + for range numToAck { + ackPns = append(ackPns, pns[r.IntN(len(pns))]) + } + } + if len(ackPns) > 1 { + slices.Sort(ackPns) + ackPns = slices.Compact(ackPns) + } + sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(ackPns...)}, protocol.Encryption1RTT, now) + t.Logf("t=%dms: received ACK for packets %v (acked: %v, lost: %v)", now.Sub(start).Milliseconds(), ackPns, packets.Acked, packets.Lost) + packets.Reset() + now = now.Add(randDuration(0, 500*time.Millisecond)) + } + if r.Int()%10 == 0 { + sph.MigratedPath(now, 1200) + now = now.Add(randDuration(0, 500*time.Millisecond)) + } + } + t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost) + sph.OnLossDetectionTimeout(now) +}