add support for connection migration (#4960)

This commit is contained in:
Marten Seemann
2025-03-12 12:11:11 +07:00
committed by GitHub
parent 0a2c2f0a82
commit 24acc54ef1
14 changed files with 782 additions and 36 deletions

View File

@@ -15,7 +15,7 @@ type connRunnerCallbacks struct {
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
}
type connRunners []connRunnerCallbacks
type connRunners map[transportID]connRunnerCallbacks
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
@@ -55,6 +55,7 @@ type connIDGenerator struct {
}
func newConnIDGenerator(
tID transportID,
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
statelessResetter *statelessResetter,
@@ -66,7 +67,7 @@ func newConnIDGenerator(
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunners: []connRunnerCallbacks{connRunner},
connRunners: map[transportID]connRunnerCallbacks{tID: connRunner},
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
@@ -162,12 +163,17 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
m.connRunners.ReplaceWithClosed(connIDs, connClose)
}
func (m *connIDGenerator) AddConnRunner(r connRunnerCallbacks) {
func (m *connIDGenerator) AddConnRunner(id transportID, r connRunnerCallbacks) {
// The transport might have already been added earlier.
// This happens if the application migrates back to and old path.
if _, ok := m.connRunners[id]; ok {
return
}
m.connRunners[id] = r
if m.initialClientDestConnID != nil {
r.AddConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
r.AddConnectionID(connID)
}
m.connRunners = append(m.connRunners, r)
}

View File

@@ -32,6 +32,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
initialClientDestConnID = &connID
}
g := newConnIDGenerator(
1,
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
sr,
@@ -116,6 +117,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
removed []protocol.ConnectionID
)
g := newConnIDGenerator(
0,
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -166,6 +168,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
replacedWith []byte
)
g := newConnIDGenerator(
1,
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -229,6 +232,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
var queuedFrames []wire.Frame
g := newConnIDGenerator(
1,
initialConnID,
&clientDestConnID,
sr,
@@ -240,7 +244,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
require.Len(t, tracker1.added, 2)
// add the second runner - it should get all existing connection IDs
g.AddConnRunner(runner2)
g.AddConnRunner(2, runner2)
require.Len(t, tracker1.added, 2) // unchanged
require.Len(t, tracker2.added, 4)
require.Contains(t, tracker2.added, initialConnID)

View File

@@ -112,6 +112,8 @@ func nextConnTracingID() ConnectionTracingID { return ConnectionTracingID(connTr
// A Connection is a QUIC connection
type connection struct {
tr *Transport
// Destination connection ID used during the handshake.
// Used to check source connection ID on incoming packets.
handshakeDestConnID protocol.ConnectionID
@@ -129,8 +131,9 @@ type connection struct {
sendQueue sender
// lazily initialzed: most connections never migrate
pathManager *pathManager
largestRcvdAppData protocol.PacketNumber
pathManager *pathManager
largestRcvdAppData protocol.PacketNumber
pathManagerOutgoing atomic.Pointer[pathManagerOutgoing]
streamsMap streamManager
connIDManager *connIDManager
@@ -223,7 +226,7 @@ var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
tr *Transport,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
@@ -242,6 +245,7 @@ var newConnection = func(
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
tr: tr,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
@@ -258,6 +262,7 @@ var newConnection = func(
} else {
s.logID = destConnID.String()
}
runner := tr.connRunner()
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
@@ -265,6 +270,7 @@ var newConnection = func(
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
tr.id(),
srcConnID,
&clientDestConnID,
statelessResetter,
@@ -301,7 +307,6 @@ var newConnection = func(
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
DisableActiveMigration: true,
StatelessResetToken: &statelessResetToken,
OriginalDestinationConnectionID: origDestConnID,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
@@ -344,7 +349,7 @@ var newConnection = func(
var newClientConnection = func(
ctx context.Context,
conn sendConn,
runner connRunner,
tr *Transport,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
@@ -359,6 +364,7 @@ var newClientConnection = func(
v protocol.Version,
) quicConn {
s := &connection{
tr: tr,
conn: conn,
config: conf,
origDestConnID: destConnID,
@@ -371,6 +377,7 @@ var newClientConnection = func(
versionNegotiated: hasNegotiatedVersion,
version: v,
}
runner := tr.connRunner()
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
@@ -378,6 +385,7 @@ var newClientConnection = func(
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
tr.id(),
srcConnID,
nil,
statelessResetter,
@@ -415,7 +423,6 @@ var newClientConnection = func(
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
@@ -645,6 +652,16 @@ runLoop:
}
}
if s.perspective == protocol.PerspectiveClient {
pm := s.pathManagerOutgoing.Load()
if pm != nil {
tr, ok := pm.ShouldSwitchPath()
if ok {
s.switchToNewPath(tr, now)
}
}
}
if s.sendQueue.WouldBlock() {
// The send queue is still busy sending out packets. Wait until there's space to enqueue new packets.
sendQueueAvailable = s.sendQueue.Available()
@@ -759,6 +776,24 @@ func (s *connection) idleTimeoutStartTime() time.Time {
return startTime
}
func (s *connection) switchToNewPath(tr *Transport, now time.Time) {
initialPacketSize := protocol.ByteCount(s.config.InitialPacketSize)
s.sentPacketHandler.MigratedPath(now, initialPacketSize)
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = s.peerParams.MaxUDPPayloadSize
}
s.mtuDiscoverer.Reset(now, initialPacketSize, maxPacketSize)
s.conn = newSendConn(tr.conn, s.conn.RemoteAddr(), packetInfo{}, utils.DefaultLogger) // TODO: find a better way
s.sendQueue.Close()
s.sendQueue = newSendQueue(s.conn)
go func() {
if err := s.sendQueue.Run(); err != nil {
s.destroyImpl(err)
}
}()
}
func (s *connection) handleHandshakeComplete(now time.Time) error {
defer close(s.handshakeCompleteChan)
// Once the handshake completes, we have derived 1-RTT keys.
@@ -1642,7 +1677,29 @@ func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
}
func (s *connection) handlePathResponseFrame(f *wire.PathResponseFrame) error {
s.logger.Debugf("received PATH_RESPONSE frame: %v", f.Data)
switch s.perspective {
case protocol.PerspectiveClient:
return s.handlePathResponseFrameClient(f)
case protocol.PerspectiveServer:
return s.handlePathResponseFrameServer(f)
default:
panic("unreachable")
}
}
func (s *connection) handlePathResponseFrameClient(f *wire.PathResponseFrame) error {
pm := s.pathManagerOutgoing.Load()
if pm == nil {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "unexpected PATH_RESPONSE frame",
}
}
pm.HandlePathResponseFrame(f)
return nil
}
func (s *connection) handlePathResponseFrameServer(f *wire.PathResponseFrame) error {
if s.pathManager == nil {
// since we didn't send PATH_CHALLENGEs yet, we don't expect PATH_RESPONSEs
return &qerr.TransportError{
@@ -2020,6 +2077,25 @@ func (s *connection) triggerSending(now time.Time) error {
}
func (s *connection) sendPackets(now time.Time) error {
if s.perspective == protocol.PerspectiveClient && s.handshakeConfirmed {
if pm := s.pathManagerOutgoing.Load(); pm != nil {
connID, frame, tr, ok := pm.NextPathToProbe()
if ok {
probe, buf, err := s.packer.PackPathProbePacket(connID, frame, s.version)
if err != nil {
return err
}
s.logger.Debugf("sending path probe packet from %s", s.LocalAddr())
s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, now)
tr.WriteTo(buf.Data, s.conn.RemoteAddr())
// There's (likely) more data to send. Loop around again.
s.scheduleSending()
return nil
}
}
}
// Path MTU Discovery
// Can't use GSO, since we need to send a single packet that's larger than our current maximum size.
// Performance-wise, this doesn't matter, since we only send a very small (<10) number of
@@ -2527,6 +2603,43 @@ func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) {
func (s *connection) LocalAddr() net.Addr { return s.conn.LocalAddr() }
func (s *connection) RemoteAddr() net.Addr { return s.conn.RemoteAddr() }
func (s *connection) getPathManager() *pathManagerOutgoing {
s.pathManagerOutgoing.CompareAndSwap(nil,
func() *pathManagerOutgoing { // this function is only called if a swap is performed
return newPathManagerOutgoing(
s.connIDManager.GetConnIDForPath,
s.connIDManager.RetireConnIDForPath,
s.scheduleSending,
)
}(),
)
return s.pathManagerOutgoing.Load()
}
func (s *connection) AddPath(t *Transport) (*Path, error) {
if s.perspective == protocol.PerspectiveServer {
return nil, errors.New("server cannot initiate connection migration")
}
if s.peerParams.DisableActiveMigration {
return nil, errors.New("server disabled connection migration")
}
if err := t.init(false); err != nil {
return nil, err
}
return s.getPathManager().NewPath(t, func() {
runner := t.connRunner()
s.connIDGenerator.AddConnRunner(
t.id(),
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
RetireConnectionID: runner.Retire,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
)
}), nil
}
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
// The handshake might fail after the server rejected 0-RTT.
// This could happen if the Finished message is malformed or never received.

View File

@@ -81,7 +81,7 @@ func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
type testConnection struct {
conn *connection
connRunner *MockConnRunner
connRunner *MockPacketHandlerManager
sendConn *MockSendConn
packer *MockPacker
destConnID protocol.ConnectionID
@@ -101,7 +101,7 @@ func newServerTestConnection(
}
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
connRunner := NewMockConnRunner(mockCtrl)
phm := NewMockPacketHandlerManager(mockCtrl)
sendConn := NewMockSendConn(mockCtrl)
sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
@@ -119,7 +119,7 @@ func newServerTestConnection(
ctx,
cancel,
sendConn,
connRunner,
&Transport{handlerMap: phm},
origDestConnID,
nil,
protocol.ConnectionID{},
@@ -141,7 +141,7 @@ func newServerTestConnection(
}
return &testConnection{
conn: conn,
connRunner: connRunner,
connRunner: phm,
sendConn: sendConn,
packer: packer,
destConnID: origDestConnID,
@@ -162,7 +162,7 @@ func newClientTestConnection(
}
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
connRunner := NewMockConnRunner(mockCtrl)
phm := NewMockPacketHandlerManager(mockCtrl)
sendConn := NewMockSendConn(mockCtrl)
sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
@@ -178,7 +178,7 @@ func newClientTestConnection(
conn := newClientConnection(
context.Background(),
sendConn,
connRunner,
&Transport{handlerMap: phm},
destConnID,
srcConnID,
&protocol.DefaultConnectionIDGenerator{},
@@ -198,7 +198,7 @@ func newClientTestConnection(
}
return &testConnection{
conn: conn,
connRunner: connRunner,
connRunner: phm,
sendConn: sendConn,
packer: packer,
destConnID: destConnID,
@@ -2994,3 +2994,86 @@ func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
t.Fatal("timeout")
}
}
func TestConnectionMigrationServer(t *testing.T) {
tc := newServerTestConnection(t, nil, nil, false)
_, err := tc.conn.AddPath(&Transport{})
require.Error(t, err)
require.ErrorContains(t, err, "server cannot initiate connection migration")
}
func TestConnectionMigration(t *testing.T) {
t.Run("disabled", func(t *testing.T) {
testConnectionMigration(t, false)
})
t.Run("enabled", func(t *testing.T) {
testConnectionMigration(t, true)
})
}
func testConnectionMigration(t *testing.T, enabled bool) {
tc := newClientTestConnection(t, nil, nil, false, connectionOptHandshakeConfirmed())
require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: tc.destConnID,
DisableActiveMigration: !enabled,
}))
tr := &Transport{
Conn: newUDPConnLocalhost(t),
StatelessResetKey: &StatelessResetKey{},
}
defer tr.Close()
path, err := tc.conn.AddPath(tr)
if !enabled {
require.Error(t, err)
require.ErrorContains(t, err, "server disabled connection migration")
return
}
require.NoError(t, err)
require.NotNil(t, path)
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
shortHeaderPacket{}, errNothingToPack,
).AnyTimes()
packedProbe := make(chan struct{})
tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ protocol.ConnectionID, f ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
defer close(packedProbe)
return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
},
).AnyTimes()
// add a new connection ID, so the path can be probed
require.NoError(t, tc.conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{
SequenceNumber: 1,
ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
}))
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
// Adding the path initialized the transport.
// We can test this by triggering a stateless reset.
conn := newUDPConnLocalhost(t)
_, err = conn.WriteTo(append([]byte{0x40}, make([]byte, 100)...), tr.Conn.LocalAddr())
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(time.Second))
_, _, err = conn.ReadFrom(make([]byte, 100))
require.NoError(t, err)
go func() { path.Probe(context.Background()) }()
select {
case <-packedProbe:
case <-time.After(time.Second):
t.Fatal("timeout")
}
// teardown
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tc.conn.destroy(nil)
select {
case <-errChan:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}

View File

@@ -0,0 +1,144 @@
package self_test
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/stretchr/testify/require"
)
func TestConnectionMigration(t *testing.T) {
ln, err := quic.ListenAddr("localhost:0", tlsConfig, getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
tr1 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
defer tr1.Close()
tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
defer tr2.Close()
var packetsPath1, packetsPath2 atomic.Int64
const rtt = 5 * time.Millisecond
proxy := quicproxy.Proxy{
Conn: newUDPConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(dir quicproxy.Direction, from, to net.Addr, _ []byte) time.Duration {
var port int
switch dir {
case quicproxy.DirectionIncoming:
port = from.(*net.UDPAddr).Port
case quicproxy.DirectionOutgoing:
port = to.(*net.UDPAddr).Port
}
switch port {
case tr1.Conn.LocalAddr().(*net.UDPAddr).Port:
packetsPath1.Add(1)
case tr2.Conn.LocalAddr().(*net.UDPAddr).Port:
packetsPath2.Add(1)
default:
fmt.Println("address not found", from)
}
return rtt / 2
},
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := tr1.Dial(ctx, proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer conn.CloseWithError(0, "")
sconn, err := ln.Accept(ctx)
require.NoError(t, err)
defer sconn.CloseWithError(0, "")
sendAndReceiveFile := func(t *testing.T) {
t.Helper()
str, err := conn.OpenUniStream()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
defer close(errChan)
sstr, err := sconn.AcceptUniStream(ctx)
if err != nil {
errChan <- fmt.Errorf("accepting stream: %w", err)
return
}
data, err := io.ReadAll(sstr)
if err != nil {
errChan <- fmt.Errorf("reading stream data: %w", err)
return
}
if !bytes.Equal(data, PRData) {
errChan <- errors.New("unexpected data")
}
}()
_, err = str.Write(PRData)
require.NoError(t, err)
require.NoError(t, str.Close())
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timed out waiting for data")
}
}
sendAndReceiveFile(t) // stream 2
require.NotZero(t, packetsPath1.Load())
require.Zero(t, packetsPath2.Load())
// probing the path causes a few packets to be sent on path 2
path, err := conn.AddPath(tr2)
require.NoError(t, err)
require.ErrorIs(t, path.Switch(), quic.ErrPathNotValidated)
require.NoError(t, path.Probe(ctx))
require.Less(t, int(packetsPath2.Load()), 5)
// make sure that no more packets are sent on path 2 before switching to the path
c2 := packetsPath2.Load()
sendAndReceiveFile(t) // stream 6
require.Equal(t, packetsPath2.Load(), c2)
time.Sleep(3 * rtt) // wait for ACKs
// now switch and make sure that no packets are sent on path 1
require.NoError(t, path.Switch())
sendAndReceiveFile(t) // stream 10
c1 := packetsPath1.Load()
require.Equal(t, c1, packetsPath1.Load())
require.Greater(t, packetsPath2.Load(), c2)
require.Equal(t, tr2.Conn.LocalAddr(), conn.LocalAddr())
// switch back to the handshake path
time.Sleep(3 * rtt) // wait for ACKs
c1BeforeSwitch := packetsPath1.Load()
c2BeforeSwitch := packetsPath2.Load()
path2, err := conn.AddPath(tr1)
require.NoError(t, err)
require.NoError(t, path2.Probe(ctx))
time.Sleep(3 * rtt) // wait for ACKs
require.NoError(t, path2.Switch())
sendAndReceiveFile(t) // stream 14
require.Greater(t, packetsPath1.Load(), c1BeforeSwitch)
// some path probing might have happened
require.Less(t, int(packetsPath2.Load()-c2BeforeSwitch), 20)
require.Equal(t, tr1.Conn.LocalAddr(), conn.LocalAddr())
}

View File

@@ -205,6 +205,8 @@ type Connection interface {
SendDatagram(payload []byte) error
// ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221.
ReceiveDatagram(context.Context) ([]byte, error)
AddPath(*Transport) (*Path, error)
}
// An EarlyConnection is a connection that is handshaking.

View File

@@ -120,6 +120,45 @@ func (c *MockEarlyConnectionAcceptUniStreamCall) DoAndReturn(f func(context.Cont
return c
}
// AddPath mocks base method.
func (m *MockEarlyConnection) AddPath(arg0 *quic.Transport) (*quic.Path, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPath", arg0)
ret0, _ := ret[0].(*quic.Path)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddPath indicates an expected call of AddPath.
func (mr *MockEarlyConnectionMockRecorder) AddPath(arg0 any) *MockEarlyConnectionAddPathCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPath", reflect.TypeOf((*MockEarlyConnection)(nil).AddPath), arg0)
return &MockEarlyConnectionAddPathCall{Call: call}
}
// MockEarlyConnectionAddPathCall wrap *gomock.Call
type MockEarlyConnectionAddPathCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockEarlyConnectionAddPathCall) Return(arg0 *quic.Path, arg1 error) *MockEarlyConnectionAddPathCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockEarlyConnectionAddPathCall) Do(f func(*quic.Transport) (*quic.Path, error)) *MockEarlyConnectionAddPathCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockEarlyConnectionAddPathCall) DoAndReturn(f func(*quic.Transport) (*quic.Path, error)) *MockEarlyConnectionAddPathCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// CloseWithError mocks base method.
func (m *MockEarlyConnection) CloseWithError(arg0 quic.ApplicationErrorCode, arg1 string) error {
m.ctrl.T.Helper()

View File

@@ -119,6 +119,45 @@ func (c *MockQUICConnAcceptUniStreamCall) DoAndReturn(f func(context.Context) (R
return c
}
// AddPath mocks base method.
func (m *MockQUICConn) AddPath(arg0 *Transport) (*Path, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPath", arg0)
ret0, _ := ret[0].(*Path)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddPath indicates an expected call of AddPath.
func (mr *MockQUICConnMockRecorder) AddPath(arg0 any) *MockQUICConnAddPathCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPath", reflect.TypeOf((*MockQUICConn)(nil).AddPath), arg0)
return &MockQUICConnAddPathCall{Call: call}
}
// MockQUICConnAddPathCall wrap *gomock.Call
type MockQUICConnAddPathCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnAddPathCall) Return(arg0 *Path, arg1 error) *MockQUICConnAddPathCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnAddPathCall) Do(f func(*Transport) (*Path, error)) *MockQUICConnAddPathCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnAddPathCall) DoAndReturn(f func(*Transport) (*Path, error)) *MockQUICConnAddPathCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// CloseWithError mocks base method.
func (m *MockQUICConn) CloseWithError(arg0 ApplicationErrorCode, arg1 string) error {
m.ctrl.T.Helper()

205
path_manager_outgoing.go Normal file
View File

@@ -0,0 +1,205 @@
package quic
import (
"context"
"crypto/rand"
"errors"
"sync"
"sync/atomic"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
// ErrPathNotValidated is returned when trying to use a path before path probing has completed.
var ErrPathNotValidated = errors.New("path not yet validated")
var errPathDoesNotExist = errors.New("path does not exist")
// Path is a network path.
type Path struct {
id pathID
pathManager *pathManagerOutgoing
tr *Transport
enablePath func()
startedProbing atomic.Bool
}
func (p *Path) Probe(ctx context.Context) error {
done := make(chan struct{})
p.pathManager.addPath(p, p.enablePath, done)
p.startedProbing.Store(true)
select {
case <-ctx.Done():
return context.Cause(ctx)
case <-done:
return nil
}
}
// Switch switches the QUIC connection to this path.
// It immediately stops sending on the old path, and sends on this new path.
func (p *Path) Switch() error {
if err := p.pathManager.switchToPath(p.id); err != nil {
if errors.Is(err, errPathDoesNotExist) && !p.startedProbing.Load() {
return ErrPathNotValidated
}
return err
}
return nil
}
type pathOutgoing struct {
pathChallenge [8]byte
tr *Transport
isValidated bool
validated chan<- struct{} // closed when the path the corresponding PATH_RESPONSE is received
enablePath func()
}
type pathManagerOutgoing struct {
getConnID func(pathID) (_ protocol.ConnectionID, ok bool)
retireConnID func(pathID)
scheduleSending func()
mx sync.Mutex
pathsToProbe []pathID
paths map[pathID]*pathOutgoing
nextPathID pathID
pathToSwitchTo *pathOutgoing
}
func newPathManagerOutgoing(
getConnID func(pathID) (_ protocol.ConnectionID, ok bool),
retireConnID func(pathID),
scheduleSending func(),
) *pathManagerOutgoing {
return &pathManagerOutgoing{
getConnID: getConnID,
retireConnID: retireConnID,
scheduleSending: scheduleSending,
paths: make(map[pathID]*pathOutgoing, 4),
}
}
func (pm *pathManagerOutgoing) addPath(p *Path, enablePath func(), done chan<- struct{}) {
var b [8]byte
_, _ = rand.Read(b[:])
pm.mx.Lock()
pm.paths[p.id] = &pathOutgoing{
pathChallenge: b,
tr: p.tr,
validated: done,
enablePath: enablePath,
}
pm.pathsToProbe = append(pm.pathsToProbe, p.id)
pm.mx.Unlock()
pm.scheduleSending()
}
func (pm *pathManagerOutgoing) switchToPath(id pathID) error {
pm.mx.Lock()
defer pm.mx.Unlock()
p, ok := pm.paths[id]
if !ok {
return errPathDoesNotExist
}
if !p.isValidated {
return ErrPathNotValidated
}
pm.pathToSwitchTo = p
return nil
}
func (pm *pathManagerOutgoing) NewPath(t *Transport, enablePath func()) *Path {
pm.mx.Lock()
defer pm.mx.Unlock()
id := pm.nextPathID
pm.nextPathID++
return &Path{
pathManager: pm,
id: id,
tr: t,
enablePath: enablePath,
}
}
func (pm *pathManagerOutgoing) NextPathToProbe() (_ protocol.ConnectionID, _ ackhandler.Frame, _ *Transport, hasPath bool) {
pm.mx.Lock()
defer pm.mx.Unlock()
var p *pathOutgoing
var id pathID
for {
if len(pm.pathsToProbe) == 0 {
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
}
id = pm.pathsToProbe[0]
pm.pathsToProbe = pm.pathsToProbe[1:]
var ok bool
// if the path doesn't exist in the map, it might have been abandoned
p, ok = pm.paths[id]
if ok {
break
}
}
connID, ok := pm.getConnID(id)
if !ok {
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
}
p.enablePath()
frame := ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: p.pathChallenge},
Handler: (*pathManagerOutgoingAckHandler)(pm),
}
return connID, frame, p.tr, true
}
func (pm *pathManagerOutgoing) HandlePathResponseFrame(f *wire.PathResponseFrame) {
pm.mx.Lock()
defer pm.mx.Unlock()
for _, p := range pm.paths {
if f.Data == p.pathChallenge {
// path validated
if !p.isValidated {
p.isValidated = true
close(p.validated)
}
// makes sure that duplicate PATH_RESPONSE frames are ignored
p.validated = nil
break
}
}
}
func (pm *pathManagerOutgoing) ShouldSwitchPath() (*Transport, bool) {
pm.mx.Lock()
defer pm.mx.Unlock()
if pm.pathToSwitchTo == nil {
return nil, false
}
p := pm.pathToSwitchTo
pm.pathToSwitchTo = nil
return p.tr, true
}
type pathManagerOutgoingAckHandler pathManagerOutgoing
var _ ackhandler.FrameHandler = &pathManagerOutgoingAckHandler{}
// OnAcked is called when the PATH_CHALLENGE is acked.
// This doesn't validate the path, only receiving the PATH_RESPONSE does.
func (pm *pathManagerOutgoingAckHandler) OnAcked(wire.Frame) {}
func (pm *pathManagerOutgoingAckHandler) OnLost(wire.Frame) {}

View File

@@ -0,0 +1,99 @@
package quic
import (
"context"
"testing"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
func TestPathManagerOutgoing(t *testing.T) {
connIDs := []protocol.ConnectionID{
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
protocol.ParseConnectionID([]byte{2, 3, 4, 5, 6, 7, 8, 9}),
}
var retiredConnIDs []protocol.ConnectionID
pm := newPathManagerOutgoing(
func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true },
func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) },
func() {},
)
_, _, _, ok := pm.NextPathToProbe()
require.False(t, ok)
tr1 := &Transport{}
var enabled bool
p := pm.NewPath(tr1, func() { enabled = true })
require.ErrorIs(t, p.Switch(), ErrPathNotValidated)
errChan := make(chan error, 1)
go func() { errChan <- p.Probe(context.Background()) }()
// wait for the path to be queued for probing
time.Sleep(scaleDuration(5 * time.Millisecond))
require.False(t, enabled)
connID, f, tr, ok := pm.NextPathToProbe()
require.True(t, ok)
require.Equal(t, tr1, tr)
require.Equal(t, connIDs[0], connID)
require.IsType(t, &wire.PathChallengeFrame{}, f.Frame)
pc := f.Frame.(*wire.PathChallengeFrame)
require.True(t, enabled)
_, _, _, ok = pm.NextPathToProbe()
require.False(t, ok)
select {
case <-errChan:
t.Fatal("should still be probing")
default:
}
// acking the frame doesn't complete path validation...
f.Handler.OnAcked(f.Frame)
select {
case <-errChan:
t.Fatal("should still be probing")
default:
}
require.ErrorIs(t, p.Switch(), ErrPathNotValidated)
_, ok = pm.ShouldSwitchPath()
require.False(t, ok)
// ... neither does receiving a random PATH_RESPONSE...
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: [8]byte{'f', 'o', 'o', 'f', 'o', 'o'}})
f.Handler.OnAcked(f.Frame) // doesn't do anything
f.Handler.OnLost(f.Frame) // doesn't do anything
select {
case <-errChan:
t.Fatal("should still be probing")
default:
}
// ... only receiving the corresponding PATH_RESPONSE does
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc.Data})
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// receiving it multiple times is ok
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc.Data})
// now switch to the other path
_, ok = pm.ShouldSwitchPath()
require.False(t, ok)
require.NoError(t, p.Switch())
switchToTransport, ok := pm.ShouldSwitchPath()
require.True(t, ok)
require.Equal(t, tr1, switchToTransport)
}

View File

@@ -61,6 +61,7 @@ type rejectedPacket struct {
// A Listener of QUIC
type baseServer struct {
tr *Transport
disableVersionNegotiation bool
acceptEarlyConns bool
@@ -74,7 +75,6 @@ type baseServer struct {
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
connHandler packetHandlerManager
onClose func()
receivedPackets chan receivedPacket
@@ -89,7 +89,7 @@ type baseServer struct {
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
*Transport,
protocol.ConnectionID, /* original dest connection ID */
*protocol.ConnectionID, /* retry src connection ID */
protocol.ConnectionID, /* client dest connection ID */
@@ -247,7 +247,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
func newServer(
conn rawConn,
connHandler packetHandlerManager,
tr *Transport,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
connContext func(context.Context) context.Context,
@@ -264,6 +264,7 @@ func newServer(
s := &baseServer{
conn: conn,
connContext: connContext,
tr: tr,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
@@ -271,7 +272,6 @@ func newServer(
verifySourceAddress: verifySourceAddress,
connIDGenerator: connIDGenerator,
statelessResetter: statelessResetter,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
stopAccepting: make(chan struct{}),
@@ -501,7 +501,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
}
// check again if we might have a connection now
if handler, ok := s.connHandler.Get(connID); ok {
if handler, ok := s.tr.connRunner().Get(connID); ok {
handler.handlePacket(p)
return true
}
@@ -591,7 +591,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
// The server queues packets for a while, and we might already have established a connection by now.
// This results in a second check in the connection map.
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
if handler, ok := s.tr.connRunner().Get(hdr.DestConnectionID); ok {
handler.handlePacket(p)
return nil
}
@@ -706,7 +706,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
s.tr,
origDestConnID,
retrySrcConnID,
hdr.DestConnectionID,
@@ -727,7 +727,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
// under normal circumstances the packet would just be routed to that connection.
// The only time this collision will occur if we receive the two Initial packets at the same time.
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
if added := s.tr.connRunner().AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
conn.closeWithTransportError(qerr.ConnectionRefused)
return nil

View File

@@ -37,7 +37,7 @@ type serverOpts struct {
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
*Transport,
protocol.ConnectionID, // original dest connection ID
*protocol.ConnectionID, // retry src connection ID
protocol.ConnectionID, // client dest connection ID
@@ -63,7 +63,7 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
config := populateConfig(serverOpts.config)
s := newServer(
c,
newPacketHandlerMap(nil, utils.DefaultLogger),
&Transport{handlerMap: newPacketHandlerMap(nil, utils.DefaultLogger)},
&protocol.DefaultConnectionIDGenerator{},
&statelessResetter{},
func(ctx context.Context) context.Context { return ctx },
@@ -623,7 +623,7 @@ func (r *connConstructorRecorder) NewConn(
ctx context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ *Transport,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
@@ -863,7 +863,7 @@ func TestServerPacketHandling(t *testing.T) {
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
handledPacket <- p
})
server.connHandler.Add(destConnID, conn)
server.tr.handlerMap.Add(destConnID, conn)
server.handlePacket(
getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, srcConnID, destConnID),
@@ -889,7 +889,7 @@ func TestServerReceiveQueue(t *testing.T) {
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ *Transport,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,

View File

@@ -38,6 +38,10 @@ func (e *errTransportClosed) Is(target error) bool {
return ok
}
type transportID uint64
var transportIDCounter atomic.Uint64
var errListenerAlreadySet = errors.New("listener already set")
// The Transport is the central point to manage incoming and outgoing QUIC connections.
@@ -133,6 +137,7 @@ type Transport struct {
initErr error
// Set in init.
transportID transportID
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
// Set in init.
@@ -207,7 +212,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
}
s := newServer(
t.conn,
t.handlerMap,
t,
t.connIDGenerator,
t.statelessResetter,
t.ConnContext,
@@ -298,7 +303,7 @@ func (t *Transport) doDial(
conn := newClientConnection(
context.WithoutCancel(ctx),
sendConn,
t.handlerMap,
t,
destConnID,
srcConnID,
t.connIDGenerator,
@@ -371,6 +376,7 @@ func (t *Transport) doDial(
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.initOnce.Do(func() {
t.transportID = transportID(transportIDCounter.Add(1))
var conn rawConn
if c, ok := t.Conn.(rawConn); ok {
conn = c
@@ -420,6 +426,12 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
return t.initErr
}
func (t *Transport) connRunner() packetHandlerManager {
return t.handlerMap
}
func (t *Transport) id() transportID { return t.transportID }
// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {

View File

@@ -510,7 +510,7 @@ func testTransportDial(t *testing.T, early bool) {
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ *Transport,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
@@ -585,7 +585,7 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ *Transport,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,