forked from quic-go/quic-go
simplify tracking of Transports for connection migration (#5111)
No functional change expected.
This commit is contained in:
@@ -16,7 +16,8 @@ type connRunnerCallbacks struct {
|
||||
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
}
|
||||
|
||||
type connRunners map[transportID]connRunnerCallbacks
|
||||
// The memory address of the Transport is used as the key.
|
||||
type connRunners map[*Transport]connRunnerCallbacks
|
||||
|
||||
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
@@ -56,7 +57,7 @@ type connIDGenerator struct {
|
||||
}
|
||||
|
||||
func newConnIDGenerator(
|
||||
tID transportID,
|
||||
tr *Transport,
|
||||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
statelessResetter *statelessResetter,
|
||||
@@ -68,7 +69,7 @@ func newConnIDGenerator(
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
statelessResetter: statelessResetter,
|
||||
connRunners: map[transportID]connRunnerCallbacks{tID: connRunner},
|
||||
connRunners: map[*Transport]connRunnerCallbacks{tr: connRunner},
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
@@ -201,7 +202,7 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
|
||||
m.connRunners.ReplaceWithClosed(connIDs, connClose)
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) AddConnRunner(id transportID, r connRunnerCallbacks) {
|
||||
func (m *connIDGenerator) AddConnRunner(id *Transport, r connRunnerCallbacks) {
|
||||
// The transport might have already been added earlier.
|
||||
// This happens if the application migrates back to and old path.
|
||||
if _, ok := m.connRunners[id]; ok {
|
||||
|
||||
@@ -34,7 +34,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
|
||||
initialClientDestConnID = &connID
|
||||
}
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
&Transport{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
sr,
|
||||
@@ -107,7 +107,7 @@ func TestConnIDGeneratorRetiring(t *testing.T) {
|
||||
initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
|
||||
var added, removed []protocol.ConnectionID
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
&Transport{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
&initialConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -178,7 +178,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
|
||||
removed []protocol.ConnectionID
|
||||
)
|
||||
g := newConnIDGenerator(
|
||||
0,
|
||||
&Transport{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -228,7 +228,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
|
||||
replacedWith []byte
|
||||
)
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
&Transport{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -274,7 +274,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
added, removed, replaced []protocol.ConnectionID
|
||||
}
|
||||
|
||||
var tracker1, tracker2 connIDTracker
|
||||
var tracker1, tracker2, tracker3 connIDTracker
|
||||
runner1 := connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
|
||||
@@ -289,12 +289,20 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
tracker2.replaced = append(tracker2.replaced, connIDs...)
|
||||
},
|
||||
}
|
||||
runner3 := connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { tracker3.added = append(tracker3.added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { tracker3.removed = append(tracker3.removed, c) },
|
||||
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
|
||||
tracker3.replaced = append(tracker3.replaced, connIDs...)
|
||||
},
|
||||
}
|
||||
|
||||
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
|
||||
var queuedFrames []wire.Frame
|
||||
|
||||
tr := &Transport{}
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
tr,
|
||||
initialConnID,
|
||||
&clientDestConnID,
|
||||
sr,
|
||||
@@ -306,7 +314,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
require.Len(t, tracker1.added, 2)
|
||||
|
||||
// add the second runner - it should get all existing connection IDs
|
||||
g.AddConnRunner(2, runner2)
|
||||
g.AddConnRunner(&Transport{}, runner2)
|
||||
require.Len(t, tracker1.added, 2) // unchanged
|
||||
require.Len(t, tracker2.added, 4)
|
||||
require.Contains(t, tracker2.added, initialConnID)
|
||||
@@ -314,6 +322,11 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
require.Contains(t, tracker2.added, tracker1.added[0])
|
||||
require.Contains(t, tracker2.added, tracker1.added[1])
|
||||
|
||||
// adding the same transport again doesn't do anything
|
||||
trCopy := tr
|
||||
g.AddConnRunner(trCopy, runner3)
|
||||
require.Empty(t, tracker3.added)
|
||||
|
||||
var connIDToRetire protocol.ConnectionID
|
||||
var seqToRetire uint64
|
||||
ncid := queuedFrames[0].(*wire.NewConnectionIDFrame)
|
||||
|
||||
@@ -269,7 +269,7 @@ var newConnection = func(
|
||||
s.queueControlFrame,
|
||||
)
|
||||
s.connIDGenerator = newConnIDGenerator(
|
||||
tr.id(),
|
||||
tr,
|
||||
srcConnID,
|
||||
&clientDestConnID,
|
||||
statelessResetter,
|
||||
@@ -383,7 +383,7 @@ var newClientConnection = func(
|
||||
s.queueControlFrame,
|
||||
)
|
||||
s.connIDGenerator = newConnIDGenerator(
|
||||
tr.id(),
|
||||
tr,
|
||||
srcConnID,
|
||||
nil,
|
||||
statelessResetter,
|
||||
@@ -2652,7 +2652,7 @@ func (s *connection) AddPath(t *Transport) (*Path, error) {
|
||||
func() {
|
||||
runner := t.connRunner()
|
||||
s.connIDGenerator.AddConnRunner(
|
||||
t.id(),
|
||||
t,
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
RemoveConnectionID: runner.Remove,
|
||||
|
||||
@@ -38,10 +38,6 @@ func (e *errTransportClosed) Is(target error) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
type transportID uint64
|
||||
|
||||
var transportIDCounter atomic.Uint64
|
||||
|
||||
var errListenerAlreadySet = errors.New("listener already set")
|
||||
|
||||
// The Transport is the central point to manage incoming and outgoing QUIC connections.
|
||||
@@ -136,8 +132,6 @@ type Transport struct {
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
|
||||
// Set in init.
|
||||
transportID transportID
|
||||
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
|
||||
connIDLen int
|
||||
// Set in init.
|
||||
@@ -376,7 +370,6 @@ func (t *Transport) doDial(
|
||||
|
||||
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
t.initOnce.Do(func() {
|
||||
t.transportID = transportID(transportIDCounter.Add(1))
|
||||
var conn rawConn
|
||||
if c, ok := t.Conn.(rawConn); ok {
|
||||
conn = c
|
||||
@@ -430,8 +423,6 @@ func (t *Transport) connRunner() packetHandlerManager {
|
||||
return t.handlerMap
|
||||
}
|
||||
|
||||
func (t *Transport) id() transportID { return t.transportID }
|
||||
|
||||
// WriteTo sends a packet on the underlying connection.
|
||||
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if err := t.init(false); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user