forked from quic-go/quic-go
add support for connection migration (#4960)
This commit is contained in:
@@ -15,7 +15,7 @@ type connRunnerCallbacks struct {
|
||||
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
}
|
||||
|
||||
type connRunners []connRunnerCallbacks
|
||||
type connRunners map[transportID]connRunnerCallbacks
|
||||
|
||||
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
@@ -55,6 +55,7 @@ type connIDGenerator struct {
|
||||
}
|
||||
|
||||
func newConnIDGenerator(
|
||||
tID transportID,
|
||||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
statelessResetter *statelessResetter,
|
||||
@@ -66,7 +67,7 @@ func newConnIDGenerator(
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
statelessResetter: statelessResetter,
|
||||
connRunners: []connRunnerCallbacks{connRunner},
|
||||
connRunners: map[transportID]connRunnerCallbacks{tID: connRunner},
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
@@ -162,12 +163,17 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
|
||||
m.connRunners.ReplaceWithClosed(connIDs, connClose)
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) AddConnRunner(r connRunnerCallbacks) {
|
||||
func (m *connIDGenerator) AddConnRunner(id transportID, r connRunnerCallbacks) {
|
||||
// The transport might have already been added earlier.
|
||||
// This happens if the application migrates back to and old path.
|
||||
if _, ok := m.connRunners[id]; ok {
|
||||
return
|
||||
}
|
||||
m.connRunners[id] = r
|
||||
if m.initialClientDestConnID != nil {
|
||||
r.AddConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
r.AddConnectionID(connID)
|
||||
}
|
||||
m.connRunners = append(m.connRunners, r)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
|
||||
initialClientDestConnID = &connID
|
||||
}
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
sr,
|
||||
@@ -116,6 +117,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
|
||||
removed []protocol.ConnectionID
|
||||
)
|
||||
g := newConnIDGenerator(
|
||||
0,
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -166,6 +168,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
|
||||
replacedWith []byte
|
||||
)
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
|
||||
initialClientDestConnID,
|
||||
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
|
||||
@@ -229,6 +232,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
var queuedFrames []wire.Frame
|
||||
|
||||
g := newConnIDGenerator(
|
||||
1,
|
||||
initialConnID,
|
||||
&clientDestConnID,
|
||||
sr,
|
||||
@@ -240,7 +244,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
|
||||
require.Len(t, tracker1.added, 2)
|
||||
|
||||
// add the second runner - it should get all existing connection IDs
|
||||
g.AddConnRunner(runner2)
|
||||
g.AddConnRunner(2, runner2)
|
||||
require.Len(t, tracker1.added, 2) // unchanged
|
||||
require.Len(t, tracker2.added, 4)
|
||||
require.Contains(t, tracker2.added, initialConnID)
|
||||
|
||||
127
connection.go
127
connection.go
@@ -112,6 +112,8 @@ func nextConnTracingID() ConnectionTracingID { return ConnectionTracingID(connTr
|
||||
|
||||
// A Connection is a QUIC connection
|
||||
type connection struct {
|
||||
tr *Transport
|
||||
|
||||
// Destination connection ID used during the handshake.
|
||||
// Used to check source connection ID on incoming packets.
|
||||
handshakeDestConnID protocol.ConnectionID
|
||||
@@ -129,8 +131,9 @@ type connection struct {
|
||||
sendQueue sender
|
||||
|
||||
// lazily initialzed: most connections never migrate
|
||||
pathManager *pathManager
|
||||
largestRcvdAppData protocol.PacketNumber
|
||||
pathManager *pathManager
|
||||
largestRcvdAppData protocol.PacketNumber
|
||||
pathManagerOutgoing atomic.Pointer[pathManagerOutgoing]
|
||||
|
||||
streamsMap streamManager
|
||||
connIDManager *connIDManager
|
||||
@@ -223,7 +226,7 @@ var newConnection = func(
|
||||
ctx context.Context,
|
||||
ctxCancel context.CancelCauseFunc,
|
||||
conn sendConn,
|
||||
runner connRunner,
|
||||
tr *Transport,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
retrySrcConnID *protocol.ConnectionID,
|
||||
clientDestConnID protocol.ConnectionID,
|
||||
@@ -242,6 +245,7 @@ var newConnection = func(
|
||||
s := &connection{
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
tr: tr,
|
||||
conn: conn,
|
||||
config: conf,
|
||||
handshakeDestConnID: destConnID,
|
||||
@@ -258,6 +262,7 @@ var newConnection = func(
|
||||
} else {
|
||||
s.logID = destConnID.String()
|
||||
}
|
||||
runner := tr.connRunner()
|
||||
s.connIDManager = newConnIDManager(
|
||||
destConnID,
|
||||
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
|
||||
@@ -265,6 +270,7 @@ var newConnection = func(
|
||||
s.queueControlFrame,
|
||||
)
|
||||
s.connIDGenerator = newConnIDGenerator(
|
||||
tr.id(),
|
||||
srcConnID,
|
||||
&clientDestConnID,
|
||||
statelessResetter,
|
||||
@@ -301,7 +307,6 @@ var newConnection = func(
|
||||
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
|
||||
AckDelayExponent: protocol.AckDelayExponent,
|
||||
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
|
||||
DisableActiveMigration: true,
|
||||
StatelessResetToken: &statelessResetToken,
|
||||
OriginalDestinationConnectionID: origDestConnID,
|
||||
// For interoperability with quic-go versions before May 2023, this value must be set to a value
|
||||
@@ -344,7 +349,7 @@ var newConnection = func(
|
||||
var newClientConnection = func(
|
||||
ctx context.Context,
|
||||
conn sendConn,
|
||||
runner connRunner,
|
||||
tr *Transport,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
@@ -359,6 +364,7 @@ var newClientConnection = func(
|
||||
v protocol.Version,
|
||||
) quicConn {
|
||||
s := &connection{
|
||||
tr: tr,
|
||||
conn: conn,
|
||||
config: conf,
|
||||
origDestConnID: destConnID,
|
||||
@@ -371,6 +377,7 @@ var newClientConnection = func(
|
||||
versionNegotiated: hasNegotiatedVersion,
|
||||
version: v,
|
||||
}
|
||||
runner := tr.connRunner()
|
||||
s.connIDManager = newConnIDManager(
|
||||
destConnID,
|
||||
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
|
||||
@@ -378,6 +385,7 @@ var newClientConnection = func(
|
||||
s.queueControlFrame,
|
||||
)
|
||||
s.connIDGenerator = newConnIDGenerator(
|
||||
tr.id(),
|
||||
srcConnID,
|
||||
nil,
|
||||
statelessResetter,
|
||||
@@ -415,7 +423,6 @@ var newClientConnection = func(
|
||||
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
|
||||
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
|
||||
AckDelayExponent: protocol.AckDelayExponent,
|
||||
DisableActiveMigration: true,
|
||||
// For interoperability with quic-go versions before May 2023, this value must be set to a value
|
||||
// different from protocol.DefaultActiveConnectionIDLimit.
|
||||
// If set to the default value, it will be omitted from the transport parameters, which will make
|
||||
@@ -645,6 +652,16 @@ runLoop:
|
||||
}
|
||||
}
|
||||
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
pm := s.pathManagerOutgoing.Load()
|
||||
if pm != nil {
|
||||
tr, ok := pm.ShouldSwitchPath()
|
||||
if ok {
|
||||
s.switchToNewPath(tr, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.sendQueue.WouldBlock() {
|
||||
// The send queue is still busy sending out packets. Wait until there's space to enqueue new packets.
|
||||
sendQueueAvailable = s.sendQueue.Available()
|
||||
@@ -759,6 +776,24 @@ func (s *connection) idleTimeoutStartTime() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
func (s *connection) switchToNewPath(tr *Transport, now time.Time) {
|
||||
initialPacketSize := protocol.ByteCount(s.config.InitialPacketSize)
|
||||
s.sentPacketHandler.MigratedPath(now, initialPacketSize)
|
||||
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
|
||||
if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize {
|
||||
maxPacketSize = s.peerParams.MaxUDPPayloadSize
|
||||
}
|
||||
s.mtuDiscoverer.Reset(now, initialPacketSize, maxPacketSize)
|
||||
s.conn = newSendConn(tr.conn, s.conn.RemoteAddr(), packetInfo{}, utils.DefaultLogger) // TODO: find a better way
|
||||
s.sendQueue.Close()
|
||||
s.sendQueue = newSendQueue(s.conn)
|
||||
go func() {
|
||||
if err := s.sendQueue.Run(); err != nil {
|
||||
s.destroyImpl(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *connection) handleHandshakeComplete(now time.Time) error {
|
||||
defer close(s.handshakeCompleteChan)
|
||||
// Once the handshake completes, we have derived 1-RTT keys.
|
||||
@@ -1642,7 +1677,29 @@ func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
|
||||
}
|
||||
|
||||
func (s *connection) handlePathResponseFrame(f *wire.PathResponseFrame) error {
|
||||
s.logger.Debugf("received PATH_RESPONSE frame: %v", f.Data)
|
||||
switch s.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
return s.handlePathResponseFrameClient(f)
|
||||
case protocol.PerspectiveServer:
|
||||
return s.handlePathResponseFrameServer(f)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *connection) handlePathResponseFrameClient(f *wire.PathResponseFrame) error {
|
||||
pm := s.pathManagerOutgoing.Load()
|
||||
if pm == nil {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "unexpected PATH_RESPONSE frame",
|
||||
}
|
||||
}
|
||||
pm.HandlePathResponseFrame(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) handlePathResponseFrameServer(f *wire.PathResponseFrame) error {
|
||||
if s.pathManager == nil {
|
||||
// since we didn't send PATH_CHALLENGEs yet, we don't expect PATH_RESPONSEs
|
||||
return &qerr.TransportError{
|
||||
@@ -2020,6 +2077,25 @@ func (s *connection) triggerSending(now time.Time) error {
|
||||
}
|
||||
|
||||
func (s *connection) sendPackets(now time.Time) error {
|
||||
if s.perspective == protocol.PerspectiveClient && s.handshakeConfirmed {
|
||||
if pm := s.pathManagerOutgoing.Load(); pm != nil {
|
||||
connID, frame, tr, ok := pm.NextPathToProbe()
|
||||
if ok {
|
||||
probe, buf, err := s.packer.PackPathProbePacket(connID, frame, s.version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debugf("sending path probe packet from %s", s.LocalAddr())
|
||||
s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
|
||||
s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, now)
|
||||
tr.WriteTo(buf.Data, s.conn.RemoteAddr())
|
||||
// There's (likely) more data to send. Loop around again.
|
||||
s.scheduleSending()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Path MTU Discovery
|
||||
// Can't use GSO, since we need to send a single packet that's larger than our current maximum size.
|
||||
// Performance-wise, this doesn't matter, since we only send a very small (<10) number of
|
||||
@@ -2527,6 +2603,43 @@ func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||
func (s *connection) LocalAddr() net.Addr { return s.conn.LocalAddr() }
|
||||
func (s *connection) RemoteAddr() net.Addr { return s.conn.RemoteAddr() }
|
||||
|
||||
func (s *connection) getPathManager() *pathManagerOutgoing {
|
||||
s.pathManagerOutgoing.CompareAndSwap(nil,
|
||||
func() *pathManagerOutgoing { // this function is only called if a swap is performed
|
||||
return newPathManagerOutgoing(
|
||||
s.connIDManager.GetConnIDForPath,
|
||||
s.connIDManager.RetireConnIDForPath,
|
||||
s.scheduleSending,
|
||||
)
|
||||
}(),
|
||||
)
|
||||
return s.pathManagerOutgoing.Load()
|
||||
}
|
||||
|
||||
func (s *connection) AddPath(t *Transport) (*Path, error) {
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
return nil, errors.New("server cannot initiate connection migration")
|
||||
}
|
||||
if s.peerParams.DisableActiveMigration {
|
||||
return nil, errors.New("server disabled connection migration")
|
||||
}
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.getPathManager().NewPath(t, func() {
|
||||
runner := t.connRunner()
|
||||
s.connIDGenerator.AddConnRunner(
|
||||
t.id(),
|
||||
connRunnerCallbacks{
|
||||
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
RemoveConnectionID: runner.Remove,
|
||||
RetireConnectionID: runner.Retire,
|
||||
ReplaceWithClosed: runner.ReplaceWithClosed,
|
||||
},
|
||||
)
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
|
||||
// The handshake might fail after the server rejected 0-RTT.
|
||||
// This could happen if the Finished message is malformed or never received.
|
||||
|
||||
@@ -81,7 +81,7 @@ func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
|
||||
|
||||
type testConnection struct {
|
||||
conn *connection
|
||||
connRunner *MockConnRunner
|
||||
connRunner *MockPacketHandlerManager
|
||||
sendConn *MockSendConn
|
||||
packer *MockPacker
|
||||
destConnID protocol.ConnectionID
|
||||
@@ -101,7 +101,7 @@ func newServerTestConnection(
|
||||
}
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
||||
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
||||
connRunner := NewMockConnRunner(mockCtrl)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
sendConn := NewMockSendConn(mockCtrl)
|
||||
sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
|
||||
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
||||
@@ -119,7 +119,7 @@ func newServerTestConnection(
|
||||
ctx,
|
||||
cancel,
|
||||
sendConn,
|
||||
connRunner,
|
||||
&Transport{handlerMap: phm},
|
||||
origDestConnID,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
@@ -141,7 +141,7 @@ func newServerTestConnection(
|
||||
}
|
||||
return &testConnection{
|
||||
conn: conn,
|
||||
connRunner: connRunner,
|
||||
connRunner: phm,
|
||||
sendConn: sendConn,
|
||||
packer: packer,
|
||||
destConnID: origDestConnID,
|
||||
@@ -162,7 +162,7 @@ func newClientTestConnection(
|
||||
}
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
||||
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
||||
connRunner := NewMockConnRunner(mockCtrl)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
sendConn := NewMockSendConn(mockCtrl)
|
||||
sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
|
||||
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
||||
@@ -178,7 +178,7 @@ func newClientTestConnection(
|
||||
conn := newClientConnection(
|
||||
context.Background(),
|
||||
sendConn,
|
||||
connRunner,
|
||||
&Transport{handlerMap: phm},
|
||||
destConnID,
|
||||
srcConnID,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
@@ -198,7 +198,7 @@ func newClientTestConnection(
|
||||
}
|
||||
return &testConnection{
|
||||
conn: conn,
|
||||
connRunner: connRunner,
|
||||
connRunner: phm,
|
||||
sendConn: sendConn,
|
||||
packer: packer,
|
||||
destConnID: destConnID,
|
||||
@@ -2994,3 +2994,86 @@ func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionMigrationServer(t *testing.T) {
|
||||
tc := newServerTestConnection(t, nil, nil, false)
|
||||
_, err := tc.conn.AddPath(&Transport{})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "server cannot initiate connection migration")
|
||||
}
|
||||
|
||||
func TestConnectionMigration(t *testing.T) {
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
testConnectionMigration(t, false)
|
||||
})
|
||||
|
||||
t.Run("enabled", func(t *testing.T) {
|
||||
testConnectionMigration(t, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testConnectionMigration(t *testing.T, enabled bool) {
|
||||
tc := newClientTestConnection(t, nil, nil, false, connectionOptHandshakeConfirmed())
|
||||
require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
|
||||
InitialSourceConnectionID: tc.destConnID,
|
||||
OriginalDestinationConnectionID: tc.destConnID,
|
||||
DisableActiveMigration: !enabled,
|
||||
}))
|
||||
|
||||
tr := &Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
StatelessResetKey: &StatelessResetKey{},
|
||||
}
|
||||
defer tr.Close()
|
||||
path, err := tc.conn.AddPath(tr)
|
||||
if !enabled {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "server disabled connection migration")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, path)
|
||||
|
||||
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
|
||||
shortHeaderPacket{}, errNothingToPack,
|
||||
).AnyTimes()
|
||||
packedProbe := make(chan struct{})
|
||||
tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ protocol.ConnectionID, f ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
defer close(packedProbe)
|
||||
return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
|
||||
},
|
||||
).AnyTimes()
|
||||
// add a new connection ID, so the path can be probed
|
||||
require.NoError(t, tc.conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: 1,
|
||||
ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
}))
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- tc.conn.run() }()
|
||||
|
||||
// Adding the path initialized the transport.
|
||||
// We can test this by triggering a stateless reset.
|
||||
conn := newUDPConnLocalhost(t)
|
||||
_, err = conn.WriteTo(append([]byte{0x40}, make([]byte, 100)...), tr.Conn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, _, err = conn.ReadFrom(make([]byte, 100))
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() { path.Probe(context.Background()) }()
|
||||
select {
|
||||
case <-packedProbe:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// teardown
|
||||
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
||||
tc.conn.destroy(nil)
|
||||
select {
|
||||
case <-errChan:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
144
integrationtests/self/connection_migration_test.go
Normal file
144
integrationtests/self/connection_migration_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnectionMigration(t *testing.T) {
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConfig, getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
tr1 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
||||
defer tr1.Close()
|
||||
tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
||||
defer tr2.Close()
|
||||
|
||||
var packetsPath1, packetsPath2 atomic.Int64
|
||||
|
||||
const rtt = 5 * time.Millisecond
|
||||
proxy := quicproxy.Proxy{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ServerAddr: ln.Addr().(*net.UDPAddr),
|
||||
DelayPacket: func(dir quicproxy.Direction, from, to net.Addr, _ []byte) time.Duration {
|
||||
var port int
|
||||
switch dir {
|
||||
case quicproxy.DirectionIncoming:
|
||||
port = from.(*net.UDPAddr).Port
|
||||
case quicproxy.DirectionOutgoing:
|
||||
port = to.(*net.UDPAddr).Port
|
||||
}
|
||||
switch port {
|
||||
case tr1.Conn.LocalAddr().(*net.UDPAddr).Port:
|
||||
packetsPath1.Add(1)
|
||||
case tr2.Conn.LocalAddr().(*net.UDPAddr).Port:
|
||||
packetsPath2.Add(1)
|
||||
default:
|
||||
fmt.Println("address not found", from)
|
||||
}
|
||||
return rtt / 2
|
||||
},
|
||||
}
|
||||
require.NoError(t, proxy.Start())
|
||||
defer proxy.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := tr1.Dial(ctx, proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
sconn, err := ln.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sconn.CloseWithError(0, "")
|
||||
|
||||
sendAndReceiveFile := func(t *testing.T) {
|
||||
t.Helper()
|
||||
str, err := conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
sstr, err := sconn.AcceptUniStream(ctx)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("accepting stream: %w", err)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(sstr)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("reading stream data: %w", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(data, PRData) {
|
||||
errChan <- errors.New("unexpected data")
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = str.Write(PRData)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, str.Close())
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for data")
|
||||
}
|
||||
}
|
||||
|
||||
sendAndReceiveFile(t) // stream 2
|
||||
require.NotZero(t, packetsPath1.Load())
|
||||
require.Zero(t, packetsPath2.Load())
|
||||
|
||||
// probing the path causes a few packets to be sent on path 2
|
||||
path, err := conn.AddPath(tr2)
|
||||
require.NoError(t, err)
|
||||
require.ErrorIs(t, path.Switch(), quic.ErrPathNotValidated)
|
||||
require.NoError(t, path.Probe(ctx))
|
||||
require.Less(t, int(packetsPath2.Load()), 5)
|
||||
|
||||
// make sure that no more packets are sent on path 2 before switching to the path
|
||||
c2 := packetsPath2.Load()
|
||||
sendAndReceiveFile(t) // stream 6
|
||||
require.Equal(t, packetsPath2.Load(), c2)
|
||||
|
||||
time.Sleep(3 * rtt) // wait for ACKs
|
||||
|
||||
// now switch and make sure that no packets are sent on path 1
|
||||
require.NoError(t, path.Switch())
|
||||
sendAndReceiveFile(t) // stream 10
|
||||
c1 := packetsPath1.Load()
|
||||
require.Equal(t, c1, packetsPath1.Load())
|
||||
require.Greater(t, packetsPath2.Load(), c2)
|
||||
require.Equal(t, tr2.Conn.LocalAddr(), conn.LocalAddr())
|
||||
|
||||
// switch back to the handshake path
|
||||
time.Sleep(3 * rtt) // wait for ACKs
|
||||
c1BeforeSwitch := packetsPath1.Load()
|
||||
c2BeforeSwitch := packetsPath2.Load()
|
||||
path2, err := conn.AddPath(tr1)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, path2.Probe(ctx))
|
||||
time.Sleep(3 * rtt) // wait for ACKs
|
||||
require.NoError(t, path2.Switch())
|
||||
sendAndReceiveFile(t) // stream 14
|
||||
require.Greater(t, packetsPath1.Load(), c1BeforeSwitch)
|
||||
// some path probing might have happened
|
||||
require.Less(t, int(packetsPath2.Load()-c2BeforeSwitch), 20)
|
||||
require.Equal(t, tr1.Conn.LocalAddr(), conn.LocalAddr())
|
||||
}
|
||||
@@ -205,6 +205,8 @@ type Connection interface {
|
||||
SendDatagram(payload []byte) error
|
||||
// ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221.
|
||||
ReceiveDatagram(context.Context) ([]byte, error)
|
||||
|
||||
AddPath(*Transport) (*Path, error)
|
||||
}
|
||||
|
||||
// An EarlyConnection is a connection that is handshaking.
|
||||
|
||||
@@ -120,6 +120,45 @@ func (c *MockEarlyConnectionAcceptUniStreamCall) DoAndReturn(f func(context.Cont
|
||||
return c
|
||||
}
|
||||
|
||||
// AddPath mocks base method.
|
||||
func (m *MockEarlyConnection) AddPath(arg0 *quic.Transport) (*quic.Path, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddPath", arg0)
|
||||
ret0, _ := ret[0].(*quic.Path)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AddPath indicates an expected call of AddPath.
|
||||
func (mr *MockEarlyConnectionMockRecorder) AddPath(arg0 any) *MockEarlyConnectionAddPathCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPath", reflect.TypeOf((*MockEarlyConnection)(nil).AddPath), arg0)
|
||||
return &MockEarlyConnectionAddPathCall{Call: call}
|
||||
}
|
||||
|
||||
// MockEarlyConnectionAddPathCall wrap *gomock.Call
|
||||
type MockEarlyConnectionAddPathCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockEarlyConnectionAddPathCall) Return(arg0 *quic.Path, arg1 error) *MockEarlyConnectionAddPathCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockEarlyConnectionAddPathCall) Do(f func(*quic.Transport) (*quic.Path, error)) *MockEarlyConnectionAddPathCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockEarlyConnectionAddPathCall) DoAndReturn(f func(*quic.Transport) (*quic.Path, error)) *MockEarlyConnectionAddPathCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// CloseWithError mocks base method.
|
||||
func (m *MockEarlyConnection) CloseWithError(arg0 quic.ApplicationErrorCode, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -119,6 +119,45 @@ func (c *MockQUICConnAcceptUniStreamCall) DoAndReturn(f func(context.Context) (R
|
||||
return c
|
||||
}
|
||||
|
||||
// AddPath mocks base method.
|
||||
func (m *MockQUICConn) AddPath(arg0 *Transport) (*Path, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddPath", arg0)
|
||||
ret0, _ := ret[0].(*Path)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AddPath indicates an expected call of AddPath.
|
||||
func (mr *MockQUICConnMockRecorder) AddPath(arg0 any) *MockQUICConnAddPathCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPath", reflect.TypeOf((*MockQUICConn)(nil).AddPath), arg0)
|
||||
return &MockQUICConnAddPathCall{Call: call}
|
||||
}
|
||||
|
||||
// MockQUICConnAddPathCall wrap *gomock.Call
|
||||
type MockQUICConnAddPathCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockQUICConnAddPathCall) Return(arg0 *Path, arg1 error) *MockQUICConnAddPathCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockQUICConnAddPathCall) Do(f func(*Transport) (*Path, error)) *MockQUICConnAddPathCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockQUICConnAddPathCall) DoAndReturn(f func(*Transport) (*Path, error)) *MockQUICConnAddPathCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// CloseWithError mocks base method.
|
||||
func (m *MockQUICConn) CloseWithError(arg0 ApplicationErrorCode, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
205
path_manager_outgoing.go
Normal file
205
path_manager_outgoing.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/ackhandler"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// ErrPathNotValidated is returned when trying to use a path before path probing has completed.
|
||||
var ErrPathNotValidated = errors.New("path not yet validated")
|
||||
|
||||
var errPathDoesNotExist = errors.New("path does not exist")
|
||||
|
||||
// Path is a network path.
|
||||
type Path struct {
|
||||
id pathID
|
||||
pathManager *pathManagerOutgoing
|
||||
tr *Transport
|
||||
|
||||
enablePath func()
|
||||
startedProbing atomic.Bool
|
||||
}
|
||||
|
||||
func (p *Path) Probe(ctx context.Context) error {
|
||||
done := make(chan struct{})
|
||||
p.pathManager.addPath(p, p.enablePath, done)
|
||||
|
||||
p.startedProbing.Store(true)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return context.Cause(ctx)
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Switch switches the QUIC connection to this path.
|
||||
// It immediately stops sending on the old path, and sends on this new path.
|
||||
func (p *Path) Switch() error {
|
||||
if err := p.pathManager.switchToPath(p.id); err != nil {
|
||||
if errors.Is(err, errPathDoesNotExist) && !p.startedProbing.Load() {
|
||||
return ErrPathNotValidated
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type pathOutgoing struct {
|
||||
pathChallenge [8]byte
|
||||
tr *Transport
|
||||
isValidated bool
|
||||
validated chan<- struct{} // closed when the path the corresponding PATH_RESPONSE is received
|
||||
enablePath func()
|
||||
}
|
||||
|
||||
type pathManagerOutgoing struct {
|
||||
getConnID func(pathID) (_ protocol.ConnectionID, ok bool)
|
||||
retireConnID func(pathID)
|
||||
scheduleSending func()
|
||||
|
||||
mx sync.Mutex
|
||||
pathsToProbe []pathID
|
||||
paths map[pathID]*pathOutgoing
|
||||
nextPathID pathID
|
||||
pathToSwitchTo *pathOutgoing
|
||||
}
|
||||
|
||||
func newPathManagerOutgoing(
|
||||
getConnID func(pathID) (_ protocol.ConnectionID, ok bool),
|
||||
retireConnID func(pathID),
|
||||
scheduleSending func(),
|
||||
) *pathManagerOutgoing {
|
||||
return &pathManagerOutgoing{
|
||||
getConnID: getConnID,
|
||||
retireConnID: retireConnID,
|
||||
scheduleSending: scheduleSending,
|
||||
paths: make(map[pathID]*pathOutgoing, 4),
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *pathManagerOutgoing) addPath(p *Path, enablePath func(), done chan<- struct{}) {
|
||||
var b [8]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
pm.mx.Lock()
|
||||
pm.paths[p.id] = &pathOutgoing{
|
||||
pathChallenge: b,
|
||||
tr: p.tr,
|
||||
validated: done,
|
||||
enablePath: enablePath,
|
||||
}
|
||||
pm.pathsToProbe = append(pm.pathsToProbe, p.id)
|
||||
pm.mx.Unlock()
|
||||
pm.scheduleSending()
|
||||
}
|
||||
|
||||
func (pm *pathManagerOutgoing) switchToPath(id pathID) error {
|
||||
pm.mx.Lock()
|
||||
defer pm.mx.Unlock()
|
||||
|
||||
p, ok := pm.paths[id]
|
||||
if !ok {
|
||||
return errPathDoesNotExist
|
||||
}
|
||||
if !p.isValidated {
|
||||
return ErrPathNotValidated
|
||||
}
|
||||
pm.pathToSwitchTo = p
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *pathManagerOutgoing) NewPath(t *Transport, enablePath func()) *Path {
|
||||
pm.mx.Lock()
|
||||
defer pm.mx.Unlock()
|
||||
|
||||
id := pm.nextPathID
|
||||
pm.nextPathID++
|
||||
return &Path{
|
||||
pathManager: pm,
|
||||
id: id,
|
||||
tr: t,
|
||||
enablePath: enablePath,
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *pathManagerOutgoing) NextPathToProbe() (_ protocol.ConnectionID, _ ackhandler.Frame, _ *Transport, hasPath bool) {
|
||||
pm.mx.Lock()
|
||||
defer pm.mx.Unlock()
|
||||
|
||||
var p *pathOutgoing
|
||||
var id pathID
|
||||
for {
|
||||
if len(pm.pathsToProbe) == 0 {
|
||||
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
|
||||
}
|
||||
|
||||
id = pm.pathsToProbe[0]
|
||||
pm.pathsToProbe = pm.pathsToProbe[1:]
|
||||
|
||||
var ok bool
|
||||
// if the path doesn't exist in the map, it might have been abandoned
|
||||
p, ok = pm.paths[id]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
connID, ok := pm.getConnID(id)
|
||||
if !ok {
|
||||
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
|
||||
}
|
||||
|
||||
p.enablePath()
|
||||
frame := ackhandler.Frame{
|
||||
Frame: &wire.PathChallengeFrame{Data: p.pathChallenge},
|
||||
Handler: (*pathManagerOutgoingAckHandler)(pm),
|
||||
}
|
||||
return connID, frame, p.tr, true
|
||||
}
|
||||
|
||||
func (pm *pathManagerOutgoing) HandlePathResponseFrame(f *wire.PathResponseFrame) {
|
||||
pm.mx.Lock()
|
||||
defer pm.mx.Unlock()
|
||||
|
||||
for _, p := range pm.paths {
|
||||
if f.Data == p.pathChallenge {
|
||||
// path validated
|
||||
if !p.isValidated {
|
||||
p.isValidated = true
|
||||
close(p.validated)
|
||||
}
|
||||
// makes sure that duplicate PATH_RESPONSE frames are ignored
|
||||
p.validated = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *pathManagerOutgoing) ShouldSwitchPath() (*Transport, bool) {
|
||||
pm.mx.Lock()
|
||||
defer pm.mx.Unlock()
|
||||
|
||||
if pm.pathToSwitchTo == nil {
|
||||
return nil, false
|
||||
}
|
||||
p := pm.pathToSwitchTo
|
||||
pm.pathToSwitchTo = nil
|
||||
return p.tr, true
|
||||
}
|
||||
|
||||
type pathManagerOutgoingAckHandler pathManagerOutgoing
|
||||
|
||||
var _ ackhandler.FrameHandler = &pathManagerOutgoingAckHandler{}
|
||||
|
||||
// OnAcked is called when the PATH_CHALLENGE is acked.
|
||||
// This doesn't validate the path, only receiving the PATH_RESPONSE does.
|
||||
func (pm *pathManagerOutgoingAckHandler) OnAcked(wire.Frame) {}
|
||||
|
||||
func (pm *pathManagerOutgoingAckHandler) OnLost(wire.Frame) {}
|
||||
99
path_manager_outgoing_test.go
Normal file
99
path_manager_outgoing_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPathManagerOutgoing(t *testing.T) {
|
||||
connIDs := []protocol.ConnectionID{
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
||||
protocol.ParseConnectionID([]byte{2, 3, 4, 5, 6, 7, 8, 9}),
|
||||
}
|
||||
var retiredConnIDs []protocol.ConnectionID
|
||||
pm := newPathManagerOutgoing(
|
||||
func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true },
|
||||
func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) },
|
||||
func() {},
|
||||
)
|
||||
|
||||
_, _, _, ok := pm.NextPathToProbe()
|
||||
require.False(t, ok)
|
||||
|
||||
tr1 := &Transport{}
|
||||
var enabled bool
|
||||
p := pm.NewPath(tr1, func() { enabled = true })
|
||||
require.ErrorIs(t, p.Switch(), ErrPathNotValidated)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- p.Probe(context.Background()) }()
|
||||
|
||||
// wait for the path to be queued for probing
|
||||
time.Sleep(scaleDuration(5 * time.Millisecond))
|
||||
|
||||
require.False(t, enabled)
|
||||
connID, f, tr, ok := pm.NextPathToProbe()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tr1, tr)
|
||||
require.Equal(t, connIDs[0], connID)
|
||||
require.IsType(t, &wire.PathChallengeFrame{}, f.Frame)
|
||||
pc := f.Frame.(*wire.PathChallengeFrame)
|
||||
require.True(t, enabled)
|
||||
|
||||
_, _, _, ok = pm.NextPathToProbe()
|
||||
require.False(t, ok)
|
||||
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("should still be probing")
|
||||
default:
|
||||
}
|
||||
|
||||
// acking the frame doesn't complete path validation...
|
||||
f.Handler.OnAcked(f.Frame)
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("should still be probing")
|
||||
default:
|
||||
}
|
||||
|
||||
require.ErrorIs(t, p.Switch(), ErrPathNotValidated)
|
||||
_, ok = pm.ShouldSwitchPath()
|
||||
require.False(t, ok)
|
||||
|
||||
// ... neither does receiving a random PATH_RESPONSE...
|
||||
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: [8]byte{'f', 'o', 'o', 'f', 'o', 'o'}})
|
||||
f.Handler.OnAcked(f.Frame) // doesn't do anything
|
||||
f.Handler.OnLost(f.Frame) // doesn't do anything
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("should still be probing")
|
||||
default:
|
||||
}
|
||||
|
||||
// ... only receiving the corresponding PATH_RESPONSE does
|
||||
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc.Data})
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// receiving it multiple times is ok
|
||||
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc.Data})
|
||||
|
||||
// now switch to the other path
|
||||
_, ok = pm.ShouldSwitchPath()
|
||||
require.False(t, ok)
|
||||
require.NoError(t, p.Switch())
|
||||
switchToTransport, ok := pm.ShouldSwitchPath()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tr1, switchToTransport)
|
||||
}
|
||||
16
server.go
16
server.go
@@ -61,6 +61,7 @@ type rejectedPacket struct {
|
||||
|
||||
// A Listener of QUIC
|
||||
type baseServer struct {
|
||||
tr *Transport
|
||||
disableVersionNegotiation bool
|
||||
acceptEarlyConns bool
|
||||
|
||||
@@ -74,7 +75,6 @@ type baseServer struct {
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
statelessResetter *statelessResetter
|
||||
connHandler packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
receivedPackets chan receivedPacket
|
||||
@@ -89,7 +89,7 @@ type baseServer struct {
|
||||
context.Context,
|
||||
context.CancelCauseFunc,
|
||||
sendConn,
|
||||
connRunner,
|
||||
*Transport,
|
||||
protocol.ConnectionID, /* original dest connection ID */
|
||||
*protocol.ConnectionID, /* retry src connection ID */
|
||||
protocol.ConnectionID, /* client dest connection ID */
|
||||
@@ -247,7 +247,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
|
||||
|
||||
func newServer(
|
||||
conn rawConn,
|
||||
connHandler packetHandlerManager,
|
||||
tr *Transport,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
connContext func(context.Context) context.Context,
|
||||
@@ -264,6 +264,7 @@ func newServer(
|
||||
s := &baseServer{
|
||||
conn: conn,
|
||||
connContext: connContext,
|
||||
tr: tr,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||
@@ -271,7 +272,6 @@ func newServer(
|
||||
verifySourceAddress: verifySourceAddress,
|
||||
connIDGenerator: connIDGenerator,
|
||||
statelessResetter: statelessResetter,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||
errorChan: make(chan struct{}),
|
||||
stopAccepting: make(chan struct{}),
|
||||
@@ -501,7 +501,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
|
||||
}
|
||||
|
||||
// check again if we might have a connection now
|
||||
if handler, ok := s.connHandler.Get(connID); ok {
|
||||
if handler, ok := s.tr.connRunner().Get(connID); ok {
|
||||
handler.handlePacket(p)
|
||||
return true
|
||||
}
|
||||
@@ -591,7 +591,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
// The server queues packets for a while, and we might already have established a connection by now.
|
||||
// This results in a second check in the connection map.
|
||||
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
|
||||
if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
|
||||
if handler, ok := s.tr.connRunner().Get(hdr.DestConnectionID); ok {
|
||||
handler.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
@@ -706,7 +706,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
ctx,
|
||||
cancel,
|
||||
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
||||
s.connHandler,
|
||||
s.tr,
|
||||
origDestConnID,
|
||||
retrySrcConnID,
|
||||
hdr.DestConnectionID,
|
||||
@@ -727,7 +727,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
|
||||
// under normal circumstances the packet would just be routed to that connection.
|
||||
// The only time this collision will occur if we receive the two Initial packets at the same time.
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
|
||||
if added := s.tr.connRunner().AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
conn.closeWithTransportError(qerr.ConnectionRefused)
|
||||
return nil
|
||||
|
||||
@@ -37,7 +37,7 @@ type serverOpts struct {
|
||||
context.Context,
|
||||
context.CancelCauseFunc,
|
||||
sendConn,
|
||||
connRunner,
|
||||
*Transport,
|
||||
protocol.ConnectionID, // original dest connection ID
|
||||
*protocol.ConnectionID, // retry src connection ID
|
||||
protocol.ConnectionID, // client dest connection ID
|
||||
@@ -63,7 +63,7 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
|
||||
config := populateConfig(serverOpts.config)
|
||||
s := newServer(
|
||||
c,
|
||||
newPacketHandlerMap(nil, utils.DefaultLogger),
|
||||
&Transport{handlerMap: newPacketHandlerMap(nil, utils.DefaultLogger)},
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
&statelessResetter{},
|
||||
func(ctx context.Context) context.Context { return ctx },
|
||||
@@ -623,7 +623,7 @@ func (r *connConstructorRecorder) NewConn(
|
||||
ctx context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ *Transport,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
retrySrcConnID *protocol.ConnectionID,
|
||||
clientDestConnID protocol.ConnectionID,
|
||||
@@ -863,7 +863,7 @@ func TestServerPacketHandling(t *testing.T) {
|
||||
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
|
||||
handledPacket <- p
|
||||
})
|
||||
server.connHandler.Add(destConnID, conn)
|
||||
server.tr.handlerMap.Add(destConnID, conn)
|
||||
|
||||
server.handlePacket(
|
||||
getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, srcConnID, destConnID),
|
||||
@@ -889,7 +889,7 @@ func TestServerReceiveQueue(t *testing.T) {
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ *Transport,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
|
||||
16
transport.go
16
transport.go
@@ -38,6 +38,10 @@ func (e *errTransportClosed) Is(target error) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
type transportID uint64
|
||||
|
||||
var transportIDCounter atomic.Uint64
|
||||
|
||||
var errListenerAlreadySet = errors.New("listener already set")
|
||||
|
||||
// The Transport is the central point to manage incoming and outgoing QUIC connections.
|
||||
@@ -133,6 +137,7 @@ type Transport struct {
|
||||
initErr error
|
||||
|
||||
// Set in init.
|
||||
transportID transportID
|
||||
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
|
||||
connIDLen int
|
||||
// Set in init.
|
||||
@@ -207,7 +212,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||
}
|
||||
s := newServer(
|
||||
t.conn,
|
||||
t.handlerMap,
|
||||
t,
|
||||
t.connIDGenerator,
|
||||
t.statelessResetter,
|
||||
t.ConnContext,
|
||||
@@ -298,7 +303,7 @@ func (t *Transport) doDial(
|
||||
conn := newClientConnection(
|
||||
context.WithoutCancel(ctx),
|
||||
sendConn,
|
||||
t.handlerMap,
|
||||
t,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
t.connIDGenerator,
|
||||
@@ -371,6 +376,7 @@ func (t *Transport) doDial(
|
||||
|
||||
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
t.initOnce.Do(func() {
|
||||
t.transportID = transportID(transportIDCounter.Add(1))
|
||||
var conn rawConn
|
||||
if c, ok := t.Conn.(rawConn); ok {
|
||||
conn = c
|
||||
@@ -420,6 +426,12 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
return t.initErr
|
||||
}
|
||||
|
||||
func (t *Transport) connRunner() packetHandlerManager {
|
||||
return t.handlerMap
|
||||
}
|
||||
|
||||
func (t *Transport) id() transportID { return t.transportID }
|
||||
|
||||
// WriteTo sends a packet on the underlying connection.
|
||||
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if err := t.init(false); err != nil {
|
||||
|
||||
@@ -510,7 +510,7 @@ func testTransportDial(t *testing.T, early bool) {
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ *Transport,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
@@ -585,7 +585,7 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ *Transport,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
|
||||
Reference in New Issue
Block a user