forked from quic-go/quic-go
refactor connection ID generator callback handling (#4961)
This commit is contained in:
@@ -8,41 +8,40 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type connRunnerCallbacks struct {
|
||||
AddConnectionID func(protocol.ConnectionID)
|
||||
RemoveConnectionID func(protocol.ConnectionID)
|
||||
RetireConnectionID func(protocol.ConnectionID)
|
||||
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
}
|
||||
|
||||
type connIDGenerator struct {
|
||||
generator ConnectionIDGenerator
|
||||
connRunner connRunnerCallbacks
|
||||
highestSeq uint64
|
||||
|
||||
activeSrcConnIDs map[uint64]protocol.ConnectionID
|
||||
initialClientDestConnID *protocol.ConnectionID // nil for the client
|
||||
|
||||
addConnectionID func(protocol.ConnectionID)
|
||||
statelessResetter *statelessResetter
|
||||
removeConnectionID func(protocol.ConnectionID)
|
||||
retireConnectionID func(protocol.ConnectionID)
|
||||
replaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
queueControlFrame func(wire.Frame)
|
||||
statelessResetter *statelessResetter
|
||||
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
func newConnIDGenerator(
|
||||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
addConnectionID func(protocol.ConnectionID),
|
||||
statelessResetter *statelessResetter,
|
||||
removeConnectionID func(protocol.ConnectionID),
|
||||
retireConnectionID func(protocol.ConnectionID),
|
||||
replaceWithClosed func([]protocol.ConnectionID, []byte),
|
||||
connRunner connRunnerCallbacks,
|
||||
queueControlFrame func(wire.Frame),
|
||||
generator ConnectionIDGenerator,
|
||||
) *connIDGenerator {
|
||||
m := &connIDGenerator{
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
addConnectionID: addConnectionID,
|
||||
statelessResetter: statelessResetter,
|
||||
removeConnectionID: removeConnectionID,
|
||||
retireConnectionID: retireConnectionID,
|
||||
replaceWithClosed: replaceWithClosed,
|
||||
queueControlFrame: queueControlFrame,
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
statelessResetter: statelessResetter,
|
||||
connRunner: connRunner,
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
m.initialClientDestConnID = initialClientDestConnID
|
||||
@@ -85,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
|
||||
}
|
||||
}
|
||||
m.retireConnectionID(connID)
|
||||
m.connRunner.RetireConnectionID(connID)
|
||||
delete(m.activeSrcConnIDs, seq)
|
||||
// Don't issue a replacement for the initial connection ID.
|
||||
if seq == 0 {
|
||||
@@ -100,7 +99,7 @@ func (m *connIDGenerator) issueNewConnID() error {
|
||||
return err
|
||||
}
|
||||
m.activeSrcConnIDs[m.highestSeq+1] = connID
|
||||
m.addConnectionID(connID)
|
||||
m.connRunner.AddConnectionID(connID)
|
||||
m.queueControlFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: m.highestSeq + 1,
|
||||
ConnectionID: connID,
|
||||
@@ -112,17 +111,17 @@ func (m *connIDGenerator) issueNewConnID() error {
|
||||
|
||||
func (m *connIDGenerator) SetHandshakeComplete() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.retireConnectionID(*m.initialClientDestConnID)
|
||||
m.connRunner.RetireConnectionID(*m.initialClientDestConnID)
|
||||
m.initialClientDestConnID = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveAll() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.removeConnectionID(*m.initialClientDestConnID)
|
||||
m.connRunner.RemoveConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
m.removeConnectionID(connID)
|
||||
m.connRunner.RemoveConnectionID(connID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,5 +133,5 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
m.replaceWithClosed(connIDs, connClose)
|
||||
m.connRunner.ReplaceWithClosed(connIDs, connClose)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user