send PATH_RESPONSEs on the same path (#4991)

* make it possible to pack path probes with multiple frames

* simplify function signature of pathManager.HandlePacket

* simplify connection short header packet handling logic

No functional change expected.

* make server send PATH_RESPONSEs on the same path

This makes sure that we’re actually testing for return routability.
This commit is contained in:
Marten Seemann
2025-03-16 10:28:53 +07:00
committed by GitHub
parent 7e3d668981
commit 6fe46d6253
8 changed files with 281 additions and 152 deletions

View File

@@ -1043,7 +1043,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed boo
)
}
}
isNonProbing, err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log)
isNonProbing, pathChallenge, err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log)
if err != nil {
return false, err
}
@@ -1052,47 +1052,46 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed boo
if s.perspective == protocol.PerspectiveClient {
return true, nil
}
if addrsEqual(p.remoteAddr, s.RemoteAddr()) {
return true, nil
}
var shouldSwitchPath bool
if !addrsEqual(p.remoteAddr, s.RemoteAddr()) {
if s.pathManager == nil {
s.pathManager = newPathManager(
s.connIDManager.GetConnIDForPath,
s.connIDManager.RetireConnIDForPath,
s.logger,
)
}
var destConnID protocol.ConnectionID
var pathChallenge ackhandler.Frame
destConnID, pathChallenge, shouldSwitchPath = s.pathManager.HandlePacket(p, isNonProbing)
if pathChallenge.Frame != nil {
probe, buf, err := s.packer.PackPathProbePacket(destConnID, pathChallenge, s.version)
if err != nil {
return false, err
}
s.logger.Debugf("sending path probe packet to %s", p.remoteAddr)
s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime)
s.sendQueue.SendProbe(buf, p.remoteAddr)
}
// We only switch paths in response to the highest-numbered non-probing packet,
// see section 9.3 of RFC 9000.
if !shouldSwitchPath || pn != s.largestRcvdAppData {
return true, nil
}
s.pathManager.SwitchToPath(p.remoteAddr)
s.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(s.config.InitialPacketSize))
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = s.peerParams.MaxUDPPayloadSize
}
s.mtuDiscoverer.Reset(
p.rcvTime,
protocol.ByteCount(s.config.InitialPacketSize),
maxPacketSize,
if s.pathManager == nil {
s.pathManager = newPathManager(
s.connIDManager.GetConnIDForPath,
s.connIDManager.RetireConnIDForPath,
s.logger,
)
s.conn.ChangeRemoteAddr(p.remoteAddr, p.info)
}
destConnID, frames, shouldSwitchPath := s.pathManager.HandlePacket(p.remoteAddr, pathChallenge, isNonProbing)
if len(frames) > 0 {
probe, buf, err := s.packer.PackPathProbePacket(destConnID, frames, s.version)
if err != nil {
return true, err
}
s.logger.Debugf("sending path probe packet to %s", p.remoteAddr)
s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime)
s.sendQueue.SendProbe(buf, p.remoteAddr)
}
// We only switch paths in response to the highest-numbered non-probing packet,
// see section 9.3 of RFC 9000.
if !shouldSwitchPath || pn != s.largestRcvdAppData {
return true, nil
}
s.pathManager.SwitchToPath(p.remoteAddr)
s.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(s.config.InitialPacketSize))
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = s.peerParams.MaxUDPPayloadSize
}
s.mtuDiscoverer.Reset(
p.rcvTime,
protocol.ByteCount(s.config.InitialPacketSize),
maxPacketSize,
)
s.conn.ChangeRemoteAddr(p.remoteAddr, p.info)
return true, nil
}
@@ -1377,7 +1376,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames)
}
}
isAckEliciting, _, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime)
isAckEliciting, _, _, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime)
if err != nil {
return err
}
@@ -1391,28 +1390,30 @@ func (s *connection) handleUnpackedShortHeaderPacket(
ecn protocol.ECN,
rcvTime time.Time,
log func([]logging.Frame),
) (isNonProbing bool, _ error) {
) (isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) {
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false
isAckEliciting, isNonProbing, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime)
isAckEliciting, isNonProbing, pathChallenge, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime)
if err != nil {
return false, err
return false, nil, err
}
if err := s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting); err != nil {
return false, err
return false, nil, err
}
return isNonProbing, nil
return isNonProbing, pathChallenge, nil
}
// handleFrames parses the frames, one after the other, and handles them.
// It returns the last PATH_CHALLENGE frame contained in the packet, if any.
func (s *connection) handleFrames(
data []byte,
destConnID protocol.ConnectionID,
encLevel protocol.EncryptionLevel,
log func([]logging.Frame),
rcvTime time.Time,
) (isAckEliciting, isNonProbing bool, _ error) {
) (isAckEliciting, isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) {
// Only used for tracing.
// If we're not tracing, this slice will always remain empty.
var frames []logging.Frame
@@ -1424,7 +1425,7 @@ func (s *connection) handleFrames(
for len(data) > 0 {
l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version)
if err != nil {
return false, false, err
return false, false, nil, err
}
data = data[l:]
if frame == nil {
@@ -1444,19 +1445,23 @@ func (s *connection) handleFrames(
if handleErr != nil {
continue
}
if err := s.handleFrame(frame, encLevel, destConnID, rcvTime); err != nil {
pc, err := s.handleFrame(frame, encLevel, destConnID, rcvTime)
if err != nil {
if log == nil {
return false, false, err
return false, false, nil, err
}
// If we're logging, we need to keep parsing (but not handling) all frames.
handleErr = err
}
if pc != nil {
pathChallenge = pc
}
}
if log != nil {
log(frames)
if handleErr != nil {
return false, false, handleErr
return false, false, nil, handleErr
}
}
@@ -1466,7 +1471,7 @@ func (s *connection) handleFrames(
// and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame.
if !handshakeWasComplete && s.handshakeComplete {
if err := s.handleHandshakeComplete(rcvTime); err != nil {
return false, false, err
return false, false, nil, err
}
}
return
@@ -1477,7 +1482,7 @@ func (s *connection) handleFrame(
encLevel protocol.EncryptionLevel,
destConnID protocol.ConnectionID,
rcvTime time.Time,
) error {
) (pathChallenge *wire.PathChallengeFrame, _ error) {
var err error
wire.LogFrame(s.logger, f, false)
switch frame := f.(type) {
@@ -1506,6 +1511,7 @@ func (s *connection) handleFrame(
case *wire.PingFrame:
case *wire.PathChallengeFrame:
s.handlePathChallengeFrame(frame)
pathChallenge = frame
case *wire.PathResponseFrame:
err = s.handlePathResponseFrame(frame)
case *wire.NewTokenFrame:
@@ -1521,7 +1527,7 @@ func (s *connection) handleFrame(
default:
err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name())
}
return err
return pathChallenge, err
}
// handlePacket is called by the server with a new packet
@@ -1675,7 +1681,9 @@ func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error
}
func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data})
if s.perspective == protocol.PerspectiveClient {
s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data})
}
}
func (s *connection) handlePathResponseFrame(f *wire.PathResponseFrame) error {
@@ -2083,7 +2091,7 @@ func (s *connection) sendPackets(now time.Time) error {
if pm := s.pathManagerOutgoing.Load(); pm != nil {
connID, frame, tr, ok := pm.NextPathToProbe()
if ok {
probe, buf, err := s.packer.PackPathProbePacket(connID, frame, s.version)
probe, buf, err := s.packer.PackPathProbePacket(connID, []ackhandler.Frame{frame}, s.version)
if err != nil {
return err
}

View File

@@ -222,14 +222,17 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
str.EXPECT().handleStreamFrame(f, now)
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
str.EXPECT().handleResetStreamFrame(rsf, now)
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
t.Run("for closed streams", func(t *testing.T) {
@@ -238,13 +241,16 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
t.Run("for invalid streams", func(t *testing.T) {
@@ -254,13 +260,16 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
testErr := errors.New("test err")
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now), testErr)
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, testErr)
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now), testErr)
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, testErr)
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now), testErr)
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, testErr)
})
}
@@ -279,11 +288,13 @@ func TestConnectionHandleSendStreamFrames(t *testing.T) {
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
str.EXPECT().handleStopSendingFrame(ss)
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
str.EXPECT().updateSendWindow(msd.MaximumStreamData)
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
t.Run("for closed streams", func(t *testing.T) {
@@ -292,10 +303,12 @@ func TestConnectionHandleSendStreamFrames(t *testing.T) {
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
t.Run("for invalid streams", func(t *testing.T) {
@@ -305,10 +318,12 @@ func TestConnectionHandleSendStreamFrames(t *testing.T) {
testErr := errors.New("test err")
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now), testErr)
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, testErr)
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now), testErr)
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, testErr)
})
}
@@ -321,9 +336,11 @@ func TestConnectionHandleStreamNumFrames(t *testing.T) {
// MAX_STREAMS frame
msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
require.NoError(t, tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now))
_, err := tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// STREAMS_BLOCKED frame
tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
_, err = tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
}
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
@@ -334,9 +351,11 @@ func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
connID := protocol.ConnectionID{}
// MAX_DATA frame
connFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
require.NoError(t, tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
_, err := tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// DATA_BLOCKED frame
require.NoError(t, tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
_, err = tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
}
func TestConnectionOpenStreams(t *testing.T) {
@@ -404,10 +423,8 @@ func TestConnectionServerInvalidFrames(t *testing.T) {
{Name: "PATH_RESPONSE", Frame: &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
} {
t.Run(test.Name, func(t *testing.T) {
require.ErrorIs(t,
tc.conn.handleFrame(test.Frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now()),
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
)
_, err := tc.conn.handleFrame(test.Frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now())
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
})
}
}
@@ -2887,8 +2904,8 @@ func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
),
tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ protocol.ConnectionID, f ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pathChallenge = f.Frame.(*wire.PathChallengeFrame)
func(_ protocol.ConnectionID, frames []ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pathChallenge = frames[0].Frame.(*wire.PathChallengeFrame)
return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
},
),
@@ -2960,6 +2977,8 @@ func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
}
payload := []byte{1} // PING frame
payload, err = (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(payload, protocol.Version1)
require.NoError(t, err)
gomock.InOrder(
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
@@ -3039,7 +3058,7 @@ func testConnectionMigration(t *testing.T, enabled bool) {
).AnyTimes()
packedProbe := make(chan struct{})
tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ protocol.ConnectionID, f ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
defer close(packedProbe)
return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
},

View File

@@ -18,7 +18,7 @@ import (
)
func TestConnectionMigration(t *testing.T) {
ln, err := quic.ListenAddr("localhost:0", tlsConfig, getQuicConfig(nil))
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()

View File

@@ -319,7 +319,7 @@ func (c *MockPackerPackMTUProbePacketCall) DoAndReturn(f func(ackhandler.Frame,
}
// PackPathProbePacket mocks base method.
func (m *MockPacker) PackPathProbePacket(arg0 protocol.ConnectionID, arg1 ackhandler.Frame, arg2 protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
func (m *MockPacker) PackPathProbePacket(arg0 protocol.ConnectionID, arg1 []ackhandler.Frame, arg2 protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackPathProbePacket", arg0, arg1, arg2)
ret0, _ := ret[0].(shortHeaderPacket)
@@ -347,13 +347,13 @@ func (c *MockPackerPackPathProbePacketCall) Return(arg0 shortHeaderPacket, arg1
}
// Do rewrite *gomock.Call.Do
func (c *MockPackerPackPathProbePacketCall) Do(f func(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
func (c *MockPackerPackPathProbePacketCall) Do(f func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPackerPackPathProbePacketCall) DoAndReturn(f func(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
func (c *MockPackerPackPathProbePacketCall) DoAndReturn(f func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -25,7 +25,7 @@ type packer interface {
MaybePackPTOProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackPathProbePacket(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
PackPathProbePacket(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
SetToken([]byte)
@@ -794,16 +794,20 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
return packet, buffer, err
}
func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, f ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, frames []ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
buf := getPacketBuffer()
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, nil, err
}
var l protocol.ByteCount
for _, f := range frames {
l += f.Frame.Length(v)
}
payload := payload{
frames: []ackhandler.Frame{f},
length: f.Frame.Length(v),
frames: frames,
length: l,
}
padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead())
packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v)

View File

@@ -917,14 +917,21 @@ func TestPackPathProbePacket(t *testing.T) {
p, buf, err := tp.packer.PackPathProbePacket(
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
ackhandler.Frame{Frame: &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
[]ackhandler.Frame{
{Frame: &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
{Frame: &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}},
},
protocol.Version1,
)
require.NoError(t, err)
require.Equal(t, protocol.PacketNumber(0x43), p.PacketNumber)
require.Nil(t, p.Ack)
require.Empty(t, p.StreamFrames)
require.Equal(t, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, p.Frames[0].Frame)
require.Len(t, p.Frames, 2)
// the frame order is randomized
frames := []wire.Frame{p.Frames[0].Frame, p.Frames[1].Frame}
require.Contains(t, frames, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}})
require.Contains(t, frames, &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}})
require.Len(t, buf.Data, protocol.MinInitialPacketSize)
require.True(t, p.IsPathProbePacket)
require.False(t, p.IsPathMTUProbePacket)

View File

@@ -46,47 +46,69 @@ func newPathManager(
// Returns a path challenge frame if one should be sent.
// May return nil.
func (pm *pathManager) HandlePacket(p receivedPacket, isNonProbing bool) (_ protocol.ConnectionID, _ ackhandler.Frame, shouldSwitch bool) {
for _, path := range pm.paths {
if addrsEqual(path.addr, p.remoteAddr) {
func (pm *pathManager) HandlePacket(
remoteAddr net.Addr,
pathChallenge *wire.PathChallengeFrame, // may be nil if the packet didn't contain a PATH_CHALLENGE
isNonProbing bool,
) (_ protocol.ConnectionID, _ []ackhandler.Frame, shouldSwitch bool) {
var p *path
pathID := pm.nextPathID
for id, path := range pm.paths {
if addrsEqual(path.addr, remoteAddr) {
p = path
pathID = id
// already sent a PATH_CHALLENGE for this path
if isNonProbing {
path.rcvdNonProbing = true
}
if pm.logger.Debug() {
pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", p.remoteAddr, path.validated)
pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", remoteAddr, path.validated)
}
shouldSwitch = path.validated && path.rcvdNonProbing
if pathChallenge == nil {
return protocol.ConnectionID{}, nil, shouldSwitch
}
return protocol.ConnectionID{}, ackhandler.Frame{}, path.validated && path.rcvdNonProbing
}
}
if len(pm.paths) >= maxPaths {
if pm.logger.Debug() {
pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", p.remoteAddr, len(pm.paths))
pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", remoteAddr, len(pm.paths))
}
return protocol.ConnectionID{}, ackhandler.Frame{}, false
return protocol.ConnectionID{}, nil, shouldSwitch
}
// previously unseen path, initiate path validation by sending a PATH_CHALLENGE
connID, ok := pm.getConnID(pm.nextPathID)
connID, ok := pm.getConnID(pathID)
if !ok {
pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", p.remoteAddr)
return protocol.ConnectionID{}, ackhandler.Frame{}, false
pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", remoteAddr)
return protocol.ConnectionID{}, nil, shouldSwitch
}
var b [8]byte
rand.Read(b[:])
pm.paths[pm.nextPathID] = &path{
addr: p.remoteAddr,
pathChallenge: b,
rcvdNonProbing: isNonProbing,
frames := make([]ackhandler.Frame, 0, 2)
if p == nil {
var pathChallengeData [8]byte
rand.Read(pathChallengeData[:])
p = &path{
addr: remoteAddr,
rcvdNonProbing: isNonProbing,
pathChallenge: pathChallengeData,
}
frames = append(frames, ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: p.pathChallenge},
Handler: (*pathManagerAckHandler)(pm),
})
pm.paths[pm.nextPathID] = p
pm.nextPathID++
pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", remoteAddr)
}
pm.nextPathID++
frame := ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: b},
Handler: (*pathManagerAckHandler)(pm),
if pathChallenge != nil {
frames = append(frames, ackhandler.Frame{
Frame: &wire.PathResponseFrame{Data: pathChallenge.Data},
Handler: (*pathManagerAckHandler)(pm),
})
}
pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", p.remoteAddr)
return connID, frame, false
return connID, frames, shouldSwitch
}
func (pm *pathManager) HandlePathResponseFrame(f *wire.PathResponseFrame) {

View File

@@ -28,67 +28,135 @@ func TestPathManagerIntentionalMigration(t *testing.T) {
func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) },
utils.DefaultLogger,
)
connID, f1, shouldSwitch := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
connID, frames, shouldSwitch := pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
&wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
false,
)
require.Equal(t, connIDs[0], connID)
require.NotNil(t, f1.Frame)
pc1 := f1.Frame.(*wire.PathChallengeFrame)
require.Len(t, frames, 2)
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
pc1 := frames[0].Frame.(*wire.PathChallengeFrame)
require.NotZero(t, pc1.Data)
require.NotEqual(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, pc1.Data)
require.IsType(t, &wire.PathResponseFrame{}, frames[1].Frame)
require.Equal(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, frames[1].Frame.(*wire.PathResponseFrame).Data)
require.False(t, shouldSwitch)
// receiving another packet for the same path doesn't trigger another PATH_CHALLENGE
connID, f, shouldSwitch := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
connID, frames, shouldSwitch = pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
nil,
false,
)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Len(t, frames, 0)
require.False(t, shouldSwitch)
// receiving a packet for a different path triggers another PATH_CHALLENGE
addr2 := &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000}
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: addr2}, false)
connID, frames, shouldSwitch = pm.HandlePacket(addr2, nil, false)
require.Equal(t, connIDs[1], connID)
require.NotNil(t, f.Frame)
pc2 := f.Frame.(*wire.PathChallengeFrame)
require.Len(t, frames, 1)
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
pc2 := frames[0].Frame.(*wire.PathChallengeFrame)
require.NotEqual(t, pc1.Data, pc2.Data)
require.False(t, shouldSwitch)
// acknowledging the PATH_CHALLENGE doesn't confirm the path
f1.Handler.OnAcked(f1.Frame)
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
frames[0].Handler.OnAcked(frames[0].Frame)
connID, frames, shouldSwitch = pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
nil,
false,
)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
require.False(t, shouldSwitch)
// receiving a PATH_RESPONSE for the second path confirms the path
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc2.Data})
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: addr2}, false)
connID, frames, shouldSwitch = pm.HandlePacket(addr2, nil, false)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
require.False(t, shouldSwitch) // no non-probing packet received yet
require.Empty(t, retiredConnIDs)
// confirming the path doesn't remove other paths
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
connID, frames, shouldSwitch = pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
nil,
false,
)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
require.False(t, shouldSwitch)
// now receive a non-probing packet for the new path
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000}}, true)
connID, frames, shouldSwitch = pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000},
nil,
true,
)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
require.True(t, shouldSwitch)
// now switch to the new path
pm.SwitchToPath(&net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000})
// switching to the path removes other paths
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
connID, frames, shouldSwitch = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, nil, false)
require.Equal(t, connIDs[2], connID)
require.NotNil(t, f.Frame)
require.NotEqual(t, f.Frame.(*wire.PathChallengeFrame).Data, pc1.Data)
require.NotEmpty(t, frames)
require.NotEqual(t, frames[0].Frame.(*wire.PathChallengeFrame).Data, pc1.Data)
require.False(t, shouldSwitch)
require.Equal(t, []protocol.ConnectionID{connIDs[0]}, retiredConnIDs)
}
func TestPathManagerMultipleProbes(t *testing.T) {
connIDs := []protocol.ConnectionID{
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
}
pm := newPathManager(
func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true },
func(id pathID) {},
utils.DefaultLogger,
)
// first receive a packet without a PATH_CHALLENGE
connID, frames, shouldSwitch := pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
nil,
false,
)
require.Equal(t, connIDs[0], connID)
require.Len(t, frames, 1)
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
require.False(t, shouldSwitch)
// now receive a packet on the same path with a PATH_CHALLENGE
connID, frames, shouldSwitch = pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
&wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
false,
)
require.Equal(t, connIDs[0], connID)
require.Len(t, frames, 1)
require.Equal(t, &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, frames[0].Frame)
require.False(t, shouldSwitch)
// now receive an other packet on the same path with a PATH_RESPONSE
connID, frames, shouldSwitch = pm.HandlePacket(
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
&wire.PathChallengeFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}},
false,
)
require.Equal(t, connIDs[0], connID)
require.Len(t, frames, 1)
require.Equal(t, &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}, frames[0].Frame)
require.False(t, shouldSwitch)
}
// The first packet received on the new path is already a non-probing packet.
// We still need to validate the new path, but we can then switch over immediately.
// This is the typical scenario when a NAT rebinding happens.
@@ -103,19 +171,20 @@ func TestPathManagerNATRebinding(t *testing.T) {
utils.DefaultLogger,
)
connID, f, shouldSwitch := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, true)
connID, frames, shouldSwitch := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, nil, true)
require.Equal(t, connIDs[0], connID)
require.NotNil(t, f.Frame)
pc1 := f.Frame.(*wire.PathChallengeFrame)
require.Len(t, frames, 1)
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
pc1 := frames[0].Frame.(*wire.PathChallengeFrame)
require.NotZero(t, pc1.Data)
require.False(t, shouldSwitch)
// receiving a PATH_RESPONSE for the second path confirms the path
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc1.Data})
// we now switch to the new path, as soon as the next packet on that path is received
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
connID, frames, shouldSwitch = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, nil, false)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
require.True(t, shouldSwitch)
}
@@ -134,41 +203,41 @@ func TestPathManagerLimits(t *testing.T) {
)
for i := range maxPaths {
connID, f, _ := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000 + i}}, true)
require.NotNil(t, f.Frame)
connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000 + i}, nil, true)
require.NotEmpty(t, frames)
require.Equal(t, connIDs[i], connID)
}
// the maximum number of paths is already being probed
connID, f, _ := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}}, true)
connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}, nil, true)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
// switching to a new path frees is up all paths
var f1 ackhandler.Frame
var f1 []ackhandler.Frame
pm.SwitchToPath(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000})
for i := range maxPaths {
connID, f, _ := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3000 + i}}, true)
connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3000 + i}, nil, true)
if i == 0 {
f1 = f
f1 = frames
}
require.NotNil(t, f.Frame)
require.NotEmpty(t, frames)
require.Equal(t, connIDs[maxPaths+i], connID)
}
// again, the maximum number of paths is already being probed
connID, f, _ = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}}, true)
connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}, nil, true)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
// losing the frame removes this path
f1.Handler.OnLost(f1.Frame)
f1[0].Handler.OnLost(f1[0].Frame)
// we can open exactly one more path
connID, f, _ = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4000}}, true)
require.NotNil(t, f.Frame)
connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4000}, nil, true)
require.NotEmpty(t, frames)
require.Equal(t, connIDs[2*maxPaths], connID)
connID, f, _ = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4001}}, true)
connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4001}, nil, true)
require.Zero(t, connID)
require.Nil(t, f.Frame)
require.Empty(t, frames)
}
type mockAddr struct {