handle stream-related frame in the streams map (#5212)

* handle stream-related frame in the streams map

* remove stream manager interface and mock
This commit is contained in:
Marten Seemann
2025-06-09 16:00:46 +08:00
committed by GitHub
parent 4f23ac2752
commit 1b07674b19
10 changed files with 529 additions and 1036 deletions

View File

@@ -30,23 +30,6 @@ type unpacker interface {
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}
type streamManager interface {
GetOrOpenSendStream(protocol.StreamID) (*SendStream, error)
GetOrOpenReceiveStream(protocol.StreamID) (*ReceiveStream, error)
OpenStream() (*Stream, error)
OpenUniStream() (*SendStream, error)
OpenStreamSync(context.Context) (*Stream, error)
OpenUniStreamSync(context.Context) (*SendStream, error)
AcceptStream(context.Context) (*Stream, error)
AcceptUniStream(context.Context) (*ReceiveStream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*wire.TransportParameters)
HandleMaxStreamsFrame(*wire.MaxStreamsFrame)
CloseWithError(error)
ResetFor0RTT()
UseResetMaps()
}
type cryptoStreamHandler interface {
StartHandshake(context.Context) error
ChangeConnectionID(protocol.ConnectionID)
@@ -133,7 +116,7 @@ type connection struct {
largestRcvdAppData protocol.PacketNumber
pathManagerOutgoing atomic.Pointer[pathManagerOutgoing]
streamsMap streamManager
streamsMap *streamsMap
connIDManager *connIDManager
connIDGenerator *connIDGenerator
@@ -1498,25 +1481,25 @@ func (s *connection) handleFrame(
case *wire.CryptoFrame:
err = s.handleCryptoFrame(frame, encLevel, rcvTime)
case *wire.StreamFrame:
err = s.handleStreamFrame(frame, rcvTime)
err = s.streamsMap.HandleStreamFrame(frame, rcvTime)
case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel, rcvTime)
case *wire.ConnectionCloseFrame:
err = s.handleConnectionCloseFrame(frame)
case *wire.ResetStreamFrame:
err = s.handleResetStreamFrame(frame, rcvTime)
err = s.streamsMap.HandleResetStreamFrame(frame, rcvTime)
case *wire.MaxDataFrame:
s.handleMaxDataFrame(frame)
s.connFlowController.UpdateSendWindow(frame.MaximumData)
case *wire.MaxStreamDataFrame:
err = s.handleMaxStreamDataFrame(frame)
err = s.streamsMap.HandleMaxStreamDataFrame(frame)
case *wire.MaxStreamsFrame:
s.handleMaxStreamsFrame(frame)
s.streamsMap.HandleMaxStreamsFrame(frame)
case *wire.DataBlockedFrame:
case *wire.StreamDataBlockedFrame:
err = s.handleStreamDataBlockedFrame(frame)
err = s.streamsMap.HandleStreamDataBlockedFrame(frame)
case *wire.StreamsBlockedFrame:
case *wire.StopSendingFrame:
err = s.handleStopSendingFrame(frame)
err = s.streamsMap.HandleStopSendingFrame(frame)
case *wire.PingFrame:
case *wire.PathChallengeFrame:
s.handlePathChallengeFrame(frame)
@@ -1526,9 +1509,9 @@ func (s *connection) handleFrame(
case *wire.NewTokenFrame:
err = s.handleNewTokenFrame(frame)
case *wire.NewConnectionIDFrame:
err = s.handleNewConnectionIDFrame(frame)
err = s.connIDManager.Add(frame)
case *wire.RetireConnectionIDFrame:
err = s.handleRetireConnectionIDFrame(rcvTime, frame, destConnID)
err = s.connIDGenerator.Retire(frame.SequenceNumber, destConnID, rcvTime.Add(3*s.rttStats.PTO(false)))
case *wire.HandshakeDoneFrame:
err = s.handleHandshakeDoneFrame(rcvTime)
case *wire.DatagramFrame:
@@ -1625,70 +1608,6 @@ func (s *connection) handleHandshakeEvents(now time.Time) error {
}
}
func (s *connection) handleStreamFrame(frame *wire.StreamFrame, rcvTime time.Time) error {
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil {
return err
}
if str == nil { // stream was already closed and garbage collected
return nil
}
return str.handleStreamFrame(frame, rcvTime)
}
func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) {
s.connFlowController.UpdateSendWindow(frame.MaximumData)
}
func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error {
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// stream is closed and already garbage collected
return nil
}
str.updateSendWindow(frame.MaximumStreamData)
return nil
}
func (s *connection) handleStreamDataBlockedFrame(frame *wire.StreamDataBlockedFrame) error {
// We don't need to do anything in response to a STREAM_DATA_BLOCKED frame,
// but we need to make sure that the stream ID is valid.
_, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
return err
}
func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) {
s.streamsMap.HandleMaxStreamsFrame(frame)
}
func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame, rcvTime time.Time) error {
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// stream is closed and already garbage collected
return nil
}
return str.handleResetStreamFrame(frame, rcvTime)
}
func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// stream is closed and already garbage collected
return nil
}
str.handleStopSendingFrame(frame)
return nil
}
func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
if s.perspective == protocol.PerspectiveClient {
s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data})
@@ -1743,14 +1662,6 @@ func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error {
return nil
}
func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error {
return s.connIDManager.Add(f)
}
func (s *connection) handleRetireConnectionIDFrame(now time.Time, f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error {
return s.connIDGenerator.Retire(f.SequenceNumber, destConnID, now.Add(3*s.rttStats.PTO(false)))
}
func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error {
if s.perspective == protocol.PerspectiveServer {
return &qerr.TransportError{

View File

@@ -35,10 +35,6 @@ func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt {
return func(conn *connection) { conn.cryptoStreamHandler = cs }
}
func connectionOptStreamManager(sm *MockStreamManager) testConnectionOpt {
return func(conn *connection) { conn.streamsMap = sm }
}
func connectionOptConnFlowController(cfc flowcontrol.ConnectionFlowController) testConnectionOpt {
return func(conn *connection) { conn.connFlowController = cfc }
}
@@ -208,145 +204,28 @@ func newClientTestConnection(
}
}
func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
const streamID protocol.StreamID = 5
now := time.Now()
func TestConnectionHandleStreamRelatedFrames(t *testing.T) {
const id protocol.StreamID = 5
connID := protocol.ConnectionID{}
f := &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar")}
rsf := &wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 42, FinalSize: 1337}
sdbf := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 1337}
t.Run("for existing and new streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(streamID, gomock.Any()).AnyTimes()
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
str := newReceiveStream(streamID, mockSender, mockFC)
// STREAM frame
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(len(f.Data)), false, gomock.Any())
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// RESET_STREAM frame
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1337), true, gomock.Any())
mockFC.EXPECT().Abandon()
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
tests := []struct {
name string
frame wire.Frame
}{
{name: "STREAM", frame: &wire.StreamFrame{StreamID: id, Data: []byte("foobar")}},
{name: "RESET_STREAM", frame: &wire.ResetStreamFrame{StreamID: id, ErrorCode: 42, FinalSize: 1337}},
{name: "STOP_SENDING", frame: &wire.StopSendingFrame{StreamID: id, ErrorCode: 42}},
{name: "MAX_STREAM_DATA", frame: &wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1337}},
{name: "STREAM_DATA_BLOCKED", frame: &wire.StreamDataBlockedFrame{StreamID: id, MaximumStreamData: 42}},
}
t.Run("for closed streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
t.Run("for invalid streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError)
_, err := tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, assert.AnError)
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError)
_, err = tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, assert.AnError)
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, assert.AnError)
_, err = tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, assert.AnError)
})
}
func TestConnectionHandleSendStreamFrames(t *testing.T) {
const streamID protocol.StreamID = 3
now := time.Now()
connID := protocol.ConnectionID{}
ss := &wire.StopSendingFrame{StreamID: streamID, ErrorCode: 42}
msd := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}
t.Run("for existing and new streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()).AnyTimes()
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
str := newSendStream(context.Background(), streamID, mockSender, mockFC)
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// MAX_STREAM_DATA frame
mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
t.Run("for closed streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
})
t.Run("for invalid streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, assert.AnError)
_, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, assert.AnError)
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, assert.AnError)
_, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
require.ErrorIs(t, err, assert.AnError)
})
}
func TestConnectionHandleStreamNumFrames(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
now := time.Now()
connID := protocol.ConnectionID{}
// MAX_STREAMS frame
msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
_, err := tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
// STREAMS_BLOCKED frame
_, err = tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
require.NoError(t, err)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tc := newServerTestConnection(t, gomock.NewController(t), nil, false)
_, err := tc.conn.handleFrame(test.frame, protocol.Encryption1RTT, connID, time.Now())
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
})
}
}
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
@@ -365,60 +244,6 @@ func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
require.NoError(t, err)
}
func TestConnectionOpenStreams(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// using OpenStream
str1 := &Stream{}
streamsMap.EXPECT().OpenStream().Return(str1, nil)
str, err := tc.conn.OpenStream()
require.NoError(t, err)
require.Equal(t, str1, str)
// using OpenStreamSync
streamsMap.EXPECT().OpenStreamSync(context.Background()).Return(str1, nil)
str, err = tc.conn.OpenStreamSync(context.Background())
require.NoError(t, err)
require.Equal(t, str1, str)
// using OpenUniStream
str2 := &SendStream{}
streamsMap.EXPECT().OpenUniStream().Return(str2, nil)
ustr, err := tc.conn.OpenUniStream()
require.NoError(t, err)
require.Equal(t, str2, ustr)
// using OpenUniStreamSync
streamsMap.EXPECT().OpenUniStreamSync(context.Background()).Return(str2, nil)
ustr, err = tc.conn.OpenUniStreamSync(context.Background())
require.NoError(t, err)
require.Equal(t, str2, ustr)
}
func TestConnectionAcceptStreams(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// bidirectional streams
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
str1 := &Stream{}
streamsMap.EXPECT().AcceptStream(ctx).Return(str1, nil)
str, err := tc.conn.AcceptStream(ctx)
require.NoError(t, err)
require.Equal(t, str1, str)
// unidirectional streams
str2 := &ReceiveStream{}
streamsMap.EXPECT().AcceptUniStream(ctx).Return(str2, nil)
ustr, err := tc.conn.AcceptUniStream(ctx)
require.NoError(t, err)
require.Equal(t, str2, ustr)
}
func TestConnectionServerInvalidFrames(t *testing.T) {
mockCtrl := gomock.NewController(t)
tc := newServerTestConnection(t, mockCtrl, nil, false)
@@ -943,14 +768,12 @@ func TestConnectionMaxUnprocessedPackets(t *testing.T) {
func TestConnectionRemoteClose(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockStreamManager := NewMockStreamManager(mockCtrl)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptStreamManager(mockStreamManager),
connectionOptTracer(tr),
connectionOptUnpacker(unpacker),
)
@@ -964,8 +787,6 @@ func TestConnectionRemoteClose(t *testing.T) {
expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any())
streamErrChan := make(chan error, 1)
mockStreamManager.EXPECT().CloseWithError(gomock.Any()).Do(func(e error) { streamErrChan <- e })
tracerErrChan := make(chan error, 1)
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrChan <- e })
tracer.EXPECT().Close()
@@ -988,12 +809,6 @@ func TestConnectionRemoteClose(t *testing.T) {
case <-time.After(time.Second):
t.Fatal("timeout")
}
select {
case err := <-streamErrChan:
require.ErrorIs(t, err, expectedErr)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) {
@@ -1061,6 +876,8 @@ func TestConnectionTransportParameters(t *testing.T) {
)
_, err := tc.conn.OpenStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
_, err = tc.conn.OpenUniStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
tracer.EXPECT().ReceivedTransportParameters(gomock.Any())
params := &wire.TransportParameters{
MaxIdleTimeout: 90 * time.Second,
@@ -1071,11 +888,78 @@ func TestConnectionTransportParameters(t *testing.T) {
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
OriginalDestinationConnectionID: tc.destConnID,
MaxBidiStreamNum: 1,
MaxUniStreamNum: 1,
}
require.NoError(t, tc.conn.handleTransportParameters(params))
require.Equal(t, protocol.ByteCount(1337), connFC.SendWindowSize())
_, err = tc.conn.OpenStream()
require.NoError(t, err)
_, err = tc.conn.OpenUniStream()
require.NoError(t, err)
}
func TestConnectionHandleMaxStreamsFrame(t *testing.T) {
mockCtrl := gomock.NewController(t)
connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
tc.conn.handleTransportParameters(&wire.TransportParameters{})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
uniStreamChan := make(chan error)
go func() {
_, err := tc.conn.OpenUniStreamSync(ctx)
uniStreamChan <- err
}()
bidiStreamChan := make(chan error)
go func() {
_, err := tc.conn.OpenStreamSync(ctx)
bidiStreamChan <- err
}()
select {
case <-uniStreamChan:
t.Fatal("uni stream should be blocked")
case <-bidiStreamChan:
t.Fatal("bidi stream should be blocked")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
// MAX_STREAMS frame for bidirectional stream
_, err := tc.conn.handleFrame(
&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10},
protocol.Encryption1RTT,
protocol.ConnectionID{},
time.Now(),
)
require.NoError(t, err)
select {
case <-uniStreamChan:
t.Fatal("uni stream should be blocked")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
select {
case err := <-bidiStreamChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// MAX_STREAMS frame for bidirectional stream
_, err = tc.conn.handleFrame(
&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 10},
protocol.Encryption1RTT,
protocol.ConnectionID{},
time.Now(),
)
require.NoError(t, err)
select {
case err := <-uniStreamChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestConnectionTransportParameterValidationFailureServer(t *testing.T) {
@@ -3075,10 +2959,11 @@ func testConnectionMigration(t *testing.T, enabled bool) {
).AnyTimes()
tc.connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any())
// add a new connection ID, so the path can be probed
require.NoError(t, tc.conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{
_, err = tc.conn.handleFrame(&wire.NewConnectionIDFrame{
SequenceNumber: 1,
ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
}))
}, protocol.EncryptionInitial, tc.destConnID, time.Now())
require.NoError(t, err)
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()

View File

@@ -1,573 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: StreamManager)
//
// Generated by this command:
//
// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager
//
// Package quic is a generated GoMock package.
package quic
import (
context "context"
reflect "reflect"
protocol "github.com/quic-go/quic-go/internal/protocol"
wire "github.com/quic-go/quic-go/internal/wire"
gomock "go.uber.org/mock/gomock"
)
// MockStreamManager is a mock of StreamManager interface.
type MockStreamManager struct {
ctrl *gomock.Controller
recorder *MockStreamManagerMockRecorder
isgomock struct{}
}
// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager.
type MockStreamManagerMockRecorder struct {
mock *MockStreamManager
}
// NewMockStreamManager creates a new mock instance.
func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager {
mock := &MockStreamManager{ctrl: ctrl}
mock.recorder = &MockStreamManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
return m.recorder
}
// AcceptStream mocks base method.
func (m *MockStreamManager) AcceptStream(arg0 context.Context) (*Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream", arg0)
ret0, _ := ret[0].(*Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream.
func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 any) *MockStreamManagerAcceptStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0)
return &MockStreamManagerAcceptStreamCall{Call: call}
}
// MockStreamManagerAcceptStreamCall wrap *gomock.Call
type MockStreamManagerAcceptStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerAcceptStreamCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerAcceptStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerAcceptStreamCall) Do(f func(context.Context) (*Stream, error)) *MockStreamManagerAcceptStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerAcceptStreamCall) DoAndReturn(f func(context.Context) (*Stream, error)) *MockStreamManagerAcceptStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// AcceptUniStream mocks base method.
func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (*ReceiveStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
ret0, _ := ret[0].(*ReceiveStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptUniStream indicates an expected call of AcceptUniStream.
func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 any) *MockStreamManagerAcceptUniStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0)
return &MockStreamManagerAcceptUniStreamCall{Call: call}
}
// MockStreamManagerAcceptUniStreamCall wrap *gomock.Call
type MockStreamManagerAcceptUniStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerAcceptUniStreamCall) Return(arg0 *ReceiveStream, arg1 error) *MockStreamManagerAcceptUniStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerAcceptUniStreamCall) Do(f func(context.Context) (*ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerAcceptUniStreamCall) DoAndReturn(f func(context.Context) (*ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// CloseWithError mocks base method.
func (m *MockStreamManager) CloseWithError(arg0 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "CloseWithError", arg0)
}
// CloseWithError indicates an expected call of CloseWithError.
func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 any) *MockStreamManagerCloseWithErrorCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0)
return &MockStreamManagerCloseWithErrorCall{Call: call}
}
// MockStreamManagerCloseWithErrorCall wrap *gomock.Call
type MockStreamManagerCloseWithErrorCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerCloseWithErrorCall) Return() *MockStreamManagerCloseWithErrorCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerCloseWithErrorCall) Do(f func(error)) *MockStreamManagerCloseWithErrorCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerCloseWithErrorCall) DoAndReturn(f func(error)) *MockStreamManagerCloseWithErrorCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// DeleteStream mocks base method.
func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteStream", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteStream indicates an expected call of DeleteStream.
func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 any) *MockStreamManagerDeleteStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0)
return &MockStreamManagerDeleteStreamCall{Call: call}
}
// MockStreamManagerDeleteStreamCall wrap *gomock.Call
type MockStreamManagerDeleteStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerDeleteStreamCall) Return(arg0 error) *MockStreamManagerDeleteStreamCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerDeleteStreamCall) Do(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerDeleteStreamCall) DoAndReturn(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// GetOrOpenReceiveStream mocks base method.
func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (*ReceiveStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
ret0, _ := ret[0].(*ReceiveStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream.
func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 any) *MockStreamManagerGetOrOpenReceiveStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0)
return &MockStreamManagerGetOrOpenReceiveStreamCall{Call: call}
}
// MockStreamManagerGetOrOpenReceiveStreamCall wrap *gomock.Call
type MockStreamManagerGetOrOpenReceiveStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Return(arg0 *ReceiveStream, arg1 error) *MockStreamManagerGetOrOpenReceiveStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Do(f func(protocol.StreamID) (*ReceiveStream, error)) *MockStreamManagerGetOrOpenReceiveStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerGetOrOpenReceiveStreamCall) DoAndReturn(f func(protocol.StreamID) (*ReceiveStream, error)) *MockStreamManagerGetOrOpenReceiveStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// GetOrOpenSendStream mocks base method.
func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (*SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
ret0, _ := ret[0].(*SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream.
func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 any) *MockStreamManagerGetOrOpenSendStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
return &MockStreamManagerGetOrOpenSendStreamCall{Call: call}
}
// MockStreamManagerGetOrOpenSendStreamCall wrap *gomock.Call
type MockStreamManagerGetOrOpenSendStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerGetOrOpenSendStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerGetOrOpenSendStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// HandleMaxStreamsFrame mocks base method.
func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0)
}
// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame.
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 any) *MockStreamManagerHandleMaxStreamsFrameCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0)
return &MockStreamManagerHandleMaxStreamsFrameCall{Call: call}
}
// MockStreamManagerHandleMaxStreamsFrameCall wrap *gomock.Call
type MockStreamManagerHandleMaxStreamsFrameCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerHandleMaxStreamsFrameCall) Return() *MockStreamManagerHandleMaxStreamsFrameCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerHandleMaxStreamsFrameCall) Do(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerHandleMaxStreamsFrameCall) DoAndReturn(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// OpenStream mocks base method.
func (m *MockStreamManager) OpenStream() (*Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStream")
ret0, _ := ret[0].(*Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStream indicates an expected call of OpenStream.
func (mr *MockStreamManagerMockRecorder) OpenStream() *MockStreamManagerOpenStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream))
return &MockStreamManagerOpenStreamCall{Call: call}
}
// MockStreamManagerOpenStreamCall wrap *gomock.Call
type MockStreamManagerOpenStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerOpenStreamCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerOpenStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerOpenStreamCall) Do(f func() (*Stream, error)) *MockStreamManagerOpenStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerOpenStreamCall) DoAndReturn(f func() (*Stream, error)) *MockStreamManagerOpenStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// OpenStreamSync mocks base method.
func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (*Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
ret0, _ := ret[0].(*Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync.
func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 any) *MockStreamManagerOpenStreamSyncCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0)
return &MockStreamManagerOpenStreamSyncCall{Call: call}
}
// MockStreamManagerOpenStreamSyncCall wrap *gomock.Call
type MockStreamManagerOpenStreamSyncCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerOpenStreamSyncCall) Return(arg0 *Stream, arg1 error) *MockStreamManagerOpenStreamSyncCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerOpenStreamSyncCall) Do(f func(context.Context) (*Stream, error)) *MockStreamManagerOpenStreamSyncCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerOpenStreamSyncCall) DoAndReturn(f func(context.Context) (*Stream, error)) *MockStreamManagerOpenStreamSyncCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// OpenUniStream mocks base method.
func (m *MockStreamManager) OpenUniStream() (*SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStream")
ret0, _ := ret[0].(*SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStream indicates an expected call of OpenUniStream.
func (mr *MockStreamManagerMockRecorder) OpenUniStream() *MockStreamManagerOpenUniStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStream))
return &MockStreamManagerOpenUniStreamCall{Call: call}
}
// MockStreamManagerOpenUniStreamCall wrap *gomock.Call
type MockStreamManagerOpenUniStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerOpenUniStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerOpenUniStreamCall) Do(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerOpenUniStreamCall) DoAndReturn(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// OpenUniStreamSync mocks base method.
func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (*SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
ret0, _ := ret[0].(*SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync.
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 any) *MockStreamManagerOpenUniStreamSyncCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0)
return &MockStreamManagerOpenUniStreamSyncCall{Call: call}
}
// MockStreamManagerOpenUniStreamSyncCall wrap *gomock.Call
type MockStreamManagerOpenUniStreamSyncCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerOpenUniStreamSyncCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamSyncCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerOpenUniStreamSyncCall) Do(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// ResetFor0RTT mocks base method.
func (m *MockStreamManager) ResetFor0RTT() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ResetFor0RTT")
}
// ResetFor0RTT indicates an expected call of ResetFor0RTT.
func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *MockStreamManagerResetFor0RTTCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT))
return &MockStreamManagerResetFor0RTTCall{Call: call}
}
// MockStreamManagerResetFor0RTTCall wrap *gomock.Call
type MockStreamManagerResetFor0RTTCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerResetFor0RTTCall) Return() *MockStreamManagerResetFor0RTTCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerResetFor0RTTCall) Do(f func()) *MockStreamManagerResetFor0RTTCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerResetFor0RTTCall) DoAndReturn(f func()) *MockStreamManagerResetFor0RTTCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// UpdateLimits mocks base method.
func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdateLimits", arg0)
}
// UpdateLimits indicates an expected call of UpdateLimits.
func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 any) *MockStreamManagerUpdateLimitsCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0)
return &MockStreamManagerUpdateLimitsCall{Call: call}
}
// MockStreamManagerUpdateLimitsCall wrap *gomock.Call
type MockStreamManagerUpdateLimitsCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerUpdateLimitsCall) Return() *MockStreamManagerUpdateLimitsCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerUpdateLimitsCall) Do(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerUpdateLimitsCall) DoAndReturn(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// UseResetMaps mocks base method.
func (m *MockStreamManager) UseResetMaps() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UseResetMaps")
}
// UseResetMaps indicates an expected call of UseResetMaps.
func (mr *MockStreamManagerMockRecorder) UseResetMaps() *MockStreamManagerUseResetMapsCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps))
return &MockStreamManagerUseResetMapsCall{Call: call}
}
// MockStreamManagerUseResetMapsCall wrap *gomock.Call
type MockStreamManagerUseResetMapsCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamManagerUseResetMapsCall) Return() *MockStreamManagerUseResetMapsCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamManagerUseResetMapsCall) Do(f func()) *MockStreamManagerUseResetMapsCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamManagerUseResetMapsCall) DoAndReturn(f func()) *MockStreamManagerUseResetMapsCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -26,9 +26,6 @@ type FrameSource = frameSource
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource"
type AckFrameSource = ackFrameSource
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager"
type StreamManager = streamManager
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager"
type SealingManager = sealingManager

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"sync"
"time"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
@@ -37,8 +38,6 @@ type streamsMap struct {
reset bool
}
var _ streamManager = &streamsMap{}
func newStreamsMap(
ctx context.Context,
sender streamSender,
@@ -180,90 +179,6 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
panic("")
}
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
str, err := m.getOrOpenReceiveStream(id)
if err != nil {
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: err.Error(),
}
}
return str, nil
}
func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() == m.perspective {
// an outgoing unidirectional stream is a send stream, not a receive stream
return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
}
return m.incomingUniStreams.GetOrOpenStream(id)
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingBidiStreams.GetStream(id)
if str == nil && err == nil {
return nil, nil
}
return str.ReceiveStream, err
} else {
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
if str == nil && err == nil {
return nil, nil
}
return str.ReceiveStream, err
}
}
panic("")
}
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
str, err := m.getOrOpenSendStream(id)
if err != nil {
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: err.Error(),
}
}
return str, nil
}
func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingUniStreams.GetStream(id)
if str == nil && err == nil {
return nil, nil
}
return str, err
}
// an incoming unidirectional stream is a receive stream, not a send stream
return nil, fmt.Errorf("peer attempted to open send stream %d", id)
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingBidiStreams.GetStream(id)
if str == nil && err == nil {
return nil, nil
}
if err != nil {
return nil, err
}
return str.SendStream, nil
} else {
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
if str == nil && err == nil {
return nil, nil
}
if err != nil {
return nil, err
}
return str.SendStream, nil
}
}
panic("")
}
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
switch f.Type {
case protocol.StreamTypeUni:
@@ -273,6 +188,134 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
}
}
type sendStreamFrameHandler interface {
updateSendWindow(protocol.ByteCount)
handleStopSendingFrame(*wire.StopSendingFrame)
}
func (m *streamsMap) getSendStream(id protocol.StreamID) (sendStreamFrameHandler, error) {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() != m.perspective {
// an outgoing unidirectional stream is a send stream, not a receive stream
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("invalid frame for send stream %d", id),
}
}
str, err := m.outgoingUniStreams.GetStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingBidiStreams.GetStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
}
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
}
panic("unreachable")
}
func (m *streamsMap) HandleMaxStreamDataFrame(f *wire.MaxStreamDataFrame) error {
str, err := m.getSendStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
str.updateSendWindow(f.MaximumStreamData)
return nil
}
func (m *streamsMap) HandleStopSendingFrame(f *wire.StopSendingFrame) error {
str, err := m.getSendStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
str.handleStopSendingFrame(f)
return nil
}
type receiveStreamFrameHandler interface {
handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error
handleStreamFrame(*wire.StreamFrame, time.Time) error
}
func (m *streamsMap) getReceiveStream(id protocol.StreamID) (receiveStreamFrameHandler, error) {
switch id.Type() {
case protocol.StreamTypeUni:
// an outgoing unidirectional stream is a send stream, not a receive stream
if id.InitiatedBy() == m.perspective {
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("invalid frame for receive stream %d", id),
}
}
str, err := m.incomingUniStreams.GetOrOpenStream(id)
if err != nil || str == nil {
return nil, err
}
return str, nil
case protocol.StreamTypeBidi:
var str *Stream
var err error
if id.InitiatedBy() == m.perspective {
str, err = m.outgoingBidiStreams.GetStream(id)
} else {
str, err = m.incomingBidiStreams.GetOrOpenStream(id)
}
if str == nil || err != nil {
return nil, err
}
return str, nil
}
panic("unreachable")
}
func (m *streamsMap) HandleStreamDataBlockedFrame(f *wire.StreamDataBlockedFrame) error {
if _, err := m.getReceiveStream(f.StreamID); err != nil {
return err
}
// We don't need to do anything in response to a STREAM_DATA_BLOCKED frame,
// but we need to make sure that the stream ID is valid.
return nil // we don't need to do anything in response to a STREAM_DATA_BLOCKED frame
}
func (m *streamsMap) HandleResetStreamFrame(f *wire.ResetStreamFrame, rcvTime time.Time) error {
str, err := m.getReceiveStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
return str.handleResetStreamFrame(f, rcvTime)
}
func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime time.Time) error {
str, err := m.getReceiveStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
return str.handleStreamFrame(f, rcvTime)
}
func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective))

