simplify tracking of Transports for connection migration (#5111)

No functional change expected.
This commit is contained in:
Marten Seemann
2025-05-03 22:06:23 +08:00
committed by GitHub
parent 97e7657df5
commit 4d39fc31c5
4 changed files with 28 additions and 23 deletions

View File

@@ -16,7 +16,8 @@ type connRunnerCallbacks struct {
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
}
type connRunners map[transportID]connRunnerCallbacks
// The memory address of the Transport is used as the key.
type connRunners map[*Transport]connRunnerCallbacks
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
@@ -56,7 +57,7 @@ type connIDGenerator struct {
}
func newConnIDGenerator(
tID transportID,
tr *Transport,
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
statelessResetter *statelessResetter,
@@ -68,7 +69,7 @@ func newConnIDGenerator(
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunners: map[transportID]connRunnerCallbacks{tID: connRunner},
connRunners: map[*Transport]connRunnerCallbacks{tr: connRunner},
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
@@ -201,7 +202,7 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
m.connRunners.ReplaceWithClosed(connIDs, connClose)
}
func (m *connIDGenerator) AddConnRunner(id transportID, r connRunnerCallbacks) {
func (m *connIDGenerator) AddConnRunner(id *Transport, r connRunnerCallbacks) {
// The transport might have already been added earlier.
// This happens if the application migrates back to and old path.
if _, ok := m.connRunners[id]; ok {

View File

@@ -34,7 +34,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
initialClientDestConnID = &connID
}
g := newConnIDGenerator(
1,
&Transport{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
sr,
@@ -107,7 +107,7 @@ func TestConnIDGeneratorRetiring(t *testing.T) {
initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
var added, removed []protocol.ConnectionID
g := newConnIDGenerator(
1,
&Transport{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
&initialConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -178,7 +178,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
removed []protocol.ConnectionID
)
g := newConnIDGenerator(
0,
&Transport{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -228,7 +228,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
replacedWith []byte
)
g := newConnIDGenerator(
1,
&Transport{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -274,7 +274,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
added, removed, replaced []protocol.ConnectionID
}
var tracker1, tracker2 connIDTracker
var tracker1, tracker2, tracker3 connIDTracker
runner1 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
@@ -289,12 +289,20 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
tracker2.replaced = append(tracker2.replaced, connIDs...)
},
}
runner3 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker3.added = append(tracker3.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker3.removed = append(tracker3.removed, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
tracker3.replaced = append(tracker3.replaced, connIDs...)
},
}
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
var queuedFrames []wire.Frame
tr := &Transport{}
g := newConnIDGenerator(
1,
tr,
initialConnID,
&clientDestConnID,
sr,
@@ -306,7 +314,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
require.Len(t, tracker1.added, 2)
// add the second runner - it should get all existing connection IDs
g.AddConnRunner(2, runner2)
g.AddConnRunner(&Transport{}, runner2)
require.Len(t, tracker1.added, 2) // unchanged
require.Len(t, tracker2.added, 4)
require.Contains(t, tracker2.added, initialConnID)
@@ -314,6 +322,11 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
require.Contains(t, tracker2.added, tracker1.added[0])
require.Contains(t, tracker2.added, tracker1.added[1])
// adding the same transport again doesn't do anything
trCopy := tr
g.AddConnRunner(trCopy, runner3)
require.Empty(t, tracker3.added)
var connIDToRetire protocol.ConnectionID
var seqToRetire uint64
ncid := queuedFrames[0].(*wire.NewConnectionIDFrame)

View File

@@ -269,7 +269,7 @@ var newConnection = func(
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
tr.id(),
tr,
srcConnID,
&clientDestConnID,
statelessResetter,
@@ -383,7 +383,7 @@ var newClientConnection = func(
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
tr.id(),
tr,
srcConnID,
nil,
statelessResetter,
@@ -2652,7 +2652,7 @@ func (s *connection) AddPath(t *Transport) (*Path, error) {
func() {
runner := t.connRunner()
s.connIDGenerator.AddConnRunner(
t.id(),
t,
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,

View File

@@ -38,10 +38,6 @@ func (e *errTransportClosed) Is(target error) bool {
return ok
}
type transportID uint64
var transportIDCounter atomic.Uint64
var errListenerAlreadySet = errors.New("listener already set")
// The Transport is the central point to manage incoming and outgoing QUIC connections.
@@ -136,8 +132,6 @@ type Transport struct {
initOnce sync.Once
initErr error
// Set in init.
transportID transportID
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
// Set in init.
@@ -376,7 +370,6 @@ func (t *Transport) doDial(
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.initOnce.Do(func() {
t.transportID = transportID(transportIDCounter.Add(1))
var conn rawConn
if c, ok := t.Conn.(rawConn); ok {
conn = c
@@ -430,8 +423,6 @@ func (t *Transport) connRunner() packetHandlerManager {
return t.handlerMap
}
func (t *Transport) id() transportID { return t.transportID }
// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {