delete retired connection IDs after 3 PTOs (#5109)

This commit is contained in:
Marten Seemann
2025-05-03 20:24:40 +08:00
committed by GitHub
parent 55229d3f21
commit 97e7657df5
10 changed files with 167 additions and 164 deletions

View File

@@ -2,6 +2,8 @@ package quic
import (
"fmt"
"slices"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
@@ -11,7 +13,6 @@ import (
type connRunnerCallbacks struct {
AddConnectionID func(protocol.ConnectionID)
RemoveConnectionID func(protocol.ConnectionID)
RetireConnectionID func(protocol.ConnectionID)
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
}
@@ -29,24 +30,24 @@ func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) {
}
}
func (cr connRunners) RetireConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.RetireConnectionID(id)
}
}
func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte) {
for _, c := range cr {
c.ReplaceWithClosed(ids, b)
}
}
type connIDToRetire struct {
t time.Time
connID protocol.ConnectionID
}
type connIDGenerator struct {
generator ConnectionIDGenerator
highestSeq uint64
connRunners connRunners
activeSrcConnIDs map[uint64]protocol.ConnectionID
connIDsToRetire []connIDToRetire // sorted by t
initialClientDestConnID *protocol.ConnectionID // nil for the client
statelessResetter *statelessResetter
@@ -93,7 +94,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
return nil
}
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID, expiry time.Time) error {
if seq > m.highestSeq {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
@@ -111,7 +112,8 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
}
}
m.connRunners.RetireConnectionID(connID)
m.queueConnIDForRetiring(connID, expiry)
delete(m.activeSrcConnIDs, seq)
// Don't issue a replacement for the initial connection ID.
if seq == 0 {
@@ -120,6 +122,16 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
return m.issueNewConnID()
}
func (m *connIDGenerator) queueConnIDForRetiring(connID protocol.ConnectionID, expiry time.Time) {
idx := slices.IndexFunc(m.connIDsToRetire, func(c connIDToRetire) bool {
return c.t.After(expiry)
})
if idx == -1 {
idx = len(m.connIDsToRetire)
}
m.connIDsToRetire = slices.Insert(m.connIDsToRetire, idx, connIDToRetire{t: expiry, connID: connID})
}
func (m *connIDGenerator) issueNewConnID() error {
connID, err := m.generator.GenerateConnectionID()
if err != nil {
@@ -136,13 +148,33 @@ func (m *connIDGenerator) issueNewConnID() error {
return nil
}
func (m *connIDGenerator) SetHandshakeComplete() {
func (m *connIDGenerator) SetHandshakeComplete(connIDExpiry time.Time) {
if m.initialClientDestConnID != nil {
m.connRunners.RetireConnectionID(*m.initialClientDestConnID)
m.queueConnIDForRetiring(*m.initialClientDestConnID, connIDExpiry)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) NextRetireTime() time.Time {
if len(m.connIDsToRetire) == 0 {
return time.Time{}
}
return m.connIDsToRetire[0].t
}
func (m *connIDGenerator) RemoveRetiredConnIDs(now time.Time) {
if len(m.connIDsToRetire) == 0 {
return
}
for _, c := range m.connIDsToRetire {
if c.t.After(now) {
break
}
m.connRunners.RemoveConnectionID(c.connID)
m.connIDsToRetire = m.connIDsToRetire[1:]
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.connRunners.RemoveConnectionID(*m.initialClientDestConnID)
@@ -150,16 +182,22 @@ func (m *connIDGenerator) RemoveAll() {
for _, connID := range m.activeSrcConnIDs {
m.connRunners.RemoveConnectionID(connID)
}
for _, c := range m.connIDsToRetire {
m.connRunners.RemoveConnectionID(c.connID)
}
}
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+len(m.connIDsToRetire)+1)
if m.initialClientDestConnID != nil {
connIDs = append(connIDs, *m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
for _, c := range m.connIDsToRetire {
connIDs = append(connIDs, c.connID)
}
m.connRunners.ReplaceWithClosed(connIDs, connClose)
}

View File

@@ -1,7 +1,9 @@
package quic
import (
"math/rand/v2"
"testing"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
@@ -22,7 +24,7 @@ func TestConnIDGeneratorIssueAndRetire(t *testing.T) {
func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID bool) {
var (
added []protocol.ConnectionID
retired []protocol.ConnectionID
removed []protocol.ConnectionID
)
var queuedFrames []wire.Frame
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
@@ -38,8 +40,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
sr,
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
RetireConnectionID: func(c protocol.ConnectionID) { retired = append(retired, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
@@ -50,7 +51,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
require.NoError(t, g.SetMaxActiveConnIDs(4))
require.Len(t, added, 3)
require.Len(t, queuedFrames, 3)
require.Empty(t, retired)
require.Empty(t, removed)
connIDs := make(map[uint64]protocol.ConnectionID)
// connection IDs 1, 2 and 3 were issued
for i, f := range queuedFrames {
@@ -64,37 +65,97 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
// completing the handshake retires the initial client destination connection ID
added = added[:0]
queuedFrames = queuedFrames[:0]
g.SetHandshakeComplete()
now := time.Now()
g.SetHandshakeComplete(now)
require.Empty(t, added)
require.Empty(t, queuedFrames)
require.Empty(t, removed)
g.RemoveRetiredConnIDs(now)
if hasInitialClientDestConnID {
require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, retired)
retired = retired[:0]
require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, removed)
removed = removed[:0]
} else {
require.Empty(t, retired)
require.Empty(t, removed)
}
// it's invalid to retire a connection ID that hasn't been issued yet
err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}))
err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now())
require.ErrorIs(t, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}, err)
require.ErrorContains(t, err, "retired connection ID 4 (highest issued: 3)")
// it's invalid to retire a connection ID in a packet that uses that connection ID
err = g.Retire(3, connIDs[3])
err = g.Retire(3, connIDs[3], time.Now())
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
require.ErrorContains(t, err, "was used as the Destination Connection ID on this packet")
// retiring a connection ID makes us issue a new one
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
require.Equal(t, []protocol.ConnectionID{connIDs[2]}, retired)
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now()))
g.RemoveRetiredConnIDs(time.Now())
require.Equal(t, []protocol.ConnectionID{connIDs[2]}, removed)
require.Len(t, queuedFrames, 1)
require.EqualValues(t, 4, queuedFrames[0].(*wire.NewConnectionIDFrame).SequenceNumber)
queuedFrames = queuedFrames[:0]
retired = retired[:0]
removed = removed[:0]
// duplicate retirements don't do anything
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now()))
g.RemoveRetiredConnIDs(time.Now())
require.Empty(t, queuedFrames)
require.Empty(t, retired)
require.Empty(t, removed)
}
func TestConnIDGeneratorRetiring(t *testing.T) {
initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
var added, removed []protocol.ConnectionID
g := newConnIDGenerator(
1,
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
&initialConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
},
func(f wire.Frame) {},
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
require.NoError(t, g.SetMaxActiveConnIDs(6))
require.Empty(t, removed)
require.Len(t, added, 5)
now := time.Now()
retirements := map[protocol.ConnectionID]time.Time{}
t1 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond)
retirements[initialConnID] = t1
g.SetHandshakeComplete(t1)
for i := range 5 {
t2 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond)
require.NoError(t, g.Retire(uint64(i+1), protocol.ParseConnectionID([]byte{9, 9, 9, 9}), t2))
retirements[added[i]] = t2
var nextRetirement time.Time
for _, r := range retirements {
if nextRetirement.IsZero() || r.Before(nextRetirement) {
nextRetirement = r
}
}
require.Equal(t, nextRetirement, g.NextRetireTime())
if rand.IntN(2) == 0 {
now = now.Add(time.Duration(rand.IntN(500)) * time.Millisecond)
g.RemoveRetiredConnIDs(now)
for _, r := range removed {
require.Contains(t, retirements, r)
require.LessOrEqual(t, retirements[r], now)
delete(retirements, r)
}
removed = removed[:0]
for _, r := range retirements {
require.Greater(t, r, now)
}
}
}
}
func TestConnIDGeneratorRemoveAll(t *testing.T) {
@@ -124,7 +185,6 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
},
func(f wire.Frame) {},
@@ -175,7 +235,6 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, b []byte) {
replaced = connIDs
replacedWith = b
@@ -187,13 +246,18 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
require.NoError(t, g.SetMaxActiveConnIDs(1000))
require.Len(t, added, protocol.MaxIssuedConnectionIDs-1)
// Retire two of these connection ID.
// This makes us issue two more connection IDs.
require.NoError(t, g.Retire(3, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), time.Now()))
require.NoError(t, g.Retire(4, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), time.Now()))
require.Len(t, added, protocol.MaxIssuedConnectionIDs+1)
g.ReplaceWithClosed([]byte("foobar"))
if hasInitialClientDestConnID {
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+1)
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+3)
require.Contains(t, replaced, *initialClientDestConnID)
} else {
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs)
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+2)
}
for _, id := range added {
require.Contains(t, replaced, id)
@@ -207,14 +271,13 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
type connIDTracker struct {
added, removed, retired, replaced []protocol.ConnectionID
added, removed, replaced []protocol.ConnectionID
}
var tracker1, tracker2 connIDTracker
runner1 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
RetireConnectionID: func(c protocol.ConnectionID) { tracker1.retired = append(tracker1.retired, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
tracker1.replaced = append(tracker1.replaced, connIDs...)
},
@@ -222,7 +285,6 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
runner2 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) },
RetireConnectionID: func(c protocol.ConnectionID) { tracker2.retired = append(tracker2.retired, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
tracker2.replaced = append(tracker2.replaced, connIDs...)
},
@@ -258,17 +320,17 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
connIDToRetire = ncid.ConnectionID
seqToRetire = ncid.SequenceNumber
tracker1.retired = nil
tracker2.retired = nil
require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.retired)
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.retired)
require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now()))
g.RemoveRetiredConnIDs(time.Now())
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.removed)
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.removed)
tracker1.retired = nil
tracker2.retired = nil
g.SetHandshakeComplete()
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.retired)
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.retired)
tracker1.removed = nil
tracker2.removed = nil
g.SetHandshakeComplete(time.Now())
g.RemoveRetiredConnIDs(time.Now())
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.removed)
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.removed)
g.ReplaceWithClosed([]byte("connection closed"))
require.True(t, len(tracker1.replaced) > 0)

