From 2082d2c6b92bd5aa51620730ac44a67477ef04ac Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 5 May 2025 11:05:17 +0800 Subject: [PATCH] move connection tracking into the Transport (#5113) --- conn_id_generator.go | 14 +- conn_id_generator_test.go | 12 +- connection.go | 18 +- connection_test.go | 14 +- mock_packet_handler_manager_test.go | 376 ---------------------------- mockgen.go | 5 - packet_handler_map.go | 180 ------------- packet_handler_map_test.go | 151 ----------- server.go | 20 +- server_test.go | 43 +++- sys_conn.go | 25 ++ transport.go | 156 ++++++++++-- transport_test.go | 161 ++++++++---- 13 files changed, 333 insertions(+), 842 deletions(-) delete mode 100644 mock_packet_handler_manager_test.go delete mode 100644 packet_handler_map.go delete mode 100644 packet_handler_map_test.go diff --git a/conn_id_generator.go b/conn_id_generator.go index 394d5b8a8..f1a0c92c0 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -17,7 +17,7 @@ type connRunnerCallbacks struct { } // The memory address of the Transport is used as the key. -type connRunners map[*Transport]connRunnerCallbacks +type connRunners map[connRunner]connRunnerCallbacks func (cr connRunners) AddConnectionID(id protocol.ConnectionID) { for _, c := range cr { @@ -57,11 +57,11 @@ type connIDGenerator struct { } func newConnIDGenerator( - tr *Transport, + runner connRunner, initialConnectionID protocol.ConnectionID, initialClientDestConnID *protocol.ConnectionID, // nil for the client statelessResetter *statelessResetter, - connRunner connRunnerCallbacks, + callbacks connRunnerCallbacks, queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, ) *connIDGenerator { @@ -69,7 +69,7 @@ func newConnIDGenerator( generator: generator, activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), statelessResetter: statelessResetter, - connRunners: map[*Transport]connRunnerCallbacks{tr: connRunner}, + connRunners: map[connRunner]connRunnerCallbacks{runner: callbacks}, queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID @@ -202,13 +202,13 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte, expiry time.Durati m.connRunners.ReplaceWithClosed(connIDs, connClose, expiry) } -func (m *connIDGenerator) AddConnRunner(id *Transport, r connRunnerCallbacks) { +func (m *connIDGenerator) AddConnRunner(runner connRunner, r connRunnerCallbacks) { // The transport might have already been added earlier. // This happens if the application migrates back to and old path. - if _, ok := m.connRunners[id]; ok { + if _, ok := m.connRunners[runner]; ok { return } - m.connRunners[id] = r + m.connRunners[runner] = r if m.initialClientDestConnID != nil { r.AddConnectionID(*m.initialClientDestConnID) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 2c482cdbc..6d1716522 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -34,7 +34,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID initialClientDestConnID = &connID } g := newConnIDGenerator( - &Transport{}, + &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, sr, @@ -107,7 +107,7 @@ func TestConnIDGeneratorRetiring(t *testing.T) { initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) var added, removed []protocol.ConnectionID g := newConnIDGenerator( - &Transport{}, + &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), &initialConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), @@ -178,7 +178,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool) removed []protocol.ConnectionID ) g := newConnIDGenerator( - &Transport{}, + &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), @@ -228,7 +228,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn replacedWith []byte ) g := newConnIDGenerator( - &Transport{}, + &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), @@ -300,7 +300,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}) var queuedFrames []wire.Frame - tr := &Transport{} + tr := &packetHandlerMap{} g := newConnIDGenerator( tr, initialConnID, @@ -314,7 +314,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) { require.Len(t, tracker1.added, 2) // add the second runner - it should get all existing connection IDs - g.AddConnRunner(&Transport{}, runner2) + g.AddConnRunner(&packetHandlerMap{}, runner2) require.Len(t, tracker1.added, 2) // unchanged require.Len(t, tracker2.added, 4) require.Contains(t, tracker2.added, initialConnID) diff --git a/connection.go b/connection.go index e03e8c945..0e71aaa88 100644 --- a/connection.go +++ b/connection.go @@ -111,8 +111,6 @@ func nextConnTracingID() ConnectionTracingID { return ConnectionTracingID(connTr // A Connection is a QUIC connection type connection struct { - tr *Transport - // Destination connection ID used during the handshake. // Used to check source connection ID on incoming packets. handshakeDestConnID protocol.ConnectionID @@ -225,7 +223,7 @@ var newConnection = func( ctx context.Context, ctxCancel context.CancelCauseFunc, conn sendConn, - tr *Transport, + runner connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, @@ -244,7 +242,6 @@ var newConnection = func( s := &connection{ ctx: ctx, ctxCancel: ctxCancel, - tr: tr, conn: conn, config: conf, handshakeDestConnID: destConnID, @@ -261,7 +258,6 @@ var newConnection = func( } else { s.logID = destConnID.String() } - runner := tr.connRunner() s.connIDManager = newConnIDManager( destConnID, func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, @@ -269,7 +265,7 @@ var newConnection = func( s.queueControlFrame, ) s.connIDGenerator = newConnIDGenerator( - tr, + runner, srcConnID, &clientDestConnID, statelessResetter, @@ -347,7 +343,7 @@ var newConnection = func( var newClientConnection = func( ctx context.Context, conn sendConn, - tr *Transport, + runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, connIDGenerator ConnectionIDGenerator, @@ -362,7 +358,6 @@ var newClientConnection = func( v protocol.Version, ) quicConn { s := &connection{ - tr: tr, conn: conn, config: conf, origDestConnID: destConnID, @@ -375,7 +370,6 @@ var newClientConnection = func( versionNegotiated: hasNegotiatedVersion, version: v, } - runner := tr.connRunner() s.connIDManager = newConnIDManager( destConnID, func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, @@ -383,7 +377,7 @@ var newClientConnection = func( s.queueControlFrame, ) s.connIDGenerator = newConnIDGenerator( - tr, + runner, srcConnID, nil, statelessResetter, @@ -2650,9 +2644,9 @@ func (s *connection) AddPath(t *Transport) (*Path, error) { t, 200*time.Millisecond, // initial RTT estimate func() { - runner := t.connRunner() + runner := (*packetHandlerMap)(t) s.connIDGenerator.AddConnRunner( - t, + runner, connRunnerCallbacks{ AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, RemoveConnectionID: runner.Remove, diff --git a/connection_test.go b/connection_test.go index 691494f4f..9e4bf763b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -81,7 +81,7 @@ func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt { type testConnection struct { conn *connection - connRunner *MockPacketHandlerManager + connRunner *MockConnRunner sendConn *MockSendConn packer *MockPacker destConnID protocol.ConnectionID @@ -101,7 +101,7 @@ func newServerTestConnection( } remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} - phm := NewMockPacketHandlerManager(mockCtrl) + connRunner := NewMockConnRunner(mockCtrl) sendConn := NewMockSendConn(mockCtrl) sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes() sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() @@ -119,7 +119,7 @@ func newServerTestConnection( ctx, cancel, sendConn, - &Transport{handlerMap: phm}, + connRunner, origDestConnID, nil, protocol.ConnectionID{}, @@ -141,7 +141,7 @@ func newServerTestConnection( } return &testConnection{ conn: conn, - connRunner: phm, + connRunner: connRunner, sendConn: sendConn, packer: packer, destConnID: origDestConnID, @@ -162,7 +162,7 @@ func newClientTestConnection( } remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} - phm := NewMockPacketHandlerManager(mockCtrl) + connRunner := NewMockConnRunner(mockCtrl) sendConn := NewMockSendConn(mockCtrl) sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes() sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() @@ -178,7 +178,7 @@ func newClientTestConnection( conn := newClientConnection( context.Background(), sendConn, - &Transport{handlerMap: phm}, + connRunner, destConnID, srcConnID, &protocol.DefaultConnectionIDGenerator{}, @@ -198,7 +198,7 @@ func newClientTestConnection( } return &testConnection{ conn: conn, - connRunner: phm, + connRunner: connRunner, sendConn: sendConn, packer: packer, destConnID: destConnID, diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go deleted file mode 100644 index 98b2ea172..000000000 --- a/mock_packet_handler_manager_test.go +++ /dev/null @@ -1,376 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: PacketHandlerManager) -// -// Generated by this command: -// -// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager -// - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - time "time" - - protocol "github.com/quic-go/quic-go/internal/protocol" - gomock "go.uber.org/mock/gomock" -) - -// MockPacketHandlerManager is a mock of PacketHandlerManager interface. -type MockPacketHandlerManager struct { - ctrl *gomock.Controller - recorder *MockPacketHandlerManagerMockRecorder - isgomock struct{} -} - -// MockPacketHandlerManagerMockRecorder is the mock recorder for MockPacketHandlerManager. -type MockPacketHandlerManagerMockRecorder struct { - mock *MockPacketHandlerManager -} - -// NewMockPacketHandlerManager creates a new mock instance. -func NewMockPacketHandlerManager(ctrl *gomock.Controller) *MockPacketHandlerManager { - mock := &MockPacketHandlerManager{ctrl: ctrl} - mock.recorder = &MockPacketHandlerManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorder { - return m.recorder -} - -// Add mocks base method. -func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Add indicates an expected call of Add. -func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 any) *MockPacketHandlerManagerAddCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) - return &MockPacketHandlerManagerAddCall{Call: call} -} - -// MockPacketHandlerManagerAddCall wrap *gomock.Call -type MockPacketHandlerManagerAddCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerAddCall) Return(arg0 bool) *MockPacketHandlerManagerAddCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerAddCall) Do(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerAddCall) DoAndReturn(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// AddResetToken mocks base method. -func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddResetToken", arg0, arg1) -} - -// AddResetToken indicates an expected call of AddResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 any) *MockPacketHandlerManagerAddResetTokenCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) - return &MockPacketHandlerManagerAddResetTokenCall{Call: call} -} - -// MockPacketHandlerManagerAddResetTokenCall wrap *gomock.Call -type MockPacketHandlerManagerAddResetTokenCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerAddResetTokenCall) Return() *MockPacketHandlerManagerAddResetTokenCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerAddResetTokenCall) Do(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// AddWithConnID mocks base method. -func (m *MockPacketHandlerManager) AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddWithConnID", destConnID, newConnID, h) - ret0, _ := ret[0].(bool) - return ret0 -} - -// AddWithConnID indicates an expected call of AddWithConnID. -func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(destConnID, newConnID, h any) *MockPacketHandlerManagerAddWithConnIDCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), destConnID, newConnID, h) - return &MockPacketHandlerManagerAddWithConnIDCall{Call: call} -} - -// MockPacketHandlerManagerAddWithConnIDCall wrap *gomock.Call -type MockPacketHandlerManagerAddWithConnIDCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerAddWithConnIDCall) Return(arg0 bool) *MockPacketHandlerManagerAddWithConnIDCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerAddWithConnIDCall) Do(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerAddWithConnIDCall) DoAndReturn(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Close mocks base method. -func (m *MockPacketHandlerManager) Close(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close", arg0) -} - -// Close indicates an expected call of Close. -func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 any) *MockPacketHandlerManagerCloseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) - return &MockPacketHandlerManagerCloseCall{Call: call} -} - -// MockPacketHandlerManagerCloseCall wrap *gomock.Call -type MockPacketHandlerManagerCloseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerCloseCall) Return() *MockPacketHandlerManagerCloseCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerCloseCall) Do(f func(error)) *MockPacketHandlerManagerCloseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *MockPacketHandlerManagerCloseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Get mocks base method. -func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(packetHandler) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 any) *MockPacketHandlerManagerGetCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) - return &MockPacketHandlerManagerGetCall{Call: call} -} - -// MockPacketHandlerManagerGetCall wrap *gomock.Call -type MockPacketHandlerManagerGetCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerGetCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerGetCall) Do(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerGetCall) DoAndReturn(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// GetByResetToken mocks base method. -func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetByResetToken", arg0) - ret0, _ := ret[0].(packetHandler) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// GetByResetToken indicates an expected call of GetByResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 any) *MockPacketHandlerManagerGetByResetTokenCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0) - return &MockPacketHandlerManagerGetByResetTokenCall{Call: call} -} - -// MockPacketHandlerManagerGetByResetTokenCall wrap *gomock.Call -type MockPacketHandlerManagerGetByResetTokenCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerGetByResetTokenCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetByResetTokenCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerGetByResetTokenCall) Do(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerGetByResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Remove mocks base method. -func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Remove", arg0) -} - -// Remove indicates an expected call of Remove. -func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 any) *MockPacketHandlerManagerRemoveCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) - return &MockPacketHandlerManagerRemoveCall{Call: call} -} - -// MockPacketHandlerManagerRemoveCall wrap *gomock.Call -type MockPacketHandlerManagerRemoveCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerRemoveCall) Return() *MockPacketHandlerManagerRemoveCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerRemoveCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerRemoveCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// RemoveResetToken mocks base method. -func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessResetToken) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RemoveResetToken", arg0) -} - -// RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 any) *MockPacketHandlerManagerRemoveResetTokenCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) - return &MockPacketHandlerManagerRemoveResetTokenCall{Call: call} -} - -// MockPacketHandlerManagerRemoveResetTokenCall wrap *gomock.Call -type MockPacketHandlerManagerRemoveResetTokenCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerRemoveResetTokenCall) Return() *MockPacketHandlerManagerRemoveResetTokenCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerRemoveResetTokenCall) Do(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerRemoveResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// ReplaceWithClosed mocks base method. -func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte, arg2 time.Duration) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) -} - -// ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *MockPacketHandlerManagerReplaceWithClosedCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2) - return &MockPacketHandlerManagerReplaceWithClosedCall{Call: call} -} - -// MockPacketHandlerManagerReplaceWithClosedCall wrap *gomock.Call -type MockPacketHandlerManagerReplaceWithClosedCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerReplaceWithClosedCall) Return() *MockPacketHandlerManagerReplaceWithClosedCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockPacketHandlerManagerReplaceWithClosedCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockPacketHandlerManagerReplaceWithClosedCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/mockgen.go b/mockgen.go index 1a8b28db8..f46c20498 100644 --- a/mockgen.go +++ b/mockgen.go @@ -56,9 +56,4 @@ type QUICConn = quicConn //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" type PacketHandler = packetHandler -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" - -//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" -type PacketHandlerManager = packetHandlerManager - //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/packet_handler_map.go b/packet_handler_map.go deleted file mode 100644 index ec65ae0b1..000000000 --- a/packet_handler_map.go +++ /dev/null @@ -1,180 +0,0 @@ -package quic - -import ( - "io" - "net" - "sync" - "time" - - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" -) - -type connCapabilities struct { - // This connection has the Don't Fragment (DF) bit set. - // This means it makes to run DPLPMTUD. - DF bool - // GSO (Generic Segmentation Offload) supported - GSO bool - // ECN (Explicit Congestion Notifications) supported - ECN bool -} - -// rawConn is a connection that allow reading of a receivedPackeh. -type rawConn interface { - ReadPacket() (receivedPacket, error) - // WritePacket writes a packet on the wire. - // gsoSize is the size of a single packet, or 0 to disable GSO. - // It is invalid to set gsoSize if capabilities.GSO is not set. - WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) - LocalAddr() net.Addr - SetReadDeadline(time.Time) error - io.Closer - - capabilities() connCapabilities -} - -type closePacket struct { - payload []byte - addr net.Addr - info packetInfo -} - -type packetHandlerMap struct { - mutex sync.Mutex - handlers map[protocol.ConnectionID]packetHandler - resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - - closed bool - - enqueueClosePacket func(closePacket) - - logger utils.Logger -} - -var _ packetHandlerManager = &packetHandlerMap{} - -func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { - return &packetHandlerMap{ - handlers: make(map[protocol.ConnectionID]packetHandler), - resetTokens: make(map[protocol.StatelessResetToken]packetHandler), - enqueueClosePacket: enqueueClosePacket, - logger: logger, - } -} - -func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) { - h.mutex.Lock() - defer h.mutex.Unlock() - - handler, ok := h.handlers[id] - return handler, ok -} - -func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { - h.mutex.Lock() - defer h.mutex.Unlock() - - if _, ok := h.handlers[id]; ok { - h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) - return false - } - h.handlers[id] = handler - h.logger.Debugf("Adding connection ID %s.", id) - return true -} - -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool { - h.mutex.Lock() - defer h.mutex.Unlock() - - if _, ok := h.handlers[clientDestConnID]; ok { - h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) - return false - } - h.handlers[clientDestConnID] = handler - h.handlers[newConnID] = handler - h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) - return true -} - -func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { - h.mutex.Lock() - delete(h.handlers, id) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s.", id) -} - -// ReplaceWithClosed is called when a connection is closed. -// Depending on which side closed the connection, we need to: -// * remote close: absorb delayed packets -// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost -func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) { - var handler packetHandler - if connClosePacket != nil { - handler = newClosedLocalConn( - func(addr net.Addr, info packetInfo) { - h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) - }, - h.logger, - ) - } else { - handler = newClosedRemoteConn() - } - - h.mutex.Lock() - for _, id := range ids { - h.handlers[id] = handler - } - h.mutex.Unlock() - h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids) - - time.AfterFunc(expiry, func() { - h.mutex.Lock() - for _, id := range ids { - delete(h.handlers, id) - } - h.mutex.Unlock() - h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids) - }) -} - -func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { - h.mutex.Lock() - h.resetTokens[token] = handler - h.mutex.Unlock() -} - -func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { - h.mutex.Lock() - delete(h.resetTokens, token) - h.mutex.Unlock() -} - -func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) { - h.mutex.Lock() - defer h.mutex.Unlock() - - handler, ok := h.resetTokens[token] - return handler, ok -} - -func (h *packetHandlerMap) Close(e error) { - h.mutex.Lock() - if h.closed { - h.mutex.Unlock() - return - } - - var wg sync.WaitGroup - for _, handler := range h.handlers { - wg.Add(1) - go func(handler packetHandler) { - handler.destroy(e) - wg.Done() - }(handler) - } - h.closed = true - h.mutex.Unlock() - wg.Wait() -} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go deleted file mode 100644 index a285ba80e..000000000 --- a/packet_handler_map_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package quic - -import ( - "crypto/rand" - "errors" - "net" - "testing" - "time" - - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestPacketHandlerMapAddAndRemove(t *testing.T) { - m := newPacketHandlerMap(nil, utils.DefaultLogger) - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - h := &mockPacketHandler{} - require.True(t, m.Add(connID, h)) - got, ok := m.Get(connID) - require.True(t, ok) - require.Equal(t, h, got) - - // cannot add the same handler twice - require.False(t, m.Add(connID, h)) - got, ok = m.Get(connID) - require.True(t, ok) - require.Equal(t, h, got) - - // remove the handler - m.Remove(connID) - got, ok = m.Get(connID) - require.False(t, ok) - require.Nil(t, got) -} - -func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) { - m := newPacketHandlerMap(nil, utils.DefaultLogger) - h := &mockPacketHandler{} - - connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - require.True(t, m.AddWithConnID(connID1, connID2, h)) - // collision of the connection ID, this handler should not be added - require.False(t, m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), nil)) - - got, ok := m.Get(connID1) - require.True(t, ok) - require.Equal(t, h, got) - got, ok = m.Get(connID2) - require.True(t, ok) - require.Equal(t, h, got) -} - -func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) { - m := newPacketHandlerMap(nil, utils.DefaultLogger) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} - handler := &mockPacketHandler{} - m.AddResetToken(token, handler) - h, ok := m.GetByResetToken(token) - require.True(t, ok) - require.Equal(t, handler, h) - m.RemoveResetToken(token) - _, ok = m.GetByResetToken(token) - require.False(t, ok) -} - -func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) { - var closePackets []closePacket - m := newPacketHandlerMap( - func(p closePacket) { closePackets = append(closePackets, p) }, - utils.DefaultLogger, - ) - dur := scaleDuration(10 * time.Millisecond) - - handler := &mockPacketHandler{} - connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - require.True(t, m.Add(connID, handler)) - m.ReplaceWithClosed([]protocol.ConnectionID{connID}, []byte("foobar"), dur) - h, ok := m.Get(connID) - require.True(t, ok) - require.NotEqual(t, handler, h) - addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} - h.handlePacket(receivedPacket{remoteAddr: addr}) - require.Len(t, closePackets, 1) - require.Equal(t, addr, closePackets[0].addr) - require.Equal(t, []byte("foobar"), closePackets[0].payload) - - time.Sleep(dur) - require.Eventually(t, func() bool { - _, ok := m.Get(connID) - return !ok - }, time.Second, dur) -} - -func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) { - var closePackets []closePacket - m := newPacketHandlerMap( - func(p closePacket) { closePackets = append(closePackets, p) }, - utils.DefaultLogger, - ) - dur := scaleDuration(25 * time.Millisecond) - - handler := &mockPacketHandler{} - connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - require.True(t, m.Add(connID, handler)) - m.ReplaceWithClosed([]protocol.ConnectionID{connID}, nil, dur) - h, ok := m.Get(connID) - require.True(t, ok) - require.NotEqual(t, handler, h) - addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} - h.handlePacket(receivedPacket{remoteAddr: addr}) - require.Empty(t, closePackets) - - time.Sleep(dur) - require.Eventually(t, func() bool { - _, ok := m.Get(connID) - return !ok - }, time.Second, 10*time.Millisecond) -} - -func TestPacketHandlerMapClose(t *testing.T) { - m := newPacketHandlerMap(nil, utils.DefaultLogger) - const numConns = 10 - destroyChan := make(chan error, 2*numConns) - for i := 0; i < numConns; i++ { - conn := &mockPacketHandler{destruction: destroyChan} - b := make([]byte, 12) - rand.Read(b) - m.Add(protocol.ParseConnectionID(b), conn) - } - m.Close(assert.AnError) - // check that Close can be called multiple times - m.Close(errors.New("close")) - - for i := 0; i < numConns; i++ { - select { - case err := <-destroyChan: - require.ErrorIs(t, err, assert.AnError) - default: - t.Fatalf("connection not destroyed") - } - } - select { - case err := <-destroyChan: - t.Fatalf("connection destroyed more than once: %s", err) - default: - } -} diff --git a/server.go b/server.go index 3eb8e1d43..fae34be23 100644 --- a/server.go +++ b/server.go @@ -32,14 +32,6 @@ type packetHandler interface { closeWithTransportError(qerr.TransportErrorCode) } -type packetHandlerManager interface { - Get(protocol.ConnectionID) (packetHandler, bool) - GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) - AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool - Close(error) - connRunner -} - type quicConn interface { EarlyConnection earlyConnReady() <-chan struct{} @@ -61,7 +53,7 @@ type rejectedPacket struct { // A Listener of QUIC type baseServer struct { - tr *Transport + tr *packetHandlerMap disableVersionNegotiation bool acceptEarlyConns bool @@ -89,7 +81,7 @@ type baseServer struct { context.Context, context.CancelCauseFunc, sendConn, - *Transport, + connRunner, protocol.ConnectionID, /* original dest connection ID */ *protocol.ConnectionID, /* retry src connection ID */ protocol.ConnectionID, /* client dest connection ID */ @@ -247,7 +239,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear func newServer( conn rawConn, - tr *Transport, + tr *packetHandlerMap, connIDGenerator ConnectionIDGenerator, statelessResetter *statelessResetter, connContext func(context.Context) context.Context, @@ -501,7 +493,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { } // check again if we might have a connection now - if handler, ok := s.tr.connRunner().Get(connID); ok { + if handler, ok := s.tr.Get(connID); ok { handler.handlePacket(p) return true } @@ -591,7 +583,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error // The server queues packets for a while, and we might already have established a connection by now. // This results in a second check in the connection map. // That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets). - if handler, ok := s.tr.connRunner().Get(hdr.DestConnectionID); ok { + if handler, ok := s.tr.Get(hdr.DestConnectionID); ok { handler.handlePacket(p) return nil } @@ -727,7 +719,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error // This is very unlikely: Even if an attacker chooses a connection ID that's already in use, // under normal circumstances the packet would just be routed to that connection. // The only time this collision will occur if we receive the two Initial packets at the same time. - if added := s.tr.connRunner().AddWithConnID(hdr.DestConnectionID, connID, conn); !added { + if added := s.tr.AddWithConnID(hdr.DestConnectionID, connID, conn); !added { delete(s.zeroRTTQueues, hdr.DestConnectionID) conn.closeWithTransportError(qerr.ConnectionRefused) return nil diff --git a/server_test.go b/server_test.go index 5ed4a498d..38cdfe891 100644 --- a/server_test.go +++ b/server_test.go @@ -37,7 +37,7 @@ type serverOpts struct { context.Context, context.CancelCauseFunc, sendConn, - *Transport, + connRunner, protocol.ConnectionID, // original dest connection ID *protocol.ConnectionID, // retry src connection ID protocol.ConnectionID, // client dest connection ID @@ -61,9 +61,11 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer { require.NoError(t, err) verifySourceAddress := func(net.Addr) bool { return serverOpts.useRetry } config := populateConfig(serverOpts.config) + tr := &Transport{Conn: newUDPConnLocalhost(t)} + tr.init(true) s := newServer( c, - &Transport{handlerMap: newPacketHandlerMap(nil, utils.DefaultLogger)}, + (*packetHandlerMap)(tr), &protocol.DefaultConnectionIDGenerator{}, &statelessResetter{}, func(ctx context.Context) context.Context { return ctx }, @@ -596,6 +598,7 @@ func testServerTokenValidation( type connConstructorArgs struct { ctx context.Context + connRunner connRunner config *Config origDestConnID protocol.ConnectionID retrySrcConnID *protocol.ConnectionID @@ -607,10 +610,10 @@ type connConstructorArgs struct { type connConstructorRecorder struct { ch chan connConstructorArgs - conns []quicConn + conns []*MockQUICConn } -func newConnConstructorRecorder(conns ...quicConn) *connConstructorRecorder { +func newConnConstructorRecorder(conns ...*MockQUICConn) *connConstructorRecorder { return &connConstructorRecorder{ ch: make(chan connConstructorArgs, len(conns)), conns: conns, @@ -623,7 +626,7 @@ func (r *connConstructorRecorder) NewConn( ctx context.Context, _ context.CancelCauseFunc, _ sendConn, - _ *Transport, + connRunner connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, @@ -641,6 +644,7 @@ func (r *connConstructorRecorder) NewConn( ) quicConn { r.ch <- connConstructorArgs{ ctx: ctx, + connRunner: connRunner, config: config, origDestConnID: origDestConnID, retrySrcConnID: retrySrcConnID, @@ -747,7 +751,7 @@ func testServerCreateConnection(t *testing.T, useRetry bool) { func TestServerClose(t *testing.T) { mockCtrl := gomock.NewController(t) - var conns []quicConn + var conns []*MockQUICConn const numConns = 3 done := make(chan struct{}, numConns) for range numConns { @@ -779,7 +783,7 @@ func TestServerClose(t *testing.T) { } server.Close() - // closing closes all handshakeing connections with CONNECTION_REFUSED + // closing closes all handshaking connections with CONNECTION_REFUSED for range numConns { select { case <-done: @@ -796,6 +800,11 @@ func TestServerClose(t *testing.T) { require.ErrorIs(t, err, ErrServerClosed) require.ErrorIs(t, err, net.ErrClosed) } + + // test shutdown + for _, conn := range conns { + conn.EXPECT().destroy(gomock.Any()).AnyTimes() + } } func TestServerGetConfigForClientAccept(t *testing.T) { @@ -878,7 +887,7 @@ func TestServerPacketHandling(t *testing.T) { conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { handledPacket <- p }) - server.tr.handlerMap.Add(destConnID, conn) + server.tr.Add(destConnID, conn) server.handlePacket( getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, srcConnID, destConnID), @@ -904,7 +913,7 @@ func TestServerReceiveQueue(t *testing.T) { _ context.Context, _ context.CancelCauseFunc, _ sendConn, - _ *Transport, + _ connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -927,6 +936,8 @@ func TestServerReceiveQueue(t *testing.T) { conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) + // during test shutdown + conn.EXPECT().destroy(gomock.Any()).AnyTimes() return conn }, }) @@ -1013,6 +1024,9 @@ func testServerAccept(t *testing.T, acceptEarly bool) { case <-time.After(time.Second): t.Fatal("timeout") } + + // test shutdown + c.EXPECT().destroy(gomock.Any()).AnyTimes() } func TestServerAcceptHandshakeFailure(t *testing.T) { @@ -1044,11 +1058,14 @@ func TestServerAcceptHandshakeFailure(t *testing.T) { t.Fatal("server should not have accepted the connection") case <-time.After(scaleDuration(5 * time.Millisecond)): } + + // test shutdown + c.EXPECT().destroy(gomock.Any()).AnyTimes() } func TestServerAcceptQueue(t *testing.T) { mockCtrl := gomock.NewController(t) - var conns []quicConn + var conns []*MockQUICConn var rejectedConn *MockQUICConn for i := range protocol.MaxAcceptQueueSize + 2 { conn := NewMockQUICConn(mockCtrl) @@ -1125,6 +1142,11 @@ func TestServerAcceptQueue(t *testing.T) { case <-time.After(time.Second): t.Fatal("timeout") } + + // test shutdown + for _, conn := range conns { + conn.EXPECT().destroy(gomock.Any()).AnyTimes() + } } func TestServer0RTTReordering(t *testing.T) { @@ -1210,6 +1232,7 @@ func TestServer0RTTReordering(t *testing.T) { // shutdown conn.EXPECT().closeWithTransportError(gomock.Any()).AnyTimes() + conn.EXPECT().destroy(gomock.Any()).AnyTimes() } func TestServer0RTTQueueing(t *testing.T) { diff --git a/sys_conn.go b/sys_conn.go index 8159a1469..6e768dcf2 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -1,6 +1,7 @@ package quic import ( + "io" "log" "net" "os" @@ -13,6 +14,30 @@ import ( "github.com/quic-go/quic-go/internal/utils" ) +type connCapabilities struct { + // This connection has the Don't Fragment (DF) bit set. + // This means it makes to run DPLPMTUD. + DF bool + // GSO (Generic Segmentation Offload) supported + GSO bool + // ECN (Explicit Congestion Notifications) supported + ECN bool +} + +// rawConn is a connection that allow reading of a receivedPackeh. +type rawConn interface { + ReadPacket() (receivedPacket, error) + // WritePacket writes a packet on the wire. + // gsoSize is the size of a single packet, or 0 to disable GSO. + // It is invalid to set gsoSize if capabilities.GSO is not set. + WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) + LocalAddr() net.Addr + SetReadDeadline(time.Time) error + io.Closer + + capabilities() connCapabilities +} + // OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header. // If the PacketConn passed to the [Transport] satisfies this interface, quic-go will use it. // In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets. diff --git a/transport.go b/transport.go index 915c1ce08..527bdf00d 100644 --- a/transport.go +++ b/transport.go @@ -40,6 +40,12 @@ func (e *errTransportClosed) Is(target error) bool { var errListenerAlreadySet = errors.New("listener already set") +type closePacket struct { + payload []byte + addr net.Addr + info packetInfo +} + // The Transport is the central point to manage incoming and outgoing QUIC connections. // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. // This means that a single UDP socket can be used for listening for incoming connections, as well as @@ -126,7 +132,9 @@ type Transport struct { // Tracer.Close is called when the transport is closed. Tracer *logging.Tracer - handlerMap packetHandlerManager + connMx sync.Mutex + handlers map[protocol.ConnectionID]packetHandler + resetTokens map[protocol.StatelessResetToken]packetHandler mutex sync.Mutex initOnce sync.Once @@ -206,7 +214,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo } s := newServer( t.conn, - t, + (*packetHandlerMap)(t), t.connIDGenerator, t.statelessResetter, t.ConnContext, @@ -297,7 +305,7 @@ func (t *Transport) doDial( conn := newClientConnection( context.WithoutCancel(ctx), sendConn, - t, + (*packetHandlerMap)(t), destConnID, srcConnID, t.connIDGenerator, @@ -311,7 +319,9 @@ func (t *Transport) doDial( logger, version, ) - t.handlerMap.Add(srcConnID, conn) + t.connMx.Lock() + t.handlers[srcConnID] = conn + t.connMx.Unlock() t.mutex.Unlock() // The error channel needs to be buffered, as the run loop will continue running @@ -384,9 +394,8 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn - if t.handlerMap == nil { // allows mocking the handlerMap in tests - t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger) - } + t.handlers = make(map[protocol.ConnectionID]packetHandler) + t.resetTokens = make(map[protocol.StatelessResetToken]packetHandler) t.listening = make(chan struct{}) t.closeQueue = make(chan closePacket, 4) @@ -419,10 +428,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { return t.initErr } -func (t *Transport) connRunner() packetHandlerManager { - return t.handlerMap -} - // WriteTo sends a packet on the underlying connection. func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { @@ -431,15 +436,6 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) } -func (t *Transport) enqueueClosePacket(p closePacket) { - select { - case t.closeQueue <- p: - default: - // Oops, we're backlogged. - // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. - } -} - func (t *Transport) runSendQueue() { for { select { @@ -501,9 +497,19 @@ func (t *Transport) close(e error) { } e = &errTransportClosed{err: e} - if t.handlerMap != nil { - t.handlerMap.Close(e) + + var wg sync.WaitGroup + t.connMx.Lock() + for _, handler := range t.handlers { + wg.Add(1) + go func(handler packetHandler) { + handler.destroy(e) + wg.Done() + }(handler) } + t.connMx.Unlock() + wg.Wait() + if t.server != nil { t.server.close(e, false) } @@ -566,7 +572,7 @@ func (t *Transport) handlePacket(p receivedPacket) { } // If there's a connection associated with the connection ID, pass the packet there. - if handler, ok := t.handlerMap.Get(connID); ok { + if handler, ok := (*packetHandlerMap)(t).Get(connID); ok { handler.handlePacket(p) return } @@ -651,8 +657,12 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool { return false } - token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) - if conn, ok := t.handlerMap.GetByResetToken(token); ok { + token := protocol.StatelessResetToken(data[len(data)-16:]) + t.connMx.Lock() + conn, ok := t.resetTokens[token] + t.connMx.Unlock() + + if ok { t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) go conn.destroy(&StatelessResetError{}) return true @@ -717,3 +727,99 @@ func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) { } tlsConf.ServerName = h } + +type packetHandlerMap Transport + +var _ connRunner = &packetHandlerMap{} + +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { + h.connMx.Lock() + defer h.connMx.Unlock() + + if _, ok := h.handlers[id]; ok { + h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) + return false + } + h.handlers[id] = handler + h.logger.Debugf("Adding connection ID %s.", id) + return true +} + +func (h *packetHandlerMap) Get(connID protocol.ConnectionID) (packetHandler, bool) { + h.connMx.Lock() + defer h.connMx.Unlock() + handler, ok := h.handlers[connID] + return handler, ok +} + +func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { + h.connMx.Lock() + h.resetTokens[token] = handler + h.connMx.Unlock() +} + +func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { + h.connMx.Lock() + delete(h.resetTokens, token) + h.connMx.Unlock() +} + +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool { + h.connMx.Lock() + defer h.connMx.Unlock() + + if _, ok := h.handlers[clientDestConnID]; ok { + h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) + return false + } + h.handlers[clientDestConnID] = handler + h.handlers[newConnID] = handler + h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) + return true +} + +func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.connMx.Lock() + delete(h.handlers, id) + h.connMx.Unlock() + h.logger.Debugf("Removing connection ID %s.", id) +} + +// ReplaceWithClosed is called when a connection is closed. +// Depending on which side closed the connection, we need to: +// * remote close: absorb delayed packets +// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost +func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) { + var handler packetHandler + if connClosePacket != nil { + handler = newClosedLocalConn( + func(addr net.Addr, info packetInfo) { + select { + case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}: + default: + // We're backlogged. + // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. + } + }, + h.logger, + ) + } else { + handler = newClosedRemoteConn() + } + + h.connMx.Lock() + for _, id := range ids { + h.handlers[id] = handler + } + h.connMx.Unlock() + h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids) + + time.AfterFunc(expiry, func() { + h.connMx.Lock() + for _, id := range ids { + delete(h.handlers, id) + } + h.connMx.Unlock() + h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids) + }) +} diff --git a/transport_test.go b/transport_test.go index 98c8bcc49..926624da3 100644 --- a/transport_test.go +++ b/transport_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "errors" "net" + "os" "syscall" "testing" "time" @@ -47,8 +48,16 @@ type mockPacketHandler struct { destruction chan<- error } -func (h *mockPacketHandler) handlePacket(p receivedPacket) { h.packets <- p } -func (h *mockPacketHandler) destroy(err error) { h.destruction <- err } +func (h *mockPacketHandler) handlePacket(p receivedPacket) { + h.packets <- p +} + +func (h *mockPacketHandler) destroy(err error) { + if h.destruction != nil { + h.destruction <- err + } +} + func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {} func getPacket(t *testing.T, connID protocol.ConnectionID) []byte { @@ -71,28 +80,19 @@ func getPacketWithPacketType(t *testing.T, connID protocol.ConnectionID, typ pro } func TestTransportPacketHandling(t *testing.T) { - mockCtrl := gomock.NewController(t) - phm := NewMockPacketHandlerManager(mockCtrl) - - tr := &Transport{ - Conn: newUDPConnLocalhost(t), - handlerMap: phm, - } + tr := &Transport{Conn: newUDPConnLocalhost(t)} tr.init(true) - defer func() { - phm.EXPECT().Close(gomock.Any()) - tr.Close() - }() + defer tr.Close() connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) connChan1 := make(chan receivedPacket, 1) conn1 := &mockPacketHandler{packets: connChan1} - phm.EXPECT().Get(connID1).Return(conn1, true) + (*packetHandlerMap)(tr).Add(connID1, conn1) connChan2 := make(chan receivedPacket, 1) conn2 := &mockPacketHandler{packets: connChan2} - phm.EXPECT().Get(connID2).Return(conn2, true) + (*packetHandlerMap)(tr).Add(connID2, conn2) conn := newUDPConnLocalhost(t) _, err := conn.WriteTo(getPacket(t, connID1), tr.Conn.LocalAddr()) @@ -158,51 +158,49 @@ func TestTransportAndDialConcurrentClose(t *testing.T) { } func TestTransportErrFromConn(t *testing.T) { - mockCtrl := gomock.NewController(t) - phm := NewMockPacketHandlerManager(mockCtrl) - readErrChan := make(chan error, 2) - conn := &mockPacketConn{readErrs: readErrChan, localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}} - tr := Transport{Conn: conn, handlerMap: phm} - defer tr.Close() + t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") + readErrChan := make(chan error, 2) + tr := Transport{ + Conn: &mockPacketConn{ + readErrs: readErrChan, + localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, + }, + } + defer tr.Close() tr.init(true) - tr.handlerMap = phm + + errChan := make(chan error, 1) + ph := &mockPacketHandler{destruction: errChan} + (*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph) // temporary errors don't lead to a shutdown... var tempErr deadlineError require.True(t, tempErr.Temporary()) readErrChan <- tempErr // don't expect any calls to phm.Close - time.Sleep(scaleDuration(20 * time.Millisecond)) + time.Sleep(scaleDuration(10 * time.Millisecond)) // ...but non-temporary errors do - done := make(chan struct{}) - phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) }) readErrChan <- errors.New("read failed") select { - case <-done: + case err := <-errChan: + require.ErrorIs(t, err, ErrTransportClosed) case <-time.After(time.Second): t.Fatal("timeout") } _, err := tr.Listen(&tls.Config{}, nil) - require.Error(t, err) require.ErrorIs(t, err, ErrTransportClosed) } func TestTransportStatelessResetReceiving(t *testing.T) { - mockCtrl := gomock.NewController(t) - phm := NewMockPacketHandlerManager(mockCtrl) tr := &Transport{ Conn: newUDPConnLocalhost(t), ConnectionIDLength: 4, - handlerMap: phm, } tr.init(true) - defer func() { - phm.EXPECT().Close(gomock.Any()) - tr.Close() - }() + defer tr.Close() connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) // now send a packet with a connection ID that doesn't exist @@ -213,10 +211,7 @@ func TestTransportStatelessResetReceiving(t *testing.T) { destroyChan := make(chan error, 1) conn1 := &mockPacketHandler{destruction: destroyChan} - gomock.InOrder( - phm.EXPECT().Get(connID), // no handler for this connection ID - phm.EXPECT().GetByResetToken(token).Return(conn1, true), - ) + (*packetHandlerMap)(tr).AddResetToken(token, conn1) conn := newUDPConnLocalhost(t) _, err = conn.WriteTo(b, tr.Conn.LocalAddr()) @@ -224,7 +219,7 @@ func TestTransportStatelessResetReceiving(t *testing.T) { select { case err := <-destroyChan: - require.Error(t, err) + require.ErrorIs(t, err, &qerr.StatelessResetError{}) case <-time.After(time.Second): t.Fatal("timeout") } @@ -232,25 +227,20 @@ func TestTransportStatelessResetReceiving(t *testing.T) { func TestTransportStatelessResetSending(t *testing.T) { mockCtrl := gomock.NewController(t) - phm := NewMockPacketHandlerManager(mockCtrl) tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newUDPConnLocalhost(t), ConnectionIDLength: 4, StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, - handlerMap: phm, Tracer: tracer, } tr.init(true) defer func() { mockTracer.EXPECT().Close() - phm.EXPECT().Close(gomock.Any()) tr.Close() }() connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) - phm.EXPECT().Get(connID) // no handler for this connection ID - phm.EXPECT().GetByResetToken(gomock.Any()) // now send a packet with a connection ID that doesn't exist b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne) @@ -274,8 +264,6 @@ func TestTransportStatelessResetSending(t *testing.T) { require.True(t, mockCtrl.Satisfied()) // but a stateless reset is sent for packets larger than MinStatelessResetSize - phm.EXPECT().Get(connID) // no handler for this connection ID - phm.EXPECT().GetByResetToken(gomock.Any()) _, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr()) require.NoError(t, err) conn.SetReadDeadline(time.Now().Add(time.Second)) @@ -287,7 +275,7 @@ func TestTransportStatelessResetSending(t *testing.T) { require.Contains(t, string(p[:n]), string(srt[:])) } -func TestTransportDropsUnparseableQUICPackets(t *testing.T) { +func TestTransportUnparseableQUICPackets(t *testing.T) { mockCtrl := gomock.NewController(t) tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ @@ -511,7 +499,7 @@ func testTransportDial(t *testing.T, early bool) { newClientConnection = func( _ context.Context, _ sendConn, - _ *Transport, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, @@ -586,7 +574,7 @@ func TestTransportDialingVersionNegotiation(t *testing.T) { newClientConnection = func( _ context.Context, _ sendConn, - _ *Transport, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, @@ -636,3 +624,78 @@ func TestTransportDialingVersionNegotiation(t *testing.T) { conn.EXPECT().destroy(gomock.Any()).AnyTimes() conn2.EXPECT().destroy(gomock.Any()).AnyTimes() } + +func TestTransportReplaceWithClosed(t *testing.T) { + t.Run("local", func(t *testing.T) { + testTransportReplaceWithClosed(t, true) + }) + t.Run("remote", func(t *testing.T) { + testTransportReplaceWithClosed(t, false) + }) +} + +func testTransportReplaceWithClosed(t *testing.T, local bool) { + srk := StatelessResetKey{1, 2, 3, 4} + tr := &Transport{ + Conn: newUDPConnLocalhost(t), + ConnectionIDLength: 4, + StatelessResetKey: &srk, + } + tr.init(true) + defer tr.Close() + + dur := scaleDuration(10 * time.Millisecond) + + var closePacket []byte + if local { + closePacket = []byte("foobar") + } + + handler := &mockPacketHandler{} + connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + m := (*packetHandlerMap)(tr) + require.True(t, m.Add(connID, handler)) + start := time.Now() + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur) + + p := make([]byte, 100) + p[0] = 0x40 // QUIC bit + copy(p[1:], connID.Bytes()) + + conn := newUDPConnLocalhost(t) + var sent int + for now := range time.NewTicker(dur / 20).C { + _, err := conn.WriteTo(p, tr.Conn.LocalAddr()) + require.NoError(t, err) + sent++ + if now.After(start.Add(dur)) { + break + } + } + // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff + for i := 0; i*i < sent; i++ { + conn.SetReadDeadline(time.Now().Add(time.Second)) + b := make([]byte, 100) + if local { + n, _, err := conn.ReadFrom(b) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b[:n]) + } + } + // Afterwards, we receive a stateless reset, not a copy of the CONNECTION_CLOSE packet. + // Retry up to 3 times, the connection is deleted from the map on a timer. + for i := 0; i < 3; i++ { + _, err := conn.WriteTo(p, tr.Conn.LocalAddr()) + require.NoError(t, err) + conn.SetReadDeadline(time.Now().Add(dur / 4)) + b := make([]byte, 100) + n, _, err := conn.ReadFrom(b) + if errors.Is(err, os.ErrDeadlineExceeded) { + continue + } + require.NoError(t, err) + require.NotEqual(t, []byte("foobar"), b[:n]) + require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) + break + } +}