move connection tracking into the Transport (#5113)

This commit is contained in:
Marten Seemann
2025-05-05 11:05:17 +08:00
committed by GitHub
parent 6ef55ca271
commit 2082d2c6b9
13 changed files with 333 additions and 842 deletions

View File

@@ -17,7 +17,7 @@ type connRunnerCallbacks struct {
}
// The memory address of the Transport is used as the key.
type connRunners map[*Transport]connRunnerCallbacks
type connRunners map[connRunner]connRunnerCallbacks
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
@@ -57,11 +57,11 @@ type connIDGenerator struct {
}
func newConnIDGenerator(
tr *Transport,
runner connRunner,
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
statelessResetter *statelessResetter,
connRunner connRunnerCallbacks,
callbacks connRunnerCallbacks,
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
) *connIDGenerator {
@@ -69,7 +69,7 @@ func newConnIDGenerator(
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunners: map[*Transport]connRunnerCallbacks{tr: connRunner},
connRunners: map[connRunner]connRunnerCallbacks{runner: callbacks},
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
@@ -202,13 +202,13 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte, expiry time.Durati
m.connRunners.ReplaceWithClosed(connIDs, connClose, expiry)
}
func (m *connIDGenerator) AddConnRunner(id *Transport, r connRunnerCallbacks) {
func (m *connIDGenerator) AddConnRunner(runner connRunner, r connRunnerCallbacks) {
// The transport might have already been added earlier.
// This happens if the application migrates back to and old path.
if _, ok := m.connRunners[id]; ok {
if _, ok := m.connRunners[runner]; ok {
return
}
m.connRunners[id] = r
m.connRunners[runner] = r
if m.initialClientDestConnID != nil {
r.AddConnectionID(*m.initialClientDestConnID)
}

View File

@@ -34,7 +34,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
initialClientDestConnID = &connID
}
g := newConnIDGenerator(
&Transport{},
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
sr,
@@ -107,7 +107,7 @@ func TestConnIDGeneratorRetiring(t *testing.T) {
initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
var added, removed []protocol.ConnectionID
g := newConnIDGenerator(
&Transport{},
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
&initialConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -178,7 +178,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
removed []protocol.ConnectionID
)
g := newConnIDGenerator(
&Transport{},
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -228,7 +228,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
replacedWith []byte
)
g := newConnIDGenerator(
&Transport{},
&packetHandlerMap{},
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -300,7 +300,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
var queuedFrames []wire.Frame
tr := &Transport{}
tr := &packetHandlerMap{}
g := newConnIDGenerator(
tr,
initialConnID,
@@ -314,7 +314,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
require.Len(t, tracker1.added, 2)
// add the second runner - it should get all existing connection IDs
g.AddConnRunner(&Transport{}, runner2)
g.AddConnRunner(&packetHandlerMap{}, runner2)
require.Len(t, tracker1.added, 2) // unchanged
require.Len(t, tracker2.added, 4)
require.Contains(t, tracker2.added, initialConnID)

View File

@@ -111,8 +111,6 @@ func nextConnTracingID() ConnectionTracingID { return ConnectionTracingID(connTr
// A Connection is a QUIC connection
type connection struct {
tr *Transport
// Destination connection ID used during the handshake.
// Used to check source connection ID on incoming packets.
handshakeDestConnID protocol.ConnectionID
@@ -225,7 +223,7 @@ var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
tr *Transport,
runner connRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
@@ -244,7 +242,6 @@ var newConnection = func(
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
tr: tr,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
@@ -261,7 +258,6 @@ var newConnection = func(
} else {
s.logID = destConnID.String()
}
runner := tr.connRunner()
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
@@ -269,7 +265,7 @@ var newConnection = func(
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
tr,
runner,
srcConnID,
&clientDestConnID,
statelessResetter,
@@ -347,7 +343,7 @@ var newConnection = func(
var newClientConnection = func(
ctx context.Context,
conn sendConn,
tr *Transport,
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
@@ -362,7 +358,6 @@ var newClientConnection = func(
v protocol.Version,
) quicConn {
s := &connection{
tr: tr,
conn: conn,
config: conf,
origDestConnID: destConnID,
@@ -375,7 +370,6 @@ var newClientConnection = func(
versionNegotiated: hasNegotiatedVersion,
version: v,
}
runner := tr.connRunner()
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
@@ -383,7 +377,7 @@ var newClientConnection = func(
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
tr,
runner,
srcConnID,
nil,
statelessResetter,
@@ -2650,9 +2644,9 @@ func (s *connection) AddPath(t *Transport) (*Path, error) {
t,
200*time.Millisecond, // initial RTT estimate
func() {
runner := t.connRunner()
runner := (*packetHandlerMap)(t)
s.connIDGenerator.AddConnRunner(
t,
runner,
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,

View File

@@ -81,7 +81,7 @@ func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
type testConnection struct {
conn *connection
connRunner *MockPacketHandlerManager
connRunner *MockConnRunner
sendConn *MockSendConn
packer *MockPacker
destConnID protocol.ConnectionID
@@ -101,7 +101,7 @@ func newServerTestConnection(
}
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
phm := NewMockPacketHandlerManager(mockCtrl)
connRunner := NewMockConnRunner(mockCtrl)
sendConn := NewMockSendConn(mockCtrl)
sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
@@ -119,7 +119,7 @@ func newServerTestConnection(
ctx,
cancel,
sendConn,
&Transport{handlerMap: phm},
connRunner,
origDestConnID,
nil,
protocol.ConnectionID{},
@@ -141,7 +141,7 @@ func newServerTestConnection(
}
return &testConnection{
conn: conn,
connRunner: phm,
connRunner: connRunner,
sendConn: sendConn,
packer: packer,
destConnID: origDestConnID,
@@ -162,7 +162,7 @@ func newClientTestConnection(
}
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
phm := NewMockPacketHandlerManager(mockCtrl)
connRunner := NewMockConnRunner(mockCtrl)
sendConn := NewMockSendConn(mockCtrl)
sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
@@ -178,7 +178,7 @@ func newClientTestConnection(
conn := newClientConnection(
context.Background(),
sendConn,
&Transport{handlerMap: phm},
connRunner,
destConnID,
srcConnID,
&protocol.DefaultConnectionIDGenerator{},
@@ -198,7 +198,7 @@ func newClientTestConnection(
}
return &testConnection{
conn: conn,
connRunner: phm,
connRunner: connRunner,
sendConn: sendConn,
packer: packer,
destConnID: destConnID,

View File

@@ -1,376 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: PacketHandlerManager)
//
// Generated by this command:
//
// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager
//
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
time "time"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock"
)
// MockPacketHandlerManager is a mock of PacketHandlerManager interface.
type MockPacketHandlerManager struct {
ctrl *gomock.Controller
recorder *MockPacketHandlerManagerMockRecorder
isgomock struct{}
}
// MockPacketHandlerManagerMockRecorder is the mock recorder for MockPacketHandlerManager.
type MockPacketHandlerManagerMockRecorder struct {
mock *MockPacketHandlerManager
}
// NewMockPacketHandlerManager creates a new mock instance.
func NewMockPacketHandlerManager(ctrl *gomock.Controller) *MockPacketHandlerManager {
mock := &MockPacketHandlerManager{ctrl: ctrl}
mock.recorder = &MockPacketHandlerManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// Add indicates an expected call of Add.
func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 any) *MockPacketHandlerManagerAddCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
return &MockPacketHandlerManagerAddCall{Call: call}
}
// MockPacketHandlerManagerAddCall wrap *gomock.Call
type MockPacketHandlerManagerAddCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerAddCall) Return(arg0 bool) *MockPacketHandlerManagerAddCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerAddCall) Do(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerAddCall) DoAndReturn(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// AddResetToken mocks base method.
func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddResetToken", arg0, arg1)
}
// AddResetToken indicates an expected call of AddResetToken.
func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 any) *MockPacketHandlerManagerAddResetTokenCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1)
return &MockPacketHandlerManagerAddResetTokenCall{Call: call}
}
// MockPacketHandlerManagerAddResetTokenCall wrap *gomock.Call
type MockPacketHandlerManagerAddResetTokenCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerAddResetTokenCall) Return() *MockPacketHandlerManagerAddResetTokenCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerAddResetTokenCall) Do(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// AddWithConnID mocks base method.
func (m *MockPacketHandlerManager) AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddWithConnID", destConnID, newConnID, h)
ret0, _ := ret[0].(bool)
return ret0
}
// AddWithConnID indicates an expected call of AddWithConnID.
func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(destConnID, newConnID, h any) *MockPacketHandlerManagerAddWithConnIDCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), destConnID, newConnID, h)
return &MockPacketHandlerManagerAddWithConnIDCall{Call: call}
}
// MockPacketHandlerManagerAddWithConnIDCall wrap *gomock.Call
type MockPacketHandlerManagerAddWithConnIDCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerAddWithConnIDCall) Return(arg0 bool) *MockPacketHandlerManagerAddWithConnIDCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerAddWithConnIDCall) Do(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerAddWithConnIDCall) DoAndReturn(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Close mocks base method.
func (m *MockPacketHandlerManager) Close(arg0 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close", arg0)
}
// Close indicates an expected call of Close.
func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 any) *MockPacketHandlerManagerCloseCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0)
return &MockPacketHandlerManagerCloseCall{Call: call}
}
// MockPacketHandlerManagerCloseCall wrap *gomock.Call
type MockPacketHandlerManagerCloseCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerCloseCall) Return() *MockPacketHandlerManagerCloseCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerCloseCall) Do(f func(error)) *MockPacketHandlerManagerCloseCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *MockPacketHandlerManagerCloseCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Get mocks base method.
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(packetHandler)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 any) *MockPacketHandlerManagerGetCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
return &MockPacketHandlerManagerGetCall{Call: call}
}
// MockPacketHandlerManagerGetCall wrap *gomock.Call
type MockPacketHandlerManagerGetCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerGetCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerGetCall) Do(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerGetCall) DoAndReturn(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// GetByResetToken mocks base method.
func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByResetToken", arg0)
ret0, _ := ret[0].(packetHandler)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// GetByResetToken indicates an expected call of GetByResetToken.
func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 any) *MockPacketHandlerManagerGetByResetTokenCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0)
return &MockPacketHandlerManagerGetByResetTokenCall{Call: call}
}
// MockPacketHandlerManagerGetByResetTokenCall wrap *gomock.Call
type MockPacketHandlerManagerGetByResetTokenCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerGetByResetTokenCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetByResetTokenCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerGetByResetTokenCall) Do(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerGetByResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Remove mocks base method.
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", arg0)
}
// Remove indicates an expected call of Remove.
func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 any) *MockPacketHandlerManagerRemoveCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0)
return &MockPacketHandlerManagerRemoveCall{Call: call}
}
// MockPacketHandlerManagerRemoveCall wrap *gomock.Call
type MockPacketHandlerManagerRemoveCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerRemoveCall) Return() *MockPacketHandlerManagerRemoveCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerRemoveCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerRemoveCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// RemoveResetToken mocks base method.
func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessResetToken) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RemoveResetToken", arg0)
}
// RemoveResetToken indicates an expected call of RemoveResetToken.
func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 any) *MockPacketHandlerManagerRemoveResetTokenCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0)
return &MockPacketHandlerManagerRemoveResetTokenCall{Call: call}
}
// MockPacketHandlerManagerRemoveResetTokenCall wrap *gomock.Call
type MockPacketHandlerManagerRemoveResetTokenCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerRemoveResetTokenCall) Return() *MockPacketHandlerManagerRemoveResetTokenCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerRemoveResetTokenCall) Do(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerRemoveResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// ReplaceWithClosed mocks base method.
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte, arg2 time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2)
}
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *MockPacketHandlerManagerReplaceWithClosedCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2)
return &MockPacketHandlerManagerReplaceWithClosedCall{Call: call}
}
// MockPacketHandlerManagerReplaceWithClosedCall wrap *gomock.Call
type MockPacketHandlerManagerReplaceWithClosedCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerReplaceWithClosedCall) Return() *MockPacketHandlerManagerReplaceWithClosedCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockPacketHandlerManagerReplaceWithClosedCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockPacketHandlerManagerReplaceWithClosedCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -56,9 +56,4 @@ type QUICConn = quicConn
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler"
type PacketHandler = packetHandler
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
type PacketHandlerManager = packetHandlerManager
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn"

