diff --git a/conn_id_generator.go b/conn_id_generator.go index c309c2cd7..a5c203ff2 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -8,41 +8,40 @@ import ( "github.com/quic-go/quic-go/internal/wire" ) +type connRunnerCallbacks struct { + AddConnectionID func(protocol.ConnectionID) + RemoveConnectionID func(protocol.ConnectionID) + RetireConnectionID func(protocol.ConnectionID) + ReplaceWithClosed func([]protocol.ConnectionID, []byte) +} + type connIDGenerator struct { generator ConnectionIDGenerator + connRunner connRunnerCallbacks highestSeq uint64 activeSrcConnIDs map[uint64]protocol.ConnectionID initialClientDestConnID *protocol.ConnectionID // nil for the client - addConnectionID func(protocol.ConnectionID) - statelessResetter *statelessResetter - removeConnectionID func(protocol.ConnectionID) - retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, []byte) - queueControlFrame func(wire.Frame) + statelessResetter *statelessResetter + + queueControlFrame func(wire.Frame) } func newConnIDGenerator( initialConnectionID protocol.ConnectionID, initialClientDestConnID *protocol.ConnectionID, // nil for the client - addConnectionID func(protocol.ConnectionID), statelessResetter *statelessResetter, - removeConnectionID func(protocol.ConnectionID), - retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func([]protocol.ConnectionID, []byte), + connRunner connRunnerCallbacks, queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, ) *connIDGenerator { m := &connIDGenerator{ - generator: generator, - activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), - addConnectionID: addConnectionID, - statelessResetter: statelessResetter, - removeConnectionID: removeConnectionID, - retireConnectionID: retireConnectionID, - replaceWithClosed: replaceWithClosed, - queueControlFrame: queueControlFrame, + generator: generator, + activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), + statelessResetter: statelessResetter, + connRunner: connRunner, + queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID m.initialClientDestConnID = initialClientDestConnID @@ -85,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), } } - m.retireConnectionID(connID) + m.connRunner.RetireConnectionID(connID) delete(m.activeSrcConnIDs, seq) // Don't issue a replacement for the initial connection ID. if seq == 0 { @@ -100,7 +99,7 @@ func (m *connIDGenerator) issueNewConnID() error { return err } m.activeSrcConnIDs[m.highestSeq+1] = connID - m.addConnectionID(connID) + m.connRunner.AddConnectionID(connID) m.queueControlFrame(&wire.NewConnectionIDFrame{ SequenceNumber: m.highestSeq + 1, ConnectionID: connID, @@ -112,17 +111,17 @@ func (m *connIDGenerator) issueNewConnID() error { func (m *connIDGenerator) SetHandshakeComplete() { if m.initialClientDestConnID != nil { - m.retireConnectionID(*m.initialClientDestConnID) + m.connRunner.RetireConnectionID(*m.initialClientDestConnID) m.initialClientDestConnID = nil } } func (m *connIDGenerator) RemoveAll() { if m.initialClientDestConnID != nil { - m.removeConnectionID(*m.initialClientDestConnID) + m.connRunner.RemoveConnectionID(*m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { - m.removeConnectionID(connID) + m.connRunner.RemoveConnectionID(connID) } } @@ -134,5 +133,5 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } - m.replaceWithClosed(connIDs, connClose) + m.connRunner.ReplaceWithClosed(connIDs, connClose) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 85cc8215d..4a9a09432 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -34,11 +34,13 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID g := newConnIDGenerator( protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, - func(c protocol.ConnectionID) { added = append(added, c) }, sr, - func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") }, - func(c protocol.ConnectionID) { retired = append(retired, c) }, - func([]protocol.ConnectionID, []byte) {}, + connRunnerCallbacks{ + AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") }, + RetireConnectionID: func(c protocol.ConnectionID) { retired = append(retired, c) }, + ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {}, + }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, ) @@ -116,11 +118,13 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool) g := newConnIDGenerator( protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, - func(c protocol.ConnectionID) { added = append(added, c) }, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), - func(c protocol.ConnectionID) { removed = append(removed, c) }, - func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") }, - func([]protocol.ConnectionID, []byte) {}, + connRunnerCallbacks{ + AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) }, + RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") }, + ReplaceWithClosed: func([]protocol.ConnectionID, []byte) {}, + }, func(f wire.Frame) {}, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, ) @@ -164,13 +168,15 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn g := newConnIDGenerator( protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, - func(c protocol.ConnectionID) { added = append(added, c) }, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), - func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") }, - func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") }, - func(connIDs []protocol.ConnectionID, b []byte) { - replaced = connIDs - replacedWith = b + connRunnerCallbacks{ + AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, + RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") }, + RetireConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") }, + ReplaceWithClosed: func(connIDs []protocol.ConnectionID, b []byte) { + replaced = connIDs + replacedWith = b + }, }, func(f wire.Frame) {}, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, diff --git a/connection.go b/connection.go index 9415584de..5785c3751 100644 --- a/connection.go +++ b/connection.go @@ -267,11 +267,13 @@ var newConnection = func( s.connIDGenerator = newConnIDGenerator( srcConnID, &clientDestConnID, - func(connID protocol.ConnectionID) { runner.Add(connID, s) }, statelessResetter, - runner.Remove, - runner.Retire, - runner.ReplaceWithClosed, + connRunnerCallbacks{ + AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + RemoveConnectionID: runner.Remove, + RetireConnectionID: runner.Retire, + ReplaceWithClosed: runner.ReplaceWithClosed, + }, s.queueControlFrame, connIDGenerator, ) @@ -378,11 +380,13 @@ var newClientConnection = func( s.connIDGenerator = newConnIDGenerator( srcConnID, nil, - func(connID protocol.ConnectionID) { runner.Add(connID, s) }, statelessResetter, - runner.Remove, - runner.Retire, - runner.ReplaceWithClosed, + connRunnerCallbacks{ + AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + RemoveConnectionID: runner.Remove, + RetireConnectionID: runner.Retire, + ReplaceWithClosed: runner.ReplaceWithClosed, + }, s.queueControlFrame, connIDGenerator, )