forked from quic-go/quic-go
handle stream-related frame in the streams map (#5212)
* handle stream-related frame in the streams map * remove stream manager interface and mock
This commit is contained in:
109
connection.go
109
connection.go
@@ -30,23 +30,6 @@ type unpacker interface {
|
||||
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
||||
}
|
||||
|
||||
type streamManager interface {
|
||||
GetOrOpenSendStream(protocol.StreamID) (*SendStream, error)
|
||||
GetOrOpenReceiveStream(protocol.StreamID) (*ReceiveStream, error)
|
||||
OpenStream() (*Stream, error)
|
||||
OpenUniStream() (*SendStream, error)
|
||||
OpenStreamSync(context.Context) (*Stream, error)
|
||||
OpenUniStreamSync(context.Context) (*SendStream, error)
|
||||
AcceptStream(context.Context) (*Stream, error)
|
||||
AcceptUniStream(context.Context) (*ReceiveStream, error)
|
||||
DeleteStream(protocol.StreamID) error
|
||||
UpdateLimits(*wire.TransportParameters)
|
||||
HandleMaxStreamsFrame(*wire.MaxStreamsFrame)
|
||||
CloseWithError(error)
|
||||
ResetFor0RTT()
|
||||
UseResetMaps()
|
||||
}
|
||||
|
||||
type cryptoStreamHandler interface {
|
||||
StartHandshake(context.Context) error
|
||||
ChangeConnectionID(protocol.ConnectionID)
|
||||
@@ -133,7 +116,7 @@ type connection struct {
|
||||
largestRcvdAppData protocol.PacketNumber
|
||||
pathManagerOutgoing atomic.Pointer[pathManagerOutgoing]
|
||||
|
||||
streamsMap streamManager
|
||||
streamsMap *streamsMap
|
||||
connIDManager *connIDManager
|
||||
connIDGenerator *connIDGenerator
|
||||
|
||||
@@ -1498,25 +1481,25 @@ func (s *connection) handleFrame(
|
||||
case *wire.CryptoFrame:
|
||||
err = s.handleCryptoFrame(frame, encLevel, rcvTime)
|
||||
case *wire.StreamFrame:
|
||||
err = s.handleStreamFrame(frame, rcvTime)
|
||||
err = s.streamsMap.HandleStreamFrame(frame, rcvTime)
|
||||
case *wire.AckFrame:
|
||||
err = s.handleAckFrame(frame, encLevel, rcvTime)
|
||||
case *wire.ConnectionCloseFrame:
|
||||
err = s.handleConnectionCloseFrame(frame)
|
||||
case *wire.ResetStreamFrame:
|
||||
err = s.handleResetStreamFrame(frame, rcvTime)
|
||||
err = s.streamsMap.HandleResetStreamFrame(frame, rcvTime)
|
||||
case *wire.MaxDataFrame:
|
||||
s.handleMaxDataFrame(frame)
|
||||
s.connFlowController.UpdateSendWindow(frame.MaximumData)
|
||||
case *wire.MaxStreamDataFrame:
|
||||
err = s.handleMaxStreamDataFrame(frame)
|
||||
err = s.streamsMap.HandleMaxStreamDataFrame(frame)
|
||||
case *wire.MaxStreamsFrame:
|
||||
s.handleMaxStreamsFrame(frame)
|
||||
s.streamsMap.HandleMaxStreamsFrame(frame)
|
||||
case *wire.DataBlockedFrame:
|
||||
case *wire.StreamDataBlockedFrame:
|
||||
err = s.handleStreamDataBlockedFrame(frame)
|
||||
err = s.streamsMap.HandleStreamDataBlockedFrame(frame)
|
||||
case *wire.StreamsBlockedFrame:
|
||||
case *wire.StopSendingFrame:
|
||||
err = s.handleStopSendingFrame(frame)
|
||||
err = s.streamsMap.HandleStopSendingFrame(frame)
|
||||
case *wire.PingFrame:
|
||||
case *wire.PathChallengeFrame:
|
||||
s.handlePathChallengeFrame(frame)
|
||||
@@ -1526,9 +1509,9 @@ func (s *connection) handleFrame(
|
||||
case *wire.NewTokenFrame:
|
||||
err = s.handleNewTokenFrame(frame)
|
||||
case *wire.NewConnectionIDFrame:
|
||||
err = s.handleNewConnectionIDFrame(frame)
|
||||
err = s.connIDManager.Add(frame)
|
||||
case *wire.RetireConnectionIDFrame:
|
||||
err = s.handleRetireConnectionIDFrame(rcvTime, frame, destConnID)
|
||||
err = s.connIDGenerator.Retire(frame.SequenceNumber, destConnID, rcvTime.Add(3*s.rttStats.PTO(false)))
|
||||
case *wire.HandshakeDoneFrame:
|
||||
err = s.handleHandshakeDoneFrame(rcvTime)
|
||||
case *wire.DatagramFrame:
|
||||
@@ -1625,70 +1608,6 @@ func (s *connection) handleHandshakeEvents(now time.Time) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *connection) handleStreamFrame(frame *wire.StreamFrame, rcvTime time.Time) error {
|
||||
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream was already closed and garbage collected
|
||||
return nil
|
||||
}
|
||||
return str.handleStreamFrame(frame, rcvTime)
|
||||
}
|
||||
|
||||
func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) {
|
||||
s.connFlowController.UpdateSendWindow(frame.MaximumData)
|
||||
}
|
||||
|
||||
func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error {
|
||||
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil {
|
||||
// stream is closed and already garbage collected
|
||||
return nil
|
||||
}
|
||||
str.updateSendWindow(frame.MaximumStreamData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) handleStreamDataBlockedFrame(frame *wire.StreamDataBlockedFrame) error {
|
||||
// We don't need to do anything in response to a STREAM_DATA_BLOCKED frame,
|
||||
// but we need to make sure that the stream ID is valid.
|
||||
_, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) {
|
||||
s.streamsMap.HandleMaxStreamsFrame(frame)
|
||||
}
|
||||
|
||||
func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame, rcvTime time.Time) error {
|
||||
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil {
|
||||
// stream is closed and already garbage collected
|
||||
return nil
|
||||
}
|
||||
return str.handleResetStreamFrame(frame, rcvTime)
|
||||
}
|
||||
|
||||
func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
|
||||
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil {
|
||||
// stream is closed and already garbage collected
|
||||
return nil
|
||||
}
|
||||
str.handleStopSendingFrame(frame)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data})
|
||||
@@ -1743,14 +1662,6 @@ func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error {
|
||||
return s.connIDManager.Add(f)
|
||||
}
|
||||
|
||||
func (s *connection) handleRetireConnectionIDFrame(now time.Time, f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error {
|
||||
return s.connIDGenerator.Retire(f.SequenceNumber, destConnID, now.Add(3*s.rttStats.PTO(false)))
|
||||
}
|
||||
|
||||
func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error {
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
return &qerr.TransportError{
|
||||
|
||||
@@ -35,10 +35,6 @@ func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt {
|
||||
return func(conn *connection) { conn.cryptoStreamHandler = cs }
|
||||
}
|
||||
|
||||
func connectionOptStreamManager(sm *MockStreamManager) testConnectionOpt {
|
||||
return func(conn *connection) { conn.streamsMap = sm }
|
||||
}
|
||||
|
||||
func connectionOptConnFlowController(cfc flowcontrol.ConnectionFlowController) testConnectionOpt {
|
||||
return func(conn *connection) { conn.connFlowController = cfc }
|
||||
}
|
||||
@@ -208,145 +204,28 @@ func newClientTestConnection(
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
|
||||
const streamID protocol.StreamID = 5
|
||||
now := time.Now()
|
||||
func TestConnectionHandleStreamRelatedFrames(t *testing.T) {
|
||||
const id protocol.StreamID = 5
|
||||
connID := protocol.ConnectionID{}
|
||||
f := &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar")}
|
||||
rsf := &wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 42, FinalSize: 1337}
|
||||
sdbf := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 1337}
|
||||
|
||||
t.Run("for existing and new streams", func(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, gomock.Any()).AnyTimes()
|
||||
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
str := newReceiveStream(streamID, mockSender, mockFC)
|
||||
// STREAM frame
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(len(f.Data)), false, gomock.Any())
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
||||
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// RESET_STREAM frame
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1337), true, gomock.Any())
|
||||
mockFC.EXPECT().Abandon()
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
||||
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
||||
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
frame wire.Frame
|
||||
}{
|
||||
{name: "STREAM", frame: &wire.StreamFrame{StreamID: id, Data: []byte("foobar")}},
|
||||
{name: "RESET_STREAM", frame: &wire.ResetStreamFrame{StreamID: id, ErrorCode: 42, FinalSize: 1337}},
|
||||
{name: "STOP_SENDING", frame: &wire.StopSendingFrame{StreamID: id, ErrorCode: 42}},
|
||||
{name: "MAX_STREAM_DATA", frame: &wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1337}},
|
||||
{name: "STREAM_DATA_BLOCKED", frame: &wire.StreamDataBlockedFrame{StreamID: id, MaximumStreamData: 42}},
|
||||
}
|
||||
|
||||
t.Run("for closed streams", func(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
// STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
||||
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// RESET_STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
||||
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
||||
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("for invalid streams", func(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
// STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError)
|
||||
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
// RESET_STREAM frame
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError)
|
||||
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
||||
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError)
|
||||
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectionHandleSendStreamFrames(t *testing.T) {
|
||||
const streamID protocol.StreamID = 3
|
||||
now := time.Now()
|
||||
connID := protocol.ConnectionID{}
|
||||
ss := &wire.StopSendingFrame{StreamID: streamID, ErrorCode: 42}
|
||||
msd := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}
|
||||
|
||||
t.Run("for existing and new streams", func(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()).AnyTimes()
|
||||
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
str := newSendStream(context.Background(), streamID, mockSender, mockFC)
|
||||
// STOP_SENDING frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
||||
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// MAX_STREAM_DATA frame
|
||||
mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
||||
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("for closed streams", func(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
// STOP_SENDING frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
||||
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// MAX_STREAM_DATA frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
||||
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("for invalid streams", func(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
// STOP_SENDING frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, assert.AnError)
|
||||
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
// MAX_STREAM_DATA frame
|
||||
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, assert.AnError)
|
||||
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectionHandleStreamNumFrames(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
now := time.Now()
|
||||
connID := protocol.ConnectionID{}
|
||||
// MAX_STREAMS frame
|
||||
msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
|
||||
streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
|
||||
_, err := tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
// STREAMS_BLOCKED frame
|
||||
_, err = tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
|
||||
require.NoError(t, err)
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
tc := newServerTestConnection(t, gomock.NewController(t), nil, false)
|
||||
_, err := tc.conn.handleFrame(test.frame, protocol.Encryption1RTT, connID, time.Now())
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
|
||||
@@ -365,60 +244,6 @@ func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConnectionOpenStreams(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
|
||||
// using OpenStream
|
||||
str1 := &Stream{}
|
||||
streamsMap.EXPECT().OpenStream().Return(str1, nil)
|
||||
str, err := tc.conn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, str1, str)
|
||||
|
||||
// using OpenStreamSync
|
||||
streamsMap.EXPECT().OpenStreamSync(context.Background()).Return(str1, nil)
|
||||
str, err = tc.conn.OpenStreamSync(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, str1, str)
|
||||
|
||||
// using OpenUniStream
|
||||
str2 := &SendStream{}
|
||||
streamsMap.EXPECT().OpenUniStream().Return(str2, nil)
|
||||
ustr, err := tc.conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, str2, ustr)
|
||||
|
||||
// using OpenUniStreamSync
|
||||
streamsMap.EXPECT().OpenUniStreamSync(context.Background()).Return(str2, nil)
|
||||
ustr, err = tc.conn.OpenUniStreamSync(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, str2, ustr)
|
||||
}
|
||||
|
||||
func TestConnectionAcceptStreams(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
streamsMap := NewMockStreamManager(mockCtrl)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
||||
|
||||
// bidirectional streams
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
str1 := &Stream{}
|
||||
streamsMap.EXPECT().AcceptStream(ctx).Return(str1, nil)
|
||||
str, err := tc.conn.AcceptStream(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, str1, str)
|
||||
|
||||
// unidirectional streams
|
||||
str2 := &ReceiveStream{}
|
||||
streamsMap.EXPECT().AcceptUniStream(ctx).Return(str2, nil)
|
||||
ustr, err := tc.conn.AcceptUniStream(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, str2, ustr)
|
||||
}
|
||||
|
||||
func TestConnectionServerInvalidFrames(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false)
|
||||
@@ -943,14 +768,12 @@ func TestConnectionMaxUnprocessedPackets(t *testing.T) {
|
||||
|
||||
func TestConnectionRemoteClose(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
mockStreamManager := NewMockStreamManager(mockCtrl)
|
||||
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
unpacker := NewMockUnpacker(mockCtrl)
|
||||
tc := newServerTestConnection(t,
|
||||
mockCtrl,
|
||||
nil,
|
||||
false,
|
||||
connectionOptStreamManager(mockStreamManager),
|
||||
connectionOptTracer(tr),
|
||||
connectionOptUnpacker(unpacker),
|
||||
)
|
||||
@@ -964,8 +787,6 @@ func TestConnectionRemoteClose(t *testing.T) {
|
||||
|
||||
expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
|
||||
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
streamErrChan := make(chan error, 1)
|
||||
mockStreamManager.EXPECT().CloseWithError(gomock.Any()).Do(func(e error) { streamErrChan <- e })
|
||||
tracerErrChan := make(chan error, 1)
|
||||
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrChan <- e })
|
||||
tracer.EXPECT().Close()
|
||||
@@ -988,12 +809,6 @@ func TestConnectionRemoteClose(t *testing.T) {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
select {
|
||||
case err := <-streamErrChan:
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) {
|
||||
@@ -1061,6 +876,8 @@ func TestConnectionTransportParameters(t *testing.T) {
|
||||
)
|
||||
_, err := tc.conn.OpenStream()
|
||||
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
||||
_, err = tc.conn.OpenUniStream()
|
||||
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
||||
tracer.EXPECT().ReceivedTransportParameters(gomock.Any())
|
||||
params := &wire.TransportParameters{
|
||||
MaxIdleTimeout: 90 * time.Second,
|
||||
@@ -1071,11 +888,78 @@ func TestConnectionTransportParameters(t *testing.T) {
|
||||
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
|
||||
OriginalDestinationConnectionID: tc.destConnID,
|
||||
MaxBidiStreamNum: 1,
|
||||
MaxUniStreamNum: 1,
|
||||
}
|
||||
require.NoError(t, tc.conn.handleTransportParameters(params))
|
||||
require.Equal(t, protocol.ByteCount(1337), connFC.SendWindowSize())
|
||||
_, err = tc.conn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
_, err = tc.conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConnectionHandleMaxStreamsFrame(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger)
|
||||
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
|
||||
tc.conn.handleTransportParameters(&wire.TransportParameters{})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
uniStreamChan := make(chan error)
|
||||
go func() {
|
||||
_, err := tc.conn.OpenUniStreamSync(ctx)
|
||||
uniStreamChan <- err
|
||||
}()
|
||||
bidiStreamChan := make(chan error)
|
||||
go func() {
|
||||
_, err := tc.conn.OpenStreamSync(ctx)
|
||||
bidiStreamChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-uniStreamChan:
|
||||
t.Fatal("uni stream should be blocked")
|
||||
case <-bidiStreamChan:
|
||||
t.Fatal("bidi stream should be blocked")
|
||||
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
||||
}
|
||||
|
||||
// MAX_STREAMS frame for bidirectional stream
|
||||
_, err := tc.conn.handleFrame(
|
||||
&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10},
|
||||
protocol.Encryption1RTT,
|
||||
protocol.ConnectionID{},
|
||||
time.Now(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-uniStreamChan:
|
||||
t.Fatal("uni stream should be blocked")
|
||||
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
||||
}
|
||||
select {
|
||||
case err := <-bidiStreamChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// MAX_STREAMS frame for bidirectional stream
|
||||
_, err = tc.conn.handleFrame(
|
||||
&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 10},
|
||||
protocol.Encryption1RTT,
|
||||
protocol.ConnectionID{},
|
||||
time.Now(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
select {
|
||||
case err := <-uniStreamChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionTransportParameterValidationFailureServer(t *testing.T) {
|
||||
@@ -3075,10 +2959,11 @@ func testConnectionMigration(t *testing.T, enabled bool) {
|
||||
).AnyTimes()
|
||||
tc.connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any())
|
||||
// add a new connection ID, so the path can be probed
|
||||
require.NoError(t, tc.conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{
|
||||
_, err = tc.conn.handleFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: 1,
|
||||
ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
}))
|
||||
}, protocol.EncryptionInitial, tc.destConnID, time.Now())
|
||||
require.NoError(t, err)
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- tc.conn.run() }()
|
||||
|
||||
|
||||
@@ -1,573 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/quic-go/quic-go (interfaces: StreamManager)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager
|
||||
//
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
protocol "github.com/quic-go/quic-go/internal/protocol"
|
||||
wire "github.com/quic-go/quic-go/internal/wire"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockStreamManager is a mock of StreamManager interface.
|
||||
type MockStreamManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStreamManagerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager.
|
||||
type MockStreamManagerMockRecorder struct {
|
||||
mock *MockStreamManager
|
||||
}
|
||||
|
||||
// NewMockStreamManager creates a new mock instance.
|
||||
func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager {
|
||||
mock := &MockStreamManager{ctrl: ctrl}
|
||||
mock.recorder = &MockStreamManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcceptStream mocks base method.
|
||||
func (m *MockStreamManager) AcceptStream(arg0 context.Context) (*Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||
ret0, _ := ret[0].(*Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream.
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 any) *MockStreamManagerAcceptStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0)
|
||||
return &MockStreamManagerAcceptStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerAcceptStreamCall wrap *gomock.Call
|
||||
type MockStreamManagerAcceptStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerAcceptStreamCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerAcceptStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerAcceptStreamCall) Do(f func(context.Context) (*Stream, error)) *MockStreamManagerAcceptStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerAcceptStreamCall) DoAndReturn(f func(context.Context) (*Stream, error)) *MockStreamManagerAcceptStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// AcceptUniStream mocks base method.
|
||||
func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (*ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||
ret0, _ := ret[0].(*ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptUniStream indicates an expected call of AcceptUniStream.
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 any) *MockStreamManagerAcceptUniStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0)
|
||||
return &MockStreamManagerAcceptUniStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerAcceptUniStreamCall wrap *gomock.Call
|
||||
type MockStreamManagerAcceptUniStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerAcceptUniStreamCall) Return(arg0 *ReceiveStream, arg1 error) *MockStreamManagerAcceptUniStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerAcceptUniStreamCall) Do(f func(context.Context) (*ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerAcceptUniStreamCall) DoAndReturn(f func(context.Context) (*ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// CloseWithError mocks base method.
|
||||
func (m *MockStreamManager) CloseWithError(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "CloseWithError", arg0)
|
||||
}
|
||||
|
||||
// CloseWithError indicates an expected call of CloseWithError.
|
||||
func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 any) *MockStreamManagerCloseWithErrorCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0)
|
||||
return &MockStreamManagerCloseWithErrorCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerCloseWithErrorCall wrap *gomock.Call
|
||||
type MockStreamManagerCloseWithErrorCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerCloseWithErrorCall) Return() *MockStreamManagerCloseWithErrorCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerCloseWithErrorCall) Do(f func(error)) *MockStreamManagerCloseWithErrorCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerCloseWithErrorCall) DoAndReturn(f func(error)) *MockStreamManagerCloseWithErrorCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DeleteStream mocks base method.
|
||||
func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteStream", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteStream indicates an expected call of DeleteStream.
|
||||
func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 any) *MockStreamManagerDeleteStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0)
|
||||
return &MockStreamManagerDeleteStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerDeleteStreamCall wrap *gomock.Call
|
||||
type MockStreamManagerDeleteStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerDeleteStreamCall) Return(arg0 error) *MockStreamManagerDeleteStreamCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerDeleteStreamCall) Do(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerDeleteStreamCall) DoAndReturn(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream mocks base method.
|
||||
func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (*ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
|
||||
ret0, _ := ret[0].(*ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream.
|
||||
func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 any) *MockStreamManagerGetOrOpenReceiveStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0)
|
||||
return &MockStreamManagerGetOrOpenReceiveStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerGetOrOpenReceiveStreamCall wrap *gomock.Call
|
||||
type MockStreamManagerGetOrOpenReceiveStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Return(arg0 *ReceiveStream, arg1 error) *MockStreamManagerGetOrOpenReceiveStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Do(f func(protocol.StreamID) (*ReceiveStream, error)) *MockStreamManagerGetOrOpenReceiveStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerGetOrOpenReceiveStreamCall) DoAndReturn(f func(protocol.StreamID) (*ReceiveStream, error)) *MockStreamManagerGetOrOpenReceiveStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream mocks base method.
|
||||
func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (*SendStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
|
||||
ret0, _ := ret[0].(*SendStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream.
|
||||
func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 any) *MockStreamManagerGetOrOpenSendStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
|
||||
return &MockStreamManagerGetOrOpenSendStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerGetOrOpenSendStreamCall wrap *gomock.Call
|
||||
type MockStreamManagerGetOrOpenSendStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerGetOrOpenSendStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerGetOrOpenSendStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// HandleMaxStreamsFrame mocks base method.
|
||||
func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0)
|
||||
}
|
||||
|
||||
// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame.
|
||||
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 any) *MockStreamManagerHandleMaxStreamsFrameCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0)
|
||||
return &MockStreamManagerHandleMaxStreamsFrameCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerHandleMaxStreamsFrameCall wrap *gomock.Call
|
||||
type MockStreamManagerHandleMaxStreamsFrameCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerHandleMaxStreamsFrameCall) Return() *MockStreamManagerHandleMaxStreamsFrameCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerHandleMaxStreamsFrameCall) Do(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerHandleMaxStreamsFrameCall) DoAndReturn(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// OpenStream mocks base method.
|
||||
func (m *MockStreamManager) OpenStream() (*Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenStream")
|
||||
ret0, _ := ret[0].(*Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenStream indicates an expected call of OpenStream.
|
||||
func (mr *MockStreamManagerMockRecorder) OpenStream() *MockStreamManagerOpenStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream))
|
||||
return &MockStreamManagerOpenStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerOpenStreamCall wrap *gomock.Call
|
||||
type MockStreamManagerOpenStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerOpenStreamCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerOpenStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerOpenStreamCall) Do(f func() (*Stream, error)) *MockStreamManagerOpenStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerOpenStreamCall) DoAndReturn(f func() (*Stream, error)) *MockStreamManagerOpenStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// OpenStreamSync mocks base method.
|
||||
func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (*Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
|
||||
ret0, _ := ret[0].(*Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenStreamSync indicates an expected call of OpenStreamSync.
|
||||
func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 any) *MockStreamManagerOpenStreamSyncCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0)
|
||||
return &MockStreamManagerOpenStreamSyncCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerOpenStreamSyncCall wrap *gomock.Call
|
||||
type MockStreamManagerOpenStreamSyncCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerOpenStreamSyncCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerOpenStreamSyncCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerOpenStreamSyncCall) Do(f func(context.Context) (*Stream, error)) *MockStreamManagerOpenStreamSyncCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerOpenStreamSyncCall) DoAndReturn(f func(context.Context) (*Stream, error)) *MockStreamManagerOpenStreamSyncCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// OpenUniStream mocks base method.
|
||||
func (m *MockStreamManager) OpenUniStream() (*SendStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenUniStream")
|
||||
ret0, _ := ret[0].(*SendStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenUniStream indicates an expected call of OpenUniStream.
|
||||
func (mr *MockStreamManagerMockRecorder) OpenUniStream() *MockStreamManagerOpenUniStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStream))
|
||||
return &MockStreamManagerOpenUniStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerOpenUniStreamCall wrap *gomock.Call
|
||||
type MockStreamManagerOpenUniStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerOpenUniStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerOpenUniStreamCall) Do(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerOpenUniStreamCall) DoAndReturn(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// OpenUniStreamSync mocks base method.
|
||||
func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (*SendStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
|
||||
ret0, _ := ret[0].(*SendStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync.
|
||||
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 any) *MockStreamManagerOpenUniStreamSyncCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0)
|
||||
return &MockStreamManagerOpenUniStreamSyncCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerOpenUniStreamSyncCall wrap *gomock.Call
|
||||
type MockStreamManagerOpenUniStreamSyncCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerOpenUniStreamSyncCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamSyncCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerOpenUniStreamSyncCall) Do(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// ResetFor0RTT mocks base method.
|
||||
func (m *MockStreamManager) ResetFor0RTT() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ResetFor0RTT")
|
||||
}
|
||||
|
||||
// ResetFor0RTT indicates an expected call of ResetFor0RTT.
|
||||
func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *MockStreamManagerResetFor0RTTCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT))
|
||||
return &MockStreamManagerResetFor0RTTCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerResetFor0RTTCall wrap *gomock.Call
|
||||
type MockStreamManagerResetFor0RTTCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerResetFor0RTTCall) Return() *MockStreamManagerResetFor0RTTCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerResetFor0RTTCall) Do(f func()) *MockStreamManagerResetFor0RTTCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerResetFor0RTTCall) DoAndReturn(f func()) *MockStreamManagerResetFor0RTTCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// UpdateLimits mocks base method.
|
||||
func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "UpdateLimits", arg0)
|
||||
}
|
||||
|
||||
// UpdateLimits indicates an expected call of UpdateLimits.
|
||||
func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 any) *MockStreamManagerUpdateLimitsCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0)
|
||||
return &MockStreamManagerUpdateLimitsCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerUpdateLimitsCall wrap *gomock.Call
|
||||
type MockStreamManagerUpdateLimitsCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerUpdateLimitsCall) Return() *MockStreamManagerUpdateLimitsCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerUpdateLimitsCall) Do(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerUpdateLimitsCall) DoAndReturn(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// UseResetMaps mocks base method.
|
||||
func (m *MockStreamManager) UseResetMaps() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "UseResetMaps")
|
||||
}
|
||||
|
||||
// UseResetMaps indicates an expected call of UseResetMaps.
|
||||
func (mr *MockStreamManagerMockRecorder) UseResetMaps() *MockStreamManagerUseResetMapsCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps))
|
||||
return &MockStreamManagerUseResetMapsCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamManagerUseResetMapsCall wrap *gomock.Call
|
||||
type MockStreamManagerUseResetMapsCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamManagerUseResetMapsCall) Return() *MockStreamManagerUseResetMapsCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamManagerUseResetMapsCall) Do(f func()) *MockStreamManagerUseResetMapsCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamManagerUseResetMapsCall) DoAndReturn(f func()) *MockStreamManagerUseResetMapsCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
@@ -26,9 +26,6 @@ type FrameSource = frameSource
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource"
|
||||
type AckFrameSource = ackFrameSource
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager"
|
||||
type StreamManager = streamManager
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager"
|
||||
type SealingManager = sealingManager
|
||||
|
||||
|
||||
215
streams_map.go
215
streams_map.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/flowcontrol"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
@@ -37,8 +38,6 @@ type streamsMap struct {
|
||||
reset bool
|
||||
}
|
||||
|
||||
var _ streamManager = &streamsMap{}
|
||||
|
||||
func newStreamsMap(
|
||||
ctx context.Context,
|
||||
sender streamSender,
|
||||
@@ -180,90 +179,6 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
|
||||
str, err := m.getOrOpenReceiveStream(id)
|
||||
if err != nil {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
// an outgoing unidirectional stream is a send stream, not a receive stream
|
||||
return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
|
||||
}
|
||||
return m.incomingUniStreams.GetOrOpenStream(id)
|
||||
case protocol.StreamTypeBidi:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingBidiStreams.GetStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return str.ReceiveStream, err
|
||||
} else {
|
||||
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return str.ReceiveStream, err
|
||||
}
|
||||
}
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
|
||||
str, err := m.getOrOpenSendStream(id)
|
||||
if err != nil {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingUniStreams.GetStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return str, err
|
||||
}
|
||||
// an incoming unidirectional stream is a receive stream, not a send stream
|
||||
return nil, fmt.Errorf("peer attempted to open send stream %d", id)
|
||||
case protocol.StreamTypeBidi:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingBidiStreams.GetStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str.SendStream, nil
|
||||
} else {
|
||||
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str.SendStream, nil
|
||||
}
|
||||
}
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
|
||||
switch f.Type {
|
||||
case protocol.StreamTypeUni:
|
||||
@@ -273,6 +188,134 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
|
||||
}
|
||||
}
|
||||
|
||||
type sendStreamFrameHandler interface {
|
||||
updateSendWindow(protocol.ByteCount)
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
}
|
||||
|
||||
func (m *streamsMap) getSendStream(id protocol.StreamID) (sendStreamFrameHandler, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() != m.perspective {
|
||||
// an outgoing unidirectional stream is a send stream, not a receive stream
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("invalid frame for send stream %d", id),
|
||||
}
|
||||
}
|
||||
str, err := m.outgoingUniStreams.GetStream(id)
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
case protocol.StreamTypeBidi:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingBidiStreams.GetStream(id)
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleMaxStreamDataFrame(f *wire.MaxStreamDataFrame) error {
|
||||
str, err := m.getSendStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
str.updateSendWindow(f.MaximumStreamData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleStopSendingFrame(f *wire.StopSendingFrame) error {
|
||||
str, err := m.getSendStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
str.handleStopSendingFrame(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
type receiveStreamFrameHandler interface {
|
||||
handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error
|
||||
handleStreamFrame(*wire.StreamFrame, time.Time) error
|
||||
}
|
||||
|
||||
func (m *streamsMap) getReceiveStream(id protocol.StreamID) (receiveStreamFrameHandler, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
// an outgoing unidirectional stream is a send stream, not a receive stream
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("invalid frame for receive stream %d", id),
|
||||
}
|
||||
}
|
||||
str, err := m.incomingUniStreams.GetOrOpenStream(id)
|
||||
if err != nil || str == nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
case protocol.StreamTypeBidi:
|
||||
var str *Stream
|
||||
var err error
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err = m.outgoingBidiStreams.GetStream(id)
|
||||
} else {
|
||||
str, err = m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
}
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleStreamDataBlockedFrame(f *wire.StreamDataBlockedFrame) error {
|
||||
if _, err := m.getReceiveStream(f.StreamID); err != nil {
|
||||
return err
|
||||
}
|
||||
// We don't need to do anything in response to a STREAM_DATA_BLOCKED frame,
|
||||
// but we need to make sure that the stream ID is valid.
|
||||
return nil // we don't need to do anything in response to a STREAM_DATA_BLOCKED frame
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleResetStreamFrame(f *wire.ResetStreamFrame, rcvTime time.Time) error {
|
||||
str, err := m.getReceiveStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
return str.handleResetStreamFrame(f, rcvTime)
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime time.Time) error {
|
||||
str, err := m.getReceiveStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
return str.handleStreamFrame(f, rcvTime)
|
||||
}
|
||||
|
||||
func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
|
||||
m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
|
||||
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective))
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
@@ -115,7 +116,10 @@ func (m *incomingStreamsMap[T]) GetOrOpenStream(id protocol.StreamID) (T, error)
|
||||
m.mutex.RLock()
|
||||
if id > m.maxStream {
|
||||
m.mutex.RUnlock()
|
||||
return *new(T), fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
|
||||
return *new(T), &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamLimitError,
|
||||
ErrorMessage: fmt.Sprintf("peer tried to open stream %d (current limit: %d)", id, m.maxStream),
|
||||
}
|
||||
}
|
||||
// if the num is smaller than the highest we accepted
|
||||
// * this stream exists in the map, and we can return it, or
|
||||
@@ -152,7 +156,13 @@ func (m *incomingStreamsMap[T]) DeleteStream(id protocol.StreamID) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteStream(id)
|
||||
if err := m.deleteStream(id); err != nil {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *incomingStreamsMap[T]) deleteStream(id protocol.StreamID) error {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -80,6 +81,7 @@ func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Per
|
||||
_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams - 4)
|
||||
require.NoError(t, err)
|
||||
_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamLimitError})
|
||||
require.ErrorContains(t, err, "peer tried to open stream")
|
||||
require.Equal(t, maxNumStreams, newStreamCounter)
|
||||
}
|
||||
@@ -151,6 +153,7 @@ func testStreamsMapIncomingDeletingStreams(t *testing.T, perspective protocol.Pe
|
||||
perspective,
|
||||
)
|
||||
err := m.DeleteStream(firstStream + 1337*4)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, "tried to delete unknown incoming stream")
|
||||
|
||||
s, err := m.GetOrOpenStream(firstStream + 4)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
@@ -153,7 +154,10 @@ func (m *outgoingStreamsMap[T]) GetStream(id protocol.StreamID) (T, error) {
|
||||
m.mutex.RLock()
|
||||
if id >= m.nextStream {
|
||||
m.mutex.RUnlock()
|
||||
return *new(T), fmt.Errorf("peer attempted to open stream %d", id)
|
||||
return *new(T), &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
|
||||
}
|
||||
}
|
||||
s := m.streams[id]
|
||||
m.mutex.RUnlock()
|
||||
@@ -165,7 +169,10 @@ func (m *outgoingStreamsMap[T]) DeleteStream(id protocol.StreamID) error {
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if _, ok := m.streams[id]; !ok {
|
||||
return fmt.Errorf("tried to delete unknown outgoing stream %d", id)
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("tried to delete unknown outgoing stream %d", id),
|
||||
}
|
||||
}
|
||||
delete(m.streams, id)
|
||||
return nil
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -34,6 +35,7 @@ func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Pers
|
||||
m.SetMaxStream(protocol.MaxStreamID)
|
||||
|
||||
_, err := m.GetStream(firstStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstStream))
|
||||
|
||||
str1, err := m.OpenStream()
|
||||
@@ -54,14 +56,14 @@ func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Pers
|
||||
|
||||
err = m.DeleteStream(firstStream + 1337*4)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, "tried to delete unknown outgoing stream")
|
||||
|
||||
require.NoError(t, m.DeleteStream(firstStream))
|
||||
// deleting the same stream twice will fail
|
||||
require.ErrorContains(t,
|
||||
m.DeleteStream(firstStream),
|
||||
"tried to delete unknown outgoing stream",
|
||||
)
|
||||
err = m.DeleteStream(firstStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, "tried to delete unknown outgoing stream")
|
||||
// after deleting the stream it's not available anymore
|
||||
str, err := m.GetStream(firstStream)
|
||||
require.NoError(t, err)
|
||||
@@ -98,7 +100,6 @@ func testStreamsMapOutgoingLimits(t *testing.T, perspective protocol.Perspective
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = m.OpenStreamSync(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
|
||||
// OpenStreamSync blocks until the context is canceled...
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
|
||||
func TestStreamsMapCreatingStreams(t *testing.T) {
|
||||
t.Run("client", func(t *testing.T) {
|
||||
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveClient,
|
||||
testStreamsMapCreatingStreams(t, protocol.PerspectiveClient,
|
||||
protocol.FirstIncomingBidiStreamClient,
|
||||
protocol.FirstOutgoingBidiStreamClient,
|
||||
protocol.FirstIncomingUniStreamClient,
|
||||
@@ -27,7 +27,7 @@ func TestStreamsMapCreatingStreams(t *testing.T) {
|
||||
)
|
||||
})
|
||||
t.Run("server", func(t *testing.T) {
|
||||
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveServer,
|
||||
testStreamsMapCreatingStreams(t, protocol.PerspectiveServer,
|
||||
protocol.FirstIncomingBidiStreamServer,
|
||||
protocol.FirstOutgoingBidiStreamServer,
|
||||
protocol.FirstIncomingUniStreamServer,
|
||||
@@ -36,7 +36,7 @@ func TestStreamsMapCreatingStreams(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
|
||||
func testStreamsMapCreatingStreams(t *testing.T,
|
||||
perspective protocol.Perspective,
|
||||
firstIncomingBidiStream protocol.StreamID,
|
||||
firstOutgoingBidiStream protocol.StreamID,
|
||||
@@ -50,7 +50,9 @@ func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
|
||||
mockSender,
|
||||
func(wire.Frame) {},
|
||||
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
||||
return mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
return fc
|
||||
},
|
||||
1,
|
||||
1,
|
||||
@@ -76,25 +78,9 @@ func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
|
||||
assert.Equal(t, ustr1.StreamID(), firstOutgoingUniStream)
|
||||
assert.Equal(t, ustr2.StreamID(), firstOutgoingUniStream+4)
|
||||
|
||||
// accepting streams:
|
||||
// This function is called when a frame referencing this stream is received.
|
||||
// The peer may open a peer-initiated stream...
|
||||
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
|
||||
require.NoError(t, err)
|
||||
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ... but not a stream that is initiated by us.
|
||||
_, err = m.GetOrOpenSendStream(firstOutgoingBidiStream + 8)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream+8),
|
||||
})
|
||||
_, err = m.GetOrOpenSendStream(firstOutgoingUniStream + 8)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream+8),
|
||||
})
|
||||
// accepting streams is triggered by receiving a frame referencing this stream
|
||||
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, time.Now()))
|
||||
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, time.Now()))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
@@ -141,7 +127,9 @@ func testStreamsMapDeletingStreams(t *testing.T,
|
||||
mockSender,
|
||||
func(frame wire.Frame) { frameQueue = append(frameQueue, frame) },
|
||||
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
||||
return mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
return fc
|
||||
},
|
||||
100,
|
||||
100,
|
||||
@@ -155,37 +143,25 @@ func testStreamsMapDeletingStreams(t *testing.T,
|
||||
_, err := m.OpenStream()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
|
||||
sstr, err := m.GetOrOpenSendStream(firstOutgoingBidiStream)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, sstr)
|
||||
require.ErrorContains(t,
|
||||
m.DeleteStream(firstOutgoingBidiStream+400),
|
||||
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400),
|
||||
)
|
||||
err = m.DeleteStream(firstOutgoingBidiStream + 400)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400))
|
||||
|
||||
_, err = m.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, m.DeleteStream(firstOutgoingUniStream))
|
||||
sstr, err = m.GetOrOpenSendStream(firstOutgoingUniStream)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, sstr)
|
||||
require.ErrorContains(t,
|
||||
m.DeleteStream(firstOutgoingUniStream+400),
|
||||
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400),
|
||||
)
|
||||
err = m.DeleteStream(firstOutgoingUniStream + 400)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400))
|
||||
|
||||
require.Empty(t, frameQueue)
|
||||
// deleting incoming bidirectional streams
|
||||
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, time.Now()))
|
||||
require.NoError(t, m.DeleteStream(firstIncomingBidiStream))
|
||||
sstr, err = m.GetOrOpenSendStream(firstIncomingBidiStream)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, sstr)
|
||||
require.ErrorContains(t,
|
||||
m.DeleteStream(firstIncomingBidiStream+400),
|
||||
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400),
|
||||
)
|
||||
err = m.DeleteStream(firstIncomingBidiStream + 400)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400))
|
||||
|
||||
// the MAX_STREAMS frame is only queued once the stream is accepted
|
||||
require.Empty(t, frameQueue)
|
||||
_, err = m.AcceptStream(context.Background())
|
||||
@@ -200,16 +176,12 @@ func testStreamsMapDeletingStreams(t *testing.T,
|
||||
frameQueue = frameQueue[:0]
|
||||
|
||||
// deleting incoming unidirectional streams
|
||||
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, time.Now()))
|
||||
require.NoError(t, m.DeleteStream(firstIncomingUniStream))
|
||||
rstr, err := m.GetOrOpenReceiveStream(firstIncomingUniStream)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, rstr)
|
||||
require.ErrorContains(t,
|
||||
m.DeleteStream(firstIncomingUniStream+400),
|
||||
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400),
|
||||
)
|
||||
err = m.DeleteStream(firstIncomingUniStream + 400)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400))
|
||||
|
||||
// the MAX_STREAMS frame is only queued once the stream is accepted
|
||||
require.Empty(t, frameQueue)
|
||||
_, err = m.AcceptUniStream(context.Background())
|
||||
@@ -254,6 +226,7 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective)
|
||||
// increase via transport parameters
|
||||
_, err := m.OpenStream()
|
||||
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
||||
require.ErrorContains(t, err, "too many open streams")
|
||||
m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 1})
|
||||
_, err = m.OpenStream()
|
||||
require.NoError(t, err)
|
||||
@@ -293,6 +266,239 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective)
|
||||
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
||||
}
|
||||
|
||||
func TestStreamsMapHandleReceiveStreamFrames(t *testing.T) {
|
||||
for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} {
|
||||
t.Run(pers.String(), func(t *testing.T) {
|
||||
t.Run("STREAM frame", func(t *testing.T) {
|
||||
testStreamsMapHandleReceiveStreamFrames(t,
|
||||
pers,
|
||||
func(m *streamsMap, id protocol.StreamID) error {
|
||||
return m.HandleStreamFrame(&wire.StreamFrame{StreamID: id}, time.Now())
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("STREAM_DATA_BLOCKED frame", func(t *testing.T) {
|
||||
testStreamsMapHandleReceiveStreamFrames(t,
|
||||
pers,
|
||||
func(m *streamsMap, id protocol.StreamID) error {
|
||||
return m.HandleStreamDataBlockedFrame(&wire.StreamDataBlockedFrame{StreamID: id})
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("RESET_STREAM frame", func(t *testing.T) {
|
||||
testStreamsMapHandleReceiveStreamFrames(t,
|
||||
pers,
|
||||
func(m *streamsMap, id protocol.StreamID) error {
|
||||
return m.HandleResetStreamFrame(&wire.ResetStreamFrame{StreamID: id}, time.Now())
|
||||
},
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testStreamsMapHandleReceiveStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(*streamsMap, protocol.StreamID) error) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
var streamsCreated []protocol.StreamID
|
||||
m := newStreamsMap(
|
||||
context.Background(),
|
||||
mockSender,
|
||||
func(frame wire.Frame) {},
|
||||
func(id protocol.StreamID) flowcontrol.StreamFlowController {
|
||||
streamsCreated = append(streamsCreated, id)
|
||||
fc := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
fc.EXPECT().Abandon().AnyTimes()
|
||||
return fc
|
||||
},
|
||||
100,
|
||||
100,
|
||||
pers,
|
||||
)
|
||||
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount})
|
||||
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount})
|
||||
|
||||
var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream protocol.StreamID
|
||||
if pers == protocol.PerspectiveClient {
|
||||
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient
|
||||
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient
|
||||
firstIncomingUniStream = protocol.FirstIncomingUniStreamClient
|
||||
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient
|
||||
} else {
|
||||
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer
|
||||
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer
|
||||
firstIncomingUniStream = protocol.FirstIncomingUniStreamServer
|
||||
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer
|
||||
}
|
||||
|
||||
// 1. The peer can't open a unidirectional send stream...
|
||||
err := handleFrame(m, firstOutgoingUniStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
// ... and a STREAM frame for a unidirectional send stream is invalid even if the stream is open.
|
||||
_, err = m.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
err = handleFrame(m, firstOutgoingUniStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream))
|
||||
streamsCreated = streamsCreated[:0]
|
||||
|
||||
// 2. The peer can't open a bidirectional stream initiated by us...
|
||||
err = handleFrame(m, firstOutgoingBidiStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
// ... but it's valid once we have opened the stream.
|
||||
_, err = m.OpenStream()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
|
||||
streamsCreated = streamsCreated[:0]
|
||||
// Delayed frames for deleted streams are absorbed.
|
||||
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
|
||||
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
|
||||
// 3. The peer can send STREAM frames for unidirectional receive streams,
|
||||
// as long as they're below the stream limit.
|
||||
require.ErrorIs(t,
|
||||
handleFrame(m, firstIncomingUniStream+400),
|
||||
&qerr.TransportError{ErrorCode: qerr.StreamLimitError},
|
||||
)
|
||||
require.Empty(t, streamsCreated)
|
||||
require.NoError(t, handleFrame(m, firstIncomingUniStream))
|
||||
require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingUniStream})
|
||||
streamsCreated = streamsCreated[:0]
|
||||
// Delayed frames for deleted streams are absorbed.
|
||||
require.NoError(t, m.DeleteStream(firstIncomingUniStream))
|
||||
require.NoError(t, handleFrame(m, firstIncomingUniStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
|
||||
// 4. The peer can send STREAM frames for bidirectional receive streams,
|
||||
// as long as they're below the stream limit.
|
||||
require.ErrorIs(t,
|
||||
handleFrame(m, firstIncomingBidiStream+400),
|
||||
&qerr.TransportError{ErrorCode: qerr.StreamLimitError},
|
||||
)
|
||||
require.Empty(t, streamsCreated)
|
||||
require.NoError(t, handleFrame(m, firstIncomingBidiStream))
|
||||
require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream})
|
||||
}
|
||||
|
||||
func TestStreamsMapHandleSendStreamFrames(t *testing.T) {
|
||||
for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} {
|
||||
t.Run(pers.String(), func(t *testing.T) {
|
||||
t.Run("STOP_SENDING frame", func(t *testing.T) {
|
||||
testStreamsMapHandleSendStreamFrames(t,
|
||||
pers,
|
||||
func(m *streamsMap, id protocol.StreamID) error {
|
||||
return m.HandleStopSendingFrame(&wire.StopSendingFrame{StreamID: id})
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("MAX_STREAM_DATA frame", func(t *testing.T) {
|
||||
testStreamsMapHandleSendStreamFrames(t,
|
||||
pers,
|
||||
func(m *streamsMap, id protocol.StreamID) error {
|
||||
return m.HandleMaxStreamDataFrame(&wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1000})
|
||||
},
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testStreamsMapHandleSendStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(m *streamsMap, id protocol.StreamID) error) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
var streamsCreated []protocol.StreamID
|
||||
m := newStreamsMap(
|
||||
context.Background(),
|
||||
mockSender,
|
||||
func(frame wire.Frame) {},
|
||||
func(id protocol.StreamID) flowcontrol.StreamFlowController {
|
||||
streamsCreated = append(streamsCreated, id)
|
||||
fc := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc.EXPECT().UpdateSendWindow(gomock.Any()).AnyTimes()
|
||||
return fc
|
||||
},
|
||||
100,
|
||||
100,
|
||||
pers,
|
||||
)
|
||||
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount})
|
||||
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount})
|
||||
|
||||
var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream protocol.StreamID
|
||||
if pers == protocol.PerspectiveClient {
|
||||
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient
|
||||
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient
|
||||
firstIncomingUniStream = protocol.FirstIncomingUniStreamClient
|
||||
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient
|
||||
} else {
|
||||
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer
|
||||
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer
|
||||
firstIncomingUniStream = protocol.FirstIncomingUniStreamServer
|
||||
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer
|
||||
}
|
||||
|
||||
// 1. The peer can't open a unidirectional send stream...
|
||||
err := handleFrame(m, firstOutgoingUniStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
// ... but once we have opened the stream, it's valid.
|
||||
_, err = m.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, handleFrame(m, firstOutgoingUniStream))
|
||||
streamsCreated = streamsCreated[:0]
|
||||
// Delayed frames for deleted streams are absorbed.
|
||||
require.NoError(t, m.DeleteStream(firstOutgoingUniStream))
|
||||
require.NoError(t, handleFrame(m, firstOutgoingUniStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
|
||||
// 2. The peer can't open a bidirectional stream initiated by us...
|
||||
err = handleFrame(m, firstOutgoingBidiStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
// ... but once we have opened the stream, it's valid.
|
||||
_, err = m.OpenStream()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
|
||||
streamsCreated = streamsCreated[:0]
|
||||
// Delayed frames for deleted streams are absorbed.
|
||||
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
|
||||
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
|
||||
// 3. The peer can't send STOP_SENDING frames for unidirectional send streams
|
||||
err = handleFrame(m, firstIncomingUniStream)
|
||||
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
||||
require.ErrorContains(t, err, fmt.Sprintf("invalid frame for send stream %d", firstIncomingUniStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
|
||||
// 4. The peer can send STOP_SENDING frames for bidirectional receive streams iniated by itself,
|
||||
// as long as they're below the stream limit.
|
||||
require.ErrorIs(t,
|
||||
handleFrame(m, firstIncomingBidiStream+400),
|
||||
&qerr.TransportError{ErrorCode: qerr.StreamLimitError},
|
||||
)
|
||||
require.Empty(t, streamsCreated)
|
||||
require.NoError(t, handleFrame(m, firstIncomingBidiStream))
|
||||
require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream})
|
||||
streamsCreated = streamsCreated[:0]
|
||||
// Delayed frames for deleted streams are absorbed.
|
||||
require.NoError(t, m.DeleteStream(firstIncomingBidiStream))
|
||||
require.NoError(t, handleFrame(m, firstIncomingBidiStream))
|
||||
require.Empty(t, streamsCreated)
|
||||
}
|
||||
|
||||
func TestStreamsMapClosing(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
@@ -366,7 +572,9 @@ func TestStreamsMap0RTTRejection(t *testing.T) {
|
||||
mockSender,
|
||||
func(wire.Frame) {},
|
||||
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
||||
return mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
return fc
|
||||
},
|
||||
1,
|
||||
1,
|
||||
@@ -380,10 +588,11 @@ func TestStreamsMap0RTTRejection(t *testing.T) {
|
||||
require.ErrorIs(t, err, Err0RTTRejected)
|
||||
_, err = m.AcceptStream(context.Background())
|
||||
require.ErrorIs(t, err, Err0RTTRejected)
|
||||
_, err = m.AcceptUniStream(context.Background())
|
||||
require.ErrorIs(t, err, Err0RTTRejected)
|
||||
|
||||
// make sure that we can still get new streams, as the server might be sending us data
|
||||
str, err := m.GetOrOpenReceiveStream(3)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, str)
|
||||
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: 3}, time.Now()))
|
||||
|
||||
// now switch to using the new streams map
|
||||
m.UseResetMaps()
|
||||
|
||||
Reference in New Issue
Block a user