forked from quic-go/quic-go
ackhandler: fix panic in probe packet tracking logic (#4998)
Under certain circumstances (loss and acknowledgment patterns), the probe packet tracking logic could run into a nil-pointer dereference.
This commit is contained in:
@@ -460,10 +460,10 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
|
||||
}
|
||||
if p.isPathProbePacket {
|
||||
probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
|
||||
if probePacket == nil {
|
||||
panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber))
|
||||
// the probe packet might already have been declared lost
|
||||
if probePacket != nil {
|
||||
h.ackedPackets = append(h.ackedPackets, probePacket)
|
||||
}
|
||||
h.ackedPackets = append(h.ackedPackets, probePacket)
|
||||
continue
|
||||
}
|
||||
h.ackedPackets = append(h.ackedPackets, p)
|
||||
@@ -658,7 +658,6 @@ func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
|
||||
for _, f := range p.Frames {
|
||||
f.Handler.OnLost(f.Frame)
|
||||
}
|
||||
h.appDataPackets.history.Remove(p.PacketNumber)
|
||||
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -39,6 +41,11 @@ type packetTracker struct {
|
||||
Lost []protocol.PacketNumber
|
||||
}
|
||||
|
||||
func (t *packetTracker) Reset() {
|
||||
t.Acked = nil
|
||||
t.Lost = nil
|
||||
}
|
||||
|
||||
func (t *packetTracker) NewPingFrame(pn protocol.PacketNumber) Frame {
|
||||
return Frame{
|
||||
Frame: &wire.PingFrame{},
|
||||
@@ -1227,3 +1234,91 @@ func TestSentPacketHandlerPathProbeAckAndLoss(t *testing.T) {
|
||||
|
||||
require.Equal(t, t2.Add(pathProbePacketLossTimeout), sph.GetLossDetectionTimeout())
|
||||
}
|
||||
|
||||
// The packet tracking logic is pretty complex.
|
||||
// We test it with a randomized approach, to make sure that it doesn't panic under any circumstances.
|
||||
func TestSentPacketHandlerRandomized(t *testing.T) {
|
||||
seed := uint64(time.Now().UnixNano())
|
||||
for i := range 5 {
|
||||
t.Run(fmt.Sprintf("run %d (seed %d)", i+1, seed), func(t *testing.T) {
|
||||
testSentPacketHandlerRandomized(t, seed)
|
||||
})
|
||||
seed++
|
||||
}
|
||||
}
|
||||
|
||||
func testSentPacketHandlerRandomized(t *testing.T, seed uint64) {
|
||||
var b [32]byte
|
||||
binary.BigEndian.PutUint64(b[:], seed)
|
||||
r := rand.New(rand.NewChaCha8(b))
|
||||
|
||||
var rttStats utils.RTTStats
|
||||
rtt := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, 1000 * time.Millisecond}[r.IntN(3)]
|
||||
t.Logf("rtt: %dms", rtt.Milliseconds())
|
||||
rttStats.UpdateRTT(rtt, 0) // RTT of the original path
|
||||
|
||||
randDuration := func(min, max time.Duration) time.Duration {
|
||||
return time.Duration(rand.Int64N(int64(max-min))) + min
|
||||
}
|
||||
|
||||
sph := newSentPacketHandler(
|
||||
0,
|
||||
1200,
|
||||
&rttStats,
|
||||
true,
|
||||
false,
|
||||
protocol.PerspectiveClient,
|
||||
nil,
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
sph.DropPackets(protocol.EncryptionInitial, time.Now())
|
||||
sph.DropPackets(protocol.EncryptionHandshake, time.Now())
|
||||
|
||||
var packets packetTracker
|
||||
sendPacket := func(ti time.Time, isPathProbe bool) protocol.PacketNumber {
|
||||
pn := sph.PopPacketNumber(protocol.Encryption1RTT)
|
||||
sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, isPathProbe)
|
||||
return pn
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
start := now
|
||||
var pns []protocol.PacketNumber
|
||||
for range 4 {
|
||||
isProbe := r.Int()%2 == 0
|
||||
pn := sendPacket(now, isProbe)
|
||||
t.Logf("t=%dms: sending packet %d (probe packet: %t)", now.Sub(start).Milliseconds(), pn, isProbe)
|
||||
pns = append(pns, pn)
|
||||
now = now.Add(randDuration(0, 500*time.Millisecond))
|
||||
if r.Int()%3 == 0 {
|
||||
sph.OnLossDetectionTimeout(now)
|
||||
t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost)
|
||||
packets.Reset()
|
||||
now = now.Add(randDuration(0, 500*time.Millisecond))
|
||||
}
|
||||
if r.Int()%3 == 0 {
|
||||
// acknowledge up to 2 random packet numbers from the pns slice
|
||||
var ackPns []protocol.PacketNumber
|
||||
if len(pns) > 0 {
|
||||
numToAck := min(1+r.IntN(2), len(pns))
|
||||
for range numToAck {
|
||||
ackPns = append(ackPns, pns[r.IntN(len(pns))])
|
||||
}
|
||||
}
|
||||
if len(ackPns) > 1 {
|
||||
slices.Sort(ackPns)
|
||||
ackPns = slices.Compact(ackPns)
|
||||
}
|
||||
sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(ackPns...)}, protocol.Encryption1RTT, now)
|
||||
t.Logf("t=%dms: received ACK for packets %v (acked: %v, lost: %v)", now.Sub(start).Milliseconds(), ackPns, packets.Acked, packets.Lost)
|
||||
packets.Reset()
|
||||
now = now.Add(randDuration(0, 500*time.Millisecond))
|
||||
}
|
||||
if r.Int()%10 == 0 {
|
||||
sph.MigratedPath(now, 1200)
|
||||
now = now.Add(randDuration(0, 500*time.Millisecond))
|
||||
}
|
||||
}
|
||||
t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost)
|
||||
sph.OnLossDetectionTimeout(now)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user