forked from quic-go/quic-go
move connection tracking into the Transport (#5113)
This commit is contained in:
@@ -17,7 +17,7 @@ type connRunnerCallbacks struct {
|
||||
}
|
||||
|
||||
// The memory address of the Transport is used as the key.
|
||||
type connRunners map[*Transport]connRunnerCallbacks
|
||||
type connRunners map[connRunner]connRunnerCallbacks
|
||||
|
||||
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
@@ -57,11 +57,11 @@ type connIDGenerator struct {
|
||||
}
|
||||
|
||||
func newConnIDGenerator(
|
||||
tr *Transport,
|
||||
runner connRunner,
|
||||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
statelessResetter *statelessResetter,
|
||||
connRunner connRunnerCallbacks,
|
||||
callbacks connRunnerCallbacks,
|
||||
queueControlFrame func(wire.Frame),
|
||||
generator ConnectionIDGenerator,
|
||||
) *connIDGenerator {
|
||||
@@ -69,7 +69,7 @@ func newConnIDGenerator(
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
statelessResetter: statelessResetter,
|
||||
connRunners: map[*Transport]connRunnerCallbacks{tr: connRunner},
|
||||
connRunners: map[connRunner]connRunnerCallbacks{runner: callbacks},
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
@@ -202,13 +202,13 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte, expiry time.Durati
|
||||
m.connRunners.ReplaceWithClosed(connIDs, connClose, expiry)
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) AddConnRunner(id *Transport, r connRunnerCallbacks) {
|
||||
func (m *connIDGenerator) AddConnRunner(runner connRunner, r connRunnerCallbacks) {
|
||||
// The transport might have already been added earlier.
|
||||
// This happens if the application migrates back to and old path.
|
||||
if _, ok := m.connRunners[id]; ok {
|
||||
if _, ok := m.connRunners[runner]; ok {
|
||||
return
|
||||
}
|
||||
m.connRunners[id] = r
|
||||
m.connRunners[runner] = r
|
||||
if m.initialClientDestConnID != nil {
|
||||
r.AddConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
|
||||
initialClientDestConnID = &connID
|
||||
}
|
||||
g := newConnIDGenerator(
|
||||
&Transport{},
|
||||
&packetHandlerMap{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
sr,
|
||||
@@ -107,7 +107,7 @@ func TestConnIDGeneratorRetiring(t *testing.T) {
|
||||
initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
|
||||
var added, removed []protocol.ConnectionID
|
||||
g := newConnIDGenerator(
|
||||
&Transport{},
|
||||
&packetHandlerMap{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
&initialConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -178,7 +178,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
|
||||
removed []protocol.ConnectionID
|
||||
)
|
||||
g := newConnIDGenerator(
|
||||
&Transport{},
|
||||
&packetHandlerMap{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -228,7 +228,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
|
||||
replacedWith []byte
|
||||
)
|
||||
g := newConnIDGenerator(
|
||||
&Transport{},
|
||||
&packetHandlerMap{},
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -300,7 +300,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
|
||||
var queuedFrames []wire.Frame
|
||||
|
||||
tr := &Transport{}
|
||||
tr := &packetHandlerMap{}
|
||||
g := newConnIDGenerator(
|
||||
tr,
|
||||
initialConnID,
|
||||
@@ -314,7 +314,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
require.Len(t, tracker1.added, 2)
|
||||
|
||||
// add the second runner - it should get all existing connection IDs
|
||||
g.AddConnRunner(&Transport{}, runner2)
|
||||
g.AddConnRunner(&packetHandlerMap{}, runner2)
|
||||
require.Len(t, tracker1.added, 2) // unchanged
|
||||
require.Len(t, tracker2.added, 4)
|
||||
require.Contains(t, tracker2.added, initialConnID)
|
||||
|
||||
@@ -111,8 +111,6 @@ func nextConnTracingID() ConnectionTracingID { return ConnectionTracingID(connTr
|
||||
|
||||
// A Connection is a QUIC connection
|
||||
type connection struct {
|
||||
tr *Transport
|
||||
|
||||
// Destination connection ID used during the handshake.
|
||||
// Used to check source connection ID on incoming packets.
|
||||
handshakeDestConnID protocol.ConnectionID
|
||||
@@ -225,7 +223,7 @@ var newConnection = func(
|
||||
ctx context.Context,
|
||||
ctxCancel context.CancelCauseFunc,
|
||||
conn sendConn,
|
||||
tr *Transport,
|
||||
runner connRunner,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
retrySrcConnID *protocol.ConnectionID,
|
||||
clientDestConnID protocol.ConnectionID,
|
||||
@@ -244,7 +242,6 @@ var newConnection = func(
|
||||
s := &connection{
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
tr: tr,
|
||||
conn: conn,
|
||||
config: conf,
|
||||
handshakeDestConnID: destConnID,
|
||||
@@ -261,7 +258,6 @@ var newConnection = func(
|
||||
} else {
|
||||
s.logID = destConnID.String()
|
||||
}
|
||||
runner := tr.connRunner()
|
||||
s.connIDManager = newConnIDManager(
|
||||
destConnID,
|
||||
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
|
||||
@@ -269,7 +265,7 @@ var newConnection = func(
|
||||
s.queueControlFrame,
|
||||
)
|
||||
s.connIDGenerator = newConnIDGenerator(
|
||||
tr,
|
||||
runner,
|
||||
srcConnID,
|
||||
&clientDestConnID,
|
||||
statelessResetter,
|
||||
@@ -347,7 +343,7 @@ var newConnection = func(
|
||||
var newClientConnection = func(
|
||||
ctx context.Context,
|
||||
conn sendConn,
|
||||
tr *Transport,
|
||||
runner connRunner,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
@@ -362,7 +358,6 @@ var newClientConnection = func(
|
||||
v protocol.Version,
|
||||
) quicConn {
|
||||
s := &connection{
|
||||
tr: tr,
|
||||
conn: conn,
|
||||
config: conf,
|
||||
origDestConnID: destConnID,
|
||||
@@ -375,7 +370,6 @@ var newClientConnection = func(
|
||||
versionNegotiated: hasNegotiatedVersion,
|
||||
version: v,
|
||||
}
|
||||
runner := tr.connRunner()
|
||||
s.connIDManager = newConnIDManager(
|
||||
destConnID,
|
||||
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
|
||||
@@ -383,7 +377,7 @@ var newClientConnection = func(
|
||||
s.queueControlFrame,
|
||||
)
|
||||
s.connIDGenerator = newConnIDGenerator(
|
||||
tr,
|
||||
runner,
|
||||
srcConnID,
|
||||
nil,
|
||||
statelessResetter,
|
||||
@@ -2650,9 +2644,9 @@ func (s *connection) AddPath(t *Transport) (*Path, error) {
|
||||
t,
|
||||
200*time.Millisecond, // initial RTT estimate
|
||||
func() {
|
||||
runner := t.connRunner()
|
||||
runner := (*packetHandlerMap)(t)
|
||||
s.connIDGenerator.AddConnRunner(
|
||||
t,
|
||||
runner,
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
RemoveConnectionID: runner.Remove,
|
||||
|
||||
@@ -81,7 +81,7 @@ func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
|
||||
|
||||
type testConnection struct {
|
||||
conn *connection
|
||||
connRunner *MockPacketHandlerManager
|
||||
connRunner *MockConnRunner
|
||||
sendConn *MockSendConn
|
||||
packer *MockPacker
|
||||
destConnID protocol.ConnectionID
|
||||
@@ -101,7 +101,7 @@ func newServerTestConnection(
|
||||
}
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
||||
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
connRunner := NewMockConnRunner(mockCtrl)
|
||||
sendConn := NewMockSendConn(mockCtrl)
|
||||
sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
|
||||
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
||||
@@ -119,7 +119,7 @@ func newServerTestConnection(
|
||||
ctx,
|
||||
cancel,
|
||||
sendConn,
|
||||
&Transport{handlerMap: phm},
|
||||
connRunner,
|
||||
origDestConnID,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
@@ -141,7 +141,7 @@ func newServerTestConnection(
|
||||
}
|
||||
return &testConnection{
|
||||
conn: conn,
|
||||
connRunner: phm,
|
||||
connRunner: connRunner,
|
||||
sendConn: sendConn,
|
||||
packer: packer,
|
||||
destConnID: origDestConnID,
|
||||
@@ -162,7 +162,7 @@ func newClientTestConnection(
|
||||
}
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
||||
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
connRunner := NewMockConnRunner(mockCtrl)
|
||||
sendConn := NewMockSendConn(mockCtrl)
|
||||
sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
|
||||
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
||||
@@ -178,7 +178,7 @@ func newClientTestConnection(
|
||||
conn := newClientConnection(
|
||||
context.Background(),
|
||||
sendConn,
|
||||
&Transport{handlerMap: phm},
|
||||
connRunner,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
@@ -198,7 +198,7 @@ func newClientTestConnection(
|
||||
}
|
||||
return &testConnection{
|
||||
conn: conn,
|
||||
connRunner: phm,
|
||||
connRunner: connRunner,
|
||||
sendConn: sendConn,
|
||||
packer: packer,
|
||||
destConnID: destConnID,
|
||||
|
||||
@@ -1,376 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/quic-go/quic-go (interfaces: PacketHandlerManager)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager
|
||||
//
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
protocol "github.com/quic-go/quic-go/internal/protocol"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPacketHandlerManager is a mock of PacketHandlerManager interface.
|
||||
type MockPacketHandlerManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketHandlerManagerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerMockRecorder is the mock recorder for MockPacketHandlerManager.
|
||||
type MockPacketHandlerManagerMockRecorder struct {
|
||||
mock *MockPacketHandlerManager
|
||||
}
|
||||
|
||||
// NewMockPacketHandlerManager creates a new mock instance.
|
||||
func NewMockPacketHandlerManager(ctrl *gomock.Controller) *MockPacketHandlerManager {
|
||||
mock := &MockPacketHandlerManager{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketHandlerManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Add mocks base method.
|
||||
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 any) *MockPacketHandlerManagerAddCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
|
||||
return &MockPacketHandlerManagerAddCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerAddCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerAddCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerAddCall) Return(arg0 bool) *MockPacketHandlerManagerAddCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerAddCall) Do(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerAddCall) DoAndReturn(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// AddResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddResetToken", arg0, arg1)
|
||||
}
|
||||
|
||||
// AddResetToken indicates an expected call of AddResetToken.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 any) *MockPacketHandlerManagerAddResetTokenCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1)
|
||||
return &MockPacketHandlerManagerAddResetTokenCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerAddResetTokenCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerAddResetTokenCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerAddResetTokenCall) Return() *MockPacketHandlerManagerAddResetTokenCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerAddResetTokenCall) Do(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// AddWithConnID mocks base method.
|
||||
func (m *MockPacketHandlerManager) AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddWithConnID", destConnID, newConnID, h)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddWithConnID indicates an expected call of AddWithConnID.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(destConnID, newConnID, h any) *MockPacketHandlerManagerAddWithConnIDCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), destConnID, newConnID, h)
|
||||
return &MockPacketHandlerManagerAddWithConnIDCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerAddWithConnIDCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerAddWithConnIDCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerAddWithConnIDCall) Return(arg0 bool) *MockPacketHandlerManagerAddWithConnIDCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerAddWithConnIDCall) Do(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerAddWithConnIDCall) DoAndReturn(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockPacketHandlerManager) Close(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Close", arg0)
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 any) *MockPacketHandlerManagerCloseCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0)
|
||||
return &MockPacketHandlerManagerCloseCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerCloseCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerCloseCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerCloseCall) Return() *MockPacketHandlerManagerCloseCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerCloseCall) Do(f func(error)) *MockPacketHandlerManagerCloseCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *MockPacketHandlerManagerCloseCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", arg0)
|
||||
ret0, _ := ret[0].(packetHandler)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 any) *MockPacketHandlerManagerGetCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
|
||||
return &MockPacketHandlerManagerGetCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerGetCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerGetCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerGetCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerGetCall) Do(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerGetCall) DoAndReturn(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// GetByResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByResetToken", arg0)
|
||||
ret0, _ := ret[0].(packetHandler)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByResetToken indicates an expected call of GetByResetToken.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 any) *MockPacketHandlerManagerGetByResetTokenCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0)
|
||||
return &MockPacketHandlerManagerGetByResetTokenCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerGetByResetTokenCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerGetByResetTokenCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerGetByResetTokenCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetByResetTokenCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerGetByResetTokenCall) Do(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerGetByResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Remove mocks base method.
|
||||
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Remove", arg0)
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 any) *MockPacketHandlerManagerRemoveCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0)
|
||||
return &MockPacketHandlerManagerRemoveCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerRemoveCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerRemoveCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerRemoveCall) Return() *MockPacketHandlerManagerRemoveCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerRemoveCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerRemoveCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// RemoveResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessResetToken) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RemoveResetToken", arg0)
|
||||
}
|
||||
|
||||
// RemoveResetToken indicates an expected call of RemoveResetToken.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 any) *MockPacketHandlerManagerRemoveResetTokenCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0)
|
||||
return &MockPacketHandlerManagerRemoveResetTokenCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerRemoveResetTokenCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerRemoveResetTokenCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerRemoveResetTokenCall) Return() *MockPacketHandlerManagerRemoveResetTokenCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerRemoveResetTokenCall) Do(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerRemoveResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// ReplaceWithClosed mocks base method.
|
||||
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte, arg2 time.Duration) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *MockPacketHandlerManagerReplaceWithClosedCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2)
|
||||
return &MockPacketHandlerManagerReplaceWithClosedCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerReplaceWithClosedCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerReplaceWithClosedCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerReplaceWithClosedCall) Return() *MockPacketHandlerManagerReplaceWithClosedCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockPacketHandlerManagerReplaceWithClosedCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockPacketHandlerManagerReplaceWithClosedCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
@@ -56,9 +56,4 @@ type QUICConn = quicConn
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler"
|
||||
type PacketHandler = packetHandler
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
|
||||
type PacketHandlerManager = packetHandlerManager
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn"
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type connCapabilities struct {
|
||||
// This connection has the Don't Fragment (DF) bit set.
|
||||
// This means it makes to run DPLPMTUD.
|
||||
DF bool
|
||||
// GSO (Generic Segmentation Offload) supported
|
||||
GSO bool
|
||||
// ECN (Explicit Congestion Notifications) supported
|
||||
ECN bool
|
||||
}
|
||||
|
||||
// rawConn is a connection that allow reading of a receivedPackeh.
|
||||
type rawConn interface {
|
||||
ReadPacket() (receivedPacket, error)
|
||||
// WritePacket writes a packet on the wire.
|
||||
// gsoSize is the size of a single packet, or 0 to disable GSO.
|
||||
// It is invalid to set gsoSize if capabilities.GSO is not set.
|
||||
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
|
||||
LocalAddr() net.Addr
|
||||
SetReadDeadline(time.Time) error
|
||||
io.Closer
|
||||
|
||||
capabilities() connCapabilities
|
||||
}
|
||||
|
||||
type closePacket struct {
|
||||
payload []byte
|
||||
addr net.Addr
|
||||
info packetInfo
|
||||
}
|
||||
|
||||
type packetHandlerMap struct {
|
||||
mutex sync.Mutex
|
||||
handlers map[protocol.ConnectionID]packetHandler
|
||||
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
||||
|
||||
closed bool
|
||||
|
||||
enqueueClosePacket func(closePacket)
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandlerManager = &packetHandlerMap{}
|
||||
|
||||
func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
|
||||
return &packetHandlerMap{
|
||||
handlers: make(map[protocol.ConnectionID]packetHandler),
|
||||
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
||||
enqueueClosePacket: enqueueClosePacket,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
handler, ok := h.handlers[id]
|
||||
return handler, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if _, ok := h.handlers[id]; ok {
|
||||
h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
|
||||
return false
|
||||
}
|
||||
h.handlers[id] = handler
|
||||
h.logger.Debugf("Adding connection ID %s.", id)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if _, ok := h.handlers[clientDestConnID]; ok {
|
||||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||
return false
|
||||
}
|
||||
h.handlers[clientDestConnID] = handler
|
||||
h.handlers[newConnID] = handler
|
||||
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||
h.mutex.Lock()
|
||||
delete(h.handlers, id)
|
||||
h.mutex.Unlock()
|
||||
h.logger.Debugf("Removing connection ID %s.", id)
|
||||
}
|
||||
|
||||
// ReplaceWithClosed is called when a connection is closed.
|
||||
// Depending on which side closed the connection, we need to:
|
||||
// * remote close: absorb delayed packets
|
||||
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
|
||||
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) {
|
||||
var handler packetHandler
|
||||
if connClosePacket != nil {
|
||||
handler = newClosedLocalConn(
|
||||
func(addr net.Addr, info packetInfo) {
|
||||
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
|
||||
},
|
||||
h.logger,
|
||||
)
|
||||
} else {
|
||||
handler = newClosedRemoteConn()
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
for _, id := range ids {
|
||||
h.handlers[id] = handler
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
|
||||
|
||||
time.AfterFunc(expiry, func() {
|
||||
h.mutex.Lock()
|
||||
for _, id := range ids {
|
||||
delete(h.handlers, id)
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
||||
h.mutex.Lock()
|
||||
h.resetTokens[token] = handler
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
|
||||
h.mutex.Lock()
|
||||
delete(h.resetTokens, token)
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
handler, ok := h.resetTokens[token]
|
||||
return handler, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Close(e error) {
|
||||
h.mutex.Lock()
|
||||
if h.closed {
|
||||
h.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
wg.Add(1)
|
||||
go func(handler packetHandler) {
|
||||
handler.destroy(e)
|
||||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
h.closed = true
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPacketHandlerMapAddAndRemove(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
h := &mockPacketHandler{}
|
||||
require.True(t, m.Add(connID, h))
|
||||
got, ok := m.Get(connID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, h, got)
|
||||
|
||||
// cannot add the same handler twice
|
||||
require.False(t, m.Add(connID, h))
|
||||
got, ok = m.Get(connID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, h, got)
|
||||
|
||||
// remove the handler
|
||||
m.Remove(connID)
|
||||
got, ok = m.Get(connID)
|
||||
require.False(t, ok)
|
||||
require.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
h := &mockPacketHandler{}
|
||||
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
require.True(t, m.AddWithConnID(connID1, connID2, h))
|
||||
// collision of the connection ID, this handler should not be added
|
||||
require.False(t, m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), nil))
|
||||
|
||||
got, ok := m.Get(connID1)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, h, got)
|
||||
got, ok = m.Get(connID2)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, h, got)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||
handler := &mockPacketHandler{}
|
||||
m.AddResetToken(token, handler)
|
||||
h, ok := m.GetByResetToken(token)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, handler, h)
|
||||
m.RemoveResetToken(token)
|
||||
_, ok = m.GetByResetToken(token)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(
|
||||
func(p closePacket) { closePackets = append(closePackets, p) },
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
dur := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
handler := &mockPacketHandler{}
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
require.True(t, m.Add(connID, handler))
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, []byte("foobar"), dur)
|
||||
h, ok := m.Get(connID)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, handler, h)
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
h.handlePacket(receivedPacket{remoteAddr: addr})
|
||||
require.Len(t, closePackets, 1)
|
||||
require.Equal(t, addr, closePackets[0].addr)
|
||||
require.Equal(t, []byte("foobar"), closePackets[0].payload)
|
||||
|
||||
time.Sleep(dur)
|
||||
require.Eventually(t, func() bool {
|
||||
_, ok := m.Get(connID)
|
||||
return !ok
|
||||
}, time.Second, dur)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(
|
||||
func(p closePacket) { closePackets = append(closePackets, p) },
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
dur := scaleDuration(25 * time.Millisecond)
|
||||
|
||||
handler := &mockPacketHandler{}
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
require.True(t, m.Add(connID, handler))
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, nil, dur)
|
||||
h, ok := m.Get(connID)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, handler, h)
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
h.handlePacket(receivedPacket{remoteAddr: addr})
|
||||
require.Empty(t, closePackets)
|
||||
|
||||
time.Sleep(dur)
|
||||
require.Eventually(t, func() bool {
|
||||
_, ok := m.Get(connID)
|
||||
return !ok
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapClose(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
const numConns = 10
|
||||
destroyChan := make(chan error, 2*numConns)
|
||||
for i := 0; i < numConns; i++ {
|
||||
conn := &mockPacketHandler{destruction: destroyChan}
|
||||
b := make([]byte, 12)
|
||||
rand.Read(b)
|
||||
m.Add(protocol.ParseConnectionID(b), conn)
|
||||
}
|
||||
m.Close(assert.AnError)
|
||||
// check that Close can be called multiple times
|
||||
m.Close(errors.New("close"))
|
||||
|
||||
for i := 0; i < numConns; i++ {
|
||||
select {
|
||||
case err := <-destroyChan:
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
default:
|
||||
t.Fatalf("connection not destroyed")
|
||||
}
|
||||
}
|
||||
select {
|
||||
case err := <-destroyChan:
|
||||
t.Fatalf("connection destroyed more than once: %s", err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
20
server.go
20
server.go
@@ -32,14 +32,6 @@ type packetHandler interface {
|
||||
closeWithTransportError(qerr.TransportErrorCode)
|
||||
}
|
||||
|
||||
type packetHandlerManager interface {
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||
AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
|
||||
Close(error)
|
||||
connRunner
|
||||
}
|
||||
|
||||
type quicConn interface {
|
||||
EarlyConnection
|
||||
earlyConnReady() <-chan struct{}
|
||||
@@ -61,7 +53,7 @@ type rejectedPacket struct {
|
||||
|
||||
// A Listener of QUIC
|
||||
type baseServer struct {
|
||||
tr *Transport
|
||||
tr *packetHandlerMap
|
||||
disableVersionNegotiation bool
|
||||
acceptEarlyConns bool
|
||||
|
||||
@@ -89,7 +81,7 @@ type baseServer struct {
|
||||
context.Context,
|
||||
context.CancelCauseFunc,
|
||||
sendConn,
|
||||
*Transport,
|
||||
connRunner,
|
||||
protocol.ConnectionID, /* original dest connection ID */
|
||||
*protocol.ConnectionID, /* retry src connection ID */
|
||||
protocol.ConnectionID, /* client dest connection ID */
|
||||
@@ -247,7 +239,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
|
||||
|
||||
func newServer(
|
||||
conn rawConn,
|
||||
tr *Transport,
|
||||
tr *packetHandlerMap,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
connContext func(context.Context) context.Context,
|
||||
@@ -501,7 +493,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
|
||||
}
|
||||
|
||||
// check again if we might have a connection now
|
||||
if handler, ok := s.tr.connRunner().Get(connID); ok {
|
||||
if handler, ok := s.tr.Get(connID); ok {
|
||||
handler.handlePacket(p)
|
||||
return true
|
||||
}
|
||||
@@ -591,7 +583,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
// The server queues packets for a while, and we might already have established a connection by now.
|
||||
// This results in a second check in the connection map.
|
||||
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
|
||||
if handler, ok := s.tr.connRunner().Get(hdr.DestConnectionID); ok {
|
||||
if handler, ok := s.tr.Get(hdr.DestConnectionID); ok {
|
||||
handler.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
@@ -727,7 +719,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
|
||||
// under normal circumstances the packet would just be routed to that connection.
|
||||
// The only time this collision will occur if we receive the two Initial packets at the same time.
|
||||
if added := s.tr.connRunner().AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
|
||||
if added := s.tr.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
conn.closeWithTransportError(qerr.ConnectionRefused)
|
||||
return nil
|
||||
|
||||
@@ -37,7 +37,7 @@ type serverOpts struct {
|
||||
context.Context,
|
||||
context.CancelCauseFunc,
|
||||
sendConn,
|
||||
*Transport,
|
||||
connRunner,
|
||||
protocol.ConnectionID, // original dest connection ID
|
||||
*protocol.ConnectionID, // retry src connection ID
|
||||
protocol.ConnectionID, // client dest connection ID
|
||||
@@ -61,9 +61,11 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
|
||||
require.NoError(t, err)
|
||||
verifySourceAddress := func(net.Addr) bool { return serverOpts.useRetry }
|
||||
config := populateConfig(serverOpts.config)
|
||||
tr := &Transport{Conn: newUDPConnLocalhost(t)}
|
||||
tr.init(true)
|
||||
s := newServer(
|
||||
c,
|
||||
&Transport{handlerMap: newPacketHandlerMap(nil, utils.DefaultLogger)},
|
||||
(*packetHandlerMap)(tr),
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
&statelessResetter{},
|
||||
func(ctx context.Context) context.Context { return ctx },
|
||||
@@ -596,6 +598,7 @@ func testServerTokenValidation(
|
||||
|
||||
type connConstructorArgs struct {
|
||||
ctx context.Context
|
||||
connRunner connRunner
|
||||
config *Config
|
||||
origDestConnID protocol.ConnectionID
|
||||
retrySrcConnID *protocol.ConnectionID
|
||||
@@ -607,10 +610,10 @@ type connConstructorArgs struct {
|
||||
type connConstructorRecorder struct {
|
||||
ch chan connConstructorArgs
|
||||
|
||||
conns []quicConn
|
||||
conns []*MockQUICConn
|
||||
}
|
||||
|
||||
func newConnConstructorRecorder(conns ...quicConn) *connConstructorRecorder {
|
||||
func newConnConstructorRecorder(conns ...*MockQUICConn) *connConstructorRecorder {
|
||||
return &connConstructorRecorder{
|
||||
ch: make(chan connConstructorArgs, len(conns)),
|
||||
conns: conns,
|
||||
@@ -623,7 +626,7 @@ func (r *connConstructorRecorder) NewConn(
|
||||
ctx context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ *Transport,
|
||||
connRunner connRunner,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
retrySrcConnID *protocol.ConnectionID,
|
||||
clientDestConnID protocol.ConnectionID,
|
||||
@@ -641,6 +644,7 @@ func (r *connConstructorRecorder) NewConn(
|
||||
) quicConn {
|
||||
r.ch <- connConstructorArgs{
|
||||
ctx: ctx,
|
||||
connRunner: connRunner,
|
||||
config: config,
|
||||
origDestConnID: origDestConnID,
|
||||
retrySrcConnID: retrySrcConnID,
|
||||
@@ -747,7 +751,7 @@ func testServerCreateConnection(t *testing.T, useRetry bool) {
|
||||
|
||||
func TestServerClose(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
var conns []quicConn
|
||||
var conns []*MockQUICConn
|
||||
const numConns = 3
|
||||
done := make(chan struct{}, numConns)
|
||||
for range numConns {
|
||||
@@ -779,7 +783,7 @@ func TestServerClose(t *testing.T) {
|
||||
}
|
||||
|
||||
server.Close()
|
||||
// closing closes all handshakeing connections with CONNECTION_REFUSED
|
||||
// closing closes all handshaking connections with CONNECTION_REFUSED
|
||||
for range numConns {
|
||||
select {
|
||||
case <-done:
|
||||
@@ -796,6 +800,11 @@ func TestServerClose(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrServerClosed)
|
||||
require.ErrorIs(t, err, net.ErrClosed)
|
||||
}
|
||||
|
||||
// test shutdown
|
||||
for _, conn := range conns {
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerGetConfigForClientAccept(t *testing.T) {
|
||||
@@ -878,7 +887,7 @@ func TestServerPacketHandling(t *testing.T) {
|
||||
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
|
||||
handledPacket <- p
|
||||
})
|
||||
server.tr.handlerMap.Add(destConnID, conn)
|
||||
server.tr.Add(destConnID, conn)
|
||||
|
||||
server.handlePacket(
|
||||
getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, srcConnID, destConnID),
|
||||
@@ -904,7 +913,7 @@ func TestServerReceiveQueue(t *testing.T) {
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ *Transport,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -927,6 +936,8 @@ func TestServerReceiveQueue(t *testing.T) {
|
||||
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
|
||||
// during test shutdown
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
return conn
|
||||
},
|
||||
})
|
||||
@@ -1013,6 +1024,9 @@ func testServerAccept(t *testing.T, acceptEarly bool) {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// test shutdown
|
||||
c.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
func TestServerAcceptHandshakeFailure(t *testing.T) {
|
||||
@@ -1044,11 +1058,14 @@ func TestServerAcceptHandshakeFailure(t *testing.T) {
|
||||
t.Fatal("server should not have accepted the connection")
|
||||
case <-time.After(scaleDuration(5 * time.Millisecond)):
|
||||
}
|
||||
|
||||
// test shutdown
|
||||
c.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
func TestServerAcceptQueue(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
var conns []quicConn
|
||||
var conns []*MockQUICConn
|
||||
var rejectedConn *MockQUICConn
|
||||
for i := range protocol.MaxAcceptQueueSize + 2 {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
@@ -1125,6 +1142,11 @@ func TestServerAcceptQueue(t *testing.T) {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// test shutdown
|
||||
for _, conn := range conns {
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer0RTTReordering(t *testing.T) {
|
||||
@@ -1210,6 +1232,7 @@ func TestServer0RTTReordering(t *testing.T) {
|
||||
|
||||
// shutdown
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any()).AnyTimes()
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
func TestServer0RTTQueueing(t *testing.T) {
|
||||
|
||||
25
sys_conn.go
25
sys_conn.go
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
@@ -13,6 +14,30 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type connCapabilities struct {
|
||||
// This connection has the Don't Fragment (DF) bit set.
|
||||
// This means it makes to run DPLPMTUD.
|
||||
DF bool
|
||||
// GSO (Generic Segmentation Offload) supported
|
||||
GSO bool
|
||||
// ECN (Explicit Congestion Notifications) supported
|
||||
ECN bool
|
||||
}
|
||||
|
||||
// rawConn is a connection that allow reading of a receivedPackeh.
|
||||
type rawConn interface {
|
||||
ReadPacket() (receivedPacket, error)
|
||||
// WritePacket writes a packet on the wire.
|
||||
// gsoSize is the size of a single packet, or 0 to disable GSO.
|
||||
// It is invalid to set gsoSize if capabilities.GSO is not set.
|
||||
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
|
||||
LocalAddr() net.Addr
|
||||
SetReadDeadline(time.Time) error
|
||||
io.Closer
|
||||
|
||||
capabilities() connCapabilities
|
||||
}
|
||||
|
||||
// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header.
|
||||
// If the PacketConn passed to the [Transport] satisfies this interface, quic-go will use it.
|
||||
// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets.
|
||||
|
||||
156
transport.go
156
transport.go
@@ -40,6 +40,12 @@ func (e *errTransportClosed) Is(target error) bool {
|
||||
|
||||
var errListenerAlreadySet = errors.New("listener already set")
|
||||
|
||||
type closePacket struct {
|
||||
payload []byte
|
||||
addr net.Addr
|
||||
info packetInfo
|
||||
}
|
||||
|
||||
// The Transport is the central point to manage incoming and outgoing QUIC connections.
|
||||
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
|
||||
// This means that a single UDP socket can be used for listening for incoming connections, as well as
|
||||
@@ -126,7 +132,9 @@ type Transport struct {
|
||||
// Tracer.Close is called when the transport is closed.
|
||||
Tracer *logging.Tracer
|
||||
|
||||
handlerMap packetHandlerManager
|
||||
connMx sync.Mutex
|
||||
handlers map[protocol.ConnectionID]packetHandler
|
||||
resetTokens map[protocol.StatelessResetToken]packetHandler
|
||||
|
||||
mutex sync.Mutex
|
||||
initOnce sync.Once
|
||||
@@ -206,7 +214,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||
}
|
||||
s := newServer(
|
||||
t.conn,
|
||||
t,
|
||||
(*packetHandlerMap)(t),
|
||||
t.connIDGenerator,
|
||||
t.statelessResetter,
|
||||
t.ConnContext,
|
||||
@@ -297,7 +305,7 @@ func (t *Transport) doDial(
|
||||
conn := newClientConnection(
|
||||
context.WithoutCancel(ctx),
|
||||
sendConn,
|
||||
t,
|
||||
(*packetHandlerMap)(t),
|
||||
destConnID,
|
||||
srcConnID,
|
||||
t.connIDGenerator,
|
||||
@@ -311,7 +319,9 @@ func (t *Transport) doDial(
|
||||
logger,
|
||||
version,
|
||||
)
|
||||
t.handlerMap.Add(srcConnID, conn)
|
||||
t.connMx.Lock()
|
||||
t.handlers[srcConnID] = conn
|
||||
t.connMx.Unlock()
|
||||
t.mutex.Unlock()
|
||||
|
||||
// The error channel needs to be buffered, as the run loop will continue running
|
||||
@@ -384,9 +394,8 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
|
||||
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||
t.conn = conn
|
||||
if t.handlerMap == nil { // allows mocking the handlerMap in tests
|
||||
t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger)
|
||||
}
|
||||
t.handlers = make(map[protocol.ConnectionID]packetHandler)
|
||||
t.resetTokens = make(map[protocol.StatelessResetToken]packetHandler)
|
||||
t.listening = make(chan struct{})
|
||||
|
||||
t.closeQueue = make(chan closePacket, 4)
|
||||
@@ -419,10 +428,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
return t.initErr
|
||||
}
|
||||
|
||||
func (t *Transport) connRunner() packetHandlerManager {
|
||||
return t.handlerMap
|
||||
}
|
||||
|
||||
// WriteTo sends a packet on the underlying connection.
|
||||
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if err := t.init(false); err != nil {
|
||||
@@ -431,15 +436,6 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
|
||||
}
|
||||
|
||||
func (t *Transport) enqueueClosePacket(p closePacket) {
|
||||
select {
|
||||
case t.closeQueue <- p:
|
||||
default:
|
||||
// Oops, we're backlogged.
|
||||
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) runSendQueue() {
|
||||
for {
|
||||
select {
|
||||
@@ -501,9 +497,19 @@ func (t *Transport) close(e error) {
|
||||
}
|
||||
|
||||
e = &errTransportClosed{err: e}
|
||||
if t.handlerMap != nil {
|
||||
t.handlerMap.Close(e)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
t.connMx.Lock()
|
||||
for _, handler := range t.handlers {
|
||||
wg.Add(1)
|
||||
go func(handler packetHandler) {
|
||||
handler.destroy(e)
|
||||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
t.connMx.Unlock()
|
||||
wg.Wait()
|
||||
|
||||
if t.server != nil {
|
||||
t.server.close(e, false)
|
||||
}
|
||||
@@ -566,7 +572,7 @@ func (t *Transport) handlePacket(p receivedPacket) {
|
||||
}
|
||||
|
||||
// If there's a connection associated with the connection ID, pass the packet there.
|
||||
if handler, ok := t.handlerMap.Get(connID); ok {
|
||||
if handler, ok := (*packetHandlerMap)(t).Get(connID); ok {
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
@@ -651,8 +657,12 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
|
||||
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
|
||||
token := protocol.StatelessResetToken(data[len(data)-16:])
|
||||
t.connMx.Lock()
|
||||
conn, ok := t.resetTokens[token]
|
||||
t.connMx.Unlock()
|
||||
|
||||
if ok {
|
||||
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
|
||||
go conn.destroy(&StatelessResetError{})
|
||||
return true
|
||||
@@ -717,3 +727,99 @@ func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
|
||||
}
|
||||
tlsConf.ServerName = h
|
||||
}
|
||||
|
||||
type packetHandlerMap Transport
|
||||
|
||||
var _ connRunner = &packetHandlerMap{}
|
||||
|
||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
|
||||
h.connMx.Lock()
|
||||
defer h.connMx.Unlock()
|
||||
|
||||
if _, ok := h.handlers[id]; ok {
|
||||
h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
|
||||
return false
|
||||
}
|
||||
h.handlers[id] = handler
|
||||
h.logger.Debugf("Adding connection ID %s.", id)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Get(connID protocol.ConnectionID) (packetHandler, bool) {
|
||||
h.connMx.Lock()
|
||||
defer h.connMx.Unlock()
|
||||
handler, ok := h.handlers[connID]
|
||||
return handler, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
||||
h.connMx.Lock()
|
||||
h.resetTokens[token] = handler
|
||||
h.connMx.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
|
||||
h.connMx.Lock()
|
||||
delete(h.resetTokens, token)
|
||||
h.connMx.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
|
||||
h.connMx.Lock()
|
||||
defer h.connMx.Unlock()
|
||||
|
||||
if _, ok := h.handlers[clientDestConnID]; ok {
|
||||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||
return false
|
||||
}
|
||||
h.handlers[clientDestConnID] = handler
|
||||
h.handlers[newConnID] = handler
|
||||
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||
h.connMx.Lock()
|
||||
delete(h.handlers, id)
|
||||
h.connMx.Unlock()
|
||||
h.logger.Debugf("Removing connection ID %s.", id)
|
||||
}
|
||||
|
||||
// ReplaceWithClosed is called when a connection is closed.
|
||||
// Depending on which side closed the connection, we need to:
|
||||
// * remote close: absorb delayed packets
|
||||
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
|
||||
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) {
|
||||
var handler packetHandler
|
||||
if connClosePacket != nil {
|
||||
handler = newClosedLocalConn(
|
||||
func(addr net.Addr, info packetInfo) {
|
||||
select {
|
||||
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
|
||||
default:
|
||||
// We're backlogged.
|
||||
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||
}
|
||||
},
|
||||
h.logger,
|
||||
)
|
||||
} else {
|
||||
handler = newClosedRemoteConn()
|
||||
}
|
||||
|
||||
h.connMx.Lock()
|
||||
for _, id := range ids {
|
||||
h.handlers[id] = handler
|
||||
}
|
||||
h.connMx.Unlock()
|
||||
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
|
||||
|
||||
time.AfterFunc(expiry, func() {
|
||||
h.connMx.Lock()
|
||||
for _, id := range ids {
|
||||
delete(h.handlers, id)
|
||||
}
|
||||
h.connMx.Unlock()
|
||||
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -47,8 +48,16 @@ type mockPacketHandler struct {
|
||||
destruction chan<- error
|
||||
}
|
||||
|
||||
func (h *mockPacketHandler) handlePacket(p receivedPacket) { h.packets <- p }
|
||||
func (h *mockPacketHandler) destroy(err error) { h.destruction <- err }
|
||||
func (h *mockPacketHandler) handlePacket(p receivedPacket) {
|
||||
h.packets <- p
|
||||
}
|
||||
|
||||
func (h *mockPacketHandler) destroy(err error) {
|
||||
if h.destruction != nil {
|
||||
h.destruction <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {}
|
||||
|
||||
func getPacket(t *testing.T, connID protocol.ConnectionID) []byte {
|
||||
@@ -71,28 +80,19 @@ func getPacketWithPacketType(t *testing.T, connID protocol.ConnectionID, typ pro
|
||||
}
|
||||
|
||||
func TestTransportPacketHandling(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
|
||||
tr := &Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
handlerMap: phm,
|
||||
}
|
||||
tr := &Transport{Conn: newUDPConnLocalhost(t)}
|
||||
tr.init(true)
|
||||
defer func() {
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
tr.Close()
|
||||
}()
|
||||
defer tr.Close()
|
||||
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||
|
||||
connChan1 := make(chan receivedPacket, 1)
|
||||
conn1 := &mockPacketHandler{packets: connChan1}
|
||||
phm.EXPECT().Get(connID1).Return(conn1, true)
|
||||
(*packetHandlerMap)(tr).Add(connID1, conn1)
|
||||
connChan2 := make(chan receivedPacket, 1)
|
||||
conn2 := &mockPacketHandler{packets: connChan2}
|
||||
phm.EXPECT().Get(connID2).Return(conn2, true)
|
||||
(*packetHandlerMap)(tr).Add(connID2, conn2)
|
||||
|
||||
conn := newUDPConnLocalhost(t)
|
||||
_, err := conn.WriteTo(getPacket(t, connID1), tr.Conn.LocalAddr())
|
||||
@@ -158,51 +158,49 @@ func TestTransportAndDialConcurrentClose(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTransportErrFromConn(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
readErrChan := make(chan error, 2)
|
||||
conn := &mockPacketConn{readErrs: readErrChan, localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}}
|
||||
tr := Transport{Conn: conn, handlerMap: phm}
|
||||
defer tr.Close()
|
||||
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
|
||||
|
||||
readErrChan := make(chan error, 2)
|
||||
tr := Transport{
|
||||
Conn: &mockPacketConn{
|
||||
readErrs: readErrChan,
|
||||
localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234},
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
tr.init(true)
|
||||
tr.handlerMap = phm
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
ph := &mockPacketHandler{destruction: errChan}
|
||||
(*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph)
|
||||
|
||||
// temporary errors don't lead to a shutdown...
|
||||
var tempErr deadlineError
|
||||
require.True(t, tempErr.Temporary())
|
||||
readErrChan <- tempErr
|
||||
// don't expect any calls to phm.Close
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond))
|
||||
time.Sleep(scaleDuration(10 * time.Millisecond))
|
||||
|
||||
// ...but non-temporary errors do
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
|
||||
readErrChan <- errors.New("read failed")
|
||||
select {
|
||||
case <-done:
|
||||
case err := <-errChan:
|
||||
require.ErrorIs(t, err, ErrTransportClosed)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
_, err := tr.Listen(&tls.Config{}, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrTransportClosed)
|
||||
}
|
||||
|
||||
func TestTransportStatelessResetReceiving(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr := &Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnectionIDLength: 4,
|
||||
handlerMap: phm,
|
||||
}
|
||||
tr.init(true)
|
||||
defer func() {
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
tr.Close()
|
||||
}()
|
||||
defer tr.Close()
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
|
||||
// now send a packet with a connection ID that doesn't exist
|
||||
@@ -213,10 +211,7 @@ func TestTransportStatelessResetReceiving(t *testing.T) {
|
||||
|
||||
destroyChan := make(chan error, 1)
|
||||
conn1 := &mockPacketHandler{destruction: destroyChan}
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().Get(connID), // no handler for this connection ID
|
||||
phm.EXPECT().GetByResetToken(token).Return(conn1, true),
|
||||
)
|
||||
(*packetHandlerMap)(tr).AddResetToken(token, conn1)
|
||||
|
||||
conn := newUDPConnLocalhost(t)
|
||||
_, err = conn.WriteTo(b, tr.Conn.LocalAddr())
|
||||
@@ -224,7 +219,7 @@ func TestTransportStatelessResetReceiving(t *testing.T) {
|
||||
|
||||
select {
|
||||
case err := <-destroyChan:
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &qerr.StatelessResetError{})
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
@@ -232,25 +227,20 @@ func TestTransportStatelessResetReceiving(t *testing.T) {
|
||||
|
||||
func TestTransportStatelessResetSending(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr := &Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnectionIDLength: 4,
|
||||
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
|
||||
handlerMap: phm,
|
||||
Tracer: tracer,
|
||||
}
|
||||
tr.init(true)
|
||||
defer func() {
|
||||
mockTracer.EXPECT().Close()
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
tr.Close()
|
||||
}()
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
|
||||
phm.EXPECT().Get(connID) // no handler for this connection ID
|
||||
phm.EXPECT().GetByResetToken(gomock.Any())
|
||||
|
||||
// now send a packet with a connection ID that doesn't exist
|
||||
b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||
@@ -274,8 +264,6 @@ func TestTransportStatelessResetSending(t *testing.T) {
|
||||
require.True(t, mockCtrl.Satisfied())
|
||||
|
||||
// but a stateless reset is sent for packets larger than MinStatelessResetSize
|
||||
phm.EXPECT().Get(connID) // no handler for this connection ID
|
||||
phm.EXPECT().GetByResetToken(gomock.Any())
|
||||
_, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
@@ -287,7 +275,7 @@ func TestTransportStatelessResetSending(t *testing.T) {
|
||||
require.Contains(t, string(p[:n]), string(srt[:]))
|
||||
}
|
||||
|
||||
func TestTransportDropsUnparseableQUICPackets(t *testing.T) {
|
||||
func TestTransportUnparseableQUICPackets(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr := &Transport{
|
||||
@@ -511,7 +499,7 @@ func testTransportDial(t *testing.T, early bool) {
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ *Transport,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
@@ -586,7 +574,7 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ *Transport,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
@@ -636,3 +624,78 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
conn2.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
func TestTransportReplaceWithClosed(t *testing.T) {
|
||||
t.Run("local", func(t *testing.T) {
|
||||
testTransportReplaceWithClosed(t, true)
|
||||
})
|
||||
t.Run("remote", func(t *testing.T) {
|
||||
testTransportReplaceWithClosed(t, false)
|
||||
})
|
||||
}
|
||||
|
||||
func testTransportReplaceWithClosed(t *testing.T, local bool) {
|
||||
srk := StatelessResetKey{1, 2, 3, 4}
|
||||
tr := &Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnectionIDLength: 4,
|
||||
StatelessResetKey: &srk,
|
||||
}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
|
||||
dur := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
var closePacket []byte
|
||||
if local {
|
||||
closePacket = []byte("foobar")
|
||||
}
|
||||
|
||||
handler := &mockPacketHandler{}
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
m := (*packetHandlerMap)(tr)
|
||||
require.True(t, m.Add(connID, handler))
|
||||
start := time.Now()
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur)
|
||||
|
||||
p := make([]byte, 100)
|
||||
p[0] = 0x40 // QUIC bit
|
||||
copy(p[1:], connID.Bytes())
|
||||
|
||||
conn := newUDPConnLocalhost(t)
|
||||
var sent int
|
||||
for now := range time.NewTicker(dur / 20).C {
|
||||
_, err := conn.WriteTo(p, tr.Conn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
sent++
|
||||
if now.After(start.Add(dur)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
// For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff
|
||||
for i := 0; i*i < sent; i++ {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
b := make([]byte, 100)
|
||||
if local {
|
||||
n, _, err := conn.ReadFrom(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foobar"), b[:n])
|
||||
}
|
||||
}
|
||||
// Afterwards, we receive a stateless reset, not a copy of the CONNECTION_CLOSE packet.
|
||||
// Retry up to 3 times, the connection is deleted from the map on a timer.
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := conn.WriteTo(p, tr.Conn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
conn.SetReadDeadline(time.Now().Add(dur / 4))
|
||||
b := make([]byte, 100)
|
||||
n, _, err := conn.ReadFrom(b)
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
continue
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, []byte("foobar"), b[:n])
|
||||
require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user