forked from quic-go/quic-go
365 lines
10 KiB
Go
365 lines
10 KiB
Go
package quic
|
|
|
|
import (
|
|
"context"
|
|
"math/rand/v2"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/protocol"
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/qerr"
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/synctest"
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/wire"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type mockStream struct {
|
|
id protocol.StreamID
|
|
|
|
closed bool
|
|
closeErr error
|
|
sendWindow protocol.ByteCount
|
|
supportsResetStreamAt bool
|
|
}
|
|
|
|
func (s *mockStream) closeForShutdown(err error) {
|
|
s.closed = true
|
|
s.closeErr = err
|
|
}
|
|
|
|
func (s *mockStream) updateSendWindow(limit protocol.ByteCount) {
|
|
s.sendWindow = limit
|
|
}
|
|
|
|
func (s *mockStream) enableResetStreamAt() {
|
|
s.supportsResetStreamAt = true
|
|
}
|
|
|
|
func TestStreamsMapIncomingGettingStreams(t *testing.T) {
|
|
t.Run("client", func(t *testing.T) {
|
|
testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
|
|
})
|
|
t.Run("server", func(t *testing.T) {
|
|
testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
|
|
})
|
|
}
|
|
|
|
func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
|
|
var newStreamCounter int
|
|
const maxNumStreams = 10
|
|
m := newIncomingStreamsMap(
|
|
protocol.StreamTypeUni,
|
|
func(id protocol.StreamID) *mockStream {
|
|
newStreamCounter++
|
|
return &mockStream{id: id}
|
|
},
|
|
maxNumStreams,
|
|
func(f wire.Frame) {},
|
|
perspective,
|
|
)
|
|
|
|
// all streams up to the id on GetOrOpenStream are opened
|
|
str, err := m.GetOrOpenStream(firstStream + 4)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, str)
|
|
require.Equal(t, 2, newStreamCounter)
|
|
require.Equal(t, firstStream+4, str.id)
|
|
// accept one of the streams
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
str, err = m.AcceptStream(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, firstStream, str.id)
|
|
// open some more streams
|
|
str, err = m.GetOrOpenStream(firstStream + 16)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 5, newStreamCounter)
|
|
require.Equal(t, firstStream+16, str.id)
|
|
// and accept all of them
|
|
for i := 1; i < 5; i++ {
|
|
str, err := m.AcceptStream(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, firstStream+4*protocol.StreamID(i), str.id)
|
|
}
|
|
|
|
_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams - 4)
|
|
require.NoError(t, err)
|
|
_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams)
|
|
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamLimitError})
|
|
require.ErrorContains(t, err, "peer tried to open stream")
|
|
require.Equal(t, maxNumStreams, newStreamCounter)
|
|
}
|
|
|
|
func TestStreamsMapIncomingAcceptingStreams(t *testing.T) {
|
|
synctest.Test(t, func(t *testing.T) {
|
|
m := newIncomingStreamsMap(
|
|
protocol.StreamTypeUni,
|
|
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
|
5,
|
|
func(f wire.Frame) {},
|
|
protocol.PerspectiveClient,
|
|
)
|
|
|
|
// AcceptStream should respect the context
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
_, err := m.AcceptStream(ctx)
|
|
errChan <- err
|
|
}()
|
|
|
|
synctest.Wait()
|
|
|
|
select {
|
|
case <-errChan:
|
|
t.Fatal("AcceptStream should not return")
|
|
default:
|
|
}
|
|
|
|
cancel()
|
|
synctest.Wait()
|
|
select {
|
|
case err := <-errChan:
|
|
require.Equal(t, context.Canceled, err)
|
|
default:
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// AcceptStream should block if there are no streams available
|
|
go func() {
|
|
_, err := m.AcceptStream(context.Background())
|
|
errChan <- err
|
|
}()
|
|
|
|
synctest.Wait()
|
|
|
|
select {
|
|
case <-errChan:
|
|
t.Fatal("AcceptStream should block")
|
|
default:
|
|
}
|
|
|
|
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient)
|
|
require.NoError(t, err)
|
|
|
|
synctest.Wait()
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
default:
|
|
t.Fatal("timeout")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestStreamsMapIncomingDeletingStreams(t *testing.T) {
|
|
t.Run("client", func(t *testing.T) {
|
|
testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
|
|
})
|
|
t.Run("server", func(t *testing.T) {
|
|
testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
|
|
})
|
|
}
|
|
|
|
func testStreamsMapIncomingDeletingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
|
|
var frameQueue []wire.Frame
|
|
m := newIncomingStreamsMap(
|
|
protocol.StreamTypeUni,
|
|
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
|
5,
|
|
func(f wire.Frame) { frameQueue = append(frameQueue, f) },
|
|
perspective,
|
|
)
|
|
err := m.DeleteStream(firstStream + 1337*4)
|
|
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
|
|
require.ErrorContains(t, err, "tried to delete unknown incoming stream")
|
|
|
|
s, err := m.GetOrOpenStream(firstStream + 4)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, s)
|
|
// delete the stream
|
|
require.NoError(t, m.DeleteStream(firstStream+4))
|
|
require.Empty(t, frameQueue)
|
|
// it's not returned by GetOrOpenStream anymore
|
|
s, err = m.GetOrOpenStream(firstStream + 4)
|
|
require.NoError(t, err)
|
|
require.Nil(t, s)
|
|
|
|
// AcceptStream still returns this stream
|
|
str, err := m.AcceptStream(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, firstStream, str.id)
|
|
require.Empty(t, frameQueue)
|
|
|
|
str, err = m.AcceptStream(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, firstStream+4, str.id)
|
|
// now the stream is deleted and new stream credit is issued
|
|
require.Len(t, frameQueue, 1)
|
|
require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 6}, frameQueue[0])
|
|
frameQueue = frameQueue[:0]
|
|
|
|
require.NoError(t, m.DeleteStream(firstStream))
|
|
require.Len(t, frameQueue, 1)
|
|
require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 7}, frameQueue[0])
|
|
}
|
|
|
|
// There's a maximum number that can be encoded in a MAX_STREAMS frame.
|
|
// Since the stream limit is configurable by the user, we can't rely on this number
|
|
// being high enough that it will never be reached in practice.
|
|
func TestStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T) {
|
|
t.Run("client", func(t *testing.T) {
|
|
testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
|
|
})
|
|
t.Run("server", func(t *testing.T) {
|
|
testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
|
|
})
|
|
}
|
|
|
|
func testStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T, pers protocol.Perspective, firstStream protocol.StreamID) {
|
|
var frameQueue []wire.Frame
|
|
m := newIncomingStreamsMap(
|
|
protocol.StreamTypeUni,
|
|
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
|
uint64(protocol.MaxStreamCount-2),
|
|
func(f wire.Frame) { frameQueue = append(frameQueue, f) },
|
|
pers,
|
|
)
|
|
|
|
// open a bunch of streams
|
|
_, err := m.GetOrOpenStream(firstStream + 16)
|
|
require.NoError(t, err)
|
|
// accept all streams
|
|
for range 5 {
|
|
_, err := m.AcceptStream(context.Background())
|
|
require.NoError(t, err)
|
|
}
|
|
require.Empty(t, frameQueue)
|
|
require.NoError(t, m.DeleteStream(firstStream+12))
|
|
require.Len(t, frameQueue, 1)
|
|
require.Equal(t,
|
|
&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount - 1},
|
|
frameQueue[0],
|
|
)
|
|
require.NoError(t, m.DeleteStream(firstStream+8))
|
|
require.Len(t, frameQueue, 2)
|
|
require.Equal(t,
|
|
&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount},
|
|
frameQueue[1],
|
|
)
|
|
// at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent
|
|
require.NoError(t, m.DeleteStream(firstStream+4))
|
|
require.NoError(t, m.DeleteStream(firstStream))
|
|
require.Len(t, frameQueue, 2)
|
|
}
|
|
|
|
func TestStreamsMapIncomingClosing(t *testing.T) {
|
|
synctest.Test(t, func(t *testing.T) {
|
|
m := newIncomingStreamsMap(
|
|
protocol.StreamTypeUni,
|
|
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
|
5,
|
|
func(f wire.Frame) {},
|
|
protocol.PerspectiveServer,
|
|
)
|
|
|
|
var streams []*mockStream
|
|
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8)
|
|
require.NoError(t, err)
|
|
for range 3 {
|
|
str, err := m.AcceptStream(context.Background())
|
|
require.NoError(t, err)
|
|
streams = append(streams, str)
|
|
}
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
_, err := m.AcceptStream(context.Background())
|
|
errChan <- err
|
|
}()
|
|
|
|
m.CloseWithError(assert.AnError)
|
|
synctest.Wait()
|
|
|
|
// accepted streams should be closed
|
|
for _, str := range streams {
|
|
require.True(t, str.closed)
|
|
require.ErrorIs(t, str.closeErr, assert.AnError)
|
|
}
|
|
// AcceptStream should return the error
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, assert.AnError)
|
|
default:
|
|
t.Fatal("timeout")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestStreamsMapIncomingRandomized(t *testing.T) {
|
|
synctest.Test(t, func(t *testing.T) {
|
|
const num = 1000
|
|
|
|
streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
|
|
firstStream := protocol.FirstIncomingUniStreamServer
|
|
if streamType == protocol.StreamTypeBidi {
|
|
firstStream = protocol.FirstIncomingBidiStreamServer
|
|
}
|
|
|
|
m := newIncomingStreamsMap(
|
|
streamType,
|
|
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
|
num,
|
|
func(f wire.Frame) {},
|
|
protocol.PerspectiveServer,
|
|
)
|
|
|
|
ids := make([]protocol.StreamID, num)
|
|
for i := range num {
|
|
ids[i] = firstStream + 4*protocol.StreamID(i)
|
|
}
|
|
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
|
|
|
|
errChan1 := make(chan error, 1)
|
|
go func() {
|
|
for range num {
|
|
if _, err := m.AcceptStream(context.Background()); err != nil {
|
|
errChan1 <- err
|
|
return
|
|
}
|
|
}
|
|
close(errChan1)
|
|
}()
|
|
|
|
errChan2 := make(chan error, 1)
|
|
go func() {
|
|
for i := range num {
|
|
if _, err := m.GetOrOpenStream(ids[i]); err != nil {
|
|
errChan2 <- err
|
|
return
|
|
}
|
|
}
|
|
close(errChan2)
|
|
}()
|
|
|
|
synctest.Wait()
|
|
|
|
select {
|
|
case err := <-errChan1:
|
|
require.NoError(t, err)
|
|
default:
|
|
t.Fatal("should have accepted all streams")
|
|
}
|
|
select {
|
|
case err := <-errChan2:
|
|
require.NoError(t, err)
|
|
default:
|
|
t.Fatal("should have opened all streams")
|
|
}
|
|
})
|
|
}
|