View File

@@ -86,7 +86,6 @@ func (p *receivedPacket) Clone() *receivedPacket {
type connRunner interface {
Add(protocol.ConnectionID, packetHandler) bool
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
ReplaceWithClosed([]protocol.ConnectionID, []byte)
AddResetToken(protocol.StatelessResetToken, packetHandler)
@@ -277,7 +276,6 @@ var newConnection = func(
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
RetireConnectionID: runner.Retire,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
s.queueControlFrame,
@@ -392,7 +390,6 @@ var newClientConnection = func(
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
RetireConnectionID: runner.Retire,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
s.queueControlFrame,
@@ -652,6 +649,8 @@ runLoop:
}
}
s.connIDGenerator.RemoveRetiredConnIDs(now)
if s.perspective == protocol.PerspectiveClient {
pm := s.pathManagerOutgoing.Load()
if pm != nil {
@@ -762,6 +761,7 @@ func (s *connection) maybeResetTimer() {
s.timer.SetTimer(
deadline,
s.connIDGenerator.NextRetireTime(),
s.receivedPacketHandler.GetAlarmTimeout(),
s.sentPacketHandler.GetLossDetectionTimeout(),
s.pacingDeadline,
@@ -801,7 +801,7 @@ func (s *connection) handleHandshakeComplete(now time.Time) error {
s.undecryptablePackets = nil
s.connIDManager.SetHandshakeComplete()
s.connIDGenerator.SetHandshakeComplete()
s.connIDGenerator.SetHandshakeComplete(now.Add(3 * s.rttStats.PTO(false)))
if s.tracer != nil && s.tracer.ChoseALPN != nil {
s.tracer.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol)
@@ -1532,7 +1532,7 @@ func (s *connection) handleFrame(
case *wire.NewConnectionIDFrame:
err = s.handleNewConnectionIDFrame(frame)
case *wire.RetireConnectionIDFrame:
err = s.handleRetireConnectionIDFrame(frame, destConnID)
err = s.handleRetireConnectionIDFrame(rcvTime, frame, destConnID)
case *wire.HandshakeDoneFrame:
err = s.handleHandshakeDoneFrame(rcvTime)
case *wire.DatagramFrame:
@@ -1751,8 +1751,8 @@ func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) er
return s.connIDManager.Add(f)
}
func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error {
return s.connIDGenerator.Retire(f.SequenceNumber, destConnID)
func (s *connection) handleRetireConnectionIDFrame(now time.Time, f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error {
return s.connIDGenerator.Retire(f.SequenceNumber, destConnID, now.Add(3*s.rttStats.PTO(false)))
}
func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error {
@@ -2656,7 +2656,6 @@ func (s *connection) AddPath(t *Transport) (*Path, error) {
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
RetireConnectionID: runner.Retire,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
)

View File

@@ -1163,7 +1163,6 @@ func TestConnectionHandshakeServer(t *testing.T) {
require.NoError(t, err)
cs.EXPECT().DiscardInitialKeys()
tc.connRunner.EXPECT().Retire(gomock.Any())
gomock.InOrder(
cs.EXPECT().StartHandshake(gomock.Any()),
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),

View File

@@ -32,8 +32,11 @@ func (t *connectionTimer) Chan() <-chan time.Time {
// It makes sure that the deadline is strictly increasing.
// This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet.
// This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately.
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) {
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, connIDRetirement, ackAlarm, lossTime, pacing time.Time) {
deadline := idleTimeoutOrKeepAlive
if !connIDRetirement.IsZero() && connIDRetirement.Before(deadline) && connIDRetirement.After(t.last) {
deadline = connIDRetirement
}
if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) {
deadline = ackAlarm
}

View File

@@ -14,25 +14,31 @@ func TestConnectionTimerModes(t *testing.T) {
t.Run("idle timeout", func(t *testing.T) {
timer := newTimer()
timer.SetTimer(now.Add(time.Hour), time.Time{}, time.Time{}, time.Time{})
timer.SetTimer(now.Add(time.Hour), time.Time{}, time.Time{}, time.Time{}, time.Time{})
require.Equal(t, now.Add(time.Hour), timer.Deadline())
})
t.Run("connection ID expiry", func(t *testing.T) {
timer := newTimer()
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{}, time.Time{})
require.Equal(t, now.Add(time.Minute), timer.Deadline())
})
t.Run("ACK timer", func(t *testing.T) {
timer := newTimer()
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{})
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{})
require.Equal(t, now.Add(time.Minute), timer.Deadline())
})
t.Run("loss timer", func(t *testing.T) {
timer := newTimer()
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), time.Time{})
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), now.Add(time.Second), time.Time{})
require.Equal(t, now.Add(time.Second), timer.Deadline())
})
t.Run("pacing timer", func(t *testing.T) {
timer := newTimer()
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), now.Add(time.Millisecond))
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), now.Add(time.Second), now.Add(time.Millisecond))
require.Equal(t, now.Add(time.Millisecond), timer.Deadline())
})
}
@@ -40,10 +46,10 @@ func TestConnectionTimerModes(t *testing.T) {
func TestConnectionTimerReset(t *testing.T) {
now := time.Now()
timer := newTimer()
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{})
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{})
require.Equal(t, now.Add(time.Minute), timer.Deadline())
timer.SetRead()
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{})
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{})
require.Equal(t, now.Add(time.Hour), timer.Deadline())
}

