Files
quic-go/conn_id_generator_test.go
Marten Seemann adc13be540 implement a memory-optimized time.Time replacement (#5334)
* implement a memory-optimized time.Time replacement

* monotime: properly handle systems with bad timer resolution (Windows)

* monotime: simplify Since
2025-09-14 08:12:10 +02:00

359 lines
13 KiB
Go

package quic
import (
"math/rand/v2"
"testing"
"time"
"github.com/quic-go/quic-go/internal/monotime"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
func TestConnIDGeneratorIssueAndRetire(t *testing.T) {
t.Run("with initial client destination connection ID", func(t *testing.T) {
testConnIDGeneratorIssueAndRetire(t, true)
})
t.Run("without initial client destination connection ID", func(t *testing.T) {
testConnIDGeneratorIssueAndRetire(t, false)
})
}
func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID bool) {
var (
added []protocol.ConnectionID
removed []protocol.ConnectionID
)
var queuedFrames []wire.Frame
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
var initialClientDestConnID *protocol.ConnectionID
if hasInitialClientDestConnID {
connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
initialClientDestConnID = &connID
}
g := newConnIDGenerator(
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
sr,
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte, time.Duration) {},
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
require.Empty(t, added)
require.NoError(t, g.SetMaxActiveConnIDs(4))
require.Len(t, added, 3)
require.Len(t, queuedFrames, 3)
require.Empty(t, removed)
connIDs := make(map[uint64]protocol.ConnectionID)
// connection IDs 1, 2 and 3 were issued
for i, f := range queuedFrames {
ncid := f.(*wire.NewConnectionIDFrame)
require.EqualValues(t, i+1, ncid.SequenceNumber)
require.Equal(t, ncid.ConnectionID, added[i])
require.Equal(t, ncid.StatelessResetToken, sr.GetStatelessResetToken(ncid.ConnectionID))
connIDs[ncid.SequenceNumber] = ncid.ConnectionID
}
// completing the handshake retires the initial client destination connection ID
added = added[:0]
queuedFrames = queuedFrames[:0]
now := monotime.Now()
g.SetHandshakeComplete(now)
require.Empty(t, added)
require.Empty(t, queuedFrames)
require.Empty(t, removed)
g.RemoveRetiredConnIDs(now)
if hasInitialClientDestConnID {
require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, removed)
removed = removed[:0]
} else {
require.Empty(t, removed)
}
// it's invalid to retire a connection ID that hasn't been issued yet
err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now())
require.ErrorIs(t, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}, err)
require.ErrorContains(t, err, "retired connection ID 4 (highest issued: 3)")
// it's invalid to retire a connection ID in a packet that uses that connection ID
err = g.Retire(3, connIDs[3], monotime.Now())
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
require.ErrorContains(t, err, "was used as the Destination Connection ID on this packet")
// retiring a connection ID makes us issue a new one
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now()))
g.RemoveRetiredConnIDs(monotime.Now())
require.Equal(t, []protocol.ConnectionID{connIDs[2]}, removed)
require.Len(t, queuedFrames, 1)
require.EqualValues(t, 4, queuedFrames[0].(*wire.NewConnectionIDFrame).SequenceNumber)
queuedFrames = queuedFrames[:0]
removed = removed[:0]
// duplicate retirements don't do anything
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now()))
g.RemoveRetiredConnIDs(monotime.Now())
require.Empty(t, queuedFrames)
require.Empty(t, removed)
}
func TestConnIDGeneratorRetiring(t *testing.T) {
initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
var added, removed []protocol.ConnectionID
g := newConnIDGenerator(
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
&initialConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte, time.Duration) {},
},
func(f wire.Frame) {},
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
require.NoError(t, g.SetMaxActiveConnIDs(6))
require.Empty(t, removed)
require.Len(t, added, 5)
now := monotime.Now()
retirements := map[protocol.ConnectionID]monotime.Time{}
t1 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond)
retirements[initialConnID] = t1
g.SetHandshakeComplete(t1)
for i := range 5 {
t2 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond)
require.NoError(t, g.Retire(uint64(i+1), protocol.ParseConnectionID([]byte{9, 9, 9, 9}), t2))
retirements[added[i]] = t2
var nextRetirement monotime.Time
for _, r := range retirements {
if nextRetirement.IsZero() || r.Before(nextRetirement) {
nextRetirement = r
}
}
require.Equal(t, nextRetirement, g.NextRetireTime())
if rand.IntN(2) == 0 {
now = now.Add(time.Duration(rand.IntN(500)) * time.Millisecond)
g.RemoveRetiredConnIDs(now)
for _, r := range removed {
require.Contains(t, retirements, r)
require.LessOrEqual(t, retirements[r], now)
delete(retirements, r)
}
removed = removed[:0]
for _, r := range retirements {
require.Greater(t, r, now)
}
}
}
}
func TestConnIDGeneratorRemoveAll(t *testing.T) {
t.Run("with initial client destination connection ID", func(t *testing.T) {
testConnIDGeneratorRemoveAll(t, true)
})
t.Run("without initial client destination connection ID", func(t *testing.T) {
testConnIDGeneratorRemoveAll(t, false)
})
}
func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool) {
var initialClientDestConnID *protocol.ConnectionID
if hasInitialClientDestConnID {
connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
initialClientDestConnID = &connID
}
var (
added []protocol.ConnectionID
removed []protocol.ConnectionID
)
g := newConnIDGenerator(
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte, time.Duration) {},
},
func(f wire.Frame) {},
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
require.NoError(t, g.SetMaxActiveConnIDs(1000))
require.Len(t, added, protocol.MaxIssuedConnectionIDs-1)
g.RemoveAll()
if hasInitialClientDestConnID {
require.Len(t, removed, protocol.MaxIssuedConnectionIDs+1)
require.Contains(t, removed, *initialClientDestConnID)
} else {
require.Len(t, removed, protocol.MaxIssuedConnectionIDs)
}
for _, id := range added {
require.Contains(t, removed, id)
}
require.Contains(t, removed, protocol.ParseConnectionID([]byte{1, 1, 1, 1}))
}
func TestConnIDGeneratorReplaceWithClosed(t *testing.T) {
t.Run("with initial client destination connection ID", func(t *testing.T) {
testConnIDGeneratorReplaceWithClosed(t, true)
})
t.Run("without initial client destination connection ID", func(t *testing.T) {
testConnIDGeneratorReplaceWithClosed(t, false)
})
}
func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConnID bool) {
var initialClientDestConnID *protocol.ConnectionID
if hasInitialClientDestConnID {
connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
initialClientDestConnID = &connID
}
var (
added []protocol.ConnectionID
replaced []protocol.ConnectionID
replacedWith []byte
)
g := newConnIDGenerator(
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, b []byte, _ time.Duration) {
replaced = connIDs
replacedWith = b
},
},
func(f wire.Frame) {},
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
require.NoError(t, g.SetMaxActiveConnIDs(1000))
require.Len(t, added, protocol.MaxIssuedConnectionIDs-1)
// Retire two of these connection ID.
// This makes us issue two more connection IDs.
require.NoError(t, g.Retire(3, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), monotime.Now()))
require.NoError(t, g.Retire(4, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), monotime.Now()))
require.Len(t, added, protocol.MaxIssuedConnectionIDs+1)
g.ReplaceWithClosed([]byte("foobar"), time.Second)
if hasInitialClientDestConnID {
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+3)
require.Contains(t, replaced, *initialClientDestConnID)
} else {
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+2)
}
for _, id := range added {
require.Contains(t, replaced, id)
}
require.Contains(t, replaced, protocol.ParseConnectionID([]byte{1, 1, 1, 1}))
require.Equal(t, []byte("foobar"), replacedWith)
}
func TestConnIDGeneratorAddConnRunner(t *testing.T) {
initialConnID := protocol.ParseConnectionID([]byte{1, 1, 1, 1})
clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
type connIDTracker struct {
added, removed, replaced []protocol.ConnectionID
}
var tracker1, tracker2, tracker3 connIDTracker
runner1 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte, _ time.Duration) {
tracker1.replaced = append(tracker1.replaced, connIDs...)
},
}
runner2 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte, _ time.Duration) {
tracker2.replaced = append(tracker2.replaced, connIDs...)
},
}
runner3 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker3.added = append(tracker3.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker3.removed = append(tracker3.removed, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte, _ time.Duration) {
tracker3.replaced = append(tracker3.replaced, connIDs...)
},
}
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
var queuedFrames []wire.Frame
tr := &packetHandlerMap{}
g := newConnIDGenerator(
tr,
initialConnID,
&clientDestConnID,
sr,
runner1,
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
require.NoError(t, g.SetMaxActiveConnIDs(3))
require.Len(t, tracker1.added, 2)
// add the second runner - it should get all existing connection IDs
g.AddConnRunner(&packetHandlerMap{}, runner2)
require.Len(t, tracker1.added, 2) // unchanged
require.Len(t, tracker2.added, 4)
require.Contains(t, tracker2.added, initialConnID)
require.Contains(t, tracker2.added, clientDestConnID)
require.Contains(t, tracker2.added, tracker1.added[0])
require.Contains(t, tracker2.added, tracker1.added[1])
// adding the same transport again doesn't do anything
trCopy := tr
g.AddConnRunner(trCopy, runner3)
require.Empty(t, tracker3.added)
var connIDToRetire protocol.ConnectionID
var seqToRetire uint64
ncid := queuedFrames[0].(*wire.NewConnectionIDFrame)
connIDToRetire = ncid.ConnectionID
seqToRetire = ncid.SequenceNumber
require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now()))
g.RemoveRetiredConnIDs(monotime.Now())
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.removed)
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.removed)
tracker1.removed = nil
tracker2.removed = nil
g.SetHandshakeComplete(monotime.Now())
g.RemoveRetiredConnIDs(monotime.Now())
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.removed)
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.removed)
g.ReplaceWithClosed([]byte("connection closed"), time.Second)
require.True(t, len(tracker1.replaced) > 0)
require.Equal(t, tracker1.replaced, tracker2.replaced)
tracker1.removed = nil
tracker2.removed = nil
g.RemoveAll()
require.NotEmpty(t, tracker1.removed)
require.Equal(t, tracker1.removed, tracker2.removed)
}