diff --git a/connection.go b/connection.go index 0ed4e1f0..62f02ddc 100644 --- a/connection.go +++ b/connection.go @@ -30,23 +30,6 @@ type unpacker interface { UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) } -type streamManager interface { - GetOrOpenSendStream(protocol.StreamID) (*SendStream, error) - GetOrOpenReceiveStream(protocol.StreamID) (*ReceiveStream, error) - OpenStream() (*Stream, error) - OpenUniStream() (*SendStream, error) - OpenStreamSync(context.Context) (*Stream, error) - OpenUniStreamSync(context.Context) (*SendStream, error) - AcceptStream(context.Context) (*Stream, error) - AcceptUniStream(context.Context) (*ReceiveStream, error) - DeleteStream(protocol.StreamID) error - UpdateLimits(*wire.TransportParameters) - HandleMaxStreamsFrame(*wire.MaxStreamsFrame) - CloseWithError(error) - ResetFor0RTT() - UseResetMaps() -} - type cryptoStreamHandler interface { StartHandshake(context.Context) error ChangeConnectionID(protocol.ConnectionID) @@ -133,7 +116,7 @@ type connection struct { largestRcvdAppData protocol.PacketNumber pathManagerOutgoing atomic.Pointer[pathManagerOutgoing] - streamsMap streamManager + streamsMap *streamsMap connIDManager *connIDManager connIDGenerator *connIDGenerator @@ -1498,25 +1481,25 @@ func (s *connection) handleFrame( case *wire.CryptoFrame: err = s.handleCryptoFrame(frame, encLevel, rcvTime) case *wire.StreamFrame: - err = s.handleStreamFrame(frame, rcvTime) + err = s.streamsMap.HandleStreamFrame(frame, rcvTime) case *wire.AckFrame: err = s.handleAckFrame(frame, encLevel, rcvTime) case *wire.ConnectionCloseFrame: err = s.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: - err = s.handleResetStreamFrame(frame, rcvTime) + err = s.streamsMap.HandleResetStreamFrame(frame, rcvTime) case *wire.MaxDataFrame: - s.handleMaxDataFrame(frame) + s.connFlowController.UpdateSendWindow(frame.MaximumData) case *wire.MaxStreamDataFrame: - err = s.handleMaxStreamDataFrame(frame) + err = s.streamsMap.HandleMaxStreamDataFrame(frame) case *wire.MaxStreamsFrame: - s.handleMaxStreamsFrame(frame) + s.streamsMap.HandleMaxStreamsFrame(frame) case *wire.DataBlockedFrame: case *wire.StreamDataBlockedFrame: - err = s.handleStreamDataBlockedFrame(frame) + err = s.streamsMap.HandleStreamDataBlockedFrame(frame) case *wire.StreamsBlockedFrame: case *wire.StopSendingFrame: - err = s.handleStopSendingFrame(frame) + err = s.streamsMap.HandleStopSendingFrame(frame) case *wire.PingFrame: case *wire.PathChallengeFrame: s.handlePathChallengeFrame(frame) @@ -1526,9 +1509,9 @@ func (s *connection) handleFrame( case *wire.NewTokenFrame: err = s.handleNewTokenFrame(frame) case *wire.NewConnectionIDFrame: - err = s.handleNewConnectionIDFrame(frame) + err = s.connIDManager.Add(frame) case *wire.RetireConnectionIDFrame: - err = s.handleRetireConnectionIDFrame(rcvTime, frame, destConnID) + err = s.connIDGenerator.Retire(frame.SequenceNumber, destConnID, rcvTime.Add(3*s.rttStats.PTO(false))) case *wire.HandshakeDoneFrame: err = s.handleHandshakeDoneFrame(rcvTime) case *wire.DatagramFrame: @@ -1625,70 +1608,6 @@ func (s *connection) handleHandshakeEvents(now time.Time) error { } } -func (s *connection) handleStreamFrame(frame *wire.StreamFrame, rcvTime time.Time) error { - str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { // stream was already closed and garbage collected - return nil - } - return str.handleStreamFrame(frame, rcvTime) -} - -func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) { - s.connFlowController.UpdateSendWindow(frame.MaximumData) -} - -func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { - str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - // stream is closed and already garbage collected - return nil - } - str.updateSendWindow(frame.MaximumStreamData) - return nil -} - -func (s *connection) handleStreamDataBlockedFrame(frame *wire.StreamDataBlockedFrame) error { - // We don't need to do anything in response to a STREAM_DATA_BLOCKED frame, - // but we need to make sure that the stream ID is valid. - _, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) - return err -} - -func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { - s.streamsMap.HandleMaxStreamsFrame(frame) -} - -func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame, rcvTime time.Time) error { - str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - // stream is closed and already garbage collected - return nil - } - return str.handleResetStreamFrame(frame, rcvTime) -} - -func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error { - str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - // stream is closed and already garbage collected - return nil - } - str.handleStopSendingFrame(frame) - return nil -} - func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) { if s.perspective == protocol.PerspectiveClient { s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data}) @@ -1743,14 +1662,6 @@ func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error { return nil } -func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error { - return s.connIDManager.Add(f) -} - -func (s *connection) handleRetireConnectionIDFrame(now time.Time, f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { - return s.connIDGenerator.Retire(f.SequenceNumber, destConnID, now.Add(3*s.rttStats.PTO(false))) -} - func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error { if s.perspective == protocol.PerspectiveServer { return &qerr.TransportError{ diff --git a/connection_test.go b/connection_test.go index 063ec492..936671ba 100644 --- a/connection_test.go +++ b/connection_test.go @@ -35,10 +35,6 @@ func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt { return func(conn *connection) { conn.cryptoStreamHandler = cs } } -func connectionOptStreamManager(sm *MockStreamManager) testConnectionOpt { - return func(conn *connection) { conn.streamsMap = sm } -} - func connectionOptConnFlowController(cfc flowcontrol.ConnectionFlowController) testConnectionOpt { return func(conn *connection) { conn.connFlowController = cfc } } @@ -208,145 +204,28 @@ func newClientTestConnection( } } -func TestConnectionHandleReceiveStreamFrames(t *testing.T) { - const streamID protocol.StreamID = 5 - now := time.Now() +func TestConnectionHandleStreamRelatedFrames(t *testing.T) { + const id protocol.StreamID = 5 connID := protocol.ConnectionID{} - f := &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar")} - rsf := &wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 42, FinalSize: 1337} - sdbf := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 1337} - t.Run("for existing and new streams", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - mockSender := NewMockStreamSender(mockCtrl) - mockSender.EXPECT().onHasStreamData(streamID, gomock.Any()).AnyTimes() - mockFC := mocks.NewMockStreamFlowController(mockCtrl) - str := newReceiveStream(streamID, mockSender, mockFC) - // STREAM frame - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(len(f.Data)), false, gomock.Any()) - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil) - _, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - // RESET_STREAM frame - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1337), true, gomock.Any()) - mockFC.EXPECT().Abandon() - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil) - _, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - // STREAM_DATA_BLOCKED frames are not passed to the stream - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil) - _, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - }) + tests := []struct { + name string + frame wire.Frame + }{ + {name: "STREAM", frame: &wire.StreamFrame{StreamID: id, Data: []byte("foobar")}}, + {name: "RESET_STREAM", frame: &wire.ResetStreamFrame{StreamID: id, ErrorCode: 42, FinalSize: 1337}}, + {name: "STOP_SENDING", frame: &wire.StopSendingFrame{StreamID: id, ErrorCode: 42}}, + {name: "MAX_STREAM_DATA", frame: &wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1337}}, + {name: "STREAM_DATA_BLOCKED", frame: &wire.StreamDataBlockedFrame{StreamID: id, MaximumStreamData: 42}}, + } - t.Run("for closed streams", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - // STREAM frame - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil) - _, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - // RESET_STREAM frame - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil) - _, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - // STREAM_DATA_BLOCKED frames are not passed to the stream - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil) - _, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - }) - - t.Run("for invalid streams", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - // STREAM frame - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError) - _, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now) - require.ErrorIs(t, err, assert.AnError) - // RESET_STREAM frame - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError) - _, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now) - require.ErrorIs(t, err, assert.AnError) - // STREAM_DATA_BLOCKED frames are not passed to the stream - streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError) - _, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now) - require.ErrorIs(t, err, assert.AnError) - }) -} - -func TestConnectionHandleSendStreamFrames(t *testing.T) { - const streamID protocol.StreamID = 3 - now := time.Now() - connID := protocol.ConnectionID{} - ss := &wire.StopSendingFrame{StreamID: streamID, ErrorCode: 42} - msd := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337} - - t.Run("for existing and new streams", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - mockSender := NewMockStreamSender(mockCtrl) - mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()).AnyTimes() - mockFC := mocks.NewMockStreamFlowController(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) - // STOP_SENDING frame - streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil) - _, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - // MAX_STREAM_DATA frame - mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337)) - streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil) - _, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - }) - - t.Run("for closed streams", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - // STOP_SENDING frame - streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil) - _, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - // MAX_STREAM_DATA frame - streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil) - _, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - }) - - t.Run("for invalid streams", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - // STOP_SENDING frame - streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, assert.AnError) - _, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now) - require.ErrorIs(t, err, assert.AnError) - // MAX_STREAM_DATA frame - streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, assert.AnError) - _, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now) - require.ErrorIs(t, err, assert.AnError) - }) -} - -func TestConnectionHandleStreamNumFrames(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - now := time.Now() - connID := protocol.ConnectionID{} - // MAX_STREAMS frame - msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10} - streamsMap.EXPECT().HandleMaxStreamsFrame(msf) - _, err := tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) - // STREAMS_BLOCKED frame - _, err = tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now) - require.NoError(t, err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tc := newServerTestConnection(t, gomock.NewController(t), nil, false) + _, err := tc.conn.handleFrame(test.frame, protocol.Encryption1RTT, connID, time.Now()) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + }) + } } func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) { @@ -365,60 +244,6 @@ func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) { require.NoError(t, err) } -func TestConnectionOpenStreams(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - - // using OpenStream - str1 := &Stream{} - streamsMap.EXPECT().OpenStream().Return(str1, nil) - str, err := tc.conn.OpenStream() - require.NoError(t, err) - require.Equal(t, str1, str) - - // using OpenStreamSync - streamsMap.EXPECT().OpenStreamSync(context.Background()).Return(str1, nil) - str, err = tc.conn.OpenStreamSync(context.Background()) - require.NoError(t, err) - require.Equal(t, str1, str) - - // using OpenUniStream - str2 := &SendStream{} - streamsMap.EXPECT().OpenUniStream().Return(str2, nil) - ustr, err := tc.conn.OpenUniStream() - require.NoError(t, err) - require.Equal(t, str2, ustr) - - // using OpenUniStreamSync - streamsMap.EXPECT().OpenUniStreamSync(context.Background()).Return(str2, nil) - ustr, err = tc.conn.OpenUniStreamSync(context.Background()) - require.NoError(t, err) - require.Equal(t, str2, ustr) -} - -func TestConnectionAcceptStreams(t *testing.T) { - mockCtrl := gomock.NewController(t) - streamsMap := NewMockStreamManager(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap)) - - // bidirectional streams - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - str1 := &Stream{} - streamsMap.EXPECT().AcceptStream(ctx).Return(str1, nil) - str, err := tc.conn.AcceptStream(ctx) - require.NoError(t, err) - require.Equal(t, str1, str) - - // unidirectional streams - str2 := &ReceiveStream{} - streamsMap.EXPECT().AcceptUniStream(ctx).Return(str2, nil) - ustr, err := tc.conn.AcceptUniStream(ctx) - require.NoError(t, err) - require.Equal(t, str2, ustr) -} - func TestConnectionServerInvalidFrames(t *testing.T) { mockCtrl := gomock.NewController(t) tc := newServerTestConnection(t, mockCtrl, nil, false) @@ -943,14 +768,12 @@ func TestConnectionMaxUnprocessedPackets(t *testing.T) { func TestConnectionRemoteClose(t *testing.T) { mockCtrl := gomock.NewController(t) - mockStreamManager := NewMockStreamManager(mockCtrl) tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) unpacker := NewMockUnpacker(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, false, - connectionOptStreamManager(mockStreamManager), connectionOptTracer(tr), connectionOptUnpacker(unpacker), ) @@ -964,8 +787,6 @@ func TestConnectionRemoteClose(t *testing.T) { expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true} tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) - streamErrChan := make(chan error, 1) - mockStreamManager.EXPECT().CloseWithError(gomock.Any()).Do(func(e error) { streamErrChan <- e }) tracerErrChan := make(chan error, 1) tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrChan <- e }) tracer.EXPECT().Close() @@ -988,12 +809,6 @@ func TestConnectionRemoteClose(t *testing.T) { case <-time.After(time.Second): t.Fatal("timeout") } - select { - case err := <-streamErrChan: - require.ErrorIs(t, err, expectedErr) - case <-time.After(time.Second): - t.Fatal("timeout") - } } func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) { @@ -1061,6 +876,8 @@ func TestConnectionTransportParameters(t *testing.T) { ) _, err := tc.conn.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) + _, err = tc.conn.OpenUniStream() + require.ErrorIs(t, err, &StreamLimitReachedError{}) tracer.EXPECT().ReceivedTransportParameters(gomock.Any()) params := &wire.TransportParameters{ MaxIdleTimeout: 90 * time.Second, @@ -1071,11 +888,78 @@ func TestConnectionTransportParameters(t *testing.T) { MaxUDPPayloadSize: protocol.MaxPacketBufferSize, OriginalDestinationConnectionID: tc.destConnID, MaxBidiStreamNum: 1, + MaxUniStreamNum: 1, } require.NoError(t, tc.conn.handleTransportParameters(params)) require.Equal(t, protocol.ByteCount(1337), connFC.SendWindowSize()) _, err = tc.conn.OpenStream() require.NoError(t, err) + _, err = tc.conn.OpenUniStream() + require.NoError(t, err) +} + +func TestConnectionHandleMaxStreamsFrame(t *testing.T) { + mockCtrl := gomock.NewController(t) + connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger) + tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC)) + tc.conn.handleTransportParameters(&wire.TransportParameters{}) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + uniStreamChan := make(chan error) + go func() { + _, err := tc.conn.OpenUniStreamSync(ctx) + uniStreamChan <- err + }() + bidiStreamChan := make(chan error) + go func() { + _, err := tc.conn.OpenStreamSync(ctx) + bidiStreamChan <- err + }() + + select { + case <-uniStreamChan: + t.Fatal("uni stream should be blocked") + case <-bidiStreamChan: + t.Fatal("bidi stream should be blocked") + case <-time.After(scaleDuration(10 * time.Millisecond)): + } + + // MAX_STREAMS frame for bidirectional stream + _, err := tc.conn.handleFrame( + &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}, + protocol.Encryption1RTT, + protocol.ConnectionID{}, + time.Now(), + ) + require.NoError(t, err) + + select { + case <-uniStreamChan: + t.Fatal("uni stream should be blocked") + case <-time.After(scaleDuration(10 * time.Millisecond)): + } + select { + case err := <-bidiStreamChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // MAX_STREAMS frame for bidirectional stream + _, err = tc.conn.handleFrame( + &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 10}, + protocol.Encryption1RTT, + protocol.ConnectionID{}, + time.Now(), + ) + require.NoError(t, err) + select { + case err := <-uniStreamChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } } func TestConnectionTransportParameterValidationFailureServer(t *testing.T) { @@ -3075,10 +2959,11 @@ func testConnectionMigration(t *testing.T, enabled bool) { ).AnyTimes() tc.connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) // add a new connection ID, so the path can be probed - require.NoError(t, tc.conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ + _, err = tc.conn.handleFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), - })) + }, protocol.EncryptionInitial, tc.destConnID, time.Now()) + require.NoError(t, err) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go deleted file mode 100644 index 6b7f744e..00000000 --- a/mock_stream_manager_test.go +++ /dev/null @@ -1,573 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: StreamManager) -// -// Generated by this command: -// -// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager -// - -// Package quic is a generated GoMock package. -package quic - -import ( - context "context" - reflect "reflect" - - protocol "github.com/quic-go/quic-go/internal/protocol" - wire "github.com/quic-go/quic-go/internal/wire" - gomock "go.uber.org/mock/gomock" -) - -// MockStreamManager is a mock of StreamManager interface. -type MockStreamManager struct { - ctrl *gomock.Controller - recorder *MockStreamManagerMockRecorder - isgomock struct{} -} - -// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager. -type MockStreamManagerMockRecorder struct { - mock *MockStreamManager -} - -// NewMockStreamManager creates a new mock instance. -func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager { - mock := &MockStreamManager{ctrl: ctrl} - mock.recorder = &MockStreamManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder { - return m.recorder -} - -// AcceptStream mocks base method. -func (m *MockStreamManager) AcceptStream(arg0 context.Context) (*Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream", arg0) - ret0, _ := ret[0].(*Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptStream indicates an expected call of AcceptStream. -func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 any) *MockStreamManagerAcceptStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) - return &MockStreamManagerAcceptStreamCall{Call: call} -} - -// MockStreamManagerAcceptStreamCall wrap *gomock.Call -type MockStreamManagerAcceptStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerAcceptStreamCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerAcceptStreamCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerAcceptStreamCall) Do(f func(context.Context) (*Stream, error)) *MockStreamManagerAcceptStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerAcceptStreamCall) DoAndReturn(f func(context.Context) (*Stream, error)) *MockStreamManagerAcceptStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// AcceptUniStream mocks base method. -func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (*ReceiveStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream", arg0) - ret0, _ := ret[0].(*ReceiveStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 any) *MockStreamManagerAcceptUniStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) - return &MockStreamManagerAcceptUniStreamCall{Call: call} -} - -// MockStreamManagerAcceptUniStreamCall wrap *gomock.Call -type MockStreamManagerAcceptUniStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerAcceptUniStreamCall) Return(arg0 *ReceiveStream, arg1 error) *MockStreamManagerAcceptUniStreamCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerAcceptUniStreamCall) Do(f func(context.Context) (*ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerAcceptUniStreamCall) DoAndReturn(f func(context.Context) (*ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// CloseWithError mocks base method. -func (m *MockStreamManager) CloseWithError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CloseWithError", arg0) -} - -// CloseWithError indicates an expected call of CloseWithError. -func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 any) *MockStreamManagerCloseWithErrorCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) - return &MockStreamManagerCloseWithErrorCall{Call: call} -} - -// MockStreamManagerCloseWithErrorCall wrap *gomock.Call -type MockStreamManagerCloseWithErrorCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerCloseWithErrorCall) Return() *MockStreamManagerCloseWithErrorCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerCloseWithErrorCall) Do(f func(error)) *MockStreamManagerCloseWithErrorCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerCloseWithErrorCall) DoAndReturn(f func(error)) *MockStreamManagerCloseWithErrorCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// DeleteStream mocks base method. -func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteStream", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteStream indicates an expected call of DeleteStream. -func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 any) *MockStreamManagerDeleteStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) - return &MockStreamManagerDeleteStreamCall{Call: call} -} - -// MockStreamManagerDeleteStreamCall wrap *gomock.Call -type MockStreamManagerDeleteStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerDeleteStreamCall) Return(arg0 error) *MockStreamManagerDeleteStreamCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerDeleteStreamCall) Do(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerDeleteStreamCall) DoAndReturn(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// GetOrOpenReceiveStream mocks base method. -func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (*ReceiveStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0) - ret0, _ := ret[0].(*ReceiveStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 any) *MockStreamManagerGetOrOpenReceiveStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) - return &MockStreamManagerGetOrOpenReceiveStreamCall{Call: call} -} - -// MockStreamManagerGetOrOpenReceiveStreamCall wrap *gomock.Call -type MockStreamManagerGetOrOpenReceiveStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Return(arg0 *ReceiveStream, arg1 error) *MockStreamManagerGetOrOpenReceiveStreamCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Do(f func(protocol.StreamID) (*ReceiveStream, error)) *MockStreamManagerGetOrOpenReceiveStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerGetOrOpenReceiveStreamCall) DoAndReturn(f func(protocol.StreamID) (*ReceiveStream, error)) *MockStreamManagerGetOrOpenReceiveStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// GetOrOpenSendStream mocks base method. -func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (*SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0) - ret0, _ := ret[0].(*SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 any) *MockStreamManagerGetOrOpenSendStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) - return &MockStreamManagerGetOrOpenSendStreamCall{Call: call} -} - -// MockStreamManagerGetOrOpenSendStreamCall wrap *gomock.Call -type MockStreamManagerGetOrOpenSendStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerGetOrOpenSendStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerGetOrOpenSendStreamCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// HandleMaxStreamsFrame mocks base method. -func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0) -} - -// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame. -func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 any) *MockStreamManagerHandleMaxStreamsFrameCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) - return &MockStreamManagerHandleMaxStreamsFrameCall{Call: call} -} - -// MockStreamManagerHandleMaxStreamsFrameCall wrap *gomock.Call -type MockStreamManagerHandleMaxStreamsFrameCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerHandleMaxStreamsFrameCall) Return() *MockStreamManagerHandleMaxStreamsFrameCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerHandleMaxStreamsFrameCall) Do(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerHandleMaxStreamsFrameCall) DoAndReturn(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// OpenStream mocks base method. -func (m *MockStreamManager) OpenStream() (*Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStream") - ret0, _ := ret[0].(*Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStream indicates an expected call of OpenStream. -func (mr *MockStreamManagerMockRecorder) OpenStream() *MockStreamManagerOpenStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream)) - return &MockStreamManagerOpenStreamCall{Call: call} -} - -// MockStreamManagerOpenStreamCall wrap *gomock.Call -type MockStreamManagerOpenStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerOpenStreamCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerOpenStreamCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerOpenStreamCall) Do(f func() (*Stream, error)) *MockStreamManagerOpenStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerOpenStreamCall) DoAndReturn(f func() (*Stream, error)) *MockStreamManagerOpenStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// OpenStreamSync mocks base method. -func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (*Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync", arg0) - ret0, _ := ret[0].(*Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 any) *MockStreamManagerOpenStreamSyncCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) - return &MockStreamManagerOpenStreamSyncCall{Call: call} -} - -// MockStreamManagerOpenStreamSyncCall wrap *gomock.Call -type MockStreamManagerOpenStreamSyncCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerOpenStreamSyncCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerOpenStreamSyncCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerOpenStreamSyncCall) Do(f func(context.Context) (*Stream, error)) *MockStreamManagerOpenStreamSyncCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerOpenStreamSyncCall) DoAndReturn(f func(context.Context) (*Stream, error)) *MockStreamManagerOpenStreamSyncCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// OpenUniStream mocks base method. -func (m *MockStreamManager) OpenUniStream() (*SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStream") - ret0, _ := ret[0].(*SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockStreamManagerMockRecorder) OpenUniStream() *MockStreamManagerOpenUniStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStream)) - return &MockStreamManagerOpenUniStreamCall{Call: call} -} - -// MockStreamManagerOpenUniStreamCall wrap *gomock.Call -type MockStreamManagerOpenUniStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerOpenUniStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerOpenUniStreamCall) Do(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerOpenUniStreamCall) DoAndReturn(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// OpenUniStreamSync mocks base method. -func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (*SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) - ret0, _ := ret[0].(*SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 any) *MockStreamManagerOpenUniStreamSyncCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) - return &MockStreamManagerOpenUniStreamSyncCall{Call: call} -} - -// MockStreamManagerOpenUniStreamSyncCall wrap *gomock.Call -type MockStreamManagerOpenUniStreamSyncCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerOpenUniStreamSyncCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamSyncCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerOpenUniStreamSyncCall) Do(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// ResetFor0RTT mocks base method. -func (m *MockStreamManager) ResetFor0RTT() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ResetFor0RTT") -} - -// ResetFor0RTT indicates an expected call of ResetFor0RTT. -func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *MockStreamManagerResetFor0RTTCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT)) - return &MockStreamManagerResetFor0RTTCall{Call: call} -} - -// MockStreamManagerResetFor0RTTCall wrap *gomock.Call -type MockStreamManagerResetFor0RTTCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerResetFor0RTTCall) Return() *MockStreamManagerResetFor0RTTCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerResetFor0RTTCall) Do(f func()) *MockStreamManagerResetFor0RTTCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerResetFor0RTTCall) DoAndReturn(f func()) *MockStreamManagerResetFor0RTTCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// UpdateLimits mocks base method. -func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateLimits", arg0) -} - -// UpdateLimits indicates an expected call of UpdateLimits. -func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 any) *MockStreamManagerUpdateLimitsCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) - return &MockStreamManagerUpdateLimitsCall{Call: call} -} - -// MockStreamManagerUpdateLimitsCall wrap *gomock.Call -type MockStreamManagerUpdateLimitsCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerUpdateLimitsCall) Return() *MockStreamManagerUpdateLimitsCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerUpdateLimitsCall) Do(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerUpdateLimitsCall) DoAndReturn(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// UseResetMaps mocks base method. -func (m *MockStreamManager) UseResetMaps() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UseResetMaps") -} - -// UseResetMaps indicates an expected call of UseResetMaps. -func (mr *MockStreamManagerMockRecorder) UseResetMaps() *MockStreamManagerUseResetMapsCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps)) - return &MockStreamManagerUseResetMapsCall{Call: call} -} - -// MockStreamManagerUseResetMapsCall wrap *gomock.Call -type MockStreamManagerUseResetMapsCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamManagerUseResetMapsCall) Return() *MockStreamManagerUseResetMapsCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamManagerUseResetMapsCall) Do(f func()) *MockStreamManagerUseResetMapsCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamManagerUseResetMapsCall) DoAndReturn(f func()) *MockStreamManagerUseResetMapsCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/mockgen.go b/mockgen.go index 4384f5d2..22256d5b 100644 --- a/mockgen.go +++ b/mockgen.go @@ -26,9 +26,6 @@ type FrameSource = frameSource //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource" type AckFrameSource = ackFrameSource -//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager" -type StreamManager = streamManager - //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager" type SealingManager = sealingManager diff --git a/streams_map.go b/streams_map.go index 16340d1a..22c7a97b 100644 --- a/streams_map.go +++ b/streams_map.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/protocol" @@ -37,8 +38,6 @@ type streamsMap struct { reset bool } -var _ streamManager = &streamsMap{} - func newStreamsMap( ctx context.Context, sender streamSender, @@ -180,90 +179,6 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error { panic("") } -func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) { - str, err := m.getOrOpenReceiveStream(id) - if err != nil { - return nil, &qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: err.Error(), - } - } - return str, nil -} - -func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) { - switch id.Type() { - case protocol.StreamTypeUni: - if id.InitiatedBy() == m.perspective { - // an outgoing unidirectional stream is a send stream, not a receive stream - return nil, fmt.Errorf("peer attempted to open receive stream %d", id) - } - return m.incomingUniStreams.GetOrOpenStream(id) - case protocol.StreamTypeBidi: - if id.InitiatedBy() == m.perspective { - str, err := m.outgoingBidiStreams.GetStream(id) - if str == nil && err == nil { - return nil, nil - } - return str.ReceiveStream, err - } else { - str, err := m.incomingBidiStreams.GetOrOpenStream(id) - if str == nil && err == nil { - return nil, nil - } - return str.ReceiveStream, err - } - } - panic("") -} - -func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (*SendStream, error) { - str, err := m.getOrOpenSendStream(id) - if err != nil { - return nil, &qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: err.Error(), - } - } - return str, nil -} - -func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, error) { - switch id.Type() { - case protocol.StreamTypeUni: - if id.InitiatedBy() == m.perspective { - str, err := m.outgoingUniStreams.GetStream(id) - if str == nil && err == nil { - return nil, nil - } - return str, err - } - // an incoming unidirectional stream is a receive stream, not a send stream - return nil, fmt.Errorf("peer attempted to open send stream %d", id) - case protocol.StreamTypeBidi: - if id.InitiatedBy() == m.perspective { - str, err := m.outgoingBidiStreams.GetStream(id) - if str == nil && err == nil { - return nil, nil - } - if err != nil { - return nil, err - } - return str.SendStream, nil - } else { - str, err := m.incomingBidiStreams.GetOrOpenStream(id) - if str == nil && err == nil { - return nil, nil - } - if err != nil { - return nil, err - } - return str.SendStream, nil - } - } - panic("") -} - func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { switch f.Type { case protocol.StreamTypeUni: @@ -273,6 +188,134 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { } } +type sendStreamFrameHandler interface { + updateSendWindow(protocol.ByteCount) + handleStopSendingFrame(*wire.StopSendingFrame) +} + +func (m *streamsMap) getSendStream(id protocol.StreamID) (sendStreamFrameHandler, error) { + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() != m.perspective { + // an outgoing unidirectional stream is a send stream, not a receive stream + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("invalid frame for send stream %d", id), + } + } + str, err := m.outgoingUniStreams.GetStream(id) + if str == nil || err != nil { + return nil, err + } + return str, nil + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + str, err := m.outgoingBidiStreams.GetStream(id) + if str == nil || err != nil { + return nil, err + } + return str, nil + } + str, err := m.incomingBidiStreams.GetOrOpenStream(id) + if str == nil || err != nil { + return nil, err + } + return str, nil + } + panic("unreachable") +} + +func (m *streamsMap) HandleMaxStreamDataFrame(f *wire.MaxStreamDataFrame) error { + str, err := m.getSendStream(f.StreamID) + if err != nil { + return err + } + if str == nil { // stream already deleted + return nil + } + str.updateSendWindow(f.MaximumStreamData) + return nil +} + +func (m *streamsMap) HandleStopSendingFrame(f *wire.StopSendingFrame) error { + str, err := m.getSendStream(f.StreamID) + if err != nil { + return err + } + if str == nil { // stream already deleted + return nil + } + str.handleStopSendingFrame(f) + return nil +} + +type receiveStreamFrameHandler interface { + handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error + handleStreamFrame(*wire.StreamFrame, time.Time) error +} + +func (m *streamsMap) getReceiveStream(id protocol.StreamID) (receiveStreamFrameHandler, error) { + switch id.Type() { + case protocol.StreamTypeUni: + // an outgoing unidirectional stream is a send stream, not a receive stream + if id.InitiatedBy() == m.perspective { + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("invalid frame for receive stream %d", id), + } + } + str, err := m.incomingUniStreams.GetOrOpenStream(id) + if err != nil || str == nil { + return nil, err + } + return str, nil + case protocol.StreamTypeBidi: + var str *Stream + var err error + if id.InitiatedBy() == m.perspective { + str, err = m.outgoingBidiStreams.GetStream(id) + } else { + str, err = m.incomingBidiStreams.GetOrOpenStream(id) + } + if str == nil || err != nil { + return nil, err + } + return str, nil + } + panic("unreachable") +} + +func (m *streamsMap) HandleStreamDataBlockedFrame(f *wire.StreamDataBlockedFrame) error { + if _, err := m.getReceiveStream(f.StreamID); err != nil { + return err + } + // We don't need to do anything in response to a STREAM_DATA_BLOCKED frame, + // but we need to make sure that the stream ID is valid. + return nil // we don't need to do anything in response to a STREAM_DATA_BLOCKED frame +} + +func (m *streamsMap) HandleResetStreamFrame(f *wire.ResetStreamFrame, rcvTime time.Time) error { + str, err := m.getReceiveStream(f.StreamID) + if err != nil { + return err + } + if str == nil { // stream already deleted + return nil + } + return str.handleResetStreamFrame(f, rcvTime) +} + +func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime time.Time) error { + str, err := m.getReceiveStream(f.StreamID) + if err != nil { + return err + } + if str == nil { // stream already deleted + return nil + } + return str.handleStreamFrame(f, rcvTime) +} + func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective)) diff --git a/streams_map_incoming.go b/streams_map_incoming.go index 28637798..c714eaf1 100644 --- a/streams_map_incoming.go +++ b/streams_map_incoming.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) @@ -115,7 +116,10 @@ func (m *incomingStreamsMap[T]) GetOrOpenStream(id protocol.StreamID) (T, error) m.mutex.RLock() if id > m.maxStream { m.mutex.RUnlock() - return *new(T), fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + return *new(T), &qerr.TransportError{ + ErrorCode: qerr.StreamLimitError, + ErrorMessage: fmt.Sprintf("peer tried to open stream %d (current limit: %d)", id, m.maxStream), + } } // if the num is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or @@ -152,7 +156,13 @@ func (m *incomingStreamsMap[T]) DeleteStream(id protocol.StreamID) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.deleteStream(id) + if err := m.deleteStream(id); err != nil { + return &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: err.Error(), + } + } + return nil } func (m *incomingStreamsMap[T]) deleteStream(id protocol.StreamID) error { diff --git a/streams_map_incoming_test.go b/streams_map_incoming_test.go index c62e9ca6..643a43aa 100644 --- a/streams_map_incoming_test.go +++ b/streams_map_incoming_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" @@ -80,6 +81,7 @@ func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Per _, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams - 4) require.NoError(t, err) _, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamLimitError}) require.ErrorContains(t, err, "peer tried to open stream") require.Equal(t, maxNumStreams, newStreamCounter) } @@ -151,6 +153,7 @@ func testStreamsMapIncomingDeletingStreams(t *testing.T, perspective protocol.Pe perspective, ) err := m.DeleteStream(firstStream + 1337*4) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, "tried to delete unknown incoming stream") s, err := m.GetOrOpenStream(firstStream + 4) diff --git a/streams_map_outgoing.go b/streams_map_outgoing.go index 248390e3..28b2179f 100644 --- a/streams_map_outgoing.go +++ b/streams_map_outgoing.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) @@ -153,7 +154,10 @@ func (m *outgoingStreamsMap[T]) GetStream(id protocol.StreamID) (T, error) { m.mutex.RLock() if id >= m.nextStream { m.mutex.RUnlock() - return *new(T), fmt.Errorf("peer attempted to open stream %d", id) + return *new(T), &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), + } } s := m.streams[id] m.mutex.RUnlock() @@ -165,7 +169,10 @@ func (m *outgoingStreamsMap[T]) DeleteStream(id protocol.StreamID) error { defer m.mutex.Unlock() if _, ok := m.streams[id]; !ok { - return fmt.Errorf("tried to delete unknown outgoing stream %d", id) + return &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("tried to delete unknown outgoing stream %d", id), + } } delete(m.streams, id) return nil diff --git a/streams_map_outgoing_test.go b/streams_map_outgoing_test.go index 3009a0af..20330257 100644 --- a/streams_map_outgoing_test.go +++ b/streams_map_outgoing_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" @@ -34,6 +35,7 @@ func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Pers m.SetMaxStream(protocol.MaxStreamID) _, err := m.GetStream(firstStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstStream)) str1, err := m.OpenStream() @@ -54,14 +56,14 @@ func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Pers err = m.DeleteStream(firstStream + 1337*4) require.Error(t, err) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, "tried to delete unknown outgoing stream") require.NoError(t, m.DeleteStream(firstStream)) // deleting the same stream twice will fail - require.ErrorContains(t, - m.DeleteStream(firstStream), - "tried to delete unknown outgoing stream", - ) + err = m.DeleteStream(firstStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, "tried to delete unknown outgoing stream") // after deleting the stream it's not available anymore str, err := m.GetStream(firstStream) require.NoError(t, err) @@ -98,7 +100,6 @@ func testStreamsMapOutgoingLimits(t *testing.T, perspective protocol.Perspective ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = m.OpenStreamSync(ctx) - require.Error(t, err) require.ErrorIs(t, err, context.Canceled) // OpenStreamSync blocks until the context is canceled... diff --git a/streams_map_test.go b/streams_map_test.go index 80f58dee..07274e8e 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -19,7 +19,7 @@ import ( func TestStreamsMapCreatingStreams(t *testing.T) { t.Run("client", func(t *testing.T) { - testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveClient, + testStreamsMapCreatingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingBidiStreamClient, protocol.FirstOutgoingBidiStreamClient, protocol.FirstIncomingUniStreamClient, @@ -27,7 +27,7 @@ func TestStreamsMapCreatingStreams(t *testing.T) { ) }) t.Run("server", func(t *testing.T) { - testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveServer, + testStreamsMapCreatingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingBidiStreamServer, protocol.FirstOutgoingBidiStreamServer, protocol.FirstIncomingUniStreamServer, @@ -36,7 +36,7 @@ func TestStreamsMapCreatingStreams(t *testing.T) { }) } -func testStreamsMapCreatingAndDeletingStreams(t *testing.T, +func testStreamsMapCreatingStreams(t *testing.T, perspective protocol.Perspective, firstIncomingBidiStream protocol.StreamID, firstOutgoingBidiStream protocol.StreamID, @@ -50,7 +50,9 @@ func testStreamsMapCreatingAndDeletingStreams(t *testing.T, mockSender, func(wire.Frame) {}, func(protocol.StreamID) flowcontrol.StreamFlowController { - return mocks.NewMockStreamFlowController(mockCtrl) + fc := mocks.NewMockStreamFlowController(mockCtrl) + fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + return fc }, 1, 1, @@ -76,25 +78,9 @@ func testStreamsMapCreatingAndDeletingStreams(t *testing.T, assert.Equal(t, ustr1.StreamID(), firstOutgoingUniStream) assert.Equal(t, ustr2.StreamID(), firstOutgoingUniStream+4) - // accepting streams: - // This function is called when a frame referencing this stream is received. - // The peer may open a peer-initiated stream... - _, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream) - require.NoError(t, err) - _, err = m.GetOrOpenReceiveStream(firstIncomingUniStream) - require.NoError(t, err) - - // ... but not a stream that is initiated by us. - _, err = m.GetOrOpenSendStream(firstOutgoingBidiStream + 8) - require.ErrorIs(t, err, &qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream+8), - }) - _, err = m.GetOrOpenSendStream(firstOutgoingUniStream + 8) - require.ErrorIs(t, err, &qerr.TransportError{ - ErrorCode: qerr.StreamStateError, - ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream+8), - }) + // accepting streams is triggered by receiving a frame referencing this stream + require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, time.Now())) + require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, time.Now())) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -141,7 +127,9 @@ func testStreamsMapDeletingStreams(t *testing.T, mockSender, func(frame wire.Frame) { frameQueue = append(frameQueue, frame) }, func(protocol.StreamID) flowcontrol.StreamFlowController { - return mocks.NewMockStreamFlowController(mockCtrl) + fc := mocks.NewMockStreamFlowController(mockCtrl) + fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + return fc }, 100, 100, @@ -155,37 +143,25 @@ func testStreamsMapDeletingStreams(t *testing.T, _, err := m.OpenStream() require.NoError(t, err) require.NoError(t, m.DeleteStream(firstOutgoingBidiStream)) - sstr, err := m.GetOrOpenSendStream(firstOutgoingBidiStream) - require.NoError(t, err) - require.Nil(t, sstr) - require.ErrorContains(t, - m.DeleteStream(firstOutgoingBidiStream+400), - fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400), - ) + err = m.DeleteStream(firstOutgoingBidiStream + 400) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400)) _, err = m.OpenUniStream() require.NoError(t, err) require.NoError(t, m.DeleteStream(firstOutgoingUniStream)) - sstr, err = m.GetOrOpenSendStream(firstOutgoingUniStream) - require.NoError(t, err) - require.Nil(t, sstr) - require.ErrorContains(t, - m.DeleteStream(firstOutgoingUniStream+400), - fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400), - ) + err = m.DeleteStream(firstOutgoingUniStream + 400) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400)) require.Empty(t, frameQueue) // deleting incoming bidirectional streams - _, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream) - require.NoError(t, err) + require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, time.Now())) require.NoError(t, m.DeleteStream(firstIncomingBidiStream)) - sstr, err = m.GetOrOpenSendStream(firstIncomingBidiStream) - require.NoError(t, err) - require.Nil(t, sstr) - require.ErrorContains(t, - m.DeleteStream(firstIncomingBidiStream+400), - fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400), - ) + err = m.DeleteStream(firstIncomingBidiStream + 400) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400)) + // the MAX_STREAMS frame is only queued once the stream is accepted require.Empty(t, frameQueue) _, err = m.AcceptStream(context.Background()) @@ -200,16 +176,12 @@ func testStreamsMapDeletingStreams(t *testing.T, frameQueue = frameQueue[:0] // deleting incoming unidirectional streams - _, err = m.GetOrOpenReceiveStream(firstIncomingUniStream) - require.NoError(t, err) + require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, time.Now())) require.NoError(t, m.DeleteStream(firstIncomingUniStream)) - rstr, err := m.GetOrOpenReceiveStream(firstIncomingUniStream) - require.NoError(t, err) - require.Nil(t, rstr) - require.ErrorContains(t, - m.DeleteStream(firstIncomingUniStream+400), - fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400), - ) + err = m.DeleteStream(firstIncomingUniStream + 400) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400)) + // the MAX_STREAMS frame is only queued once the stream is accepted require.Empty(t, frameQueue) _, err = m.AcceptUniStream(context.Background()) @@ -254,6 +226,7 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) // increase via transport parameters _, err := m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) + require.ErrorContains(t, err, "too many open streams") m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 1}) _, err = m.OpenStream() require.NoError(t, err) @@ -293,6 +266,239 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) require.ErrorIs(t, err, &StreamLimitReachedError{}) } +func TestStreamsMapHandleReceiveStreamFrames(t *testing.T) { + for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} { + t.Run(pers.String(), func(t *testing.T) { + t.Run("STREAM frame", func(t *testing.T) { + testStreamsMapHandleReceiveStreamFrames(t, + pers, + func(m *streamsMap, id protocol.StreamID) error { + return m.HandleStreamFrame(&wire.StreamFrame{StreamID: id}, time.Now()) + }, + ) + }) + + t.Run("STREAM_DATA_BLOCKED frame", func(t *testing.T) { + testStreamsMapHandleReceiveStreamFrames(t, + pers, + func(m *streamsMap, id protocol.StreamID) error { + return m.HandleStreamDataBlockedFrame(&wire.StreamDataBlockedFrame{StreamID: id}) + }, + ) + }) + + t.Run("RESET_STREAM frame", func(t *testing.T) { + testStreamsMapHandleReceiveStreamFrames(t, + pers, + func(m *streamsMap, id protocol.StreamID) error { + return m.HandleResetStreamFrame(&wire.ResetStreamFrame{StreamID: id}, time.Now()) + }, + ) + }) + }) + } +} + +func testStreamsMapHandleReceiveStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(*streamsMap, protocol.StreamID) error) { + mockCtrl := gomock.NewController(t) + mockSender := NewMockStreamSender(mockCtrl) + var streamsCreated []protocol.StreamID + m := newStreamsMap( + context.Background(), + mockSender, + func(frame wire.Frame) {}, + func(id protocol.StreamID) flowcontrol.StreamFlowController { + streamsCreated = append(streamsCreated, id) + fc := mocks.NewMockStreamFlowController(mockCtrl) + fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + fc.EXPECT().Abandon().AnyTimes() + return fc + }, + 100, + 100, + pers, + ) + m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount}) + m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount}) + + var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream protocol.StreamID + if pers == protocol.PerspectiveClient { + firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient + firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient + firstIncomingUniStream = protocol.FirstIncomingUniStreamClient + firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient + } else { + firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer + firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer + firstIncomingUniStream = protocol.FirstIncomingUniStreamServer + firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer + } + + // 1. The peer can't open a unidirectional send stream... + err := handleFrame(m, firstOutgoingUniStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream)) + require.Empty(t, streamsCreated) + // ... and a STREAM frame for a unidirectional send stream is invalid even if the stream is open. + _, err = m.OpenUniStream() + require.NoError(t, err) + err = handleFrame(m, firstOutgoingUniStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream)) + streamsCreated = streamsCreated[:0] + + // 2. The peer can't open a bidirectional stream initiated by us... + err = handleFrame(m, firstOutgoingBidiStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream)) + require.Empty(t, streamsCreated) + // ... but it's valid once we have opened the stream. + _, err = m.OpenStream() + require.NoError(t, err) + require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) + streamsCreated = streamsCreated[:0] + // Delayed frames for deleted streams are absorbed. + require.NoError(t, m.DeleteStream(firstOutgoingBidiStream)) + require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) + require.Empty(t, streamsCreated) + + // 3. The peer can send STREAM frames for unidirectional receive streams, + // as long as they're below the stream limit. + require.ErrorIs(t, + handleFrame(m, firstIncomingUniStream+400), + &qerr.TransportError{ErrorCode: qerr.StreamLimitError}, + ) + require.Empty(t, streamsCreated) + require.NoError(t, handleFrame(m, firstIncomingUniStream)) + require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingUniStream}) + streamsCreated = streamsCreated[:0] + // Delayed frames for deleted streams are absorbed. + require.NoError(t, m.DeleteStream(firstIncomingUniStream)) + require.NoError(t, handleFrame(m, firstIncomingUniStream)) + require.Empty(t, streamsCreated) + + // 4. The peer can send STREAM frames for bidirectional receive streams, + // as long as they're below the stream limit. + require.ErrorIs(t, + handleFrame(m, firstIncomingBidiStream+400), + &qerr.TransportError{ErrorCode: qerr.StreamLimitError}, + ) + require.Empty(t, streamsCreated) + require.NoError(t, handleFrame(m, firstIncomingBidiStream)) + require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream}) +} + +func TestStreamsMapHandleSendStreamFrames(t *testing.T) { + for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} { + t.Run(pers.String(), func(t *testing.T) { + t.Run("STOP_SENDING frame", func(t *testing.T) { + testStreamsMapHandleSendStreamFrames(t, + pers, + func(m *streamsMap, id protocol.StreamID) error { + return m.HandleStopSendingFrame(&wire.StopSendingFrame{StreamID: id}) + }, + ) + }) + + t.Run("MAX_STREAM_DATA frame", func(t *testing.T) { + testStreamsMapHandleSendStreamFrames(t, + pers, + func(m *streamsMap, id protocol.StreamID) error { + return m.HandleMaxStreamDataFrame(&wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1000}) + }, + ) + }) + }) + } +} + +func testStreamsMapHandleSendStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(m *streamsMap, id protocol.StreamID) error) { + mockCtrl := gomock.NewController(t) + mockSender := NewMockStreamSender(mockCtrl) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).AnyTimes() + var streamsCreated []protocol.StreamID + m := newStreamsMap( + context.Background(), + mockSender, + func(frame wire.Frame) {}, + func(id protocol.StreamID) flowcontrol.StreamFlowController { + streamsCreated = append(streamsCreated, id) + fc := mocks.NewMockStreamFlowController(mockCtrl) + fc.EXPECT().UpdateSendWindow(gomock.Any()).AnyTimes() + return fc + }, + 100, + 100, + pers, + ) + m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount}) + m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount}) + + var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream protocol.StreamID + if pers == protocol.PerspectiveClient { + firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient + firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient + firstIncomingUniStream = protocol.FirstIncomingUniStreamClient + firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient + } else { + firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer + firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer + firstIncomingUniStream = protocol.FirstIncomingUniStreamServer + firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer + } + + // 1. The peer can't open a unidirectional send stream... + err := handleFrame(m, firstOutgoingUniStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream)) + require.Empty(t, streamsCreated) + // ... but once we have opened the stream, it's valid. + _, err = m.OpenUniStream() + require.NoError(t, err) + require.NoError(t, handleFrame(m, firstOutgoingUniStream)) + streamsCreated = streamsCreated[:0] + // Delayed frames for deleted streams are absorbed. + require.NoError(t, m.DeleteStream(firstOutgoingUniStream)) + require.NoError(t, handleFrame(m, firstOutgoingUniStream)) + require.Empty(t, streamsCreated) + + // 2. The peer can't open a bidirectional stream initiated by us... + err = handleFrame(m, firstOutgoingBidiStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream)) + require.Empty(t, streamsCreated) + // ... but once we have opened the stream, it's valid. + _, err = m.OpenStream() + require.NoError(t, err) + require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) + streamsCreated = streamsCreated[:0] + // Delayed frames for deleted streams are absorbed. + require.NoError(t, m.DeleteStream(firstOutgoingBidiStream)) + require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) + require.Empty(t, streamsCreated) + + // 3. The peer can't send STOP_SENDING frames for unidirectional send streams + err = handleFrame(m, firstIncomingUniStream) + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) + require.ErrorContains(t, err, fmt.Sprintf("invalid frame for send stream %d", firstIncomingUniStream)) + require.Empty(t, streamsCreated) + + // 4. The peer can send STOP_SENDING frames for bidirectional receive streams iniated by itself, + // as long as they're below the stream limit. + require.ErrorIs(t, + handleFrame(m, firstIncomingBidiStream+400), + &qerr.TransportError{ErrorCode: qerr.StreamLimitError}, + ) + require.Empty(t, streamsCreated) + require.NoError(t, handleFrame(m, firstIncomingBidiStream)) + require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream}) + streamsCreated = streamsCreated[:0] + // Delayed frames for deleted streams are absorbed. + require.NoError(t, m.DeleteStream(firstIncomingBidiStream)) + require.NoError(t, handleFrame(m, firstIncomingBidiStream)) + require.Empty(t, streamsCreated) +} + func TestStreamsMapClosing(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) @@ -366,7 +572,9 @@ func TestStreamsMap0RTTRejection(t *testing.T) { mockSender, func(wire.Frame) {}, func(protocol.StreamID) flowcontrol.StreamFlowController { - return mocks.NewMockStreamFlowController(mockCtrl) + fc := mocks.NewMockStreamFlowController(mockCtrl) + fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + return fc }, 1, 1, @@ -380,10 +588,11 @@ func TestStreamsMap0RTTRejection(t *testing.T) { require.ErrorIs(t, err, Err0RTTRejected) _, err = m.AcceptStream(context.Background()) require.ErrorIs(t, err, Err0RTTRejected) + _, err = m.AcceptUniStream(context.Background()) + require.ErrorIs(t, err, Err0RTTRejected) + // make sure that we can still get new streams, as the server might be sending us data - str, err := m.GetOrOpenReceiveStream(3) - require.NoError(t, err) - require.NotNil(t, str) + require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: 3}, time.Now())) // now switch to using the new streams map m.UseResetMaps()