extend connection ID generator to interact with multiple Transports (#4974)

This is needed when multiple paths are in use.
This commit is contained in:
Marten Seemann
2025-03-12 11:47:58 +07:00
committed by GitHub
parent 565b7f75ab
commit 0a2c2f0a82
2 changed files with 124 additions and 10 deletions

View File

@@ -15,10 +15,36 @@ type connRunnerCallbacks struct {
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
}
type connRunners []connRunnerCallbacks
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.AddConnectionID(id)
}
}
func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.RemoveConnectionID(id)
}
}
func (cr connRunners) RetireConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.RetireConnectionID(id)
}
}
func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte) {
for _, c := range cr {
c.ReplaceWithClosed(ids, b)
}
}
type connIDGenerator struct {
generator ConnectionIDGenerator
connRunner connRunnerCallbacks
highestSeq uint64
generator ConnectionIDGenerator
highestSeq uint64
connRunners connRunners
activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID *protocol.ConnectionID // nil for the client
@@ -40,7 +66,7 @@ func newConnIDGenerator(
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunner: connRunner,
connRunners: []connRunnerCallbacks{connRunner},
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
@@ -84,7 +110,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
}
}
m.connRunner.RetireConnectionID(connID)
m.connRunners.RetireConnectionID(connID)
delete(m.activeSrcConnIDs, seq)
// Don't issue a replacement for the initial connection ID.
if seq == 0 {
@@ -99,7 +125,7 @@ func (m *connIDGenerator) issueNewConnID() error {
return err
}
m.activeSrcConnIDs[m.highestSeq+1] = connID
m.connRunner.AddConnectionID(connID)
m.connRunners.AddConnectionID(connID)
m.queueControlFrame(&wire.NewConnectionIDFrame{
SequenceNumber: m.highestSeq + 1,
ConnectionID: connID,
@@ -111,17 +137,17 @@ func (m *connIDGenerator) issueNewConnID() error {
func (m *connIDGenerator) SetHandshakeComplete() {
if m.initialClientDestConnID != nil {
m.connRunner.RetireConnectionID(*m.initialClientDestConnID)
m.connRunners.RetireConnectionID(*m.initialClientDestConnID)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.connRunner.RemoveConnectionID(*m.initialClientDestConnID)
m.connRunners.RemoveConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.connRunner.RemoveConnectionID(connID)
m.connRunners.RemoveConnectionID(connID)
}
}
@@ -133,5 +159,15 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.connRunner.ReplaceWithClosed(connIDs, connClose)
m.connRunners.ReplaceWithClosed(connIDs, connClose)
}
func (m *connIDGenerator) AddConnRunner(r connRunnerCallbacks) {
if m.initialClientDestConnID != nil {
r.AddConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
r.AddConnectionID(connID)
}
m.connRunners = append(m.connRunners, r)
}

View File

@@ -198,3 +198,81 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
require.Contains(t, replaced, protocol.ParseConnectionID([]byte{1, 1, 1, 1}))
require.Equal(t, []byte("foobar"), replacedWith)
}
func TestConnIDGeneratorAddConnRunner(t *testing.T) {
initialConnID := protocol.ParseConnectionID([]byte{1, 1, 1, 1})
clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
type connIDTracker struct {
added, removed, retired, replaced []protocol.ConnectionID
}
var tracker1, tracker2 connIDTracker
runner1 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
RetireConnectionID: func(c protocol.ConnectionID) { tracker1.retired = append(tracker1.retired, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
tracker1.replaced = append(tracker1.replaced, connIDs...)
},
}
runner2 := connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) },
RetireConnectionID: func(c protocol.ConnectionID) { tracker2.retired = append(tracker2.retired, c) },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
tracker2.replaced = append(tracker2.replaced, connIDs...)
},
}
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
var queuedFrames []wire.Frame
g := newConnIDGenerator(
initialConnID,
&clientDestConnID,
sr,
runner1,
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
require.NoError(t, g.SetMaxActiveConnIDs(3))
require.Len(t, tracker1.added, 2)
// add the second runner - it should get all existing connection IDs
g.AddConnRunner(runner2)
require.Len(t, tracker1.added, 2) // unchanged
require.Len(t, tracker2.added, 4)
require.Contains(t, tracker2.added, initialConnID)
require.Contains(t, tracker2.added, clientDestConnID)
require.Contains(t, tracker2.added, tracker1.added[0])
require.Contains(t, tracker2.added, tracker1.added[1])
var connIDToRetire protocol.ConnectionID
var seqToRetire uint64
ncid := queuedFrames[0].(*wire.NewConnectionIDFrame)
connIDToRetire = ncid.ConnectionID
seqToRetire = ncid.SequenceNumber
tracker1.retired = nil
tracker2.retired = nil
require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.retired)
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.retired)
tracker1.retired = nil
tracker2.retired = nil
g.SetHandshakeComplete()
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.retired)
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.retired)
g.ReplaceWithClosed([]byte("connection closed"))
require.True(t, len(tracker1.replaced) > 0)
require.Equal(t, tracker1.replaced, tracker2.replaced)
tracker1.removed = nil
tracker2.removed = nil
g.RemoveAll()
require.NotEmpty(t, tracker1.removed)
require.Equal(t, tracker1.removed, tracker2.removed)
}