From 4d39fc31c5a8ed73ccd0de688a7145d85f3724e8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 3 May 2025 22:06:23 +0800 Subject: [PATCH] simplify tracking of Transports for connection migration (#5111) No functional change expected. --- conn_id_generator.go | 9 +++++---- conn_id_generator_test.go | 27 ++++++++++++++++++++------- connection.go | 6 +++--- transport.go | 9 --------- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/conn_id_generator.go b/conn_id_generator.go index 74f524142..0b0b90fd8 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -16,7 +16,8 @@ type connRunnerCallbacks struct { ReplaceWithClosed func([]protocol.ConnectionID, []byte) } -type connRunners map[transportID]connRunnerCallbacks +// The memory address of the Transport is used as the key. +type connRunners map[*Transport]connRunnerCallbacks func (cr connRunners) AddConnectionID(id protocol.ConnectionID) { for _, c := range cr { @@ -56,7 +57,7 @@ type connIDGenerator struct { } func newConnIDGenerator( - tID transportID, + tr *Transport, initialConnectionID protocol.ConnectionID, initialClientDestConnID *protocol.ConnectionID, // nil for the client statelessResetter *statelessResetter, @@ -68,7 +69,7 @@ func newConnIDGenerator( generator: generator, activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), statelessResetter: statelessResetter, - connRunners: map[transportID]connRunnerCallbacks{tID: connRunner}, + connRunners: map[*Transport]connRunnerCallbacks{tr: connRunner}, queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID @@ -201,7 +202,7 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { m.connRunners.ReplaceWithClosed(connIDs, connClose) } -func (m *connIDGenerator) AddConnRunner(id transportID, r connRunnerCallbacks) { +func (m *connIDGenerator) AddConnRunner(id *Transport, r connRunnerCallbacks) { // The transport might have already been added earlier. // This happens if the application migrates back to and old path. if _, ok := m.connRunners[id]; ok { diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index cd6d24c9a..bdd31bdb3 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -34,7 +34,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID initialClientDestConnID = &connID } g := newConnIDGenerator( - 1, + &Transport{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, sr, @@ -107,7 +107,7 @@ func TestConnIDGeneratorRetiring(t *testing.T) { initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) var added, removed []protocol.ConnectionID g := newConnIDGenerator( - 1, + &Transport{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), &initialConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), @@ -178,7 +178,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool) removed []protocol.ConnectionID ) g := newConnIDGenerator( - 0, + &Transport{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), @@ -228,7 +228,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn replacedWith []byte ) g := newConnIDGenerator( - 1, + &Transport{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), @@ -274,7 +274,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { added, removed, replaced []protocol.ConnectionID } - var tracker1, tracker2 connIDTracker + var tracker1, tracker2, tracker3 connIDTracker runner1 := connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) }, @@ -289,12 +289,20 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { tracker2.replaced = append(tracker2.replaced, connIDs...) }, } + runner3 := connRunnerCallbacks{ + AddConnectionID: func(c protocol.ConnectionID) { tracker3.added = append(tracker3.added, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { tracker3.removed = append(tracker3.removed, c) }, + ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) { + tracker3.replaced = append(tracker3.replaced, connIDs...) + }, + } sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}) var queuedFrames []wire.Frame + tr := &Transport{} g := newConnIDGenerator( - 1, + tr, initialConnID, &clientDestConnID, sr, @@ -306,7 +314,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { require.Len(t, tracker1.added, 2) // add the second runner - it should get all existing connection IDs - g.AddConnRunner(2, runner2) + g.AddConnRunner(&Transport{}, runner2) require.Len(t, tracker1.added, 2) // unchanged require.Len(t, tracker2.added, 4) require.Contains(t, tracker2.added, initialConnID) @@ -314,6 +322,11 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { require.Contains(t, tracker2.added, tracker1.added[0]) require.Contains(t, tracker2.added, tracker1.added[1]) + // adding the same transport again doesn't do anything + trCopy := tr + g.AddConnRunner(trCopy, runner3) + require.Empty(t, tracker3.added) + var connIDToRetire protocol.ConnectionID var seqToRetire uint64 ncid := queuedFrames[0].(*wire.NewConnectionIDFrame) diff --git a/connection.go b/connection.go index 682e356b5..1165078e7 100644 --- a/connection.go +++ b/connection.go @@ -269,7 +269,7 @@ var newConnection = func( s.queueControlFrame, ) s.connIDGenerator = newConnIDGenerator( - tr.id(), + tr, srcConnID, &clientDestConnID, statelessResetter, @@ -383,7 +383,7 @@ var newClientConnection = func( s.queueControlFrame, ) s.connIDGenerator = newConnIDGenerator( - tr.id(), + tr, srcConnID, nil, statelessResetter, @@ -2652,7 +2652,7 @@ func (s *connection) AddPath(t *Transport) (*Path, error) { func() { runner := t.connRunner() s.connIDGenerator.AddConnRunner( - t.id(), + t, connRunnerCallbacks{ AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, RemoveConnectionID: runner.Remove, diff --git a/transport.go b/transport.go index a7775fb14..915c1ce08 100644 --- a/transport.go +++ b/transport.go @@ -38,10 +38,6 @@ func (e *errTransportClosed) Is(target error) bool { return ok } -type transportID uint64 - -var transportIDCounter atomic.Uint64 - var errListenerAlreadySet = errors.New("listener already set") // The Transport is the central point to manage incoming and outgoing QUIC connections. @@ -136,8 +132,6 @@ type Transport struct { initOnce sync.Once initErr error - // Set in init. - transportID transportID // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. connIDLen int // Set in init. @@ -376,7 +370,6 @@ func (t *Transport) doDial( func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.initOnce.Do(func() { - t.transportID = transportID(transportIDCounter.Add(1)) var conn rawConn if c, ok := t.Conn.(rawConn); ok { conn = c @@ -430,8 +423,6 @@ func (t *Transport) connRunner() packetHandlerManager { return t.handlerMap } -func (t *Transport) id() transportID { return t.transportID } - // WriteTo sends a packet on the underlying connection. func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil {