From 0a2c2f0a829d200e36192d46888a4621432dd36d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 12 Mar 2025 11:47:58 +0700 Subject: [PATCH] extend connection ID generator to interact with multiple Transports (#4974) This is needed when multiple paths are in use. --- conn_id_generator.go | 56 +++++++++++++++++++++++----- conn_id_generator_test.go | 78 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 10 deletions(-) diff --git a/conn_id_generator.go b/conn_id_generator.go index a5c203ff2..675d93835 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -15,10 +15,36 @@ type connRunnerCallbacks struct { ReplaceWithClosed func([]protocol.ConnectionID, []byte) } +type connRunners []connRunnerCallbacks + +func (cr connRunners) AddConnectionID(id protocol.ConnectionID) { + for _, c := range cr { + c.AddConnectionID(id) + } +} + +func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) { + for _, c := range cr { + c.RemoveConnectionID(id) + } +} + +func (cr connRunners) RetireConnectionID(id protocol.ConnectionID) { + for _, c := range cr { + c.RetireConnectionID(id) + } +} + +func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte) { + for _, c := range cr { + c.ReplaceWithClosed(ids, b) + } +} + type connIDGenerator struct { - generator ConnectionIDGenerator - connRunner connRunnerCallbacks - highestSeq uint64 + generator ConnectionIDGenerator + highestSeq uint64 + connRunners connRunners activeSrcConnIDs map[uint64]protocol.ConnectionID initialClientDestConnID *protocol.ConnectionID // nil for the client @@ -40,7 +66,7 @@ func newConnIDGenerator( generator: generator, activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), statelessResetter: statelessResetter, - connRunner: connRunner, + connRunners: []connRunnerCallbacks{connRunner}, queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID @@ -84,7 +110,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), } } - m.connRunner.RetireConnectionID(connID) + m.connRunners.RetireConnectionID(connID) delete(m.activeSrcConnIDs, seq) // Don't issue a replacement for the initial connection ID. if seq == 0 { @@ -99,7 +125,7 @@ func (m *connIDGenerator) issueNewConnID() error { return err } m.activeSrcConnIDs[m.highestSeq+1] = connID - m.connRunner.AddConnectionID(connID) + m.connRunners.AddConnectionID(connID) m.queueControlFrame(&wire.NewConnectionIDFrame{ SequenceNumber: m.highestSeq + 1, ConnectionID: connID, @@ -111,17 +137,17 @@ func (m *connIDGenerator) issueNewConnID() error { func (m *connIDGenerator) SetHandshakeComplete() { if m.initialClientDestConnID != nil { - m.connRunner.RetireConnectionID(*m.initialClientDestConnID) + m.connRunners.RetireConnectionID(*m.initialClientDestConnID) m.initialClientDestConnID = nil } } func (m *connIDGenerator) RemoveAll() { if m.initialClientDestConnID != nil { - m.connRunner.RemoveConnectionID(*m.initialClientDestConnID) + m.connRunners.RemoveConnectionID(*m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { - m.connRunner.RemoveConnectionID(connID) + m.connRunners.RemoveConnectionID(connID) } } @@ -133,5 +159,15 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } - m.connRunner.ReplaceWithClosed(connIDs, connClose) + m.connRunners.ReplaceWithClosed(connIDs, connClose) +} + +func (m *connIDGenerator) AddConnRunner(r connRunnerCallbacks) { + if m.initialClientDestConnID != nil { + r.AddConnectionID(*m.initialClientDestConnID) + } + for _, connID := range m.activeSrcConnIDs { + r.AddConnectionID(connID) + } + m.connRunners = append(m.connRunners, r) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 4a9a09432..1595e90e2 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -198,3 +198,81 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn require.Contains(t, replaced, protocol.ParseConnectionID([]byte{1, 1, 1, 1})) require.Equal(t, []byte("foobar"), replacedWith) } + +func TestConnIDGeneratorAddConnRunner(t *testing.T) { + initialConnID := protocol.ParseConnectionID([]byte{1, 1, 1, 1}) + clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) + + type connIDTracker struct { + added, removed, retired, replaced []protocol.ConnectionID + } + + var tracker1, tracker2 connIDTracker + runner1 := connRunnerCallbacks{ + AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) }, + RetireConnectionID: func(c protocol.ConnectionID) { tracker1.retired = append(tracker1.retired, c) }, + ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) { + tracker1.replaced = append(tracker1.replaced, connIDs...) + }, + } + runner2 := connRunnerCallbacks{ + AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) }, + RetireConnectionID: func(c protocol.ConnectionID) { tracker2.retired = append(tracker2.retired, c) }, + ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) { + tracker2.replaced = append(tracker2.replaced, connIDs...) + }, + } + + sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}) + var queuedFrames []wire.Frame + + g := newConnIDGenerator( + initialConnID, + &clientDestConnID, + sr, + runner1, + func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, + &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, + ) + require.NoError(t, g.SetMaxActiveConnIDs(3)) + require.Len(t, tracker1.added, 2) + + // add the second runner - it should get all existing connection IDs + g.AddConnRunner(runner2) + require.Len(t, tracker1.added, 2) // unchanged + require.Len(t, tracker2.added, 4) + require.Contains(t, tracker2.added, initialConnID) + require.Contains(t, tracker2.added, clientDestConnID) + require.Contains(t, tracker2.added, tracker1.added[0]) + require.Contains(t, tracker2.added, tracker1.added[1]) + + var connIDToRetire protocol.ConnectionID + var seqToRetire uint64 + ncid := queuedFrames[0].(*wire.NewConnectionIDFrame) + connIDToRetire = ncid.ConnectionID + seqToRetire = ncid.SequenceNumber + + tracker1.retired = nil + tracker2.retired = nil + require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3}))) + require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.retired) + require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.retired) + + tracker1.retired = nil + tracker2.retired = nil + g.SetHandshakeComplete() + require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.retired) + require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.retired) + + g.ReplaceWithClosed([]byte("connection closed")) + require.True(t, len(tracker1.replaced) > 0) + require.Equal(t, tracker1.replaced, tracker2.replaced) + + tracker1.removed = nil + tracker2.removed = nil + g.RemoveAll() + require.NotEmpty(t, tracker1.removed) + require.Equal(t, tracker1.removed, tracker2.removed) +}