Files
quic-go/streams_map_test.go
2025-04-19 16:16:47 +02:00

416 lines
12 KiB
Go

package quic
import (
"context"
"fmt"
"testing"
"time"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/mocks"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
const (
firstIncomingBidiStreamServer protocol.StreamID = 0
firstOutgoingBidiStreamServer protocol.StreamID = 1
firstIncomingUniStreamServer protocol.StreamID = 2
firstOutgoingUniStreamServer protocol.StreamID = 3
)
const (
firstIncomingBidiStreamClient protocol.StreamID = 1
firstOutgoingBidiStreamClient protocol.StreamID = 0
firstIncomingUniStreamClient protocol.StreamID = 3
firstOutgoingUniStreamClient protocol.StreamID = 2
)
func (e streamError) TestError() error {
nums := make([]interface{}, len(e.nums))
for i, num := range e.nums {
nums[i] = num
}
return fmt.Errorf(e.message, nums...)
}
func TestStreamsMapCreatingStreams(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveClient,
firstIncomingBidiStreamClient,
firstOutgoingBidiStreamClient,
firstIncomingUniStreamClient,
firstOutgoingUniStreamClient,
)
})
t.Run("server", func(t *testing.T) {
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveServer,
firstIncomingBidiStreamServer,
firstOutgoingBidiStreamServer,
firstIncomingUniStreamServer,
firstOutgoingUniStreamServer,
)
})
}
func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
perspective protocol.Perspective,
firstIncomingBidiStream protocol.StreamID,
firstOutgoingBidiStream protocol.StreamID,
firstIncomingUniStream protocol.StreamID,
firstOutgoingUniStream protocol.StreamID,
) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
m := newStreamsMap(
context.Background(),
mockSender,
func(wire.Frame) {},
func(protocol.StreamID) flowcontrol.StreamFlowController {
return mocks.NewMockStreamFlowController(mockCtrl)
},
1,
1,
perspective,
)
m.UpdateLimits(&wire.TransportParameters{
MaxBidiStreamNum: protocol.MaxStreamCount,
MaxUniStreamNum: protocol.MaxStreamCount,
})
// opening streams
str1, err := m.OpenStream()
require.NoError(t, err)
str2, err := m.OpenStream()
require.NoError(t, err)
ustr1, err := m.OpenUniStream()
require.NoError(t, err)
ustr2, err := m.OpenUniStream()
require.NoError(t, err)
assert.Equal(t, str1.StreamID(), firstOutgoingBidiStream)
assert.Equal(t, str2.StreamID(), firstOutgoingBidiStream+4)
assert.Equal(t, ustr1.StreamID(), firstOutgoingUniStream)
assert.Equal(t, ustr2.StreamID(), firstOutgoingUniStream+4)
// accepting streams:
// This function is called when a frame referencing this stream is received.
// The peer may open a peer-initiated stream...
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
require.NoError(t, err)
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
require.NoError(t, err)
// ... but not a stream that is initiated by us.
_, err = m.GetOrOpenSendStream(firstOutgoingBidiStream + 8)
require.ErrorIs(t, err, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream+8),
})
_, err = m.GetOrOpenSendStream(firstOutgoingUniStream + 8)
require.ErrorIs(t, err, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream+8),
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
str, err := m.AcceptStream(ctx)
require.NoError(t, err)
ustr, err := m.AcceptUniStream(ctx)
require.NoError(t, err)
assert.Equal(t, str.StreamID(), firstIncomingBidiStream)
assert.Equal(t, ustr.StreamID(), firstIncomingUniStream)
}
func TestStreamsMapDeletingStreams(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapDeletingStreams(t, protocol.PerspectiveClient,
firstIncomingBidiStreamClient,
firstOutgoingBidiStreamClient,
firstIncomingUniStreamClient,
firstOutgoingUniStreamClient,
)
})
t.Run("server", func(t *testing.T) {
testStreamsMapDeletingStreams(t, protocol.PerspectiveServer,
firstIncomingBidiStreamServer,
firstOutgoingBidiStreamServer,
firstIncomingUniStreamServer,
firstOutgoingUniStreamServer,
)
})
}
func testStreamsMapDeletingStreams(t *testing.T,
perspective protocol.Perspective,
firstIncomingBidiStream protocol.StreamID,
firstOutgoingBidiStream protocol.StreamID,
firstIncomingUniStream protocol.StreamID,
firstOutgoingUniStream protocol.StreamID,
) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
var frameQueue []wire.Frame
m := newStreamsMap(
context.Background(),
mockSender,
func(frame wire.Frame) { frameQueue = append(frameQueue, frame) },
func(protocol.StreamID) flowcontrol.StreamFlowController {
return mocks.NewMockStreamFlowController(mockCtrl)
},
100,
100,
perspective,
)
m.UpdateLimits(&wire.TransportParameters{
MaxBidiStreamNum: 10,
MaxUniStreamNum: 10,
})
_, err := m.OpenStream()
require.NoError(t, err)
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
sstr, err := m.GetOrOpenSendStream(firstOutgoingBidiStream)
require.NoError(t, err)
require.Nil(t, sstr)
require.ErrorContains(t,
m.DeleteStream(firstOutgoingBidiStream+400),
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400),
)
_, err = m.OpenUniStream()
require.NoError(t, err)
require.NoError(t, m.DeleteStream(firstOutgoingUniStream))
sstr, err = m.GetOrOpenSendStream(firstOutgoingUniStream)
require.NoError(t, err)
require.Nil(t, sstr)
require.ErrorContains(t,
m.DeleteStream(firstOutgoingUniStream+400),
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400),
)
require.Empty(t, frameQueue)
// deleting incoming bidirectional streams
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
require.NoError(t, err)
require.NoError(t, m.DeleteStream(firstIncomingBidiStream))
sstr, err = m.GetOrOpenSendStream(firstIncomingBidiStream)
require.NoError(t, err)
require.Nil(t, sstr)
require.ErrorContains(t,
m.DeleteStream(firstIncomingBidiStream+400),
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400),
)
// the MAX_STREAMS frame is only queued once the stream is accepted
require.Empty(t, frameQueue)
_, err = m.AcceptStream(context.Background())
require.NoError(t, err)
require.Equal(t, frameQueue, []wire.Frame{
&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreamNum: 101,
},
})
frameQueue = frameQueue[:0]
// deleting incoming unidirectional streams
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
require.NoError(t, err)
require.NoError(t, m.DeleteStream(firstIncomingUniStream))
rstr, err := m.GetOrOpenReceiveStream(firstIncomingUniStream)
require.NoError(t, err)
require.Nil(t, rstr)
require.ErrorContains(t,
m.DeleteStream(firstIncomingUniStream+400),
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400),
)
// the MAX_STREAMS frame is only queued once the stream is accepted
require.Empty(t, frameQueue)
_, err = m.AcceptUniStream(context.Background())
require.NoError(t, err)
require.Equal(t, frameQueue, []wire.Frame{
&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreamNum: 101,
},
})
frameQueue = frameQueue[:0]
}
func TestStreamsMapStreamLimits(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapStreamLimits(t, protocol.PerspectiveClient)
})
t.Run("server", func(t *testing.T) {
testStreamsMapStreamLimits(t, protocol.PerspectiveServer)
})
}
func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
var frameQueue []wire.Frame
m := newStreamsMap(
context.Background(),
mockSender,
func(frame wire.Frame) { frameQueue = append(frameQueue, frame) },
func(protocol.StreamID) flowcontrol.StreamFlowController {
fc := mocks.NewMockStreamFlowController(mockCtrl)
fc.EXPECT().UpdateSendWindow(gomock.Any()).AnyTimes()
return fc
},
100,
100,
perspective,
)
// increase via transport parameters
_, err := m.OpenStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 1})
_, err = m.OpenStream()
require.NoError(t, err)
_, err = m.OpenStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
_, err = m.OpenUniStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
m.UpdateLimits(&wire.TransportParameters{MaxUniStreamNum: 1})
_, err = m.OpenUniStream()
require.NoError(t, err)
_, err = m.OpenUniStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
// increase via MAX_STREAMS frames
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreamNum: 2,
})
_, err = m.OpenStream()
require.NoError(t, err)
_, err = m.OpenStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreamNum: 2,
})
_, err = m.OpenUniStream()
require.NoError(t, err)
_, err = m.OpenUniStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
// decrease via transport parameters
m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 0})
_, err = m.OpenStream()
require.ErrorIs(t, err, &StreamLimitReachedError{})
}
func TestStreamsMapClosing(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
m := newStreamsMap(
context.Background(),
mockSender,
func(wire.Frame) {},
func(protocol.StreamID) flowcontrol.StreamFlowController {
return mocks.NewMockStreamFlowController(mockCtrl)
},
1,
1,
protocol.PerspectiveClient,
)
m.CloseWithError(assert.AnError)
_, err := m.OpenStream()
require.ErrorIs(t, err, assert.AnError)
_, err = m.OpenUniStream()
require.ErrorIs(t, err, assert.AnError)
_, err = m.AcceptStream(context.Background())
require.ErrorIs(t, err, assert.AnError)
_, err = m.AcceptUniStream(context.Background())
require.ErrorIs(t, err, assert.AnError)
}
func TestStreamsMap0RTT(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
fcBidi := mocks.NewMockStreamFlowController(mockCtrl)
fcUni := mocks.NewMockStreamFlowController(mockCtrl)
fcs := []flowcontrol.StreamFlowController{fcBidi, fcUni}
m := newStreamsMap(
context.Background(),
mockSender,
func(wire.Frame) {},
func(protocol.StreamID) flowcontrol.StreamFlowController {
fc := fcs[0]
fcs = fcs[1:]
return fc
},
1,
1,
protocol.PerspectiveClient,
)
// restored transport parameters
m.UpdateLimits(&wire.TransportParameters{
MaxBidiStreamNum: 1,
MaxUniStreamNum: 1,
})
_, err := m.OpenStream()
require.NoError(t, err)
_, err = m.OpenUniStream()
require.NoError(t, err)
fcBidi.EXPECT().UpdateSendWindow(protocol.ByteCount(1234))
fcUni.EXPECT().UpdateSendWindow(protocol.ByteCount(4321))
// new transport parameters
m.UpdateLimits(&wire.TransportParameters{
MaxBidiStreamNum: 1000,
InitialMaxStreamDataBidiRemote: 1234,
MaxUniStreamNum: 1000,
InitialMaxStreamDataUni: 4321,
})
}
func TestStreamsMap0RTTRejection(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockSender := NewMockStreamSender(mockCtrl)
m := newStreamsMap(
context.Background(),
mockSender,
func(wire.Frame) {},
func(protocol.StreamID) flowcontrol.StreamFlowController {
return mocks.NewMockStreamFlowController(mockCtrl)
},
1,
1,
protocol.PerspectiveClient,
)
m.ResetFor0RTT()
_, err := m.OpenStream()
require.ErrorIs(t, err, Err0RTTRejected)
_, err = m.OpenUniStream()
require.ErrorIs(t, err, Err0RTTRejected)
_, err = m.AcceptStream(context.Background())
require.ErrorIs(t, err, Err0RTTRejected)
// make sure that we can still get new streams, as the server might be sending us data
str, err := m.GetOrOpenReceiveStream(3)
require.NoError(t, err)
require.NotNil(t, str)
// now switch to using the new streams map
m.UseResetMaps()
_, err = m.OpenStream()
require.Error(t, err)
require.ErrorIs(t, err, &StreamLimitReachedError{})
}