Files
quic-go/streams_map_outgoing_test.go
2025-11-14 04:04:40 +03:00

567 lines
15 KiB
Go

package quic
import (
"context"
"errors"
"fmt"
"math/rand/v2"
"testing"
"time"
"git.geeks-team.ru/gr1ffon/quic-go/internal/protocol"
"git.geeks-team.ru/gr1ffon/quic-go/internal/qerr"
"git.geeks-team.ru/gr1ffon/quic-go/internal/synctest"
"git.geeks-team.ru/gr1ffon/quic-go/internal/wire"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStreamsMapOutgoingOpenAndDelete(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapOutgoingOpenAndDelete(t, protocol.PerspectiveClient, protocol.FirstOutgoingBidiStreamClient)
})
t.Run("server", func(t *testing.T) {
testStreamsMapOutgoingOpenAndDelete(t, protocol.PerspectiveServer, protocol.FirstOutgoingBidiStreamServer)
})
}
// testStreamsMapOutgoingOpenAndDelete exercises the basic lifecycle of
// outgoing bidirectional streams for the given perspective: opening streams,
// looking them up, propagating send window and RESET_STREAM_AT updates, and
// deleting them. firstStream is the ID the first opened stream is expected
// to get.
func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	m := newOutgoingStreamsMap(
		protocol.StreamTypeBidi,
		func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
		func(f wire.Frame) {},
		perspective,
	)
	m.SetMaxStream(protocol.MaxStreamID)

	// asking for an outgoing stream that we haven't opened yet is an error
	_, err := m.GetStream(firstStream)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
	require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstStream))

	str1, err := m.OpenStream()
	require.NoError(t, err)
	require.Equal(t, firstStream, str1.id)
	s, err := m.GetStream(firstStream)
	require.NoError(t, err)
	// testify expects (expected, actual)
	require.Equal(t, str1, s)

	// consecutive stream IDs of the same type are 4 apart
	str2, err := m.OpenStream()
	require.NoError(t, err)
	require.Equal(t, firstStream+4, str2.id)

	// update send window
	m.UpdateSendWindow(1000)
	require.Equal(t, protocol.ByteCount(1000), str1.sendWindow)
	require.Equal(t, protocol.ByteCount(1000), str2.sendWindow)

	// enable reset stream at
	m.EnableResetStreamAt()
	require.True(t, str1.supportsResetStreamAt)
	require.True(t, str2.supportsResetStreamAt)

	// deleting a stream that was never opened fails
	err = m.DeleteStream(firstStream + 1337*4)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
	require.ErrorContains(t, err, "tried to delete unknown outgoing stream")

	require.NoError(t, m.DeleteStream(firstStream))
	// deleting the same stream twice will fail
	err = m.DeleteStream(firstStream)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
	require.ErrorContains(t, err, "tried to delete unknown outgoing stream")

	// after deleting the stream it's not available anymore
	str, err := m.GetStream(firstStream)
	require.NoError(t, err)
	require.Nil(t, str)
}
func TestStreamsMapOutgoingLimits(t *testing.T) {
t.Run("client", func(t *testing.T) {
testStreamsMapOutgoingLimits(t, protocol.PerspectiveClient, protocol.FirstOutgoingUniStreamClient)
})
t.Run("server", func(t *testing.T) {
testStreamsMapOutgoingLimits(t, protocol.PerspectiveServer, protocol.FirstOutgoingUniStreamServer)
})
}
// testStreamsMapOutgoingLimits checks the behavior at the stream limit:
// OpenStream fails right away, whereas OpenStreamSync blocks until either its
// context is canceled or the limit is raised far enough to open a stream.
func testStreamsMapOutgoingLimits(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	synctest.Test(t, func(t *testing.T) {
		streams := newOutgoingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(wire.Frame) {},
			perspective,
		)

		// allow exactly one stream, and use it up
		streams.SetMaxStream(firstStream)
		first, err := streams.OpenStream()
		require.NoError(t, err)
		require.Equal(t, firstStream, first.id)

		// the limit is exhausted: the non-blocking variant has to fail
		_, err = streams.OpenStream()
		require.ErrorIs(t, err, &StreamLimitReachedError{})

		// a context that is already canceled makes OpenStreamSync fail immediately
		canceledCtx, cancelNow := context.WithCancel(context.Background())
		cancelNow()
		_, err = streams.OpenStreamSync(canceledCtx)
		require.ErrorIs(t, err, context.Canceled)

		// with a live context, OpenStreamSync blocks until the context is canceled...
		liveCtx, cancelLive := context.WithCancel(context.Background())
		done := make(chan error, 1)
		go func() {
			_, openErr := streams.OpenStreamSync(liveCtx)
			done <- openErr
		}()
		synctest.Wait()
		select {
		case <-done:
			t.Fatal("didn't expect OpenStreamSync to return")
		default:
		}

		// a pending OpenStreamSync doesn't change the outcome of OpenStream
		_, err = streams.OpenStream()
		require.ErrorIs(t, err, &StreamLimitReachedError{})

		// canceling the context releases the blocked call
		cancelLive()
		synctest.Wait()
		select {
		case openErr := <-done:
			require.ErrorIs(t, openErr, context.Canceled)
		default:
			t.Fatal("OpenStreamSync did not return after the context was canceled")
		}

		// ... or until it's possible to open a new stream
		var opened *mockStream
		go func() {
			s, openErr := streams.OpenStreamSync(context.Background())
			// written before the channel send, so the receiver may read it safely
			opened = s
			done <- openErr
		}()
		streams.SetMaxStream(firstStream + 4)
		synctest.Wait()
		select {
		case openErr := <-done:
			require.NoError(t, openErr)
			require.Equal(t, firstStream+4, opened.id)
		default:
			t.Fatal("OpenStreamSync did not return after the stream limit was increased")
		}
	})
}
// TestStreamsMapOutgoingConcurrentOpenStreamSync starts three blocked
// OpenStreamSync calls one after the other and checks that raising the stream
// limit serves them in the order in which they were started.
func TestStreamsMapOutgoingConcurrentOpenStreamSync(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		first := protocol.FirstOutgoingUniStreamClient
		streams := newOutgoingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(wire.Frame) {},
			protocol.PerspectiveClient,
		)

		type result struct {
			index  int
			stream *mockStream
			err    error
		}
		results := make(chan result, 3)
		for i := range 3 {
			go func() {
				s, openErr := streams.OpenStreamSync(context.Background())
				results <- result{index: i, stream: s, err: openErr}
			}()
			// Sleeping lets the goroutine block in OpenStreamSync before the
			// next one is started, which fixes the order of the callers.
			time.Sleep(time.Minute)
		}

		// raise the limit to two streams: the first two callers are served
		streams.SetMaxStream(first + 4)
		synctest.Wait()
		received := make(map[protocol.StreamID]struct{})
		for range 2 {
			select {
			case res := <-results:
				require.NoError(t, res.err)
				want := first + 4*protocol.StreamID(res.index)
				require.Equal(t, want, res.stream.id)
				received[res.stream.id] = struct{}{}
			default:
				t.Fatal("OpenStreamSync did not return after the stream limit was increased")
			}
		}
		require.Contains(t, received, first)
		require.Contains(t, received, first+4)

		// the third call has to remain blocked
		select {
		case <-results:
			t.Fatal("expected OpenStreamSync to be blocked")
		default:
		}

		// allowing one more stream releases the third caller
		streams.SetMaxStream(first + 8)
		synctest.Wait()
		select {
		case res := <-results:
			require.NoError(t, res.err)
			require.Equal(t, first+8, res.stream.id)
		default:
			t.Fatal("OpenStreamSync did not return after the stream limit was increased")
		}
	})
}
// TestStreamsMapOutgoingClosing checks that CloseWithError closes all open
// streams with the given error and makes a blocked OpenStreamSync return an
// error.
func TestStreamsMapOutgoingClosing(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		streams := newOutgoingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(wire.Frame) {},
			protocol.PerspectiveServer,
		)

		// the limit allows two streams; open both of them
		streams.SetMaxStream(protocol.FirstOutgoingUniStreamServer + 4)
		opened := make([]*mockStream, 0, 2)
		for range 2 {
			s, err := streams.OpenStream()
			require.NoError(t, err)
			opened = append(opened, s)
		}

		// a third stream is over the limit, so this call blocks
		done := make(chan error, 1)
		go func() {
			_, openErr := streams.OpenStreamSync(context.Background())
			done <- openErr
		}()

		streams.CloseWithError(assert.AnError)
		synctest.Wait()

		// every open stream must have been closed with the error
		for _, s := range opened {
			assert.True(t, s.closed)
			assert.Equal(t, assert.AnError, s.closeErr)
		}

		select {
		case openErr := <-done:
			require.Error(t, openErr)
		default:
			t.Fatal("OpenStreamSync did not return after the stream was closed")
		}
	})
}
// TestStreamsMapOutgoingBlockedFrames checks when STREAMS_BLOCKED frames are
// queued: once when a stream limit is first hit, not again for the same
// limit, and once more after a limit increase that still leaves a caller
// waiting.
func TestStreamsMapOutgoingBlockedFrames(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		var queued []wire.Frame
		streams := newOutgoingStreamsMap(
			protocol.StreamTypeBidi,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) { queued = append(queued, f) },
			protocol.PerspectiveClient,
		)

		// a limit of 3 streams: opening them doesn't queue any frame
		streams.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 8)
		for range 3 {
			_, openErr := streams.OpenStream()
			require.NoError(t, openErr)
		}
		require.Empty(t, queued)

		// the 4th attempt hits the limit and reports it to the peer
		_, err := streams.OpenStream()
		require.ErrorIs(t, err, &StreamLimitReachedError{})
		wantBlockedAt3 := []wire.Frame{
			&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 3},
		}
		require.Equal(t, wantBlockedAt3, queued)
		queued = queued[:0]

		// only a single STREAMS_BLOCKED frame is queued per offset
		for range 5 {
			_, err = streams.OpenStream()
			require.ErrorIs(t, err, &StreamLimitReachedError{})
			require.Empty(t, queued)
		}

		// three callers block waiting for a stream
		done := make(chan error, 3)
		for range 3 {
			go func() {
				_, openErr := streams.OpenStreamSync(context.Background())
				done <- openErr
			}()
		}
		synctest.Wait()

		// allow 2 more streams
		streams.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 16)
		synctest.Wait()
		for range 2 {
			select {
			case openErr := <-done:
				require.NoError(t, openErr)
			default:
				t.Fatal("OpenStreamSync did not return after the stream limit was increased")
			}
		}
		// one caller is still waiting, now blocked at the new limit of 5
		wantBlockedAt5 := []wire.Frame{
			&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 5},
		}
		require.Equal(t, wantBlockedAt5, queued)
		queued = queued[:0]

		// now accept the last stream
		streams.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 20)
		synctest.Wait()
		select {
		case openErr := <-done:
			require.NoError(t, openErr)
		default:
			t.Fatal("OpenStreamSync did not return after the stream limit was increased")
		}
		// nobody is waiting anymore, so no further frame is expected
		require.Empty(t, queued)
	})
}
// TestStreamsMapOutgoingRandomizedOpenStreamSync starts n blocked
// OpenStreamSync calls for a randomly chosen stream type and then raises the
// stream limit in random steps. After every step it checks that exactly the
// streams allowed by the limit were opened, and at the end that a
// STREAMS_BLOCKED frame was queued for every limit that left callers waiting.
func TestStreamsMapOutgoingRandomizedOpenStreamSync(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
		firstStream := protocol.FirstOutgoingUniStreamServer
		if streamType == protocol.StreamTypeBidi {
			firstStream = protocol.FirstOutgoingBidiStreamServer
		}

		const n = 100
		frameQueue := make(chan wire.Frame, n)
		m := newOutgoingStreamsMap(
			streamType,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) { frameQueue <- f },
			protocol.PerspectiveServer,
		)

		type result struct {
			id  protocol.StreamID
			err error
		}
		resultChan := make(chan result, n)
		for range n {
			go func() {
				str, err := m.OpenStreamSync(context.Background())
				if err != nil {
					// The stream is not valid on error. Only report the error,
					// so that the test fails on the assertion below instead of
					// panicking in this goroutine.
					resultChan <- result{err: err}
					return
				}
				resultChan <- result{id: str.id}
			}()
		}
		synctest.Wait()

		// no stream is allowed yet, so the callers are blocked at a limit of 0
		select {
		case f := <-frameQueue:
			require.IsType(t, &wire.StreamsBlockedFrame{}, f)
			require.Zero(t, f.(*wire.StreamsBlockedFrame).StreamLimit)
		default:
			t.Fatal("timed out waiting for STREAMS_BLOCKED frame")
		}

		limit := firstStream - 4
		// limits records every limit at which a STREAMS_BLOCKED frame is expected
		var limits []protocol.StreamID
		seen := make(map[protocol.StreamID]struct{})
		// maxStream is the ID of the stream opened by the last of the n callers
		maxStream := firstStream + 4*(n-1)
		for limit < maxStream {
			// raise the limit by between 1 and n/5 streams
			add := 4 * protocol.StreamID(rand.IntN(n/5)+1)
			limit += add
			if limit <= maxStream {
				limits = append(limits, limit)
			}
			t.Logf("setting stream limit to %d", limit)
			m.SetMaxStream(limit)
			synctest.Wait()

			// collect results until all streams allowed by the limit were seen
		loop:
			for {
				select {
				case res := <-resultChan:
					require.NoError(t, res.err)
					require.NotContains(t, seen, res.id)
					require.LessOrEqual(t, res.id, limit)
					seen[res.id] = struct{}{}
					if len(seen) == int(limit.StreamNum()) || len(seen) == n {
						break loop
					}
				default:
					t.Fatalf("timed out waiting for stream to open")
				}
			}

			// The blocked callers used up all streams up to the limit. Opening
			// another one only works once the limit exceeds what they need.
			str, err := m.OpenStream()
			if limit <= maxStream {
				require.ErrorIs(t, err, &StreamLimitReachedError{})
			} else {
				require.NoError(t, err)
				require.Equal(t, maxStream+4, str.id)
			}
		}
		require.Len(t, seen, n)

		close(frameQueue)
		var blockedAt []protocol.StreamID
		for f := range frameQueue {
			if l := f.(*wire.StreamsBlockedFrame).StreamLimit; l <= n {
				blockedAt = append(blockedAt, l.StreamID(streamType, protocol.PerspectiveServer))
			}
		}
		require.Equal(t, limits, blockedAt)
	})
}
// TestStreamsMapOutgoingRandomizedWithCancellation starts a random mix of
// OpenStreamSync calls: n calls that wait indefinitely and a random number of
// calls that share a cancelable context. It cancels that context and then
// raises the stream limit in random steps, checking that all n remaining
// calls get a stream, that every canceled call returns context.Canceled, and
// that STREAMS_BLOCKED frames were queued for the expected limits.
func TestStreamsMapOutgoingRandomizedWithCancellation(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const n = 100

		// pick the stream type at random
		streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
		firstStream := protocol.FirstOutgoingUniStreamClient
		if streamType == protocol.StreamTypeBidi {
			firstStream = protocol.FirstOutgoingBidiStreamClient
		}

		frameQueue := make(chan wire.Frame, n)
		streams := newOutgoingStreamsMap(
			streamType,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) { frameQueue <- f },
			protocol.PerspectiveClient,
		)

		type result struct {
			str *mockStream
			err error
		}
		ctx, cancel := context.WithCancel(context.Background())
		resultChan := make(chan result, 10*n)

		// Spawn callers until n of them use a context that is never canceled.
		// The others share ctx, which is canceled further down.
		var numPersistent, numCancelled int
		for numPersistent < n {
			shouldCancel := rand.IntN(n)%5 == 0
			openCtx := context.Background()
			if shouldCancel {
				openCtx = ctx
				numCancelled++
			} else {
				numPersistent++
			}
			go func() {
				s, openErr := streams.OpenStreamSync(openCtx)
				resultChan <- result{str: s, err: openErr}
			}()
		}
		synctest.Wait()

		// no stream is allowed yet, so the callers are blocked at a limit of 0
		select {
		case f := <-frameQueue:
			require.IsType(t, &wire.StreamsBlockedFrame{}, f)
			require.Zero(t, f.(*wire.StreamsBlockedFrame).StreamLimit)
		default:
			t.Fatal("timed out waiting for STREAMS_BLOCKED frame")
		}
		synctest.Wait()
		cancel()

		limit := firstStream - 4
		// maxStream is the ID of the stream opened by the last persistent caller
		maxStream := firstStream + 4*(n-1)
		// limits records every limit at which a STREAMS_BLOCKED frame is expected
		var limits []protocol.StreamID
		seen := make(map[protocol.StreamID]struct{})
		var lastStreamSeen protocol.StreamID
		var numCancelledSeen int
		for limit < maxStream {
			// raise the limit by between 1 and n/5 streams
			limit += 4 * protocol.StreamID(rand.IntN(n/5)+1)
			if limit < maxStream {
				limits = append(limits, limit)
			}
			t.Logf("setting stream limit to %d", limit)
			streams.SetMaxStream(limit)

			// drain results until the highest stream allowed by the limit shows up
			for lastStreamSeen < min(maxStream, limit) {
				select {
				case res := <-resultChan:
					if errors.Is(res.err, context.Canceled) {
						numCancelledSeen++
						continue
					}
					require.NoError(t, res.err)
					require.NotContains(t, seen, res.str.id)
					seen[res.str.id] = struct{}{}
					lastStreamSeen = res.str.id
				case <-time.After(time.Second):
					t.Fatalf("timed out waiting for stream to open")
				}
			}
		}

		require.Len(t, seen, n)
		t.Logf("saw %d streams, %d cancelled", len(seen), numCancelledSeen)
		require.Equal(t, numCancelled, numCancelledSeen)

		close(frameQueue)
		var blockedAt []protocol.StreamID
		for f := range frameQueue {
			blocked := f.(*wire.StreamsBlockedFrame)
			require.Equal(t, streamType, blocked.Type)
			blockedAt = append(blockedAt, blocked.StreamLimit.StreamID(streamType, protocol.PerspectiveClient))
		}
		require.Equal(t, limits, blockedAt)
	})
}
func TestStreamsMapConcurrent(t *testing.T) {
for i := range 5 {
t.Run(fmt.Sprintf("iteration %d", i+1), func(t *testing.T) {
testStreamsMapConcurrent(t)
})
}
}
// testStreamsMapConcurrent races closing the map, canceling the callers'
// context and raising the stream limit against a large number of blocked
// OpenStreamSync calls. Every call has to return, either successfully or with
// the close error or context.Canceled.
func testStreamsMapConcurrent(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		streams := newOutgoingStreamsMap(
			protocol.StreamTypeBidi,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(wire.Frame) {},
			protocol.PerspectiveClient,
		)

		const num = 100
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		done := make(chan error, num)
		for i := 0; i < num; i++ {
			go func() {
				_, openErr := streams.OpenStreamSync(ctx)
				done <- openErr
			}()
		}

		// three competing events, in no particular order
		go streams.CloseWithError(assert.AnError)
		go cancel()
		go streams.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 4*num/2)
		synctest.Wait()

		for i := 0; i < num; i++ {
			select {
			case openErr := <-done:
				if openErr == nil {
					continue
				}
				require.True(t, errors.Is(openErr, assert.AnError) || errors.Is(openErr, context.Canceled))
			default:
				t.Fatal("OpenStreamSync should have returned")
			}
		}
	})
}