View File

@@ -6,6 +6,7 @@ import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
@@ -115,7 +116,10 @@ func (m *incomingStreamsMap[T]) GetOrOpenStream(id protocol.StreamID) (T, error)
m.mutex.RLock()
if id > m.maxStream {
m.mutex.RUnlock()
return *new(T), fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
return *new(T), &qerr.TransportError{
ErrorCode: qerr.StreamLimitError,
ErrorMessage: fmt.Sprintf("peer tried to open stream %d (current limit: %d)", id, m.maxStream),
}
}
// if the num is smaller than the highest we accepted
// * this stream exists in the map, and we can return it, or
@@ -152,7 +156,13 @@ func (m *incomingStreamsMap[T]) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteStream(id)
if err := m.deleteStream(id); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: err.Error(),
}
}
return nil
}
func (m *incomingStreamsMap[T]) deleteStream(id protocol.StreamID) error {

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/assert"
@@ -80,6 +81,7 @@ func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Per
_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams - 4)
require.NoError(t, err)
_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamLimitError})
require.ErrorContains(t, err, "peer tried to open stream")
require.Equal(t, maxNumStreams, newStreamCounter)
}
@@ -151,6 +153,7 @@ func testStreamsMapIncomingDeletingStreams(t *testing.T, perspective protocol.Pe
perspective,
)
err := m.DeleteStream(firstStream + 1337*4)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, "tried to delete unknown incoming stream")
s, err := m.GetOrOpenStream(firstStream + 4)

