forked from quic-go/quic-go
416 lines
12 KiB
Go
416 lines
12 KiB
Go
package quic
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go/internal/flowcontrol"
|
|
"github.com/quic-go/quic-go/internal/mocks"
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/qerr"
|
|
"github.com/quic-go/quic-go/internal/wire"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
const (
|
|
firstIncomingBidiStreamServer protocol.StreamID = 0
|
|
firstOutgoingBidiStreamServer protocol.StreamID = 1
|
|
firstIncomingUniStreamServer protocol.StreamID = 2
|
|
firstOutgoingUniStreamServer protocol.StreamID = 3
|
|
)
|
|
|
|
const (
|
|
firstIncomingBidiStreamClient protocol.StreamID = 1
|
|
firstOutgoingBidiStreamClient protocol.StreamID = 0
|
|
firstIncomingUniStreamClient protocol.StreamID = 3
|
|
firstOutgoingUniStreamClient protocol.StreamID = 2
|
|
)
|
|
|
|
func (e streamError) TestError() error {
|
|
nums := make([]interface{}, len(e.nums))
|
|
for i, num := range e.nums {
|
|
nums[i] = num
|
|
}
|
|
return fmt.Errorf(e.message, nums...)
|
|
}
|
|
|
|
func TestStreamsMapCreatingStreams(t *testing.T) {
|
|
t.Run("client", func(t *testing.T) {
|
|
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveClient,
|
|
firstIncomingBidiStreamClient,
|
|
firstOutgoingBidiStreamClient,
|
|
firstIncomingUniStreamClient,
|
|
firstOutgoingUniStreamClient,
|
|
)
|
|
})
|
|
t.Run("server", func(t *testing.T) {
|
|
testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveServer,
|
|
firstIncomingBidiStreamServer,
|
|
firstOutgoingBidiStreamServer,
|
|
firstIncomingUniStreamServer,
|
|
firstOutgoingUniStreamServer,
|
|
)
|
|
})
|
|
}
|
|
|
|
func testStreamsMapCreatingAndDeletingStreams(t *testing.T,
|
|
perspective protocol.Perspective,
|
|
firstIncomingBidiStream protocol.StreamID,
|
|
firstOutgoingBidiStream protocol.StreamID,
|
|
firstIncomingUniStream protocol.StreamID,
|
|
firstOutgoingUniStream protocol.StreamID,
|
|
) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockSender := NewMockStreamSender(mockCtrl)
|
|
m := newStreamsMap(
|
|
context.Background(),
|
|
mockSender,
|
|
func(wire.Frame) {},
|
|
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
|
return mocks.NewMockStreamFlowController(mockCtrl)
|
|
},
|
|
1,
|
|
1,
|
|
perspective,
|
|
)
|
|
m.UpdateLimits(&wire.TransportParameters{
|
|
MaxBidiStreamNum: protocol.MaxStreamCount,
|
|
MaxUniStreamNum: protocol.MaxStreamCount,
|
|
})
|
|
|
|
// opening streams
|
|
str1, err := m.OpenStream()
|
|
require.NoError(t, err)
|
|
str2, err := m.OpenStream()
|
|
require.NoError(t, err)
|
|
ustr1, err := m.OpenUniStream()
|
|
require.NoError(t, err)
|
|
ustr2, err := m.OpenUniStream()
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, str1.StreamID(), firstOutgoingBidiStream)
|
|
assert.Equal(t, str2.StreamID(), firstOutgoingBidiStream+4)
|
|
assert.Equal(t, ustr1.StreamID(), firstOutgoingUniStream)
|
|
assert.Equal(t, ustr2.StreamID(), firstOutgoingUniStream+4)
|
|
|
|
// accepting streams:
|
|
// This function is called when a frame referencing this stream is received.
|
|
// The peer may open a peer-initiated stream...
|
|
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
|
|
require.NoError(t, err)
|
|
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
|
|
require.NoError(t, err)
|
|
|
|
// ... but not a stream that is initiated by us.
|
|
_, err = m.GetOrOpenSendStream(firstOutgoingBidiStream + 8)
|
|
require.ErrorIs(t, err, &qerr.TransportError{
|
|
ErrorCode: qerr.StreamStateError,
|
|
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream+8),
|
|
})
|
|
_, err = m.GetOrOpenSendStream(firstOutgoingUniStream + 8)
|
|
require.ErrorIs(t, err, &qerr.TransportError{
|
|
ErrorCode: qerr.StreamStateError,
|
|
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream+8),
|
|
})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
str, err := m.AcceptStream(ctx)
|
|
require.NoError(t, err)
|
|
ustr, err := m.AcceptUniStream(ctx)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, str.StreamID(), firstIncomingBidiStream)
|
|
assert.Equal(t, ustr.StreamID(), firstIncomingUniStream)
|
|
}
|
|
|
|
func TestStreamsMapDeletingStreams(t *testing.T) {
|
|
t.Run("client", func(t *testing.T) {
|
|
testStreamsMapDeletingStreams(t, protocol.PerspectiveClient,
|
|
firstIncomingBidiStreamClient,
|
|
firstOutgoingBidiStreamClient,
|
|
firstIncomingUniStreamClient,
|
|
firstOutgoingUniStreamClient,
|
|
)
|
|
})
|
|
t.Run("server", func(t *testing.T) {
|
|
testStreamsMapDeletingStreams(t, protocol.PerspectiveServer,
|
|
firstIncomingBidiStreamServer,
|
|
firstOutgoingBidiStreamServer,
|
|
firstIncomingUniStreamServer,
|
|
firstOutgoingUniStreamServer,
|
|
)
|
|
})
|
|
}
|
|
|
|
func testStreamsMapDeletingStreams(t *testing.T,
|
|
perspective protocol.Perspective,
|
|
firstIncomingBidiStream protocol.StreamID,
|
|
firstOutgoingBidiStream protocol.StreamID,
|
|
firstIncomingUniStream protocol.StreamID,
|
|
firstOutgoingUniStream protocol.StreamID,
|
|
) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockSender := NewMockStreamSender(mockCtrl)
|
|
var frameQueue []wire.Frame
|
|
m := newStreamsMap(
|
|
context.Background(),
|
|
mockSender,
|
|
func(frame wire.Frame) { frameQueue = append(frameQueue, frame) },
|
|
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
|
return mocks.NewMockStreamFlowController(mockCtrl)
|
|
},
|
|
100,
|
|
100,
|
|
perspective,
|
|
)
|
|
m.UpdateLimits(&wire.TransportParameters{
|
|
MaxBidiStreamNum: 10,
|
|
MaxUniStreamNum: 10,
|
|
})
|
|
|
|
_, err := m.OpenStream()
|
|
require.NoError(t, err)
|
|
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream))
|
|
sstr, err := m.GetOrOpenSendStream(firstOutgoingBidiStream)
|
|
require.NoError(t, err)
|
|
require.Nil(t, sstr)
|
|
require.ErrorContains(t,
|
|
m.DeleteStream(firstOutgoingBidiStream+400),
|
|
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400),
|
|
)
|
|
|
|
_, err = m.OpenUniStream()
|
|
require.NoError(t, err)
|
|
require.NoError(t, m.DeleteStream(firstOutgoingUniStream))
|
|
sstr, err = m.GetOrOpenSendStream(firstOutgoingUniStream)
|
|
require.NoError(t, err)
|
|
require.Nil(t, sstr)
|
|
require.ErrorContains(t,
|
|
m.DeleteStream(firstOutgoingUniStream+400),
|
|
fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400),
|
|
)
|
|
|
|
require.Empty(t, frameQueue)
|
|
// deleting incoming bidirectional streams
|
|
_, err = m.GetOrOpenReceiveStream(firstIncomingBidiStream)
|
|
require.NoError(t, err)
|
|
require.NoError(t, m.DeleteStream(firstIncomingBidiStream))
|
|
sstr, err = m.GetOrOpenSendStream(firstIncomingBidiStream)
|
|
require.NoError(t, err)
|
|
require.Nil(t, sstr)
|
|
require.ErrorContains(t,
|
|
m.DeleteStream(firstIncomingBidiStream+400),
|
|
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400),
|
|
)
|
|
// the MAX_STREAMS frame is only queued once the stream is accepted
|
|
require.Empty(t, frameQueue)
|
|
_, err = m.AcceptStream(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, frameQueue, []wire.Frame{
|
|
&wire.MaxStreamsFrame{
|
|
Type: protocol.StreamTypeBidi,
|
|
MaxStreamNum: 101,
|
|
},
|
|
})
|
|
frameQueue = frameQueue[:0]
|
|
|
|
// deleting incoming unidirectional streams
|
|
_, err = m.GetOrOpenReceiveStream(firstIncomingUniStream)
|
|
require.NoError(t, err)
|
|
require.NoError(t, m.DeleteStream(firstIncomingUniStream))
|
|
rstr, err := m.GetOrOpenReceiveStream(firstIncomingUniStream)
|
|
require.NoError(t, err)
|
|
require.Nil(t, rstr)
|
|
require.ErrorContains(t,
|
|
m.DeleteStream(firstIncomingUniStream+400),
|
|
fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400),
|
|
)
|
|
// the MAX_STREAMS frame is only queued once the stream is accepted
|
|
require.Empty(t, frameQueue)
|
|
_, err = m.AcceptUniStream(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, frameQueue, []wire.Frame{
|
|
&wire.MaxStreamsFrame{
|
|
Type: protocol.StreamTypeUni,
|
|
MaxStreamNum: 101,
|
|
},
|
|
})
|
|
frameQueue = frameQueue[:0]
|
|
}
|
|
|
|
func TestStreamsMapStreamLimits(t *testing.T) {
|
|
t.Run("client", func(t *testing.T) {
|
|
testStreamsMapStreamLimits(t, protocol.PerspectiveClient)
|
|
})
|
|
t.Run("server", func(t *testing.T) {
|
|
testStreamsMapStreamLimits(t, protocol.PerspectiveServer)
|
|
})
|
|
}
|
|
|
|
func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockSender := NewMockStreamSender(mockCtrl)
|
|
var frameQueue []wire.Frame
|
|
m := newStreamsMap(
|
|
context.Background(),
|
|
mockSender,
|
|
func(frame wire.Frame) { frameQueue = append(frameQueue, frame) },
|
|
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
|
fc := mocks.NewMockStreamFlowController(mockCtrl)
|
|
fc.EXPECT().UpdateSendWindow(gomock.Any()).AnyTimes()
|
|
return fc
|
|
},
|
|
100,
|
|
100,
|
|
perspective,
|
|
)
|
|
|
|
// increase via transport parameters
|
|
_, err := m.OpenStream()
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 1})
|
|
_, err = m.OpenStream()
|
|
require.NoError(t, err)
|
|
_, err = m.OpenStream()
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
|
|
_, err = m.OpenUniStream()
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
m.UpdateLimits(&wire.TransportParameters{MaxUniStreamNum: 1})
|
|
_, err = m.OpenUniStream()
|
|
require.NoError(t, err)
|
|
_, err = m.OpenUniStream()
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
|
|
// increase via MAX_STREAMS frames
|
|
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
|
|
Type: protocol.StreamTypeBidi,
|
|
MaxStreamNum: 2,
|
|
})
|
|
_, err = m.OpenStream()
|
|
require.NoError(t, err)
|
|
_, err = m.OpenStream()
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
|
|
m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
|
|
Type: protocol.StreamTypeUni,
|
|
MaxStreamNum: 2,
|
|
})
|
|
_, err = m.OpenUniStream()
|
|
require.NoError(t, err)
|
|
_, err = m.OpenUniStream()
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
|
|
// decrease via transport parameters
|
|
m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 0})
|
|
_, err = m.OpenStream()
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
}
|
|
|
|
func TestStreamsMapClosing(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockSender := NewMockStreamSender(mockCtrl)
|
|
m := newStreamsMap(
|
|
context.Background(),
|
|
mockSender,
|
|
func(wire.Frame) {},
|
|
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
|
return mocks.NewMockStreamFlowController(mockCtrl)
|
|
},
|
|
1,
|
|
1,
|
|
protocol.PerspectiveClient,
|
|
)
|
|
m.CloseWithError(assert.AnError)
|
|
_, err := m.OpenStream()
|
|
require.ErrorIs(t, err, assert.AnError)
|
|
_, err = m.OpenUniStream()
|
|
require.ErrorIs(t, err, assert.AnError)
|
|
_, err = m.AcceptStream(context.Background())
|
|
require.ErrorIs(t, err, assert.AnError)
|
|
_, err = m.AcceptUniStream(context.Background())
|
|
require.ErrorIs(t, err, assert.AnError)
|
|
}
|
|
|
|
func TestStreamsMap0RTT(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockSender := NewMockStreamSender(mockCtrl)
|
|
fcBidi := mocks.NewMockStreamFlowController(mockCtrl)
|
|
fcUni := mocks.NewMockStreamFlowController(mockCtrl)
|
|
fcs := []flowcontrol.StreamFlowController{fcBidi, fcUni}
|
|
m := newStreamsMap(
|
|
context.Background(),
|
|
mockSender,
|
|
func(wire.Frame) {},
|
|
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
|
fc := fcs[0]
|
|
fcs = fcs[1:]
|
|
return fc
|
|
},
|
|
1,
|
|
1,
|
|
protocol.PerspectiveClient,
|
|
)
|
|
// restored transport parameters
|
|
m.UpdateLimits(&wire.TransportParameters{
|
|
MaxBidiStreamNum: 1,
|
|
MaxUniStreamNum: 1,
|
|
})
|
|
_, err := m.OpenStream()
|
|
require.NoError(t, err)
|
|
_, err = m.OpenUniStream()
|
|
require.NoError(t, err)
|
|
|
|
fcBidi.EXPECT().UpdateSendWindow(protocol.ByteCount(1234))
|
|
fcUni.EXPECT().UpdateSendWindow(protocol.ByteCount(4321))
|
|
// new transport parameters
|
|
m.UpdateLimits(&wire.TransportParameters{
|
|
MaxBidiStreamNum: 1000,
|
|
InitialMaxStreamDataBidiRemote: 1234,
|
|
MaxUniStreamNum: 1000,
|
|
InitialMaxStreamDataUni: 4321,
|
|
})
|
|
}
|
|
|
|
func TestStreamsMap0RTTRejection(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockSender := NewMockStreamSender(mockCtrl)
|
|
m := newStreamsMap(
|
|
context.Background(),
|
|
mockSender,
|
|
func(wire.Frame) {},
|
|
func(protocol.StreamID) flowcontrol.StreamFlowController {
|
|
return mocks.NewMockStreamFlowController(mockCtrl)
|
|
},
|
|
1,
|
|
1,
|
|
protocol.PerspectiveClient,
|
|
)
|
|
|
|
m.ResetFor0RTT()
|
|
_, err := m.OpenStream()
|
|
require.ErrorIs(t, err, Err0RTTRejected)
|
|
_, err = m.OpenUniStream()
|
|
require.ErrorIs(t, err, Err0RTTRejected)
|
|
_, err = m.AcceptStream(context.Background())
|
|
require.ErrorIs(t, err, Err0RTTRejected)
|
|
// make sure that we can still get new streams, as the server might be sending us data
|
|
str, err := m.GetOrOpenReceiveStream(3)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, str)
|
|
|
|
// now switch to using the new streams map
|
|
m.UseResetMaps()
|
|
_, err = m.OpenStream()
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &StreamLimitReachedError{})
|
|
}
|