forked from quic-go/quic-go
delete retired connection IDs after 3 PTOs (#5109)
This commit is contained in:
@@ -2,6 +2,8 @@ package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
@@ -11,7 +13,6 @@ import (
|
||||
type connRunnerCallbacks struct {
|
||||
AddConnectionID func(protocol.ConnectionID)
|
||||
RemoveConnectionID func(protocol.ConnectionID)
|
||||
RetireConnectionID func(protocol.ConnectionID)
|
||||
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
}
|
||||
|
||||
@@ -29,24 +30,24 @@ func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) {
|
||||
}
|
||||
}
|
||||
|
||||
func (cr connRunners) RetireConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
c.RetireConnectionID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte) {
|
||||
for _, c := range cr {
|
||||
c.ReplaceWithClosed(ids, b)
|
||||
}
|
||||
}
|
||||
|
||||
type connIDToRetire struct {
|
||||
t time.Time
|
||||
connID protocol.ConnectionID
|
||||
}
|
||||
|
||||
type connIDGenerator struct {
|
||||
generator ConnectionIDGenerator
|
||||
highestSeq uint64
|
||||
connRunners connRunners
|
||||
|
||||
activeSrcConnIDs map[uint64]protocol.ConnectionID
|
||||
connIDsToRetire []connIDToRetire // sorted by t
|
||||
initialClientDestConnID *protocol.ConnectionID // nil for the client
|
||||
|
||||
statelessResetter *statelessResetter
|
||||
@@ -93,7 +94,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
|
||||
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID, expiry time.Time) error {
|
||||
if seq > m.highestSeq {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
@@ -111,7 +112,8 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
|
||||
}
|
||||
}
|
||||
m.connRunners.RetireConnectionID(connID)
|
||||
m.queueConnIDForRetiring(connID, expiry)
|
||||
|
||||
delete(m.activeSrcConnIDs, seq)
|
||||
// Don't issue a replacement for the initial connection ID.
|
||||
if seq == 0 {
|
||||
@@ -120,6 +122,16 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
|
||||
return m.issueNewConnID()
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) queueConnIDForRetiring(connID protocol.ConnectionID, expiry time.Time) {
|
||||
idx := slices.IndexFunc(m.connIDsToRetire, func(c connIDToRetire) bool {
|
||||
return c.t.After(expiry)
|
||||
})
|
||||
if idx == -1 {
|
||||
idx = len(m.connIDsToRetire)
|
||||
}
|
||||
m.connIDsToRetire = slices.Insert(m.connIDsToRetire, idx, connIDToRetire{t: expiry, connID: connID})
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) issueNewConnID() error {
|
||||
connID, err := m.generator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
@@ -136,13 +148,33 @@ func (m *connIDGenerator) issueNewConnID() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) SetHandshakeComplete() {
|
||||
func (m *connIDGenerator) SetHandshakeComplete(connIDExpiry time.Time) {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.connRunners.RetireConnectionID(*m.initialClientDestConnID)
|
||||
m.queueConnIDForRetiring(*m.initialClientDestConnID, connIDExpiry)
|
||||
m.initialClientDestConnID = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) NextRetireTime() time.Time {
|
||||
if len(m.connIDsToRetire) == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return m.connIDsToRetire[0].t
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveRetiredConnIDs(now time.Time) {
|
||||
if len(m.connIDsToRetire) == 0 {
|
||||
return
|
||||
}
|
||||
for _, c := range m.connIDsToRetire {
|
||||
if c.t.After(now) {
|
||||
break
|
||||
}
|
||||
m.connRunners.RemoveConnectionID(c.connID)
|
||||
m.connIDsToRetire = m.connIDsToRetire[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveAll() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.connRunners.RemoveConnectionID(*m.initialClientDestConnID)
|
||||
@@ -150,16 +182,22 @@ func (m *connIDGenerator) RemoveAll() {
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
m.connRunners.RemoveConnectionID(connID)
|
||||
}
|
||||
for _, c := range m.connIDsToRetire {
|
||||
m.connRunners.RemoveConnectionID(c.connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
|
||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+len(m.connIDsToRetire)+1)
|
||||
if m.initialClientDestConnID != nil {
|
||||
connIDs = append(connIDs, *m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
for _, c := range m.connIDsToRetire {
|
||||
connIDs = append(connIDs, c.connID)
|
||||
}
|
||||
m.connRunners.ReplaceWithClosed(connIDs, connClose)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
@@ -22,7 +24,7 @@ func TestConnIDGeneratorIssueAndRetire(t *testing.T) {
|
||||
func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID bool) {
|
||||
var (
|
||||
added []protocol.ConnectionID
|
||||
retired []protocol.ConnectionID
|
||||
removed []protocol.ConnectionID
|
||||
)
|
||||
var queuedFrames []wire.Frame
|
||||
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
|
||||
@@ -38,8 +40,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
|
||||
sr,
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
|
||||
RetireConnectionID: func(c protocol.ConnectionID) { retired = append(retired, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
|
||||
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
|
||||
},
|
||||
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
||||
@@ -50,7 +51,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
|
||||
require.NoError(t, g.SetMaxActiveConnIDs(4))
|
||||
require.Len(t, added, 3)
|
||||
require.Len(t, queuedFrames, 3)
|
||||
require.Empty(t, retired)
|
||||
require.Empty(t, removed)
|
||||
connIDs := make(map[uint64]protocol.ConnectionID)
|
||||
// connection IDs 1, 2 and 3 were issued
|
||||
for i, f := range queuedFrames {
|
||||
@@ -64,37 +65,97 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
|
||||
// completing the handshake retires the initial client destination connection ID
|
||||
added = added[:0]
|
||||
queuedFrames = queuedFrames[:0]
|
||||
g.SetHandshakeComplete()
|
||||
now := time.Now()
|
||||
g.SetHandshakeComplete(now)
|
||||
require.Empty(t, added)
|
||||
require.Empty(t, queuedFrames)
|
||||
require.Empty(t, removed)
|
||||
g.RemoveRetiredConnIDs(now)
|
||||
if hasInitialClientDestConnID {
|
||||
require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, retired)
|
||||
retired = retired[:0]
|
||||
require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, removed)
|
||||
removed = removed[:0]
|
||||
} else {
|
||||
require.Empty(t, retired)
|
||||
require.Empty(t, removed)
|
||||
}
|
||||
|
||||
// it's invalid to retire a connection ID that hasn't been issued yet
|
||||
err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}))
|
||||
err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now())
|
||||
require.ErrorIs(t, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}, err)
|
||||
require.ErrorContains(t, err, "retired connection ID 4 (highest issued: 3)")
|
||||
// it's invalid to retire a connection ID in a packet that uses that connection ID
|
||||
err = g.Retire(3, connIDs[3])
|
||||
err = g.Retire(3, connIDs[3], time.Now())
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
|
||||
require.ErrorContains(t, err, "was used as the Destination Connection ID on this packet")
|
||||
|
||||
// retiring a connection ID makes us issue a new one
|
||||
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
|
||||
require.Equal(t, []protocol.ConnectionID{connIDs[2]}, retired)
|
||||
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now()))
|
||||
g.RemoveRetiredConnIDs(time.Now())
|
||||
require.Equal(t, []protocol.ConnectionID{connIDs[2]}, removed)
|
||||
require.Len(t, queuedFrames, 1)
|
||||
require.EqualValues(t, 4, queuedFrames[0].(*wire.NewConnectionIDFrame).SequenceNumber)
|
||||
queuedFrames = queuedFrames[:0]
|
||||
retired = retired[:0]
|
||||
removed = removed[:0]
|
||||
|
||||
// duplicate retirements don't do anything
|
||||
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
|
||||
require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now()))
|
||||
g.RemoveRetiredConnIDs(time.Now())
|
||||
require.Empty(t, queuedFrames)
|
||||
require.Empty(t, retired)
|
||||
require.Empty(t, removed)
|
||||
}
|
||||
|
||||
func TestConnIDGeneratorRetiring(t *testing.T) {
|
||||
initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
|
||||
var added, removed []protocol.ConnectionID
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
&initialConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
|
||||
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
|
||||
},
|
||||
func(f wire.Frame) {},
|
||||
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
|
||||
)
|
||||
require.NoError(t, g.SetMaxActiveConnIDs(6))
|
||||
require.Empty(t, removed)
|
||||
require.Len(t, added, 5)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
retirements := map[protocol.ConnectionID]time.Time{}
|
||||
t1 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond)
|
||||
retirements[initialConnID] = t1
|
||||
g.SetHandshakeComplete(t1)
|
||||
for i := range 5 {
|
||||
t2 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond)
|
||||
require.NoError(t, g.Retire(uint64(i+1), protocol.ParseConnectionID([]byte{9, 9, 9, 9}), t2))
|
||||
retirements[added[i]] = t2
|
||||
|
||||
var nextRetirement time.Time
|
||||
for _, r := range retirements {
|
||||
if nextRetirement.IsZero() || r.Before(nextRetirement) {
|
||||
nextRetirement = r
|
||||
}
|
||||
}
|
||||
require.Equal(t, nextRetirement, g.NextRetireTime())
|
||||
|
||||
if rand.IntN(2) == 0 {
|
||||
now = now.Add(time.Duration(rand.IntN(500)) * time.Millisecond)
|
||||
g.RemoveRetiredConnIDs(now)
|
||||
for _, r := range removed {
|
||||
require.Contains(t, retirements, r)
|
||||
require.LessOrEqual(t, retirements[r], now)
|
||||
delete(retirements, r)
|
||||
}
|
||||
removed = removed[:0]
|
||||
for _, r := range retirements {
|
||||
require.Greater(t, r, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnIDGeneratorRemoveAll(t *testing.T) {
|
||||
@@ -124,7 +185,6 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
|
||||
RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
|
||||
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
|
||||
},
|
||||
func(f wire.Frame) {},
|
||||
@@ -175,7 +235,6 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
|
||||
RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
|
||||
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, b []byte) {
|
||||
replaced = connIDs
|
||||
replacedWith = b
|
||||
@@ -187,13 +246,18 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
|
||||
|
||||
require.NoError(t, g.SetMaxActiveConnIDs(1000))
|
||||
require.Len(t, added, protocol.MaxIssuedConnectionIDs-1)
|
||||
// Retire two of these connection ID.
|
||||
// This makes us issue two more connection IDs.
|
||||
require.NoError(t, g.Retire(3, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), time.Now()))
|
||||
require.NoError(t, g.Retire(4, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), time.Now()))
|
||||
require.Len(t, added, protocol.MaxIssuedConnectionIDs+1)
|
||||
|
||||
g.ReplaceWithClosed([]byte("foobar"))
|
||||
if hasInitialClientDestConnID {
|
||||
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+1)
|
||||
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+3)
|
||||
require.Contains(t, replaced, *initialClientDestConnID)
|
||||
} else {
|
||||
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs)
|
||||
require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+2)
|
||||
}
|
||||
for _, id := range added {
|
||||
require.Contains(t, replaced, id)
|
||||
@@ -207,14 +271,13 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
|
||||
|
||||
type connIDTracker struct {
|
||||
added, removed, retired, replaced []protocol.ConnectionID
|
||||
added, removed, replaced []protocol.ConnectionID
|
||||
}
|
||||
|
||||
var tracker1, tracker2 connIDTracker
|
||||
runner1 := connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
|
||||
RetireConnectionID: func(c protocol.ConnectionID) { tracker1.retired = append(tracker1.retired, c) },
|
||||
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
|
||||
tracker1.replaced = append(tracker1.replaced, connIDs...)
|
||||
},
|
||||
@@ -222,7 +285,6 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
runner2 := connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) },
|
||||
RetireConnectionID: func(c protocol.ConnectionID) { tracker2.retired = append(tracker2.retired, c) },
|
||||
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
|
||||
tracker2.replaced = append(tracker2.replaced, connIDs...)
|
||||
},
|
||||
@@ -258,17 +320,17 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
connIDToRetire = ncid.ConnectionID
|
||||
seqToRetire = ncid.SequenceNumber
|
||||
|
||||
tracker1.retired = nil
|
||||
tracker2.retired = nil
|
||||
require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
|
||||
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.retired)
|
||||
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.retired)
|
||||
require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now()))
|
||||
g.RemoveRetiredConnIDs(time.Now())
|
||||
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.removed)
|
||||
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.removed)
|
||||
|
||||
tracker1.retired = nil
|
||||
tracker2.retired = nil
|
||||
g.SetHandshakeComplete()
|
||||
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.retired)
|
||||
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.retired)
|
||||
tracker1.removed = nil
|
||||
tracker2.removed = nil
|
||||
g.SetHandshakeComplete(time.Now())
|
||||
g.RemoveRetiredConnIDs(time.Now())
|
||||
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.removed)
|
||||
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.removed)
|
||||
|
||||
g.ReplaceWithClosed([]byte("connection closed"))
|
||||
require.True(t, len(tracker1.replaced) > 0)
|
||||
|
||||
@@ -86,7 +86,6 @@ func (p *receivedPacket) Clone() *receivedPacket {
|
||||
|
||||
type connRunner interface {
|
||||
Add(protocol.ConnectionID, packetHandler) bool
|
||||
Retire(protocol.ConnectionID)
|
||||
Remove(protocol.ConnectionID)
|
||||
ReplaceWithClosed([]protocol.ConnectionID, []byte)
|
||||
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
||||
@@ -277,7 +276,6 @@ var newConnection = func(
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
RemoveConnectionID: runner.Remove,
|
||||
RetireConnectionID: runner.Retire,
|
||||
ReplaceWithClosed: runner.ReplaceWithClosed,
|
||||
},
|
||||
s.queueControlFrame,
|
||||
@@ -392,7 +390,6 @@ var newClientConnection = func(
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
RemoveConnectionID: runner.Remove,
|
||||
RetireConnectionID: runner.Retire,
|
||||
ReplaceWithClosed: runner.ReplaceWithClosed,
|
||||
},
|
||||
s.queueControlFrame,
|
||||
@@ -652,6 +649,8 @@ runLoop:
|
||||
}
|
||||
}
|
||||
|
||||
s.connIDGenerator.RemoveRetiredConnIDs(now)
|
||||
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
pm := s.pathManagerOutgoing.Load()
|
||||
if pm != nil {
|
||||
@@ -762,6 +761,7 @@ func (s *connection) maybeResetTimer() {
|
||||
|
||||
s.timer.SetTimer(
|
||||
deadline,
|
||||
s.connIDGenerator.NextRetireTime(),
|
||||
s.receivedPacketHandler.GetAlarmTimeout(),
|
||||
s.sentPacketHandler.GetLossDetectionTimeout(),
|
||||
s.pacingDeadline,
|
||||
@@ -801,7 +801,7 @@ func (s *connection) handleHandshakeComplete(now time.Time) error {
|
||||
s.undecryptablePackets = nil
|
||||
|
||||
s.connIDManager.SetHandshakeComplete()
|
||||
s.connIDGenerator.SetHandshakeComplete()
|
||||
s.connIDGenerator.SetHandshakeComplete(now.Add(3 * s.rttStats.PTO(false)))
|
||||
|
||||
if s.tracer != nil && s.tracer.ChoseALPN != nil {
|
||||
s.tracer.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol)
|
||||
@@ -1532,7 +1532,7 @@ func (s *connection) handleFrame(
|
||||
case *wire.NewConnectionIDFrame:
|
||||
err = s.handleNewConnectionIDFrame(frame)
|
||||
case *wire.RetireConnectionIDFrame:
|
||||
err = s.handleRetireConnectionIDFrame(frame, destConnID)
|
||||
err = s.handleRetireConnectionIDFrame(rcvTime, frame, destConnID)
|
||||
case *wire.HandshakeDoneFrame:
|
||||
err = s.handleHandshakeDoneFrame(rcvTime)
|
||||
case *wire.DatagramFrame:
|
||||
@@ -1751,8 +1751,8 @@ func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) er
|
||||
return s.connIDManager.Add(f)
|
||||
}
|
||||
|
||||
func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error {
|
||||
return s.connIDGenerator.Retire(f.SequenceNumber, destConnID)
|
||||
func (s *connection) handleRetireConnectionIDFrame(now time.Time, f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error {
|
||||
return s.connIDGenerator.Retire(f.SequenceNumber, destConnID, now.Add(3*s.rttStats.PTO(false)))
|
||||
}
|
||||
|
||||
func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error {
|
||||
@@ -2656,7 +2656,6 @@ func (s *connection) AddPath(t *Transport) (*Path, error) {
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
RemoveConnectionID: runner.Remove,
|
||||
RetireConnectionID: runner.Retire,
|
||||
ReplaceWithClosed: runner.ReplaceWithClosed,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1163,7 +1163,6 @@ func TestConnectionHandshakeServer(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
cs.EXPECT().DiscardInitialKeys()
|
||||
tc.connRunner.EXPECT().Retire(gomock.Any())
|
||||
gomock.InOrder(
|
||||
cs.EXPECT().StartHandshake(gomock.Any()),
|
||||
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
||||
|
||||
@@ -32,8 +32,11 @@ func (t *connectionTimer) Chan() <-chan time.Time {
|
||||
// It makes sure that the deadline is strictly increasing.
|
||||
// This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet.
|
||||
// This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately.
|
||||
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) {
|
||||
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, connIDRetirement, ackAlarm, lossTime, pacing time.Time) {
|
||||
deadline := idleTimeoutOrKeepAlive
|
||||
if !connIDRetirement.IsZero() && connIDRetirement.Before(deadline) && connIDRetirement.After(t.last) {
|
||||
deadline = connIDRetirement
|
||||
}
|
||||
if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) {
|
||||
deadline = ackAlarm
|
||||
}
|
||||
|
||||
@@ -14,25 +14,31 @@ func TestConnectionTimerModes(t *testing.T) {
|
||||
|
||||
t.Run("idle timeout", func(t *testing.T) {
|
||||
timer := newTimer()
|
||||
timer.SetTimer(now.Add(time.Hour), time.Time{}, time.Time{}, time.Time{})
|
||||
timer.SetTimer(now.Add(time.Hour), time.Time{}, time.Time{}, time.Time{}, time.Time{})
|
||||
require.Equal(t, now.Add(time.Hour), timer.Deadline())
|
||||
})
|
||||
|
||||
t.Run("connection ID expiry", func(t *testing.T) {
|
||||
timer := newTimer()
|
||||
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{}, time.Time{})
|
||||
require.Equal(t, now.Add(time.Minute), timer.Deadline())
|
||||
})
|
||||
|
||||
t.Run("ACK timer", func(t *testing.T) {
|
||||
timer := newTimer()
|
||||
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{})
|
||||
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{})
|
||||
require.Equal(t, now.Add(time.Minute), timer.Deadline())
|
||||
})
|
||||
|
||||
t.Run("loss timer", func(t *testing.T) {
|
||||
timer := newTimer()
|
||||
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), time.Time{})
|
||||
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), now.Add(time.Second), time.Time{})
|
||||
require.Equal(t, now.Add(time.Second), timer.Deadline())
|
||||
})
|
||||
|
||||
t.Run("pacing timer", func(t *testing.T) {
|
||||
timer := newTimer()
|
||||
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), now.Add(time.Millisecond))
|
||||
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), now.Add(time.Second), now.Add(time.Millisecond))
|
||||
require.Equal(t, now.Add(time.Millisecond), timer.Deadline())
|
||||
})
|
||||
}
|
||||
@@ -40,10 +46,10 @@ func TestConnectionTimerModes(t *testing.T) {
|
||||
func TestConnectionTimerReset(t *testing.T) {
|
||||
now := time.Now()
|
||||
timer := newTimer()
|
||||
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{})
|
||||
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{})
|
||||
require.Equal(t, now.Add(time.Minute), timer.Deadline())
|
||||
timer.SetRead()
|
||||
|
||||
timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{})
|
||||
timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{})
|
||||
require.Equal(t, now.Add(time.Hour), timer.Deadline())
|
||||
}
|
||||
|
||||
@@ -221,39 +221,3 @@ func (c *MockConnRunnerReplaceWithClosedCall) DoAndReturn(f func([]protocol.Conn
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Retire mocks base method.
|
||||
func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Retire", arg0)
|
||||
}
|
||||
|
||||
// Retire indicates an expected call of Retire.
|
||||
func (mr *MockConnRunnerMockRecorder) Retire(arg0 any) *MockConnRunnerRetireCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0)
|
||||
return &MockConnRunnerRetireCall{Call: call}
|
||||
}
|
||||
|
||||
// MockConnRunnerRetireCall wrap *gomock.Call
|
||||
type MockConnRunnerRetireCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockConnRunnerRetireCall) Return() *MockConnRunnerRetireCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockConnRunnerRetireCall) Do(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockConnRunnerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -373,39 +373,3 @@ func (c *MockPacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]pro
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Retire mocks base method.
|
||||
func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Retire", arg0)
|
||||
}
|
||||
|
||||
// Retire indicates an expected call of Retire.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 any) *MockPacketHandlerManagerRetireCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
|
||||
return &MockPacketHandlerManagerRetireCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerRetireCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerRetireCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerRetireCall) Return() *MockPacketHandlerManagerRetireCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerRetireCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -108,16 +108,6 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||
h.logger.Debugf("Removing connection ID %s.", id)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
|
||||
h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter)
|
||||
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
|
||||
h.mutex.Lock()
|
||||
delete(h.handlers, id)
|
||||
h.mutex.Unlock()
|
||||
h.logger.Debugf("Removing connection ID %s after it has been retired.", id)
|
||||
})
|
||||
}
|
||||
|
||||
// ReplaceWithClosed is called when a connection is closed.
|
||||
// Depending on which side closed the connection, we need to:
|
||||
// * remote close: absorb delayed packets
|
||||
|
||||
@@ -54,28 +54,6 @@ func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
|
||||
require.Equal(t, h, got)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapRetire(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
dur := scaleDuration(10 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
h := &mockPacketHandler{}
|
||||
require.True(t, m.Add(connID, h))
|
||||
m.Retire(connID)
|
||||
|
||||
// immediately after retiring, the handler should still be there
|
||||
got, ok := m.Get(connID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, h, got)
|
||||
|
||||
// after the timeout, the handler should be removed
|
||||
time.Sleep(dur)
|
||||
require.Eventually(t, func() bool {
|
||||
_, ok := m.Get(connID)
|
||||
return !ok
|
||||
}, dur, dur/10)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||
|
||||
Reference in New Issue
Block a user