View File

@@ -1,180 +0,0 @@
package quic
import (
"io"
"net"
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
type connCapabilities struct {
// This connection has the Don't Fragment (DF) bit set.
// This means it makes to run DPLPMTUD.
DF bool
// GSO (Generic Segmentation Offload) supported
GSO bool
// ECN (Explicit Congestion Notifications) supported
ECN bool
}
// rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface {
ReadPacket() (receivedPacket, error)
// WritePacket writes a packet on the wire.
// gsoSize is the size of a single packet, or 0 to disable GSO.
// It is invalid to set gsoSize if capabilities.GSO is not set.
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
LocalAddr() net.Addr
SetReadDeadline(time.Time) error
io.Closer
capabilities() connCapabilities
}
type closePacket struct {
payload []byte
addr net.Addr
info packetInfo
}
type packetHandlerMap struct {
mutex sync.Mutex
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
closed bool
enqueueClosePacket func(closePacket)
logger utils.Logger
}
var _ packetHandlerManager = &packetHandlerMap{}
func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
return &packetHandlerMap{
handlers: make(map[protocol.ConnectionID]packetHandler),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
enqueueClosePacket: enqueueClosePacket,
logger: logger,
}
}
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
h.mutex.Lock()
defer h.mutex.Unlock()
handler, ok := h.handlers[id]
return handler, ok
}
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
h.mutex.Lock()
defer h.mutex.Unlock()
if _, ok := h.handlers[id]; ok {
h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
return false
}
h.handlers[id] = handler
h.logger.Debugf("Adding connection ID %s.", id)
return true
}
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
h.mutex.Lock()
defer h.mutex.Unlock()
if _, ok := h.handlers[clientDestConnID]; ok {
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
h.handlers[clientDestConnID] = handler
h.handlers[newConnID] = handler
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true
}
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.mutex.Lock()
delete(h.handlers, id)
h.mutex.Unlock()
h.logger.Debugf("Removing connection ID %s.", id)
}
// ReplaceWithClosed is called when a connection is closed.
// Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) {
var handler packetHandler
if connClosePacket != nil {
handler = newClosedLocalConn(
func(addr net.Addr, info packetInfo) {
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
},
h.logger,
)
} else {
handler = newClosedRemoteConn()
}
h.mutex.Lock()
for _, id := range ids {
h.handlers[id] = handler
}
h.mutex.Unlock()
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
time.AfterFunc(expiry, func() {
h.mutex.Lock()
for _, id := range ids {
delete(h.handlers, id)
}
h.mutex.Unlock()
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
})
}
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
h.mutex.Lock()
h.resetTokens[token] = handler
h.mutex.Unlock()
}
func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
h.mutex.Lock()
delete(h.resetTokens, token)
h.mutex.Unlock()
}
func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
h.mutex.Lock()
defer h.mutex.Unlock()
handler, ok := h.resetTokens[token]
return handler, ok
}
func (h *packetHandlerMap) Close(e error) {
h.mutex.Lock()
if h.closed {
h.mutex.Unlock()
return
}
var wg sync.WaitGroup
for _, handler := range h.handlers {
wg.Add(1)
go func(handler packetHandler) {
handler.destroy(e)
wg.Done()
}(handler)
}
h.closed = true
h.mutex.Unlock()
wg.Wait()
}

