forked from quic-go/quic-go
send PATH_RESPONSEs on the same path (#4991)
* make it possible to pack path probes with multiple frames * simplify function signature of pathManager.HandlePacket * simplify connection short header packet handling logic No functional change expected. * make server send PATH_RESPONSEs on the same path This makes sure that we’re actually testing for return routability.
This commit is contained in:
116
connection.go
116
connection.go
@@ -1043,7 +1043,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed boo
|
||||
)
|
||||
}
|
||||
}
|
||||
isNonProbing, err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log)
|
||||
isNonProbing, pathChallenge, err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -1052,47 +1052,46 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed boo
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
return true, nil
|
||||
}
|
||||
if addrsEqual(p.remoteAddr, s.RemoteAddr()) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var shouldSwitchPath bool
|
||||
if !addrsEqual(p.remoteAddr, s.RemoteAddr()) {
|
||||
if s.pathManager == nil {
|
||||
s.pathManager = newPathManager(
|
||||
s.connIDManager.GetConnIDForPath,
|
||||
s.connIDManager.RetireConnIDForPath,
|
||||
s.logger,
|
||||
)
|
||||
}
|
||||
var destConnID protocol.ConnectionID
|
||||
var pathChallenge ackhandler.Frame
|
||||
destConnID, pathChallenge, shouldSwitchPath = s.pathManager.HandlePacket(p, isNonProbing)
|
||||
if pathChallenge.Frame != nil {
|
||||
probe, buf, err := s.packer.PackPathProbePacket(destConnID, pathChallenge, s.version)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
s.logger.Debugf("sending path probe packet to %s", p.remoteAddr)
|
||||
s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
|
||||
s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime)
|
||||
s.sendQueue.SendProbe(buf, p.remoteAddr)
|
||||
}
|
||||
// We only switch paths in response to the highest-numbered non-probing packet,
|
||||
// see section 9.3 of RFC 9000.
|
||||
if !shouldSwitchPath || pn != s.largestRcvdAppData {
|
||||
return true, nil
|
||||
}
|
||||
s.pathManager.SwitchToPath(p.remoteAddr)
|
||||
s.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(s.config.InitialPacketSize))
|
||||
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
|
||||
if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize {
|
||||
maxPacketSize = s.peerParams.MaxUDPPayloadSize
|
||||
}
|
||||
s.mtuDiscoverer.Reset(
|
||||
p.rcvTime,
|
||||
protocol.ByteCount(s.config.InitialPacketSize),
|
||||
maxPacketSize,
|
||||
if s.pathManager == nil {
|
||||
s.pathManager = newPathManager(
|
||||
s.connIDManager.GetConnIDForPath,
|
||||
s.connIDManager.RetireConnIDForPath,
|
||||
s.logger,
|
||||
)
|
||||
s.conn.ChangeRemoteAddr(p.remoteAddr, p.info)
|
||||
}
|
||||
destConnID, frames, shouldSwitchPath := s.pathManager.HandlePacket(p.remoteAddr, pathChallenge, isNonProbing)
|
||||
if len(frames) > 0 {
|
||||
probe, buf, err := s.packer.PackPathProbePacket(destConnID, frames, s.version)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
s.logger.Debugf("sending path probe packet to %s", p.remoteAddr)
|
||||
s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
|
||||
s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime)
|
||||
s.sendQueue.SendProbe(buf, p.remoteAddr)
|
||||
}
|
||||
// We only switch paths in response to the highest-numbered non-probing packet,
|
||||
// see section 9.3 of RFC 9000.
|
||||
if !shouldSwitchPath || pn != s.largestRcvdAppData {
|
||||
return true, nil
|
||||
}
|
||||
s.pathManager.SwitchToPath(p.remoteAddr)
|
||||
s.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(s.config.InitialPacketSize))
|
||||
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
|
||||
if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize {
|
||||
maxPacketSize = s.peerParams.MaxUDPPayloadSize
|
||||
}
|
||||
s.mtuDiscoverer.Reset(
|
||||
p.rcvTime,
|
||||
protocol.ByteCount(s.config.InitialPacketSize),
|
||||
maxPacketSize,
|
||||
)
|
||||
s.conn.ChangeRemoteAddr(p.remoteAddr, p.info)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -1377,7 +1376,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
|
||||
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames)
|
||||
}
|
||||
}
|
||||
isAckEliciting, _, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime)
|
||||
isAckEliciting, _, _, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1391,28 +1390,30 @@ func (s *connection) handleUnpackedShortHeaderPacket(
|
||||
ecn protocol.ECN,
|
||||
rcvTime time.Time,
|
||||
log func([]logging.Frame),
|
||||
) (isNonProbing bool, _ error) {
|
||||
) (isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) {
|
||||
s.lastPacketReceivedTime = rcvTime
|
||||
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
|
||||
s.keepAlivePingSent = false
|
||||
|
||||
isAckEliciting, isNonProbing, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime)
|
||||
isAckEliciting, isNonProbing, pathChallenge, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
if err := s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting); err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
return isNonProbing, nil
|
||||
return isNonProbing, pathChallenge, nil
|
||||
}
|
||||
|
||||
// handleFrames parses the frames, one after the other, and handles them.
|
||||
// It returns the last PATH_CHALLENGE frame contained in the packet, if any.
|
||||
func (s *connection) handleFrames(
|
||||
data []byte,
|
||||
destConnID protocol.ConnectionID,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
log func([]logging.Frame),
|
||||
rcvTime time.Time,
|
||||
) (isAckEliciting, isNonProbing bool, _ error) {
|
||||
) (isAckEliciting, isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) {
|
||||
// Only used for tracing.
|
||||
// If we're not tracing, this slice will always remain empty.
|
||||
var frames []logging.Frame
|
||||
@@ -1424,7 +1425,7 @@ func (s *connection) handleFrames(
|
||||
for len(data) > 0 {
|
||||
l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
return false, false, nil, err
|
||||
}
|
||||
data = data[l:]
|
||||
if frame == nil {
|
||||
@@ -1444,19 +1445,23 @@ func (s *connection) handleFrames(
|
||||
if handleErr != nil {
|
||||
continue
|
||||
}
|
||||
if err := s.handleFrame(frame, encLevel, destConnID, rcvTime); err != nil {
|
||||
pc, err := s.handleFrame(frame, encLevel, destConnID, rcvTime)
|
||||
if err != nil {
|
||||
if log == nil {
|
||||
return false, false, err
|
||||
return false, false, nil, err
|
||||
}
|
||||
// If we're logging, we need to keep parsing (but not handling) all frames.
|
||||
handleErr = err
|
||||
}
|
||||
if pc != nil {
|
||||
pathChallenge = pc
|
||||
}
|
||||
}
|
||||
|
||||
if log != nil {
|
||||
log(frames)
|
||||
if handleErr != nil {
|
||||
return false, false, handleErr
|
||||
return false, false, nil, handleErr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1466,7 +1471,7 @@ func (s *connection) handleFrames(
|
||||
// and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame.
|
||||
if !handshakeWasComplete && s.handshakeComplete {
|
||||
if err := s.handleHandshakeComplete(rcvTime); err != nil {
|
||||
return false, false, err
|
||||
return false, false, nil, err
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -1477,7 +1482,7 @@ func (s *connection) handleFrame(
|
||||
encLevel protocol.EncryptionLevel,
|
||||
destConnID protocol.ConnectionID,
|
||||
rcvTime time.Time,
|
||||
) error {
|
||||
) (pathChallenge *wire.PathChallengeFrame, _ error) {
|
||||
var err error
|
||||
wire.LogFrame(s.logger, f, false)
|
||||
switch frame := f.(type) {
|
||||
@@ -1506,6 +1511,7 @@ func (s *connection) handleFrame(
|
||||
case *wire.PingFrame:
|
||||
case *wire.PathChallengeFrame:
|
||||
s.handlePathChallengeFrame(frame)
|
||||
pathChallenge = frame
|
||||
case *wire.PathResponseFrame:
|
||||
err = s.handlePathResponseFrame(frame)
|
||||
case *wire.NewTokenFrame:
|
||||
@@ -1521,7 +1527,7 @@ func (s *connection) handleFrame(
|
||||
default:
|
||||
err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name())
|
||||
}
|
||||
return err
|
||||
return pathChallenge, err
|
||||
}
|
||||
|
||||
// handlePacket is called by the server with a new packet
|
||||
@@ -1675,7 +1681,9 @@ func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error
|
||||
}
|
||||
|
||||
func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
|
||||
s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data})
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *connection) handlePathResponseFrame(f *wire.PathResponseFrame) error {
|
||||
@@ -2083,7 +2091,7 @@ func (s *connection) sendPackets(now time.Time) error {
|
||||
if pm := s.pathManagerOutgoing.Load(); pm != nil {
|
||||
connID, frame, tr, ok := pm.NextPathToProbe()
|
||||
if ok {
|
||||
probe, buf, err := s.packer.PackPathProbePacket(connID, frame, s.version)
|
||||
probe, buf, err := s.packer.PackPathProbePacket(connID, []ackhandler.Frame{frame}, s.version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -222,14 +222,17 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
|
||||
// STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
||||
str.EXPECT().handleStreamFrame(f, now)
|
||||
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
|
||||
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// RESET_STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
||||
str.EXPECT().handleResetStreamFrame(rsf, now)
|
||||
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
|
||||
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
||||
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
|
||||
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("for closed streams", func(t *testing.T) {
|
||||
@@ -238,13 +241,16 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
// STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
||||
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
|
||||
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// RESET_STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
||||
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
|
||||
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
||||
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
|
||||
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("for invalid streams", func(t *testing.T) {
|
||||
@@ -254,13 +260,16 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
|
||||
testErr := errors.New("test err")
|
||||
// STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
||||
require.ErrorIs(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now), testErr)
|
||||
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, testErr)
|
||||
// RESET_STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
||||
require.ErrorIs(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now), testErr)
|
||||
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, testErr)
|
||||
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
||||
require.ErrorIs(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now), testErr)
|
||||
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, testErr)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -279,11 +288,13 @@ func TestConnectionHandleSendStreamFrames(t *testing.T) {
|
||||
// STOP_SENDING frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
||||
str.EXPECT().handleStopSendingFrame(ss)
|
||||
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
|
||||
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// MAX_STREAM_DATA frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
||||
str.EXPECT().updateSendWindow(msd.MaximumStreamData)
|
||||
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
|
||||
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("for closed streams", func(t *testing.T) {
|
||||
@@ -292,10 +303,12 @@ func TestConnectionHandleSendStreamFrames(t *testing.T) {
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
// STOP_SENDING frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
||||
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
|
||||
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// MAX_STREAM_DATA frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
||||
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
|
||||
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("for invalid streams", func(t *testing.T) {
|
||||
@@ -305,10 +318,12 @@ func TestConnectionHandleSendStreamFrames(t *testing.T) {
|
||||
testErr := errors.New("test err")
|
||||
// STOP_SENDING frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
|
||||
require.ErrorIs(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now), testErr)
|
||||
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, testErr)
|
||||
// MAX_STREAM_DATA frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
|
||||
require.ErrorIs(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now), testErr)
|
||||
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, testErr)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -321,9 +336,11 @@ func TestConnectionHandleStreamNumFrames(t *testing.T) {
|
||||
// MAX_STREAMS frame
|
||||
msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
|
||||
streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
|
||||
require.NoError(t, tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now))
|
||||
_, err := tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// STREAMS_BLOCKED frame
|
||||
tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
|
||||
_, err = tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
|
||||
@@ -334,9 +351,11 @@ func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
|
||||
connID := protocol.ConnectionID{}
|
||||
// MAX_DATA frame
|
||||
connFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
|
||||
require.NoError(t, tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
|
||||
_, err := tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// DATA_BLOCKED frame
|
||||
require.NoError(t, tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
|
||||
_, err = tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConnectionOpenStreams(t *testing.T) {
|
||||
@@ -404,10 +423,8 @@ func TestConnectionServerInvalidFrames(t *testing.T) {
|
||||
{Name: "PATH_RESPONSE", Frame: &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
|
||||
} {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
require.ErrorIs(t,
|
||||
tc.conn.handleFrame(test.Frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now()),
|
||||
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
|
||||
)
|
||||
_, err := tc.conn.handleFrame(test.Frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now())
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2887,8 +2904,8 @@ func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
|
||||
protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
|
||||
),
|
||||
tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ protocol.ConnectionID, f ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
pathChallenge = f.Frame.(*wire.PathChallengeFrame)
|
||||
func(_ protocol.ConnectionID, frames []ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
pathChallenge = frames[0].Frame.(*wire.PathChallengeFrame)
|
||||
return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
|
||||
},
|
||||
),
|
||||
@@ -2960,6 +2977,8 @@ func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
|
||||
}
|
||||
|
||||
payload := []byte{1} // PING frame
|
||||
payload, err = (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(payload, protocol.Version1)
|
||||
require.NoError(t, err)
|
||||
gomock.InOrder(
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
|
||||
protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
|
||||
@@ -3039,7 +3058,7 @@ func testConnectionMigration(t *testing.T, enabled bool) {
|
||||
).AnyTimes()
|
||||
packedProbe := make(chan struct{})
|
||||
tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ protocol.ConnectionID, f ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
defer close(packedProbe)
|
||||
return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
|
||||
},
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
)
|
||||
|
||||
func TestConnectionMigration(t *testing.T) {
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConfig, getQuicConfig(nil))
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
|
||||
@@ -319,7 +319,7 @@ func (c *MockPackerPackMTUProbePacketCall) DoAndReturn(f func(ackhandler.Frame,
|
||||
}
|
||||
|
||||
// PackPathProbePacket mocks base method.
|
||||
func (m *MockPacker) PackPathProbePacket(arg0 protocol.ConnectionID, arg1 ackhandler.Frame, arg2 protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
func (m *MockPacker) PackPathProbePacket(arg0 protocol.ConnectionID, arg1 []ackhandler.Frame, arg2 protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PackPathProbePacket", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(shortHeaderPacket)
|
||||
@@ -347,13 +347,13 @@ func (c *MockPackerPackPathProbePacketCall) Return(arg0 shortHeaderPacket, arg1
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPackerPackPathProbePacketCall) Do(f func(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
|
||||
func (c *MockPackerPackPathProbePacketCall) Do(f func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPackerPackPathProbePacketCall) DoAndReturn(f func(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
|
||||
func (c *MockPackerPackPathProbePacketCall) DoAndReturn(f func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ type packer interface {
|
||||
MaybePackPTOProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)
|
||||
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
|
||||
PackPathProbePacket(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
|
||||
PackPathProbePacket(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
|
||||
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
|
||||
|
||||
SetToken([]byte)
|
||||
@@ -794,16 +794,20 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
|
||||
return packet, buffer, err
|
||||
}
|
||||
|
||||
func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, f ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, frames []ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
||||
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
|
||||
buf := getPacketBuffer()
|
||||
s, err := p.cryptoSetup.Get1RTTSealer()
|
||||
if err != nil {
|
||||
return shortHeaderPacket{}, nil, err
|
||||
}
|
||||
var l protocol.ByteCount
|
||||
for _, f := range frames {
|
||||
l += f.Frame.Length(v)
|
||||
}
|
||||
payload := payload{
|
||||
frames: []ackhandler.Frame{f},
|
||||
length: f.Frame.Length(v),
|
||||
frames: frames,
|
||||
length: l,
|
||||
}
|
||||
padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead())
|
||||
packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v)
|
||||
|
||||
@@ -917,14 +917,21 @@ func TestPackPathProbePacket(t *testing.T) {
|
||||
|
||||
p, buf, err := tp.packer.PackPathProbePacket(
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
ackhandler.Frame{Frame: &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
|
||||
[]ackhandler.Frame{
|
||||
{Frame: &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
|
||||
{Frame: &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}},
|
||||
},
|
||||
protocol.Version1,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, protocol.PacketNumber(0x43), p.PacketNumber)
|
||||
require.Nil(t, p.Ack)
|
||||
require.Empty(t, p.StreamFrames)
|
||||
require.Equal(t, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, p.Frames[0].Frame)
|
||||
require.Len(t, p.Frames, 2)
|
||||
// the frame order is randomized
|
||||
frames := []wire.Frame{p.Frames[0].Frame, p.Frames[1].Frame}
|
||||
require.Contains(t, frames, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}})
|
||||
require.Contains(t, frames, &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}})
|
||||
require.Len(t, buf.Data, protocol.MinInitialPacketSize)
|
||||
require.True(t, p.IsPathProbePacket)
|
||||
require.False(t, p.IsPathMTUProbePacket)
|
||||
|
||||
@@ -46,47 +46,69 @@ func newPathManager(
|
||||
|
||||
// Returns a path challenge frame if one should be sent.
|
||||
// May return nil.
|
||||
func (pm *pathManager) HandlePacket(p receivedPacket, isNonProbing bool) (_ protocol.ConnectionID, _ ackhandler.Frame, shouldSwitch bool) {
|
||||
for _, path := range pm.paths {
|
||||
if addrsEqual(path.addr, p.remoteAddr) {
|
||||
func (pm *pathManager) HandlePacket(
|
||||
remoteAddr net.Addr,
|
||||
pathChallenge *wire.PathChallengeFrame, // may be nil if the packet didn't contain a PATH_CHALLENGE
|
||||
isNonProbing bool,
|
||||
) (_ protocol.ConnectionID, _ []ackhandler.Frame, shouldSwitch bool) {
|
||||
var p *path
|
||||
pathID := pm.nextPathID
|
||||
for id, path := range pm.paths {
|
||||
if addrsEqual(path.addr, remoteAddr) {
|
||||
p = path
|
||||
pathID = id
|
||||
// already sent a PATH_CHALLENGE for this path
|
||||
if isNonProbing {
|
||||
path.rcvdNonProbing = true
|
||||
}
|
||||
if pm.logger.Debug() {
|
||||
pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", p.remoteAddr, path.validated)
|
||||
pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", remoteAddr, path.validated)
|
||||
}
|
||||
shouldSwitch = path.validated && path.rcvdNonProbing
|
||||
if pathChallenge == nil {
|
||||
return protocol.ConnectionID{}, nil, shouldSwitch
|
||||
}
|
||||
return protocol.ConnectionID{}, ackhandler.Frame{}, path.validated && path.rcvdNonProbing
|
||||
}
|
||||
}
|
||||
|
||||
if len(pm.paths) >= maxPaths {
|
||||
if pm.logger.Debug() {
|
||||
pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", p.remoteAddr, len(pm.paths))
|
||||
pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", remoteAddr, len(pm.paths))
|
||||
}
|
||||
return protocol.ConnectionID{}, ackhandler.Frame{}, false
|
||||
return protocol.ConnectionID{}, nil, shouldSwitch
|
||||
}
|
||||
|
||||
// previously unseen path, initiate path validation by sending a PATH_CHALLENGE
|
||||
connID, ok := pm.getConnID(pm.nextPathID)
|
||||
connID, ok := pm.getConnID(pathID)
|
||||
if !ok {
|
||||
pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", p.remoteAddr)
|
||||
return protocol.ConnectionID{}, ackhandler.Frame{}, false
|
||||
pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", remoteAddr)
|
||||
return protocol.ConnectionID{}, nil, shouldSwitch
|
||||
}
|
||||
var b [8]byte
|
||||
rand.Read(b[:])
|
||||
pm.paths[pm.nextPathID] = &path{
|
||||
addr: p.remoteAddr,
|
||||
pathChallenge: b,
|
||||
rcvdNonProbing: isNonProbing,
|
||||
|
||||
frames := make([]ackhandler.Frame, 0, 2)
|
||||
if p == nil {
|
||||
var pathChallengeData [8]byte
|
||||
rand.Read(pathChallengeData[:])
|
||||
p = &path{
|
||||
addr: remoteAddr,
|
||||
rcvdNonProbing: isNonProbing,
|
||||
pathChallenge: pathChallengeData,
|
||||
}
|
||||
frames = append(frames, ackhandler.Frame{
|
||||
Frame: &wire.PathChallengeFrame{Data: p.pathChallenge},
|
||||
Handler: (*pathManagerAckHandler)(pm),
|
||||
})
|
||||
pm.paths[pm.nextPathID] = p
|
||||
pm.nextPathID++
|
||||
pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", remoteAddr)
|
||||
}
|
||||
pm.nextPathID++
|
||||
frame := ackhandler.Frame{
|
||||
Frame: &wire.PathChallengeFrame{Data: b},
|
||||
Handler: (*pathManagerAckHandler)(pm),
|
||||
if pathChallenge != nil {
|
||||
frames = append(frames, ackhandler.Frame{
|
||||
Frame: &wire.PathResponseFrame{Data: pathChallenge.Data},
|
||||
Handler: (*pathManagerAckHandler)(pm),
|
||||
})
|
||||
}
|
||||
pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", p.remoteAddr)
|
||||
return connID, frame, false
|
||||
return connID, frames, shouldSwitch
|
||||
}
|
||||
|
||||
func (pm *pathManager) HandlePathResponseFrame(f *wire.PathResponseFrame) {
|
||||
|
||||
@@ -28,67 +28,135 @@ func TestPathManagerIntentionalMigration(t *testing.T) {
|
||||
func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) },
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
connID, f1, shouldSwitch := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
|
||||
connID, frames, shouldSwitch := pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
|
||||
&wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
|
||||
false,
|
||||
)
|
||||
require.Equal(t, connIDs[0], connID)
|
||||
require.NotNil(t, f1.Frame)
|
||||
pc1 := f1.Frame.(*wire.PathChallengeFrame)
|
||||
require.Len(t, frames, 2)
|
||||
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
|
||||
pc1 := frames[0].Frame.(*wire.PathChallengeFrame)
|
||||
require.NotZero(t, pc1.Data)
|
||||
require.NotEqual(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, pc1.Data)
|
||||
require.IsType(t, &wire.PathResponseFrame{}, frames[1].Frame)
|
||||
require.Equal(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, frames[1].Frame.(*wire.PathResponseFrame).Data)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// receiving another packet for the same path doesn't trigger another PATH_CHALLENGE
|
||||
connID, f, shouldSwitch := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Len(t, frames, 0)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// receiving a packet for a different path triggers another PATH_CHALLENGE
|
||||
addr2 := &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000}
|
||||
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: addr2}, false)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(addr2, nil, false)
|
||||
require.Equal(t, connIDs[1], connID)
|
||||
require.NotNil(t, f.Frame)
|
||||
pc2 := f.Frame.(*wire.PathChallengeFrame)
|
||||
require.Len(t, frames, 1)
|
||||
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
|
||||
pc2 := frames[0].Frame.(*wire.PathChallengeFrame)
|
||||
require.NotEqual(t, pc1.Data, pc2.Data)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// acknowledging the PATH_CHALLENGE doesn't confirm the path
|
||||
f1.Handler.OnAcked(f1.Frame)
|
||||
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
|
||||
frames[0].Handler.OnAcked(frames[0].Frame)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// receiving a PATH_RESPONSE for the second path confirms the path
|
||||
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc2.Data})
|
||||
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: addr2}, false)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(addr2, nil, false)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
require.False(t, shouldSwitch) // no non-probing packet received yet
|
||||
require.Empty(t, retiredConnIDs)
|
||||
|
||||
// confirming the path doesn't remove other paths
|
||||
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// now receive a non-probing packet for the new path
|
||||
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000}}, true)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000},
|
||||
nil,
|
||||
true,
|
||||
)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
require.True(t, shouldSwitch)
|
||||
|
||||
// now switch to the new path
|
||||
pm.SwitchToPath(&net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000})
|
||||
|
||||
// switching to the path removes other paths
|
||||
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, nil, false)
|
||||
require.Equal(t, connIDs[2], connID)
|
||||
require.NotNil(t, f.Frame)
|
||||
require.NotEqual(t, f.Frame.(*wire.PathChallengeFrame).Data, pc1.Data)
|
||||
require.NotEmpty(t, frames)
|
||||
require.NotEqual(t, frames[0].Frame.(*wire.PathChallengeFrame).Data, pc1.Data)
|
||||
require.False(t, shouldSwitch)
|
||||
require.Equal(t, []protocol.ConnectionID{connIDs[0]}, retiredConnIDs)
|
||||
}
|
||||
|
||||
func TestPathManagerMultipleProbes(t *testing.T) {
|
||||
connIDs := []protocol.ConnectionID{
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
||||
}
|
||||
pm := newPathManager(
|
||||
func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true },
|
||||
func(id pathID) {},
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
// first receive a packet without a PATH_CHALLENGE
|
||||
connID, frames, shouldSwitch := pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
require.Equal(t, connIDs[0], connID)
|
||||
require.Len(t, frames, 1)
|
||||
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// now receive a packet on the same path with a PATH_CHALLENGE
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
|
||||
&wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
|
||||
false,
|
||||
)
|
||||
require.Equal(t, connIDs[0], connID)
|
||||
require.Len(t, frames, 1)
|
||||
require.Equal(t, &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, frames[0].Frame)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// now receive an other packet on the same path with a PATH_RESPONSE
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(
|
||||
&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000},
|
||||
&wire.PathChallengeFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}},
|
||||
false,
|
||||
)
|
||||
require.Equal(t, connIDs[0], connID)
|
||||
require.Len(t, frames, 1)
|
||||
require.Equal(t, &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}, frames[0].Frame)
|
||||
require.False(t, shouldSwitch)
|
||||
}
|
||||
|
||||
// The first packet received on the new path is already a non-probing packet.
|
||||
// We still need to validate the new path, but we can then switch over immediately.
|
||||
// This is the typical scenario when a NAT rebinding happens.
|
||||
@@ -103,19 +171,20 @@ func TestPathManagerNATRebinding(t *testing.T) {
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
|
||||
connID, f, shouldSwitch := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, true)
|
||||
connID, frames, shouldSwitch := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, nil, true)
|
||||
require.Equal(t, connIDs[0], connID)
|
||||
require.NotNil(t, f.Frame)
|
||||
pc1 := f.Frame.(*wire.PathChallengeFrame)
|
||||
require.Len(t, frames, 1)
|
||||
require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame)
|
||||
pc1 := frames[0].Frame.(*wire.PathChallengeFrame)
|
||||
require.NotZero(t, pc1.Data)
|
||||
require.False(t, shouldSwitch)
|
||||
|
||||
// receiving a PATH_RESPONSE for the second path confirms the path
|
||||
pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc1.Data})
|
||||
// we now switch to the new path, as soon as the next packet on that path is received
|
||||
connID, f, shouldSwitch = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}}, false)
|
||||
connID, frames, shouldSwitch = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, nil, false)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
require.True(t, shouldSwitch)
|
||||
}
|
||||
|
||||
@@ -134,41 +203,41 @@ func TestPathManagerLimits(t *testing.T) {
|
||||
)
|
||||
|
||||
for i := range maxPaths {
|
||||
connID, f, _ := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000 + i}}, true)
|
||||
require.NotNil(t, f.Frame)
|
||||
connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000 + i}, nil, true)
|
||||
require.NotEmpty(t, frames)
|
||||
require.Equal(t, connIDs[i], connID)
|
||||
}
|
||||
// the maximum number of paths is already being probed
|
||||
connID, f, _ := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}}, true)
|
||||
connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}, nil, true)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
|
||||
// switching to a new path frees is up all paths
|
||||
var f1 ackhandler.Frame
|
||||
var f1 []ackhandler.Frame
|
||||
pm.SwitchToPath(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000})
|
||||
for i := range maxPaths {
|
||||
connID, f, _ := pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3000 + i}}, true)
|
||||
connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3000 + i}, nil, true)
|
||||
if i == 0 {
|
||||
f1 = f
|
||||
f1 = frames
|
||||
}
|
||||
require.NotNil(t, f.Frame)
|
||||
require.NotEmpty(t, frames)
|
||||
require.Equal(t, connIDs[maxPaths+i], connID)
|
||||
}
|
||||
// again, the maximum number of paths is already being probed
|
||||
connID, f, _ = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}}, true)
|
||||
connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}, nil, true)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
|
||||
// losing the frame removes this path
|
||||
f1.Handler.OnLost(f1.Frame)
|
||||
f1[0].Handler.OnLost(f1[0].Frame)
|
||||
|
||||
// we can open exactly one more path
|
||||
connID, f, _ = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4000}}, true)
|
||||
require.NotNil(t, f.Frame)
|
||||
connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4000}, nil, true)
|
||||
require.NotEmpty(t, frames)
|
||||
require.Equal(t, connIDs[2*maxPaths], connID)
|
||||
connID, f, _ = pm.HandlePacket(receivedPacket{remoteAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4001}}, true)
|
||||
connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4001}, nil, true)
|
||||
require.Zero(t, connID)
|
||||
require.Nil(t, f.Frame)
|
||||
require.Empty(t, frames)
|
||||
}
|
||||
|
||||
type mockAddr struct {
|
||||
|
||||
Reference in New Issue
Block a user