refactor connection ID generator callback handling (#4961)

This commit is contained in:
Marten Seemann
2025-02-22 12:06:01 +01:00
committed by GitHub
parent d726a795db
commit eb2f986a06
3 changed files with 55 additions and 46 deletions

View File

@@ -8,41 +8,40 @@ import (
"github.com/quic-go/quic-go/internal/wire"
)
type connRunnerCallbacks struct {
AddConnectionID func(protocol.ConnectionID)
RemoveConnectionID func(protocol.ConnectionID)
RetireConnectionID func(protocol.ConnectionID)
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
}
type connIDGenerator struct {
generator ConnectionIDGenerator
connRunner connRunnerCallbacks
highestSeq uint64
activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID *protocol.ConnectionID // nil for the client
addConnectionID func(protocol.ConnectionID)
statelessResetter *statelessResetter
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, []byte)
queueControlFrame func(wire.Frame)
statelessResetter *statelessResetter
queueControlFrame func(wire.Frame)
}
func newConnIDGenerator(
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
addConnectionID func(protocol.ConnectionID),
statelessResetter *statelessResetter,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, []byte),
connRunner connRunnerCallbacks,
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
) *connIDGenerator {
m := &connIDGenerator{
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
statelessResetter: statelessResetter,
removeConnectionID: removeConnectionID,
retireConnectionID: retireConnectionID,
replaceWithClosed: replaceWithClosed,
queueControlFrame: queueControlFrame,
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunner: connRunner,
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
m.initialClientDestConnID = initialClientDestConnID
@@ -85,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
}
}
m.retireConnectionID(connID)
m.connRunner.RetireConnectionID(connID)
delete(m.activeSrcConnIDs, seq)
// Don't issue a replacement for the initial connection ID.
if seq == 0 {
@@ -100,7 +99,7 @@ func (m *connIDGenerator) issueNewConnID() error {
return err
}
m.activeSrcConnIDs[m.highestSeq+1] = connID
m.addConnectionID(connID)
m.connRunner.AddConnectionID(connID)
m.queueControlFrame(&wire.NewConnectionIDFrame{
SequenceNumber: m.highestSeq + 1,
ConnectionID: connID,
@@ -112,17 +111,17 @@ func (m *connIDGenerator) issueNewConnID() error {
func (m *connIDGenerator) SetHandshakeComplete() {
if m.initialClientDestConnID != nil {
m.retireConnectionID(*m.initialClientDestConnID)
m.connRunner.RetireConnectionID(*m.initialClientDestConnID)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.removeConnectionID(*m.initialClientDestConnID)
m.connRunner.RemoveConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.removeConnectionID(connID)
m.connRunner.RemoveConnectionID(connID)
}
}
@@ -134,5 +133,5 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, connClose)
m.connRunner.ReplaceWithClosed(connIDs, connClose)
}