View File

@@ -1,151 +0,0 @@
package quic
import (
"crypto/rand"
"errors"
"net"
"testing"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPacketHandlerMapAddAndRemove(t *testing.T) {
m := newPacketHandlerMap(nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
h := &mockPacketHandler{}
require.True(t, m.Add(connID, h))
got, ok := m.Get(connID)
require.True(t, ok)
require.Equal(t, h, got)
// cannot add the same handler twice
require.False(t, m.Add(connID, h))
got, ok = m.Get(connID)
require.True(t, ok)
require.Equal(t, h, got)
// remove the handler
m.Remove(connID)
got, ok = m.Get(connID)
require.False(t, ok)
require.Nil(t, got)
}
func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
m := newPacketHandlerMap(nil, utils.DefaultLogger)
h := &mockPacketHandler{}
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
require.True(t, m.AddWithConnID(connID1, connID2, h))
// collision of the connection ID, this handler should not be added
require.False(t, m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), nil))
got, ok := m.Get(connID1)
require.True(t, ok)
require.Equal(t, h, got)
got, ok = m.Get(connID2)
require.True(t, ok)
require.Equal(t, h, got)
}
func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
m := newPacketHandlerMap(nil, utils.DefaultLogger)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
handler := &mockPacketHandler{}
m.AddResetToken(token, handler)
h, ok := m.GetByResetToken(token)
require.True(t, ok)
require.Equal(t, handler, h)
m.RemoveResetToken(token)
_, ok = m.GetByResetToken(token)
require.False(t, ok)
}
func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) {
var closePackets []closePacket
m := newPacketHandlerMap(
func(p closePacket) { closePackets = append(closePackets, p) },
utils.DefaultLogger,
)
dur := scaleDuration(10 * time.Millisecond)
handler := &mockPacketHandler{}
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
require.True(t, m.Add(connID, handler))
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, []byte("foobar"), dur)
h, ok := m.Get(connID)
require.True(t, ok)
require.NotEqual(t, handler, h)
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(receivedPacket{remoteAddr: addr})
require.Len(t, closePackets, 1)
require.Equal(t, addr, closePackets[0].addr)
require.Equal(t, []byte("foobar"), closePackets[0].payload)
time.Sleep(dur)
require.Eventually(t, func() bool {
_, ok := m.Get(connID)
return !ok
}, time.Second, dur)
}
func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) {
var closePackets []closePacket
m := newPacketHandlerMap(
func(p closePacket) { closePackets = append(closePackets, p) },
utils.DefaultLogger,
)
dur := scaleDuration(25 * time.Millisecond)
handler := &mockPacketHandler{}
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
require.True(t, m.Add(connID, handler))
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, nil, dur)
h, ok := m.Get(connID)
require.True(t, ok)
require.NotEqual(t, handler, h)
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(receivedPacket{remoteAddr: addr})
require.Empty(t, closePackets)
time.Sleep(dur)
require.Eventually(t, func() bool {
_, ok := m.Get(connID)
return !ok
}, time.Second, 10*time.Millisecond)
}
func TestPacketHandlerMapClose(t *testing.T) {
m := newPacketHandlerMap(nil, utils.DefaultLogger)
const numConns = 10
destroyChan := make(chan error, 2*numConns)
for i := 0; i < numConns; i++ {
conn := &mockPacketHandler{destruction: destroyChan}
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.Close(assert.AnError)
// check that Close can be called multiple times
m.Close(errors.New("close"))
for i := 0; i < numConns; i++ {
select {
case err := <-destroyChan:
require.ErrorIs(t, err, assert.AnError)
default:
t.Fatalf("connection not destroyed")
}
}
select {
case err := <-destroyChan:
t.Fatalf("connection destroyed more than once: %s", err)
default:
}
}

View File

@@ -32,14 +32,6 @@ type packetHandler interface {
closeWithTransportError(qerr.TransportErrorCode)
}
type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool)
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
Close(error)
connRunner
}
type quicConn interface {
EarlyConnection
earlyConnReady() <-chan struct{}
@@ -61,7 +53,7 @@ type rejectedPacket struct {
// A Listener of QUIC
type baseServer struct {
tr *Transport
tr *packetHandlerMap
disableVersionNegotiation bool
acceptEarlyConns bool
@@ -89,7 +81,7 @@ type baseServer struct {
context.Context,
context.CancelCauseFunc,
sendConn,
*Transport,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
*protocol.ConnectionID, /* retry src connection ID */
protocol.ConnectionID, /* client dest connection ID */
@@ -247,7 +239,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
func newServer(
conn rawConn,
tr *Transport,
tr *packetHandlerMap,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
connContext func(context.Context) context.Context,
@@ -501,7 +493,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
}
// check again if we might have a connection now
if handler, ok := s.tr.connRunner().Get(connID); ok {
if handler, ok := s.tr.Get(connID); ok {
handler.handlePacket(p)
return true
}
@@ -591,7 +583,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
// The server queues packets for a while, and we might already have established a connection by now.
// This results in a second check in the connection map.
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
if handler, ok := s.tr.connRunner().Get(hdr.DestConnectionID); ok {
if handler, ok := s.tr.Get(hdr.DestConnectionID); ok {
handler.handlePacket(p)
return nil
}
@@ -727,7 +719,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
// under normal circumstances the packet would just be routed to that connection.
// The only time this collision will occur if we receive the two Initial packets at the same time.
if added := s.tr.connRunner().AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
if added := s.tr.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
conn.closeWithTransportError(qerr.ConnectionRefused)
return nil

View File

@@ -37,7 +37,7 @@ type serverOpts struct {
context.Context,
context.CancelCauseFunc,
sendConn,
*Transport,
connRunner,
protocol.ConnectionID, // original dest connection ID
*protocol.ConnectionID, // retry src connection ID
protocol.ConnectionID, // client dest connection ID
@@ -61,9 +61,11 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
require.NoError(t, err)
verifySourceAddress := func(net.Addr) bool { return serverOpts.useRetry }
config := populateConfig(serverOpts.config)
tr := &Transport{Conn: newUDPConnLocalhost(t)}
tr.init(true)
s := newServer(
c,
&Transport{handlerMap: newPacketHandlerMap(nil, utils.DefaultLogger)},
(*packetHandlerMap)(tr),
&protocol.DefaultConnectionIDGenerator{},
&statelessResetter{},
func(ctx context.Context) context.Context { return ctx },
@@ -596,6 +598,7 @@ func testServerTokenValidation(
type connConstructorArgs struct {
ctx context.Context
connRunner connRunner
config *Config
origDestConnID protocol.ConnectionID
retrySrcConnID *protocol.ConnectionID
@@ -607,10 +610,10 @@ type connConstructorArgs struct {
type connConstructorRecorder struct {
ch chan connConstructorArgs
conns []quicConn
conns []*MockQUICConn
}
func newConnConstructorRecorder(conns ...quicConn) *connConstructorRecorder {
func newConnConstructorRecorder(conns ...*MockQUICConn) *connConstructorRecorder {
return &connConstructorRecorder{
ch: make(chan connConstructorArgs, len(conns)),
conns: conns,
@@ -623,7 +626,7 @@ func (r *connConstructorRecorder) NewConn(
ctx context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ *Transport,
connRunner connRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
@@ -641,6 +644,7 @@ func (r *connConstructorRecorder) NewConn(
) quicConn {
r.ch <- connConstructorArgs{
ctx: ctx,
connRunner: connRunner,
config: config,
origDestConnID: origDestConnID,
retrySrcConnID: retrySrcConnID,
@@ -747,7 +751,7 @@ func testServerCreateConnection(t *testing.T, useRetry bool) {
func TestServerClose(t *testing.T) {
mockCtrl := gomock.NewController(t)
var conns []quicConn
var conns []*MockQUICConn
const numConns = 3
done := make(chan struct{}, numConns)
for range numConns {
@@ -779,7 +783,7 @@ func TestServerClose(t *testing.T) {
}
server.Close()
// closing closes all handshakeing connections with CONNECTION_REFUSED
// closing closes all handshaking connections with CONNECTION_REFUSED
for range numConns {
select {
case <-done:
@@ -796,6 +800,11 @@ func TestServerClose(t *testing.T) {
require.ErrorIs(t, err, ErrServerClosed)
require.ErrorIs(t, err, net.ErrClosed)
}
// test shutdown
for _, conn := range conns {
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
}
}
func TestServerGetConfigForClientAccept(t *testing.T) {
@@ -878,7 +887,7 @@ func TestServerPacketHandling(t *testing.T) {
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
handledPacket <- p
})
server.tr.handlerMap.Add(destConnID, conn)
server.tr.Add(destConnID, conn)
server.handlePacket(
getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, srcConnID, destConnID),
@@ -904,7 +913,7 @@ func TestServerReceiveQueue(t *testing.T) {
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ *Transport,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
@@ -927,6 +936,8 @@ func TestServerReceiveQueue(t *testing.T) {
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
// during test shutdown
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
return conn
},
})
@@ -1013,6 +1024,9 @@ func testServerAccept(t *testing.T, acceptEarly bool) {
case <-time.After(time.Second):
t.Fatal("timeout")
}
// test shutdown
c.EXPECT().destroy(gomock.Any()).AnyTimes()
}
func TestServerAcceptHandshakeFailure(t *testing.T) {
@@ -1044,11 +1058,14 @@ func TestServerAcceptHandshakeFailure(t *testing.T) {
t.Fatal("server should not have accepted the connection")
case <-time.After(scaleDuration(5 * time.Millisecond)):
}
// test shutdown
c.EXPECT().destroy(gomock.Any()).AnyTimes()
}
func TestServerAcceptQueue(t *testing.T) {
mockCtrl := gomock.NewController(t)
var conns []quicConn
var conns []*MockQUICConn
var rejectedConn *MockQUICConn
for i := range protocol.MaxAcceptQueueSize + 2 {
conn := NewMockQUICConn(mockCtrl)
@@ -1125,6 +1142,11 @@ func TestServerAcceptQueue(t *testing.T) {
case <-time.After(time.Second):
t.Fatal("timeout")
}
// test shutdown
for _, conn := range conns {
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
}
}
func TestServer0RTTReordering(t *testing.T) {
@@ -1210,6 +1232,7 @@ func TestServer0RTTReordering(t *testing.T) {
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any()).AnyTimes()
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
}
func TestServer0RTTQueueing(t *testing.T) {

View File

@@ -1,6 +1,7 @@
package quic
import (
"io"
"log"
"net"
"os"
@@ -13,6 +14,30 @@ import (
"github.com/quic-go/quic-go/internal/utils"
)
type connCapabilities struct {
// This connection has the Don't Fragment (DF) bit set.
// This means it makes to run DPLPMTUD.
DF bool
// GSO (Generic Segmentation Offload) supported
GSO bool
// ECN (Explicit Congestion Notifications) supported
ECN bool
}
// rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface {
ReadPacket() (receivedPacket, error)
// WritePacket writes a packet on the wire.
// gsoSize is the size of a single packet, or 0 to disable GSO.
// It is invalid to set gsoSize if capabilities.GSO is not set.
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
LocalAddr() net.Addr
SetReadDeadline(time.Time) error
io.Closer
capabilities() connCapabilities
}
// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header.
// If the PacketConn passed to the [Transport] satisfies this interface, quic-go will use it.
// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets.

View File

@@ -40,6 +40,12 @@ func (e *errTransportClosed) Is(target error) bool {
var errListenerAlreadySet = errors.New("listener already set")
type closePacket struct {
payload []byte
addr net.Addr
info packetInfo
}
// The Transport is the central point to manage incoming and outgoing QUIC connections.
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
// This means that a single UDP socket can be used for listening for incoming connections, as well as
@@ -126,7 +132,9 @@ type Transport struct {
// Tracer.Close is called when the transport is closed.
Tracer *logging.Tracer
handlerMap packetHandlerManager
connMx sync.Mutex
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken]packetHandler
mutex sync.Mutex
initOnce sync.Once
@@ -206,7 +214,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
}
s := newServer(
t.conn,
t,
(*packetHandlerMap)(t),
t.connIDGenerator,
t.statelessResetter,
t.ConnContext,
@@ -297,7 +305,7 @@ func (t *Transport) doDial(
conn := newClientConnection(
context.WithoutCancel(ctx),
sendConn,
t,
(*packetHandlerMap)(t),
destConnID,
srcConnID,
t.connIDGenerator,
@@ -311,7 +319,9 @@ func (t *Transport) doDial(
logger,
version,
)
t.handlerMap.Add(srcConnID, conn)
t.connMx.Lock()
t.handlers[srcConnID] = conn
t.connMx.Unlock()
t.mutex.Unlock()
// The error channel needs to be buffered, as the run loop will continue running
@@ -384,9 +394,8 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
if t.handlerMap == nil { // allows mocking the handlerMap in tests
t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger)
}
t.handlers = make(map[protocol.ConnectionID]packetHandler)
t.resetTokens = make(map[protocol.StatelessResetToken]packetHandler)
t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4)
@@ -419,10 +428,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
return t.initErr
}
func (t *Transport) connRunner() packetHandlerManager {
return t.handlerMap
}
// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {
@@ -431,15 +436,6 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
}
func (t *Transport) enqueueClosePacket(p closePacket) {
select {
case t.closeQueue <- p:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
}
func (t *Transport) runSendQueue() {
for {
select {
@@ -501,9 +497,19 @@ func (t *Transport) close(e error) {
}
e = &errTransportClosed{err: e}
if t.handlerMap != nil {
t.handlerMap.Close(e)
var wg sync.WaitGroup
t.connMx.Lock()
for _, handler := range t.handlers {
wg.Add(1)
go func(handler packetHandler) {
handler.destroy(e)
wg.Done()
}(handler)
}
t.connMx.Unlock()
wg.Wait()
if t.server != nil {
t.server.close(e, false)
}
@@ -566,7 +572,7 @@ func (t *Transport) handlePacket(p receivedPacket) {
}
// If there's a connection associated with the connection ID, pass the packet there.
if handler, ok := t.handlerMap.Get(connID); ok {
if handler, ok := (*packetHandlerMap)(t).Get(connID); ok {
handler.handlePacket(p)
return
}
@@ -651,8 +657,12 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
return false
}
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
token := protocol.StatelessResetToken(data[len(data)-16:])
t.connMx.Lock()
conn, ok := t.resetTokens[token]
t.connMx.Unlock()
if ok {
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go conn.destroy(&StatelessResetError{})
return true
@@ -717,3 +727,99 @@ func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
}
tlsConf.ServerName = h
}
type packetHandlerMap Transport
var _ connRunner = &packetHandlerMap{}
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
h.connMx.Lock()
defer h.connMx.Unlock()
if _, ok := h.handlers[id]; ok {
h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
return false
}
h.handlers[id] = handler
h.logger.Debugf("Adding connection ID %s.", id)
return true
}
func (h *packetHandlerMap) Get(connID protocol.ConnectionID) (packetHandler, bool) {
h.connMx.Lock()
defer h.connMx.Unlock()
handler, ok := h.handlers[connID]
return handler, ok
}
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
h.connMx.Lock()
h.resetTokens[token] = handler
h.connMx.Unlock()
}
func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
h.connMx.Lock()
delete(h.resetTokens, token)
h.connMx.Unlock()
}
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
h.connMx.Lock()
defer h.connMx.Unlock()
if _, ok := h.handlers[clientDestConnID]; ok {
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
h.handlers[clientDestConnID] = handler
h.handlers[newConnID] = handler
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true
}
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.connMx.Lock()
delete(h.handlers, id)
h.connMx.Unlock()
h.logger.Debugf("Removing connection ID %s.", id)
}
// ReplaceWithClosed is called when a connection is closed.
// Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) {
var handler packetHandler
if connClosePacket != nil {
handler = newClosedLocalConn(
func(addr net.Addr, info packetInfo) {
select {
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
default:
// We're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
},
h.logger,
)
} else {
handler = newClosedRemoteConn()
}
h.connMx.Lock()
for _, id := range ids {
h.handlers[id] = handler
}
h.connMx.Unlock()
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
time.AfterFunc(expiry, func() {
h.connMx.Lock()
for _, id := range ids {
delete(h.handlers, id)
}
h.connMx.Unlock()
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
})
}

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"errors"
"net"
"os"
"syscall"
"testing"
"time"
@@ -47,8 +48,16 @@ type mockPacketHandler struct {
destruction chan<- error
}
func (h *mockPacketHandler) handlePacket(p receivedPacket) { h.packets <- p }
func (h *mockPacketHandler) destroy(err error) { h.destruction <- err }
func (h *mockPacketHandler) handlePacket(p receivedPacket) {
h.packets <- p
}
func (h *mockPacketHandler) destroy(err error) {
if h.destruction != nil {
h.destruction <- err
}
}
func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {}
func getPacket(t *testing.T, connID protocol.ConnectionID) []byte {
@@ -71,28 +80,19 @@ func getPacketWithPacketType(t *testing.T, connID protocol.ConnectionID, typ pro
}
func TestTransportPacketHandling(t *testing.T) {
mockCtrl := gomock.NewController(t)
phm := NewMockPacketHandlerManager(mockCtrl)
tr := &Transport{
Conn: newUDPConnLocalhost(t),
handlerMap: phm,
}
tr := &Transport{Conn: newUDPConnLocalhost(t)}
tr.init(true)
defer func() {
phm.EXPECT().Close(gomock.Any())
tr.Close()
}()
defer tr.Close()
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
connChan1 := make(chan receivedPacket, 1)
conn1 := &mockPacketHandler{packets: connChan1}
phm.EXPECT().Get(connID1).Return(conn1, true)
(*packetHandlerMap)(tr).Add(connID1, conn1)
connChan2 := make(chan receivedPacket, 1)
conn2 := &mockPacketHandler{packets: connChan2}
phm.EXPECT().Get(connID2).Return(conn2, true)
(*packetHandlerMap)(tr).Add(connID2, conn2)
conn := newUDPConnLocalhost(t)
_, err := conn.WriteTo(getPacket(t, connID1), tr.Conn.LocalAddr())
@@ -158,51 +158,49 @@ func TestTransportAndDialConcurrentClose(t *testing.T) {
}
func TestTransportErrFromConn(t *testing.T) {
mockCtrl := gomock.NewController(t)
phm := NewMockPacketHandlerManager(mockCtrl)
readErrChan := make(chan error, 2)
conn := &mockPacketConn{readErrs: readErrChan, localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}}
tr := Transport{Conn: conn, handlerMap: phm}
defer tr.Close()
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
readErrChan := make(chan error, 2)
tr := Transport{
Conn: &mockPacketConn{
readErrs: readErrChan,
localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234},
},
}
defer tr.Close()
tr.init(true)
tr.handlerMap = phm
errChan := make(chan error, 1)
ph := &mockPacketHandler{destruction: errChan}
(*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph)
// temporary errors don't lead to a shutdown...
var tempErr deadlineError
require.True(t, tempErr.Temporary())
readErrChan <- tempErr
// don't expect any calls to phm.Close
time.Sleep(scaleDuration(20 * time.Millisecond))
time.Sleep(scaleDuration(10 * time.Millisecond))
// ...but non-temporary errors do
done := make(chan struct{})
phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
readErrChan <- errors.New("read failed")
select {
case <-done:
case err := <-errChan:
require.ErrorIs(t, err, ErrTransportClosed)
case <-time.After(time.Second):
t.Fatal("timeout")
}
_, err := tr.Listen(&tls.Config{}, nil)
require.Error(t, err)
require.ErrorIs(t, err, ErrTransportClosed)
}
func TestTransportStatelessResetReceiving(t *testing.T) {
mockCtrl := gomock.NewController(t)
phm := NewMockPacketHandlerManager(mockCtrl)
tr := &Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: 4,
handlerMap: phm,
}
tr.init(true)
defer func() {
phm.EXPECT().Close(gomock.Any())
tr.Close()
}()
defer tr.Close()
connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
// now send a packet with a connection ID that doesn't exist
@@ -213,10 +211,7 @@ func TestTransportStatelessResetReceiving(t *testing.T) {
destroyChan := make(chan error, 1)
conn1 := &mockPacketHandler{destruction: destroyChan}
gomock.InOrder(
phm.EXPECT().Get(connID), // no handler for this connection ID
phm.EXPECT().GetByResetToken(token).Return(conn1, true),
)
(*packetHandlerMap)(tr).AddResetToken(token, conn1)
conn := newUDPConnLocalhost(t)
_, err = conn.WriteTo(b, tr.Conn.LocalAddr())
@@ -224,7 +219,7 @@ func TestTransportStatelessResetReceiving(t *testing.T) {
select {
case err := <-destroyChan:
require.Error(t, err)
require.ErrorIs(t, err, &qerr.StatelessResetError{})
case <-time.After(time.Second):
t.Fatal("timeout")
}
@@ -232,25 +227,20 @@ func TestTransportStatelessResetReceiving(t *testing.T) {
func TestTransportStatelessResetSending(t *testing.T) {
mockCtrl := gomock.NewController(t)
phm := NewMockPacketHandlerManager(mockCtrl)
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
tr := &Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: 4,
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
handlerMap: phm,
Tracer: tracer,
}
tr.init(true)
defer func() {
mockTracer.EXPECT().Close()
phm.EXPECT().Close(gomock.Any())
tr.Close()
}()
connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
phm.EXPECT().Get(connID) // no handler for this connection ID
phm.EXPECT().GetByResetToken(gomock.Any())
// now send a packet with a connection ID that doesn't exist
b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne)
@@ -274,8 +264,6 @@ func TestTransportStatelessResetSending(t *testing.T) {
require.True(t, mockCtrl.Satisfied())
// but a stateless reset is sent for packets larger than MinStatelessResetSize
phm.EXPECT().Get(connID) // no handler for this connection ID
phm.EXPECT().GetByResetToken(gomock.Any())
_, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(time.Second))
@@ -287,7 +275,7 @@ func TestTransportStatelessResetSending(t *testing.T) {
require.Contains(t, string(p[:n]), string(srt[:]))
}
func TestTransportDropsUnparseableQUICPackets(t *testing.T) {
func TestTransportUnparseableQUICPackets(t *testing.T) {
mockCtrl := gomock.NewController(t)
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
tr := &Transport{
@@ -511,7 +499,7 @@ func testTransportDial(t *testing.T, early bool) {
newClientConnection = func(
_ context.Context,
_ sendConn,
_ *Transport,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
@@ -586,7 +574,7 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
newClientConnection = func(
_ context.Context,
_ sendConn,
_ *Transport,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
@@ -636,3 +624,78 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
conn2.EXPECT().destroy(gomock.Any()).AnyTimes()
}
func TestTransportReplaceWithClosed(t *testing.T) {
t.Run("local", func(t *testing.T) {
testTransportReplaceWithClosed(t, true)
})
t.Run("remote", func(t *testing.T) {
testTransportReplaceWithClosed(t, false)
})
}
func testTransportReplaceWithClosed(t *testing.T, local bool) {
srk := StatelessResetKey{1, 2, 3, 4}
tr := &Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: 4,
StatelessResetKey: &srk,
}
tr.init(true)
defer tr.Close()
dur := scaleDuration(10 * time.Millisecond)
var closePacket []byte
if local {
closePacket = []byte("foobar")
}
handler := &mockPacketHandler{}
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
m := (*packetHandlerMap)(tr)
require.True(t, m.Add(connID, handler))
start := time.Now()
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur)
p := make([]byte, 100)
p[0] = 0x40 // QUIC bit
copy(p[1:], connID.Bytes())
conn := newUDPConnLocalhost(t)
var sent int
for now := range time.NewTicker(dur / 20).C {
_, err := conn.WriteTo(p, tr.Conn.LocalAddr())
require.NoError(t, err)
sent++
if now.After(start.Add(dur)) {
break
}
}
// For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff
for i := 0; i*i < sent; i++ {
conn.SetReadDeadline(time.Now().Add(time.Second))
b := make([]byte, 100)
if local {
n, _, err := conn.ReadFrom(b)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), b[:n])
}
}
// Afterwards, we receive a stateless reset, not a copy of the CONNECTION_CLOSE packet.
// Retry up to 3 times, the connection is deleted from the map on a timer.
for i := 0; i < 3; i++ {
_, err := conn.WriteTo(p, tr.Conn.LocalAddr())
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(dur / 4))
b := make([]byte, 100)
n, _, err := conn.ReadFrom(b)
if errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
require.NoError(t, err)
require.NotEqual(t, []byte("foobar"), b[:n])
require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
break
}
}