diff --git a/conn_id_generator.go b/conn_id_generator.go index e05fbd7fb..74f524142 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -2,6 +2,8 @@ package quic import ( "fmt" + "slices" + "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" @@ -11,7 +13,6 @@ import ( type connRunnerCallbacks struct { AddConnectionID func(protocol.ConnectionID) RemoveConnectionID func(protocol.ConnectionID) - RetireConnectionID func(protocol.ConnectionID) ReplaceWithClosed func([]protocol.ConnectionID, []byte) } @@ -29,24 +30,24 @@ func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) { } } -func (cr connRunners) RetireConnectionID(id protocol.ConnectionID) { - for _, c := range cr { - c.RetireConnectionID(id) - } -} - func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte) { for _, c := range cr { c.ReplaceWithClosed(ids, b) } } +type connIDToRetire struct { + t time.Time + connID protocol.ConnectionID +} + type connIDGenerator struct { generator ConnectionIDGenerator highestSeq uint64 connRunners connRunners activeSrcConnIDs map[uint64]protocol.ConnectionID + connIDsToRetire []connIDToRetire // sorted by t initialClientDestConnID *protocol.ConnectionID // nil for the client statelessResetter *statelessResetter @@ -93,7 +94,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { return nil } -func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { +func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID, expiry time.Time) error { if seq > m.highestSeq { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, @@ -111,7 +112,8 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), } } - m.connRunners.RetireConnectionID(connID) + m.queueConnIDForRetiring(connID, expiry) + delete(m.activeSrcConnIDs, seq) // Don't issue a replacement for the initial connection ID. if seq == 0 { @@ -120,6 +122,16 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect return m.issueNewConnID() } +func (m *connIDGenerator) queueConnIDForRetiring(connID protocol.ConnectionID, expiry time.Time) { + idx := slices.IndexFunc(m.connIDsToRetire, func(c connIDToRetire) bool { + return c.t.After(expiry) + }) + if idx == -1 { + idx = len(m.connIDsToRetire) + } + m.connIDsToRetire = slices.Insert(m.connIDsToRetire, idx, connIDToRetire{t: expiry, connID: connID}) +} + func (m *connIDGenerator) issueNewConnID() error { connID, err := m.generator.GenerateConnectionID() if err != nil { @@ -136,13 +148,33 @@ func (m *connIDGenerator) issueNewConnID() error { return nil } -func (m *connIDGenerator) SetHandshakeComplete() { +func (m *connIDGenerator) SetHandshakeComplete(connIDExpiry time.Time) { if m.initialClientDestConnID != nil { - m.connRunners.RetireConnectionID(*m.initialClientDestConnID) + m.queueConnIDForRetiring(*m.initialClientDestConnID, connIDExpiry) m.initialClientDestConnID = nil } } +func (m *connIDGenerator) NextRetireTime() time.Time { + if len(m.connIDsToRetire) == 0 { + return time.Time{} + } + return m.connIDsToRetire[0].t +} + +func (m *connIDGenerator) RemoveRetiredConnIDs(now time.Time) { + if len(m.connIDsToRetire) == 0 { + return + } + for _, c := range m.connIDsToRetire { + if c.t.After(now) { + break + } + m.connRunners.RemoveConnectionID(c.connID) + m.connIDsToRetire = m.connIDsToRetire[1:] + } +} + func (m *connIDGenerator) RemoveAll() { if m.initialClientDestConnID != nil { m.connRunners.RemoveConnectionID(*m.initialClientDestConnID) @@ -150,16 +182,22 @@ func (m *connIDGenerator) RemoveAll() { for _, connID := range m.activeSrcConnIDs { m.connRunners.RemoveConnectionID(connID) } + for _, c := range m.connIDsToRetire { + m.connRunners.RemoveConnectionID(c.connID) + } } func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { - connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) + connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+len(m.connIDsToRetire)+1) if m.initialClientDestConnID != nil { connIDs = append(connIDs, *m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } + for _, c := range m.connIDsToRetire { + connIDs = append(connIDs, c.connID) + } m.connRunners.ReplaceWithClosed(connIDs, connClose) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index b4ca69519..cd6d24c9a 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -1,7 +1,9 @@ package quic import ( + "math/rand/v2" "testing" + "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" @@ -22,7 +24,7 @@ func TestConnIDGeneratorIssueAndRetire(t *testing.T) { func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID bool) { var ( added []protocol.ConnectionID - retired []protocol.ConnectionID + removed []protocol.ConnectionID ) var queuedFrames []wire.Frame sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}) @@ -38,8 +40,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID sr, connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, - RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") }, - RetireConnectionID: func(c protocol.ConnectionID) { retired = append(retired, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) }, ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {}, }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, @@ -50,7 +51,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID require.NoError(t, g.SetMaxActiveConnIDs(4)) require.Len(t, added, 3) require.Len(t, queuedFrames, 3) - require.Empty(t, retired) + require.Empty(t, removed) connIDs := make(map[uint64]protocol.ConnectionID) // connection IDs 1, 2 and 3 were issued for i, f := range queuedFrames { @@ -64,37 +65,97 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID // completing the handshake retires the initial client destination connection ID added = added[:0] queuedFrames = queuedFrames[:0] - g.SetHandshakeComplete() + now := time.Now() + g.SetHandshakeComplete(now) require.Empty(t, added) require.Empty(t, queuedFrames) + require.Empty(t, removed) + g.RemoveRetiredConnIDs(now) if hasInitialClientDestConnID { - require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, retired) - retired = retired[:0] + require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, removed) + removed = removed[:0] } else { - require.Empty(t, retired) + require.Empty(t, removed) } // it's invalid to retire a connection ID that hasn't been issued yet - err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3})) + err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now()) require.ErrorIs(t, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}, err) require.ErrorContains(t, err, "retired connection ID 4 (highest issued: 3)") // it's invalid to retire a connection ID in a packet that uses that connection ID - err = g.Retire(3, connIDs[3]) + err = g.Retire(3, connIDs[3], time.Now()) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) require.ErrorContains(t, err, "was used as the Destination Connection ID on this packet") // retiring a connection ID makes us issue a new one - require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}))) - require.Equal(t, []protocol.ConnectionID{connIDs[2]}, retired) + require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now())) + g.RemoveRetiredConnIDs(time.Now()) + require.Equal(t, []protocol.ConnectionID{connIDs[2]}, removed) require.Len(t, queuedFrames, 1) require.EqualValues(t, 4, queuedFrames[0].(*wire.NewConnectionIDFrame).SequenceNumber) queuedFrames = queuedFrames[:0] - retired = retired[:0] + removed = removed[:0] // duplicate retirements don't do anything - require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}))) + require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now())) + g.RemoveRetiredConnIDs(time.Now()) require.Empty(t, queuedFrames) - require.Empty(t, retired) + require.Empty(t, removed) +} + +func TestConnIDGeneratorRetiring(t *testing.T) { + initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) + var added, removed []protocol.ConnectionID + g := newConnIDGenerator( + 1, + protocol.ParseConnectionID([]byte{1, 1, 1, 1}), + &initialConnID, + newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), + connRunnerCallbacks{ + AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) }, + ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {}, + }, + func(f wire.Frame) {}, + &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, + ) + require.NoError(t, g.SetMaxActiveConnIDs(6)) + require.Empty(t, removed) + require.Len(t, added, 5) + + now := time.Now() + + retirements := map[protocol.ConnectionID]time.Time{} + t1 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond) + retirements[initialConnID] = t1 + g.SetHandshakeComplete(t1) + for i := range 5 { + t2 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond) + require.NoError(t, g.Retire(uint64(i+1), protocol.ParseConnectionID([]byte{9, 9, 9, 9}), t2)) + retirements[added[i]] = t2 + + var nextRetirement time.Time + for _, r := range retirements { + if nextRetirement.IsZero() || r.Before(nextRetirement) { + nextRetirement = r + } + } + require.Equal(t, nextRetirement, g.NextRetireTime()) + + if rand.IntN(2) == 0 { + now = now.Add(time.Duration(rand.IntN(500)) * time.Millisecond) + g.RemoveRetiredConnIDs(now) + for _, r := range removed { + require.Contains(t, retirements, r) + require.LessOrEqual(t, retirements[r], now) + delete(retirements, r) + } + removed = removed[:0] + for _, r := range retirements { + require.Greater(t, r, now) + } + } + } } func TestConnIDGeneratorRemoveAll(t *testing.T) { @@ -124,7 +185,6 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool) connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) }, - RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") }, ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {}, }, func(f wire.Frame) {}, @@ -175,7 +235,6 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") }, - RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") }, ReplaceWithClosed: func(connIDs []protocol.ConnectionID, b []byte) { replaced = connIDs replacedWith = b @@ -187,13 +246,18 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn require.NoError(t, g.SetMaxActiveConnIDs(1000)) require.Len(t, added, protocol.MaxIssuedConnectionIDs-1) + // Retire two of these connection ID. + // This makes us issue two more connection IDs. + require.NoError(t, g.Retire(3, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), time.Now())) + require.NoError(t, g.Retire(4, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), time.Now())) + require.Len(t, added, protocol.MaxIssuedConnectionIDs+1) g.ReplaceWithClosed([]byte("foobar")) if hasInitialClientDestConnID { - require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+1) + require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+3) require.Contains(t, replaced, *initialClientDestConnID) } else { - require.Len(t, replaced, protocol.MaxIssuedConnectionIDs) + require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+2) } for _, id := range added { require.Contains(t, replaced, id) @@ -207,14 +271,13 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) type connIDTracker struct { - added, removed, retired, replaced []protocol.ConnectionID + added, removed, replaced []protocol.ConnectionID } var tracker1, tracker2 connIDTracker runner1 := connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) }, - RetireConnectionID: func(c protocol.ConnectionID) { tracker1.retired = append(tracker1.retired, c) }, ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) { tracker1.replaced = append(tracker1.replaced, connIDs...) }, @@ -222,7 +285,6 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { runner2 := connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) }, - RetireConnectionID: func(c protocol.ConnectionID) { tracker2.retired = append(tracker2.retired, c) }, ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) { tracker2.replaced = append(tracker2.replaced, connIDs...) }, @@ -258,17 +320,17 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { connIDToRetire = ncid.ConnectionID seqToRetire = ncid.SequenceNumber - tracker1.retired = nil - tracker2.retired = nil - require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3}))) - require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.retired) - require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.retired) + require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), time.Now())) + g.RemoveRetiredConnIDs(time.Now()) + require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.removed) + require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.removed) - tracker1.retired = nil - tracker2.retired = nil - g.SetHandshakeComplete() - require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.retired) - require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.retired) + tracker1.removed = nil + tracker2.removed = nil + g.SetHandshakeComplete(time.Now()) + g.RemoveRetiredConnIDs(time.Now()) + require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.removed) + require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.removed) g.ReplaceWithClosed([]byte("connection closed")) require.True(t, len(tracker1.replaced) > 0) diff --git a/connection.go b/connection.go index 79506fedb..682e356b5 100644 --- a/connection.go +++ b/connection.go @@ -86,7 +86,6 @@ func (p *receivedPacket) Clone() *receivedPacket { type connRunner interface { Add(protocol.ConnectionID, packetHandler) bool - Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) ReplaceWithClosed([]protocol.ConnectionID, []byte) AddResetToken(protocol.StatelessResetToken, packetHandler) @@ -277,7 +276,6 @@ var newConnection = func( connRunnerCallbacks{ AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, RemoveConnectionID: runner.Remove, - RetireConnectionID: runner.Retire, ReplaceWithClosed: runner.ReplaceWithClosed, }, s.queueControlFrame, @@ -392,7 +390,6 @@ var newClientConnection = func( connRunnerCallbacks{ AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, RemoveConnectionID: runner.Remove, - RetireConnectionID: runner.Retire, ReplaceWithClosed: runner.ReplaceWithClosed, }, s.queueControlFrame, @@ -652,6 +649,8 @@ runLoop: } } + s.connIDGenerator.RemoveRetiredConnIDs(now) + if s.perspective == protocol.PerspectiveClient { pm := s.pathManagerOutgoing.Load() if pm != nil { @@ -762,6 +761,7 @@ func (s *connection) maybeResetTimer() { s.timer.SetTimer( deadline, + s.connIDGenerator.NextRetireTime(), s.receivedPacketHandler.GetAlarmTimeout(), s.sentPacketHandler.GetLossDetectionTimeout(), s.pacingDeadline, @@ -801,7 +801,7 @@ func (s *connection) handleHandshakeComplete(now time.Time) error { s.undecryptablePackets = nil s.connIDManager.SetHandshakeComplete() - s.connIDGenerator.SetHandshakeComplete() + s.connIDGenerator.SetHandshakeComplete(now.Add(3 * s.rttStats.PTO(false))) if s.tracer != nil && s.tracer.ChoseALPN != nil { s.tracer.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol) @@ -1532,7 +1532,7 @@ func (s *connection) handleFrame( case *wire.NewConnectionIDFrame: err = s.handleNewConnectionIDFrame(frame) case *wire.RetireConnectionIDFrame: - err = s.handleRetireConnectionIDFrame(frame, destConnID) + err = s.handleRetireConnectionIDFrame(rcvTime, frame, destConnID) case *wire.HandshakeDoneFrame: err = s.handleHandshakeDoneFrame(rcvTime) case *wire.DatagramFrame: @@ -1751,8 +1751,8 @@ func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) er return s.connIDManager.Add(f) } -func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { - return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) +func (s *connection) handleRetireConnectionIDFrame(now time.Time, f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { + return s.connIDGenerator.Retire(f.SequenceNumber, destConnID, now.Add(3*s.rttStats.PTO(false))) } func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error { @@ -2656,7 +2656,6 @@ func (s *connection) AddPath(t *Transport) (*Path, error) { connRunnerCallbacks{ AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, RemoveConnectionID: runner.Remove, - RetireConnectionID: runner.Retire, ReplaceWithClosed: runner.ReplaceWithClosed, }, ) diff --git a/connection_test.go b/connection_test.go index 36990e78a..75e331663 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1163,7 +1163,6 @@ func TestConnectionHandshakeServer(t *testing.T) { require.NoError(t, err) cs.EXPECT().DiscardInitialKeys() - tc.connRunner.EXPECT().Retire(gomock.Any()) gomock.InOrder( cs.EXPECT().StartHandshake(gomock.Any()), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), diff --git a/connection_timer.go b/connection_timer.go index 171fdd013..0fdc8cb8b 100644 --- a/connection_timer.go +++ b/connection_timer.go @@ -32,8 +32,11 @@ func (t *connectionTimer) Chan() <-chan time.Time { // It makes sure that the deadline is strictly increasing. // This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet. // This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately. -func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) { +func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, connIDRetirement, ackAlarm, lossTime, pacing time.Time) { deadline := idleTimeoutOrKeepAlive + if !connIDRetirement.IsZero() && connIDRetirement.Before(deadline) && connIDRetirement.After(t.last) { + deadline = connIDRetirement + } if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) { deadline = ackAlarm } diff --git a/connection_timer_test.go b/connection_timer_test.go index 189f7b5b1..5f01b1bc2 100644 --- a/connection_timer_test.go +++ b/connection_timer_test.go @@ -14,25 +14,31 @@ func TestConnectionTimerModes(t *testing.T) { t.Run("idle timeout", func(t *testing.T) { timer := newTimer() - timer.SetTimer(now.Add(time.Hour), time.Time{}, time.Time{}, time.Time{}) + timer.SetTimer(now.Add(time.Hour), time.Time{}, time.Time{}, time.Time{}, time.Time{}) require.Equal(t, now.Add(time.Hour), timer.Deadline()) }) + t.Run("connection ID expiry", func(t *testing.T) { + timer := newTimer() + timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{}, time.Time{}) + require.Equal(t, now.Add(time.Minute), timer.Deadline()) + }) + t.Run("ACK timer", func(t *testing.T) { timer := newTimer() - timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{}) + timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{}) require.Equal(t, now.Add(time.Minute), timer.Deadline()) }) t.Run("loss timer", func(t *testing.T) { timer := newTimer() - timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), time.Time{}) + timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), now.Add(time.Second), time.Time{}) require.Equal(t, now.Add(time.Second), timer.Deadline()) }) t.Run("pacing timer", func(t *testing.T) { timer := newTimer() - timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), now.Add(time.Millisecond)) + timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), now.Add(time.Second), now.Add(time.Millisecond)) require.Equal(t, now.Add(time.Millisecond), timer.Deadline()) }) } @@ -40,10 +46,10 @@ func TestConnectionTimerModes(t *testing.T) { func TestConnectionTimerReset(t *testing.T) { now := time.Now() timer := newTimer() - timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{}) + timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{}) require.Equal(t, now.Add(time.Minute), timer.Deadline()) timer.SetRead() - timer.SetTimer(now.Add(time.Hour), now.Add(time.Minute), time.Time{}, time.Time{}) + timer.SetTimer(now.Add(time.Hour), time.Time{}, now.Add(time.Minute), time.Time{}, time.Time{}) require.Equal(t, now.Add(time.Hour), timer.Deadline()) } diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index 3db6a9b90..b6868fc31 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -221,39 +221,3 @@ func (c *MockConnRunnerReplaceWithClosedCall) DoAndReturn(f func([]protocol.Conn c.Call = c.Call.DoAndReturn(f) return c } - -// Retire mocks base method. -func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Retire", arg0) -} - -// Retire indicates an expected call of Retire. -func (mr *MockConnRunnerMockRecorder) Retire(arg0 any) *MockConnRunnerRetireCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) - return &MockConnRunnerRetireCall{Call: call} -} - -// MockConnRunnerRetireCall wrap *gomock.Call -type MockConnRunnerRetireCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockConnRunnerRetireCall) Return() *MockConnRunnerRetireCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockConnRunnerRetireCall) Do(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockConnRunnerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index a51e590b2..27c8f3b87 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -373,39 +373,3 @@ func (c *MockPacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]pro c.Call = c.Call.DoAndReturn(f) return c } - -// Retire mocks base method. -func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Retire", arg0) -} - -// Retire indicates an expected call of Retire. -func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 any) *MockPacketHandlerManagerRetireCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) - return &MockPacketHandlerManagerRetireCall{Call: call} -} - -// MockPacketHandlerManagerRetireCall wrap *gomock.Call -type MockPacketHandlerManagerRetireCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerRetireCall) Return() *MockPacketHandlerManagerRetireCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerRetireCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/packet_handler_map.go b/packet_handler_map.go index cc690b1f8..83afdbfd4 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -108,16 +108,6 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { h.logger.Debugf("Removing connection ID %s.", id) } -func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { - h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) - time.AfterFunc(h.deleteRetiredConnsAfter, func() { - h.mutex.Lock() - delete(h.handlers, id) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s after it has been retired.", id) - }) -} - // ReplaceWithClosed is called when a connection is closed. // Depending on which side closed the connection, we need to: // * remote close: absorb delayed packets diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 6d8266704..e4442d4f4 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -54,28 +54,6 @@ func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) { require.Equal(t, h, got) } -func TestPacketHandlerMapRetire(t *testing.T) { - m := newPacketHandlerMap(nil, utils.DefaultLogger) - dur := scaleDuration(10 * time.Millisecond) - m.deleteRetiredConnsAfter = dur - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - h := &mockPacketHandler{} - require.True(t, m.Add(connID, h)) - m.Retire(connID) - - // immediately after retiring, the handler should still be there - got, ok := m.Get(connID) - require.True(t, ok) - require.Equal(t, h, got) - - // after the timeout, the handler should be removed - time.Sleep(dur) - require.Eventually(t, func() bool { - _, ok := m.Get(connID) - return !ok - }, dur, dur/10) -} - func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) { m := newPacketHandlerMap(nil, utils.DefaultLogger) token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}