refactor connection ID generator callback handling (#4961)

This commit is contained in:
Marten Seemann
2025-02-22 12:06:01 +01:00
committed by GitHub
parent d726a795db
commit eb2f986a06
3 changed files with 55 additions and 46 deletions

View File

@@ -8,41 +8,40 @@ import (
"github.com/quic-go/quic-go/internal/wire"
)
type connRunnerCallbacks struct {
AddConnectionID func(protocol.ConnectionID)
RemoveConnectionID func(protocol.ConnectionID)
RetireConnectionID func(protocol.ConnectionID)
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
}
type connIDGenerator struct {
generator ConnectionIDGenerator
connRunner connRunnerCallbacks
highestSeq uint64
activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID *protocol.ConnectionID // nil for the client
addConnectionID func(protocol.ConnectionID)
statelessResetter *statelessResetter
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, []byte)
queueControlFrame func(wire.Frame)
statelessResetter *statelessResetter
queueControlFrame func(wire.Frame)
}
func newConnIDGenerator(
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
addConnectionID func(protocol.ConnectionID),
statelessResetter *statelessResetter,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, []byte),
connRunner connRunnerCallbacks,
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
) *connIDGenerator {
m := &connIDGenerator{
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
statelessResetter: statelessResetter,
removeConnectionID: removeConnectionID,
retireConnectionID: retireConnectionID,
replaceWithClosed: replaceWithClosed,
queueControlFrame: queueControlFrame,
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunner: connRunner,
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
m.initialClientDestConnID = initialClientDestConnID
@@ -85,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
}
}
m.retireConnectionID(connID)
m.connRunner.RetireConnectionID(connID)
delete(m.activeSrcConnIDs, seq)
// Don't issue a replacement for the initial connection ID.
if seq == 0 {
@@ -100,7 +99,7 @@ func (m *connIDGenerator) issueNewConnID() error {
return err
}
m.activeSrcConnIDs[m.highestSeq+1] = connID
m.addConnectionID(connID)
m.connRunner.AddConnectionID(connID)
m.queueControlFrame(&wire.NewConnectionIDFrame{
SequenceNumber: m.highestSeq + 1,
ConnectionID: connID,
@@ -112,17 +111,17 @@ func (m *connIDGenerator) issueNewConnID() error {
func (m *connIDGenerator) SetHandshakeComplete() {
if m.initialClientDestConnID != nil {
m.retireConnectionID(*m.initialClientDestConnID)
m.connRunner.RetireConnectionID(*m.initialClientDestConnID)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.removeConnectionID(*m.initialClientDestConnID)
m.connRunner.RemoveConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.removeConnectionID(connID)
m.connRunner.RemoveConnectionID(connID)
}
}
@@ -134,5 +133,5 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, connClose)
m.connRunner.ReplaceWithClosed(connIDs, connClose)
}

View File

@@ -34,11 +34,13 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
g := newConnIDGenerator(
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
func(c protocol.ConnectionID) { added = append(added, c) },
sr,
func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
func(c protocol.ConnectionID) { retired = append(retired, c) },
func([]protocol.ConnectionID, []byte) {},
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
RetireConnectionID: func(c protocol.ConnectionID) { retired = append(retired, c) },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
@@ -116,11 +118,13 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
g := newConnIDGenerator(
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
func(c protocol.ConnectionID) { added = append(added, c) },
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
func(c protocol.ConnectionID) { removed = append(removed, c) },
func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
func([]protocol.ConnectionID, []byte) {},
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) },
RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {},
},
func(f wire.Frame) {},
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
@@ -164,13 +168,15 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
g := newConnIDGenerator(
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
func(c protocol.ConnectionID) { added = append(added, c) },
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
func(connIDs []protocol.ConnectionID, b []byte) {
replaced = connIDs
replacedWith = b
connRunnerCallbacks{
AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) },
RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, b []byte) {
replaced = connIDs
replacedWith = b
},
},
func(f wire.Frame) {},
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},

View File

@@ -267,11 +267,13 @@ var newConnection = func(
s.connIDGenerator = newConnIDGenerator(
srcConnID,
&clientDestConnID,
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
statelessResetter,
runner.Remove,
runner.Retire,
runner.ReplaceWithClosed,
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
RetireConnectionID: runner.Retire,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
s.queueControlFrame,
connIDGenerator,
)
@@ -378,11 +380,13 @@ var newClientConnection = func(
s.connIDGenerator = newConnIDGenerator(
srcConnID,
nil,
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
statelessResetter,
runner.Remove,
runner.Retire,
runner.ReplaceWithClosed,
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
RetireConnectionID: runner.Retire,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
s.queueControlFrame,
connIDGenerator,
)