forked from quic-go/quic-go
extend connection ID generator to interact with multiple Transports (#4974)
This is needed when multiple paths are in use.
This commit is contained in:
@@ -15,10 +15,36 @@ type connRunnerCallbacks struct {
|
||||
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
}
|
||||
|
||||
type connRunners []connRunnerCallbacks
|
||||
|
||||
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
c.AddConnectionID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
c.RemoveConnectionID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func (cr connRunners) RetireConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
c.RetireConnectionID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte) {
|
||||
for _, c := range cr {
|
||||
c.ReplaceWithClosed(ids, b)
|
||||
}
|
||||
}
|
||||
|
||||
type connIDGenerator struct {
|
||||
generator ConnectionIDGenerator
|
||||
connRunner connRunnerCallbacks
|
||||
highestSeq uint64
|
||||
generator ConnectionIDGenerator
|
||||
highestSeq uint64
|
||||
connRunners connRunners
|
||||
|
||||
activeSrcConnIDs map[uint64]protocol.ConnectionID
|
||||
initialClientDestConnID *protocol.ConnectionID // nil for the client
|
||||
@@ -40,7 +66,7 @@ func newConnIDGenerator(
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
statelessResetter: statelessResetter,
|
||||
connRunner: connRunner,
|
||||
connRunners: []connRunnerCallbacks{connRunner},
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
@@ -84,7 +110,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
|
||||
}
|
||||
}
|
||||
m.connRunner.RetireConnectionID(connID)
|
||||
m.connRunners.RetireConnectionID(connID)
|
||||
delete(m.activeSrcConnIDs, seq)
|
||||
// Don't issue a replacement for the initial connection ID.
|
||||
if seq == 0 {
|
||||
@@ -99,7 +125,7 @@ func (m *connIDGenerator) issueNewConnID() error {
|
||||
return err
|
||||
}
|
||||
m.activeSrcConnIDs[m.highestSeq+1] = connID
|
||||
m.connRunner.AddConnectionID(connID)
|
||||
m.connRunners.AddConnectionID(connID)
|
||||
m.queueControlFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: m.highestSeq + 1,
|
||||
ConnectionID: connID,
|
||||
@@ -111,17 +137,17 @@ func (m *connIDGenerator) issueNewConnID() error {
|
||||
|
||||
func (m *connIDGenerator) SetHandshakeComplete() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.connRunner.RetireConnectionID(*m.initialClientDestConnID)
|
||||
m.connRunners.RetireConnectionID(*m.initialClientDestConnID)
|
||||
m.initialClientDestConnID = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveAll() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.connRunner.RemoveConnectionID(*m.initialClientDestConnID)
|
||||
m.connRunners.RemoveConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
m.connRunner.RemoveConnectionID(connID)
|
||||
m.connRunners.RemoveConnectionID(connID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,5 +159,15 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
m.connRunner.ReplaceWithClosed(connIDs, connClose)
|
||||
m.connRunners.ReplaceWithClosed(connIDs, connClose)
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) AddConnRunner(r connRunnerCallbacks) {
|
||||
if m.initialClientDestConnID != nil {
|
||||
r.AddConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
r.AddConnectionID(connID)
|
||||
}
|
||||
m.connRunners = append(m.connRunners, r)
|
||||
}
|
||||
|
||||
@@ -198,3 +198,81 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
|
||||
require.Contains(t, replaced, protocol.ParseConnectionID([]byte{1, 1, 1, 1}))
|
||||
require.Equal(t, []byte("foobar"), replacedWith)
|
||||
}
|
||||
|
||||
func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
initialConnID := protocol.ParseConnectionID([]byte{1, 1, 1, 1})
|
||||
clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
|
||||
|
||||
type connIDTracker struct {
|
||||
added, removed, retired, replaced []protocol.ConnectionID
|
||||
}
|
||||
|
||||
var tracker1, tracker2 connIDTracker
|
||||
runner1 := connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
|
||||
RetireConnectionID: func(c protocol.ConnectionID) { tracker1.retired = append(tracker1.retired, c) },
|
||||
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
|
||||
tracker1.replaced = append(tracker1.replaced, connIDs...)
|
||||
},
|
||||
}
|
||||
runner2 := connRunnerCallbacks{
|
||||
AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) },
|
||||
RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) },
|
||||
RetireConnectionID: func(c protocol.ConnectionID) { tracker2.retired = append(tracker2.retired, c) },
|
||||
ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
|
||||
tracker2.replaced = append(tracker2.replaced, connIDs...)
|
||||
},
|
||||
}
|
||||
|
||||
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
|
||||
var queuedFrames []wire.Frame
|
||||
|
||||
g := newConnIDGenerator(
|
||||
initialConnID,
|
||||
&clientDestConnID,
|
||||
sr,
|
||||
runner1,
|
||||
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
||||
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
|
||||
)
|
||||
require.NoError(t, g.SetMaxActiveConnIDs(3))
|
||||
require.Len(t, tracker1.added, 2)
|
||||
|
||||
// add the second runner - it should get all existing connection IDs
|
||||
g.AddConnRunner(runner2)
|
||||
require.Len(t, tracker1.added, 2) // unchanged
|
||||
require.Len(t, tracker2.added, 4)
|
||||
require.Contains(t, tracker2.added, initialConnID)
|
||||
require.Contains(t, tracker2.added, clientDestConnID)
|
||||
require.Contains(t, tracker2.added, tracker1.added[0])
|
||||
require.Contains(t, tracker2.added, tracker1.added[1])
|
||||
|
||||
var connIDToRetire protocol.ConnectionID
|
||||
var seqToRetire uint64
|
||||
ncid := queuedFrames[0].(*wire.NewConnectionIDFrame)
|
||||
connIDToRetire = ncid.ConnectionID
|
||||
seqToRetire = ncid.SequenceNumber
|
||||
|
||||
tracker1.retired = nil
|
||||
tracker2.retired = nil
|
||||
require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
|
||||
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.retired)
|
||||
require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.retired)
|
||||
|
||||
tracker1.retired = nil
|
||||
tracker2.retired = nil
|
||||
g.SetHandshakeComplete()
|
||||
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.retired)
|
||||
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.retired)
|
||||
|
||||
g.ReplaceWithClosed([]byte("connection closed"))
|
||||
require.True(t, len(tracker1.replaced) > 0)
|
||||
require.Equal(t, tracker1.replaced, tracker2.replaced)
|
||||
|
||||
tracker1.removed = nil
|
||||
tracker2.removed = nil
|
||||
g.RemoveAll()
|
||||
require.NotEmpty(t, tracker1.removed)
|
||||
require.Equal(t, tracker1.removed, tracker2.removed)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user