Files
quic-go/streams_map_incoming_test.go
2025-11-14 04:04:40 +03:00

365 lines
10 KiB
Go

package quic
import (
"context"
"math/rand/v2"
"testing"
"time"
"git.geeks-team.ru/gr1ffon/quic-go/internal/protocol"
"git.geeks-team.ru/gr1ffon/quic-go/internal/qerr"
"git.geeks-team.ru/gr1ffon/quic-go/internal/synctest"
"git.geeks-team.ru/gr1ffon/quic-go/internal/wire"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockStream is a minimal stream stand-in used by the streams map tests.
// Its methods only record how they were called, so tests can inspect the
// fields afterwards.
type mockStream struct {
	id protocol.StreamID

	// Set by closeForShutdown.
	closed   bool
	closeErr error

	// Set by updateSendWindow and enableResetStreamAt, respectively.
	sendWindow            protocol.ByteCount
	supportsResetStreamAt bool
}

// closeForShutdown marks the stream as closed and remembers the error it was
// closed with.
func (ms *mockStream) closeForShutdown(err error) {
	ms.closeErr = err
	ms.closed = true
}

// updateSendWindow stores the most recent send window limit.
func (ms *mockStream) updateSendWindow(limit protocol.ByteCount) { ms.sendWindow = limit }

// enableResetStreamAt only sets the supportsResetStreamAt flag.
func (ms *mockStream) enableResetStreamAt() { ms.supportsResetStreamAt = true }
func TestStreamsMapIncomingGettingStreams(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
})
t.Run("server", func(t *testing.T) {
testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
})
}
// testStreamsMapIncomingGettingStreams checks that GetOrOpenStream implicitly
// opens every lower-numbered stream of the same type, that AcceptStream hands
// the opened streams out in ID order, and that opening a stream beyond
// maxNumStreams fails with a STREAM_LIMIT_ERROR without creating a stream.
func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	const maxNumStreams = 10
	var newStreamCounter int
	m := newIncomingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *mockStream {
			newStreamCounter++
			return &mockStream{id: id}
		},
		maxNumStreams,
		func(wire.Frame) {},
		perspective,
	)

	// All streams up to the ID passed to GetOrOpenStream are opened.
	// IDs of consecutive streams of the same type are 4 apart.
	str, err := m.GetOrOpenStream(firstStream + 4)
	require.NoError(t, err)
	require.NotNil(t, str)
	require.Equal(t, 2, newStreamCounter)
	require.Equal(t, firstStream+4, str.id)

	// accept one of the streams
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	str, err = m.AcceptStream(ctx)
	require.NoError(t, err)
	require.NotNil(t, str)
	require.Equal(t, firstStream, str.id)

	// open some more streams
	str, err = m.GetOrOpenStream(firstStream + 16)
	require.NoError(t, err)
	// GetOrOpenStream can return (nil, nil) (e.g. for deleted streams), so check
	// before dereferencing: a nil-pointer panic would abort the whole test binary.
	require.NotNil(t, str)
	require.Equal(t, 5, newStreamCounter)
	require.Equal(t, firstStream+16, str.id)

	// and accept all of them, in order
	for i := 1; i < 5; i++ {
		str, err := m.AcceptStream(ctx)
		require.NoError(t, err)
		require.NotNil(t, str)
		require.Equal(t, firstStream+4*protocol.StreamID(i), str.id)
	}

	// The last stream within the limit can still be opened...
	_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams - 4)
	require.NoError(t, err)
	// ...but the next one exceeds it.
	_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamLimitError})
	require.ErrorContains(t, err, "peer tried to open stream")
	// The rejected stream must not have been created.
	require.Equal(t, maxNumStreams, newStreamCounter)
}
// TestStreamsMapIncomingAcceptingStreams checks that AcceptStream blocks while
// no stream is available, returns the context's error once its context is
// canceled, and returns as soon as the peer opens a stream.
func TestStreamsMapIncomingAcceptingStreams(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		m := newIncomingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			5,
			func(wire.Frame) {},
			protocol.PerspectiveClient,
		)

		// AcceptStream should respect the context
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		errChan := make(chan error, 1)
		go func() {
			_, err := m.AcceptStream(ctx)
			errChan <- err
		}()
		synctest.Wait()
		select {
		case <-errChan:
			t.Fatal("AcceptStream should not return")
		default:
		}
		cancel()
		synctest.Wait()
		select {
		case err := <-errChan:
			// ErrorIs (unlike Equal) also matches if the error is wrapped.
			require.ErrorIs(t, err, context.Canceled)
		default:
			t.Fatal("AcceptStream should have returned after the context was canceled")
		}

		// AcceptStream should block if there are no streams available
		go func() {
			_, err := m.AcceptStream(context.Background())
			errChan <- err
		}()
		synctest.Wait()
		select {
		case <-errChan:
			t.Fatal("AcceptStream should block")
		default:
		}
		_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient)
		require.NoError(t, err)
		synctest.Wait()
		select {
		case err := <-errChan:
			require.NoError(t, err)
		default:
			t.Fatal("AcceptStream should have returned after a stream was opened")
		}
	})
}
func TestStreamsMapIncomingDeletingStreams(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
})
t.Run("server", func(t *testing.T) {
testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
})
}
// testStreamsMapIncomingDeletingStreams checks DeleteStream:
//   - deleting a stream that was never opened fails with a STREAM_STATE_ERROR;
//   - a deleted stream is no longer returned by GetOrOpenStream, but
//     AcceptStream still hands it out;
//   - a MAX_STREAMS frame is queued only once a stream has been both accepted
//     and deleted.
func testStreamsMapIncomingDeletingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	var queued []wire.Frame
	m := newIncomingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
		5,
		func(f wire.Frame) { queued = append(queued, f) },
		perspective,
	)

	// deleting an unknown stream is an error
	err := m.DeleteStream(firstStream + 4*1337)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
	require.ErrorContains(t, err, "tried to delete unknown incoming stream")

	deletedID := firstStream + 4
	str, err := m.GetOrOpenStream(deletedID)
	require.NoError(t, err)
	require.NotNil(t, str)

	// Delete the stream before it was accepted: no credit is issued yet.
	require.NoError(t, m.DeleteStream(deletedID))
	require.Empty(t, queued)
	// GetOrOpenStream no longer returns it...
	str, err = m.GetOrOpenStream(deletedID)
	require.NoError(t, err)
	require.Nil(t, str)

	// ...but AcceptStream still hands out both streams, in order.
	str, err = m.AcceptStream(context.Background())
	require.NoError(t, err)
	require.Equal(t, firstStream, str.id)
	require.Empty(t, queued)
	str, err = m.AcceptStream(context.Background())
	require.NoError(t, err)
	require.Equal(t, deletedID, str.id)

	// Now the deleted stream is released and new stream credit is issued.
	require.Len(t, queued, 1)
	require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 6}, queued[0])

	queued = queued[:0]
	require.NoError(t, m.DeleteStream(firstStream))
	require.Len(t, queued, 1)
	require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 7}, queued[0])
}
// There's a maximum number that can be encoded in a MAX_STREAMS frame.
// Since the stream limit is configurable by the user, we can't rely on this number
// being high enough that it will never be reached in practice.
func TestStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
})
t.Run("server", func(t *testing.T) {
testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
})
}
// testStreamsMapIncomingDeletingStreamsWithHighLimits opens and accepts five
// streams with a stream limit of protocol.MaxStreamCount-2, then deletes them
// one by one. Only the first two deletions can raise the limit (up to
// protocol.MaxStreamCount), so exactly two MAX_STREAMS frames are queued.
func testStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T, pers protocol.Perspective, firstStream protocol.StreamID) {
	var queued []wire.Frame
	m := newIncomingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
		uint64(protocol.MaxStreamCount-2),
		func(f wire.Frame) { queued = append(queued, f) },
		pers,
	)

	// Opening the fifth stream implicitly opens the four before it.
	_, err := m.GetOrOpenStream(firstStream + 4*4)
	require.NoError(t, err)
	for i := 0; i < 5; i++ {
		_, err := m.AcceptStream(context.Background())
		require.NoError(t, err)
	}
	require.Empty(t, queued)

	// Each of the first two deletions raises the limit by one.
	require.NoError(t, m.DeleteStream(firstStream+12))
	require.Len(t, queued, 1)
	require.Equal(t,
		&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount - 1},
		queued[0],
	)
	require.NoError(t, m.DeleteStream(firstStream+8))
	require.Len(t, queued, 2)
	require.Equal(t,
		&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount},
		queued[1],
	)

	// The limit can't be raised any further, so no more MAX_STREAMS frames are sent.
	for _, id := range []protocol.StreamID{firstStream + 4, firstStream} {
		require.NoError(t, m.DeleteStream(id))
	}
	require.Len(t, queued, 2)
}
// TestStreamsMapIncomingClosing checks that CloseWithError closes the streams
// that were already accepted and unblocks a pending AcceptStream call, which
// then returns the close error.
func TestStreamsMapIncomingClosing(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		m := newIncomingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			5,
			func(wire.Frame) {},
			protocol.PerspectiveServer,
		)
		// opening the third stream implicitly opens the two before it
		_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8)
		require.NoError(t, err)
		var streams []*mockStream
		for range 3 {
			str, err := m.AcceptStream(context.Background())
			require.NoError(t, err)
			streams = append(streams, str)
		}

		errChan := make(chan error, 1)
		go func() {
			_, err := m.AcceptStream(context.Background())
			errChan <- err
		}()
		// Make sure AcceptStream is actually blocked before closing the map.
		// Otherwise CloseWithError could run before AcceptStream is even called,
		// and the test wouldn't exercise unblocking a pending call.
		synctest.Wait()
		select {
		case <-errChan:
			t.Fatal("AcceptStream should block")
		default:
		}

		m.CloseWithError(assert.AnError)
		synctest.Wait()
		// accepted streams should be closed
		for _, str := range streams {
			require.True(t, str.closed)
			require.ErrorIs(t, str.closeErr, assert.AnError)
		}
		// AcceptStream should return the error
		select {
		case err := <-errChan:
			require.ErrorIs(t, err, assert.AnError)
		default:
			t.Fatal("AcceptStream should have returned after CloseWithError")
		}
	})
}
// TestStreamsMapIncomingRandomized opens num streams of a randomly chosen type
// in a random order while another goroutine concurrently accepts them, and
// checks that both sides finish without an error.
func TestStreamsMapIncomingRandomized(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const num = 1000

		streamType := protocol.StreamTypeUni
		firstStream := protocol.FirstIncomingUniStreamServer
		if rand.IntN(2) == 1 {
			streamType = protocol.StreamTypeBidi
			firstStream = protocol.FirstIncomingBidiStreamServer
		}

		m := newIncomingStreamsMap(
			streamType,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			num,
			func(wire.Frame) {},
			protocol.PerspectiveServer,
		)

		// the IDs of the first num streams of this type, shuffled
		ids := make([]protocol.StreamID, 0, num)
		for i := 0; i < num; i++ {
			ids = append(ids, firstStream+4*protocol.StreamID(i))
		}
		rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })

		// Each goroutine either sends its first error or closes its channel on success.
		acceptErr := make(chan error, 1)
		go func() {
			for range num {
				if _, err := m.AcceptStream(context.Background()); err != nil {
					acceptErr <- err
					return
				}
			}
			close(acceptErr)
		}()
		openErr := make(chan error, 1)
		go func() {
			for _, id := range ids {
				if _, err := m.GetOrOpenStream(id); err != nil {
					openErr <- err
					return
				}
			}
			close(openErr)
		}()

		synctest.Wait()
		expectFinished := func(ch <-chan error, msg string) {
			t.Helper()
			select {
			case err := <-ch:
				require.NoError(t, err)
			default:
				t.Fatal(msg)
			}
		}
		expectFinished(acceptErr, "should have accepted all streams")
		expectFinished(openErr, "should have opened all streams")
	})
}