View File

@@ -7,6 +7,7 @@ import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
@@ -153,7 +154,10 @@ func (m *outgoingStreamsMap[T]) GetStream(id protocol.StreamID) (T, error) {
m.mutex.RLock()
if id >= m.nextStream {
m.mutex.RUnlock()
return *new(T), fmt.Errorf("peer attempted to open stream %d", id)
return *new(T), &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
}
}
s := m.streams[id]
m.mutex.RUnlock()
@@ -165,7 +169,10 @@ func (m *outgoingStreamsMap[T]) DeleteStream(id protocol.StreamID) error {
defer m.mutex.Unlock()
if _, ok := m.streams[id]; !ok {
return fmt.Errorf("tried to delete unknown outgoing stream %d", id)
return &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("tried to delete unknown outgoing stream %d", id),
}
}
delete(m.streams, id)
return nil

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/assert"
@@ -34,6 +35,7 @@ func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Pers
m.SetMaxStream(protocol.MaxStreamID)
_, err := m.GetStream(firstStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstStream))
str1, err := m.OpenStream()
@@ -54,14 +56,14 @@ func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Pers
err = m.DeleteStream(firstStream + 1337*4)
require.Error(t, err)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, "tried to delete unknown outgoing stream")
require.NoError(t, m.DeleteStream(firstStream))
// deleting the same stream twice will fail
require.ErrorContains(t,
m.DeleteStream(firstStream),
"tried to delete unknown outgoing stream",
)
err = m.DeleteStream(firstStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, "tried to delete unknown outgoing stream")
// after deleting the stream it's not available anymore
str, err := m.GetStream(firstStream)
require.NoError(t, err)
@@ -98,7 +100,6 @@ func testStreamsMapOutgoingLimits(t *testing.T, perspective protocol.Perspective
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = m.OpenStreamSync(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
// OpenStreamSync blocks until the context is canceled...

View File

@@ -19,7 +19,7 @@ import (
func TestStreamsMapCreatingStreams(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveClient,
testStreamsMapCreatingStreams(t, protocol.PerspectiveClient,
protocol.FirstIncomingBidiStreamClient,
protocol.FirstOutgoingBidiStreamClient,
protocol.FirstIncomingUniStreamClient,
@@ -27,7 +27,7 @@ func TestStreamsMapCreatingStreams(t *testing.T) {
)
})
t.Run("server", func(t *testing.T) {
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveServer,
testStreamsMapCreatingStreams(t, protocol.PerspectiveServer,
protocol.FirstIncomingBidiStreamServer,
protocol.FirstOutgoingBidiStreamServer,
protocol.FirstIncomingUniStreamServer,
@@ -36,7 +36,7 @@ func TestStreamsMapCreatingStreams(t *testing.T) {
})
}
func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
func testStreamsMapCreatingStreams(t *testing.T,
perspective protocol.Perspective,
firstIncomingBidiStream protocol.StreamID,
firstOutgoingBidiStream protocol.StreamID,
@@ -50,7 +50,9 @@ func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
mockSender,
func(wire.Frame) {},
func(protocol.StreamID) flowcontrol.StreamFlowController {
return mocks.NewMockStreamFlowController(mockCtrl)
fc := mocks.NewMockStreamFlowController(mockCtrl)
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
return fc
},
1,
1,
@@ -76,25 +78,9 @@ func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
assert.Equal(t, ustr1.StreamID(), firstOutgoingUniStream)
assert.Equal(t, ustr2.StreamID(), firstOutgoingUniStream+4)
// accepting streams:
// This function is called when a frame referencing this stream is received.
// The peer may open a peer-initiated stream...
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
require.NoError(t, err)
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
require.NoError(t, err)
// ... but not a stream that is initiated by us.
_, err = m.GetOrOpenSendStream(firstOutgoingBidiStream + 8)
require.ErrorIs(t, err, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream+8),
})
_, err = m.GetOrOpenSendStream(firstOutgoingUniStream + 8)
require.ErrorIs(t, err, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream+8),
})
// accepting streams is triggered by receiving a frame referencing this stream
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, time.Now()))
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, time.Now()))
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@@ -141,7 +127,9 @@ func testStreamsMapDeletingStreams(t *testing.T,
mockSender,
func(frame wire.Frame) { frameQueue = append(frameQueue, frame) },
func(protocol.StreamID) flowcontrol.StreamFlowController {
return mocks.NewMockStreamFlowController(mockCtrl)
fc := mocks.NewMockStreamFlowController(mockCtrl)
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
return fc
},
100,
100,
@@ -155,37 +143,25 @@ func testStreamsMapDeletingStreams(t *testing.T,
_, err := m.OpenStream()
require.NoError(t, err)
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
sstr, err := m.GetOrOpenSendStream(firstOutgoingBidiStream)
require.NoError(t, err)
require.Nil(t, sstr)
require.ErrorContains(t,
m.DeleteStream(firstOutgoingBidiStream+400),
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400),
)
err = m.DeleteStream(firstOutgoingBidiStream + 400)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400))
_, err = m.OpenUniStream()
require.NoError(t, err)
require.NoError(t, m.DeleteStream(firstOutgoingUniStream))
sstr, err = m.GetOrOpenSendStream(firstOutgoingUniStream)
require.NoError(t, err)
require.Nil(t, sstr)
require.ErrorContains(t,
m.DeleteStream(firstOutgoingUniStream+400),
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400),
)
err = m.DeleteStream(firstOutgoingUniStream + 400)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400))
require.Empty(t, frameQueue)
// deleting incoming bidirectional streams
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
require.NoError(t, err)
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, time.Now()))
require.NoError(t, m.DeleteStream(firstIncomingBidiStream))
sstr, err = m.GetOrOpenSendStream(firstIncomingBidiStream)
require.NoError(t, err)
require.Nil(t, sstr)
require.ErrorContains(t,
m.DeleteStream(firstIncomingBidiStream+400),
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400),
)
err = m.DeleteStream(firstIncomingBidiStream + 400)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400))
// the MAX_STREAMS frame is only queued once the stream is accepted
require.Empty(t, frameQueue)
_, err = m.AcceptStream(context.Background())
@@ -200,16 +176,12 @@ func testStreamsMapDeletingStreams(t *testing.T,
frameQueue = frameQueue[:0]
// deleting incoming unidirectional streams
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
require.NoError(t, err)
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, time.Now()))
require.NoError(t, m.DeleteStream(firstIncomingUniStream))
rstr, err := m.GetOrOpenReceiveStream(firstIncomingUniStream)
require.NoError(t, err)
require.Nil(t, rstr)
require.ErrorContains(t,
m.DeleteStream(firstIncomingUniStream+400),
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400),
)
err = m.DeleteStream(firstIncomingUniStream + 400)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400))
// the MAX_STREAMS frame is only queued once the stream is accepted
require.Empty(t, frameQueue)
_, err = m.AcceptUniStream(context.Background())
@@ -254,6 +226,7 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective)
// increase via transport parameters
_, err := m.OpenStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
require.ErrorContains(t, err, "too many open streams")
m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 1})
_, err = m.OpenStream()
require.NoError(t, err)
@@ -293,6 +266,239 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective)
require.ErrorIs(t, err, &StreamLimitReachedError{})
}
func TestStreamsMapHandleReceiveStreamFrames(t *testing.T) {
for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} {
t.Run(pers.String(), func(t *testing.T) {
t.Run("STREAM frame", func(t *testing.T) {
testStreamsMapHandleReceiveStreamFrames(t,
pers,
func(m *streamsMap, id protocol.StreamID) error {
return m.HandleStreamFrame(&wire.StreamFrame{StreamID: id}, time.Now())
},
)
})
t.Run("STREAM_DATA_BLOCKED frame", func(t *testing.T) {
testStreamsMapHandleReceiveStreamFrames(t,
pers,
func(m *streamsMap, id protocol.StreamID) error {
return m.HandleStreamDataBlockedFrame(&wire.StreamDataBlockedFrame{StreamID: id})
},
)
})
t.Run("RESET_STREAM frame", func(t *testing.T) {
testStreamsMapHandleReceiveStreamFrames(t,
pers,
func(m *streamsMap, id protocol.StreamID) error {
return m.HandleResetStreamFrame(&wire.ResetStreamFrame{StreamID: id}, time.Now())
},
)
})
})
}
}
func testStreamsMapHandleReceiveStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(*streamsMap, protocol.StreamID) error) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
var streamsCreated []protocol.StreamID
m := newStreamsMap(
context.Background(),
mockSender,
func(frame wire.Frame) {},
func(id protocol.StreamID) flowcontrol.StreamFlowController {
streamsCreated = append(streamsCreated, id)
fc := mocks.NewMockStreamFlowController(mockCtrl)
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
fc.EXPECT().Abandon().AnyTimes()
return fc
},
100,
100,
pers,
)
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount})
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount})
var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream protocol.StreamID
if pers == protocol.PerspectiveClient {
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient
firstIncomingUniStream = protocol.FirstIncomingUniStreamClient
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient
} else {
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer
firstIncomingUniStream = protocol.FirstIncomingUniStreamServer
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer
}
// 1. The peer can't open a unidirectional send stream...
err := handleFrame(m, firstOutgoingUniStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream))
require.Empty(t, streamsCreated)
// ... and a STREAM frame for a unidirectional send stream is invalid even if the stream is open.
_, err = m.OpenUniStream()
require.NoError(t, err)
err = handleFrame(m, firstOutgoingUniStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream))
streamsCreated = streamsCreated[:0]
// 2. The peer can't open a bidirectional stream initiated by us...
err = handleFrame(m, firstOutgoingBidiStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream))
require.Empty(t, streamsCreated)
// ... but it's valid once we have opened the stream.
_, err = m.OpenStream()
require.NoError(t, err)
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
streamsCreated = streamsCreated[:0]
// Delayed frames for deleted streams are absorbed.
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
require.Empty(t, streamsCreated)
// 3. The peer can send STREAM frames for unidirectional receive streams,
// as long as they're below the stream limit.
require.ErrorIs(t,
handleFrame(m, firstIncomingUniStream+400),
&qerr.TransportError{ErrorCode: qerr.StreamLimitError},
)
require.Empty(t, streamsCreated)
require.NoError(t, handleFrame(m, firstIncomingUniStream))
require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingUniStream})
streamsCreated = streamsCreated[:0]
// Delayed frames for deleted streams are absorbed.
require.NoError(t, m.DeleteStream(firstIncomingUniStream))
require.NoError(t, handleFrame(m, firstIncomingUniStream))
require.Empty(t, streamsCreated)
// 4. The peer can send STREAM frames for bidirectional receive streams,
// as long as they're below the stream limit.
require.ErrorIs(t,
handleFrame(m, firstIncomingBidiStream+400),
&qerr.TransportError{ErrorCode: qerr.StreamLimitError},
)
require.Empty(t, streamsCreated)
require.NoError(t, handleFrame(m, firstIncomingBidiStream))
require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream})
}
func TestStreamsMapHandleSendStreamFrames(t *testing.T) {
for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} {
t.Run(pers.String(), func(t *testing.T) {
t.Run("STOP_SENDING frame", func(t *testing.T) {
testStreamsMapHandleSendStreamFrames(t,
pers,
func(m *streamsMap, id protocol.StreamID) error {
return m.HandleStopSendingFrame(&wire.StopSendingFrame{StreamID: id})
},
)
})
t.Run("MAX_STREAM_DATA frame", func(t *testing.T) {
testStreamsMapHandleSendStreamFrames(t,
pers,
func(m *streamsMap, id protocol.StreamID) error {
return m.HandleMaxStreamDataFrame(&wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1000})
},
)
})
})
}
}
func testStreamsMapHandleSendStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(m *streamsMap, id protocol.StreamID) error) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).AnyTimes()
var streamsCreated []protocol.StreamID
m := newStreamsMap(
context.Background(),
mockSender,
func(frame wire.Frame) {},
func(id protocol.StreamID) flowcontrol.StreamFlowController {
streamsCreated = append(streamsCreated, id)
fc := mocks.NewMockStreamFlowController(mockCtrl)
fc.EXPECT().UpdateSendWindow(gomock.Any()).AnyTimes()
return fc
},
100,
100,
pers,
)
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount})
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount})
var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream protocol.StreamID
if pers == protocol.PerspectiveClient {
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient
firstIncomingUniStream = protocol.FirstIncomingUniStreamClient
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient
} else {
firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer
firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer
firstIncomingUniStream = protocol.FirstIncomingUniStreamServer
firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer
}
// 1. The peer can't open a unidirectional send stream...
err := handleFrame(m, firstOutgoingUniStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream))
require.Empty(t, streamsCreated)
// ... but once we have opened the stream, it's valid.
_, err = m.OpenUniStream()
require.NoError(t, err)
require.NoError(t, handleFrame(m, firstOutgoingUniStream))
streamsCreated = streamsCreated[:0]
// Delayed frames for deleted streams are absorbed.
require.NoError(t, m.DeleteStream(firstOutgoingUniStream))
require.NoError(t, handleFrame(m, firstOutgoingUniStream))
require.Empty(t, streamsCreated)
// 2. The peer can't open a bidirectional stream initiated by us...
err = handleFrame(m, firstOutgoingBidiStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream))
require.Empty(t, streamsCreated)
// ... but once we have opened the stream, it's valid.
_, err = m.OpenStream()
require.NoError(t, err)
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
streamsCreated = streamsCreated[:0]
// Delayed frames for deleted streams are absorbed.
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
require.NoError(t, handleFrame(m, firstOutgoingBidiStream))
require.Empty(t, streamsCreated)
// 3. The peer can't send STOP_SENDING frames for unidirectional send streams
err = handleFrame(m, firstIncomingUniStream)
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
require.ErrorContains(t, err, fmt.Sprintf("invalid frame for send stream %d", firstIncomingUniStream))
require.Empty(t, streamsCreated)
// 4. The peer can send STOP_SENDING frames for bidirectional receive streams iniated by itself,
// as long as they're below the stream limit.
require.ErrorIs(t,
handleFrame(m, firstIncomingBidiStream+400),
&qerr.TransportError{ErrorCode: qerr.StreamLimitError},
)
require.Empty(t, streamsCreated)
require.NoError(t, handleFrame(m, firstIncomingBidiStream))
require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream})
streamsCreated = streamsCreated[:0]
// Delayed frames for deleted streams are absorbed.
require.NoError(t, m.DeleteStream(firstIncomingBidiStream))
require.NoError(t, handleFrame(m, firstIncomingBidiStream))
require.Empty(t, streamsCreated)
}
func TestStreamsMapClosing(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
@@ -366,7 +572,9 @@ func TestStreamsMap0RTTRejection(t *testing.T) {
mockSender,
func(wire.Frame) {},
func(protocol.StreamID) flowcontrol.StreamFlowController {
return mocks.NewMockStreamFlowController(mockCtrl)
fc := mocks.NewMockStreamFlowController(mockCtrl)
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
return fc
},
1,
1,
@@ -380,10 +588,11 @@ func TestStreamsMap0RTTRejection(t *testing.T) {
require.ErrorIs(t, err, Err0RTTRejected)
_, err = m.AcceptStream(context.Background())
require.ErrorIs(t, err, Err0RTTRejected)
_, err = m.AcceptUniStream(context.Background())
require.ErrorIs(t, err, Err0RTTRejected)
// make sure that we can still get new streams, as the server might be sending us data
str, err := m.GetOrOpenReceiveStream(3)
require.NoError(t, err)
require.NotNil(t, str)
require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: 3}, time.Now()))
// now switch to using the new streams map
m.UseResetMaps()