ackhandler: fix panic in probe packet tracking logic (#4998)

Under certain circumstances (loss and acknowledgment patterns), the
probe packet tracking logic could run into a nil-pointer dereference.
This commit is contained in:
Marten Seemann
2025-03-21 12:28:19 +07:00
committed by GitHub
parent da127a9a53
commit 017ed984f3
2 changed files with 98 additions and 4 deletions

View File

@@ -460,10 +460,10 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
}
if p.isPathProbePacket {
probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
if probePacket == nil {
panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber))
// the probe packet might already have been declared lost
if probePacket != nil {
h.ackedPackets = append(h.ackedPackets, probePacket)
}
h.ackedPackets = append(h.ackedPackets, probePacket)
continue
}
h.ackedPackets = append(h.ackedPackets, p)
@@ -658,7 +658,6 @@ func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
for _, f := range p.Frames {
f.Handler.OnLost(f.Frame)
}
h.appDataPackets.history.Remove(p.PacketNumber)
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
}
}

View File

@@ -1,7 +1,9 @@
package ackhandler
import (
"encoding/binary"
"fmt"
"math/rand/v2"
"slices"
"testing"
"time"
@@ -39,6 +41,11 @@ type packetTracker struct {
Lost []protocol.PacketNumber
}
func (t *packetTracker) Reset() {
t.Acked = nil
t.Lost = nil
}
func (t *packetTracker) NewPingFrame(pn protocol.PacketNumber) Frame {
return Frame{
Frame: &wire.PingFrame{},
@@ -1227,3 +1234,91 @@ func TestSentPacketHandlerPathProbeAckAndLoss(t *testing.T) {
require.Equal(t, t2.Add(pathProbePacketLossTimeout), sph.GetLossDetectionTimeout())
}
// The packet tracking logic is pretty complex.
// We test it with a randomized approach, to make sure that it doesn't panic under any circumstances.
func TestSentPacketHandlerRandomized(t *testing.T) {
seed := uint64(time.Now().UnixNano())
for i := range 5 {
t.Run(fmt.Sprintf("run %d (seed %d)", i+1, seed), func(t *testing.T) {
testSentPacketHandlerRandomized(t, seed)
})
seed++
}
}
func testSentPacketHandlerRandomized(t *testing.T, seed uint64) {
var b [32]byte
binary.BigEndian.PutUint64(b[:], seed)
r := rand.New(rand.NewChaCha8(b))
var rttStats utils.RTTStats
rtt := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, 1000 * time.Millisecond}[r.IntN(3)]
t.Logf("rtt: %dms", rtt.Milliseconds())
rttStats.UpdateRTT(rtt, 0) // RTT of the original path
randDuration := func(min, max time.Duration) time.Duration {
return time.Duration(rand.Int64N(int64(max-min))) + min
}
sph := newSentPacketHandler(
0,
1200,
&rttStats,
true,
false,
protocol.PerspectiveClient,
nil,
utils.DefaultLogger,
)
sph.DropPackets(protocol.EncryptionInitial, time.Now())
sph.DropPackets(protocol.EncryptionHandshake, time.Now())
var packets packetTracker
sendPacket := func(ti time.Time, isPathProbe bool) protocol.PacketNumber {
pn := sph.PopPacketNumber(protocol.Encryption1RTT)
sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, isPathProbe)
return pn
}
now := time.Now()
start := now
var pns []protocol.PacketNumber
for range 4 {
isProbe := r.Int()%2 == 0
pn := sendPacket(now, isProbe)
t.Logf("t=%dms: sending packet %d (probe packet: %t)", now.Sub(start).Milliseconds(), pn, isProbe)
pns = append(pns, pn)
now = now.Add(randDuration(0, 500*time.Millisecond))
if r.Int()%3 == 0 {
sph.OnLossDetectionTimeout(now)
t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost)
packets.Reset()
now = now.Add(randDuration(0, 500*time.Millisecond))
}
if r.Int()%3 == 0 {
// acknowledge up to 2 random packet numbers from the pns slice
var ackPns []protocol.PacketNumber
if len(pns) > 0 {
numToAck := min(1+r.IntN(2), len(pns))
for range numToAck {
ackPns = append(ackPns, pns[r.IntN(len(pns))])
}
}
if len(ackPns) > 1 {
slices.Sort(ackPns)
ackPns = slices.Compact(ackPns)
}
sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(ackPns...)}, protocol.Encryption1RTT, now)
t.Logf("t=%dms: received ACK for packets %v (acked: %v, lost: %v)", now.Sub(start).Milliseconds(), ackPns, packets.Acked, packets.Lost)
packets.Reset()
now = now.Add(randDuration(0, 500*time.Millisecond))
}
if r.Int()%10 == 0 {
sph.MigratedPath(now, 1200)
now = now.Add(randDuration(0, 500*time.Millisecond))
}
}
t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost)
sph.OnLossDetectionTimeout(now)
}