View File

@@ -221,39 +221,3 @@ func (c *MockConnRunnerReplaceWithClosedCall) DoAndReturn(f func([]protocol.Conn
c.Call = c.Call.DoAndReturn(f)
return c
}
// Retire mocks base method.
func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Retire", arg0)
}
// Retire indicates an expected call of Retire.
func (mr *MockConnRunnerMockRecorder) Retire(arg0 any) *MockConnRunnerRetireCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0)
return &MockConnRunnerRetireCall{Call: call}
}
// MockConnRunnerRetireCall wrap *gomock.Call
type MockConnRunnerRetireCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockConnRunnerRetireCall) Return() *MockConnRunnerRetireCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockConnRunnerRetireCall) Do(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockConnRunnerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -373,39 +373,3 @@ func (c *MockPacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]pro
c.Call = c.Call.DoAndReturn(f)
return c
}
// Retire mocks base method.
func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Retire", arg0)
}
// Retire indicates an expected call of Retire.
func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 any) *MockPacketHandlerManagerRetireCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
return &MockPacketHandlerManagerRetireCall{Call: call}
}
// MockPacketHandlerManagerRetireCall wrap *gomock.Call
type MockPacketHandlerManagerRetireCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerRetireCall) Return() *MockPacketHandlerManagerRetireCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerRetireCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -108,16 +108,6 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.logger.Debugf("Removing connection ID %s.", id)
}
func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter)
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
h.mutex.Lock()
delete(h.handlers, id)
h.mutex.Unlock()
h.logger.Debugf("Removing connection ID %s after it has been retired.", id)
})
}
// ReplaceWithClosed is called when a connection is closed.
// Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets

View File

@@ -54,28 +54,6 @@ func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
require.Equal(t, h, got)
}
func TestPacketHandlerMapRetire(t *testing.T) {
m := newPacketHandlerMap(nil, utils.DefaultLogger)
dur := scaleDuration(10 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
h := &mockPacketHandler{}
require.True(t, m.Add(connID, h))
m.Retire(connID)
// immediately after retiring, the handler should still be there
got, ok := m.Get(connID)
require.True(t, ok)
require.Equal(t, h, got)
// after the timeout, the handler should be removed
time.Sleep(dur)
require.Eventually(t, func() bool {
_, ok := m.Get(connID)
return !ok
}, dur, dur/10)
}
func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
m := newPacketHandlerMap(nil, utils.DefaultLogger)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}