// Forked from quic-go/quic-go (original file: 567 lines, 15 KiB, Go).

package quic
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand/v2"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/protocol"
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/qerr"
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/synctest"
|
|
"git.geeks-team.ru/gr1ffon/quic-go/internal/wire"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestStreamsMapOutgoingOpenAndDelete runs the open/delete scenario for
// outgoing bidirectional streams, once as client and once as server.
func TestStreamsMapOutgoingOpenAndDelete(t *testing.T) {
	for _, tc := range []struct {
		name        string
		perspective protocol.Perspective
		firstStream protocol.StreamID
	}{
		{name: "client", perspective: protocol.PerspectiveClient, firstStream: protocol.FirstOutgoingBidiStreamClient},
		{name: "server", perspective: protocol.PerspectiveServer, firstStream: protocol.FirstOutgoingBidiStreamServer},
	} {
		t.Run(tc.name, func(t *testing.T) {
			testStreamsMapOutgoingOpenAndDelete(t, tc.perspective, tc.firstStream)
		})
	}
}
|
|
|
|
// testStreamsMapOutgoingOpenAndDelete exercises the lifecycle of an outgoing
// bidirectional streams map for a single perspective.
//
// It checks the following:
//   - a stream can't be fetched before it was opened;
//   - OpenStream hands out consecutive IDs (4 apart);
//   - send-window and RESET_STREAM_AT updates reach every open stream;
//   - deleting an unknown or already deleted stream yields a STREAM_STATE_ERROR.
func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	m := newOutgoingStreamsMap(
		protocol.StreamTypeBidi,
		func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
		func(wire.Frame) {},
		perspective,
	)
	m.SetMaxStream(protocol.MaxStreamID)

	// requireStreamStateError asserts that err is a STREAM_STATE_ERROR
	// transport error whose message contains msg.
	requireStreamStateError := func(err error, msg string) {
		t.Helper()
		require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
		require.ErrorContains(t, err, msg)
	}

	// Nothing has been opened yet, so the first stream can't be retrieved.
	_, err := m.GetStream(firstStream)
	requireStreamStateError(err, fmt.Sprintf("peer attempted to open stream %d", firstStream))

	first, err := m.OpenStream()
	require.NoError(t, err)
	require.Equal(t, firstStream, first.id)
	fetched, err := m.GetStream(firstStream)
	require.NoError(t, err)
	require.Equal(t, fetched, first)

	second, err := m.OpenStream()
	require.NoError(t, err)
	require.Equal(t, firstStream+4, second.id)

	// A send window update is applied to all open streams.
	m.UpdateSendWindow(1000)
	require.Equal(t, protocol.ByteCount(1000), first.sendWindow)
	require.Equal(t, protocol.ByteCount(1000), second.sendWindow)

	// So is enabling RESET_STREAM_AT.
	m.EnableResetStreamAt()
	require.True(t, first.supportsResetStreamAt)
	require.True(t, second.supportsResetStreamAt)

	requireStreamStateError(m.DeleteStream(firstStream+1337*4), "tried to delete unknown outgoing stream")

	require.NoError(t, m.DeleteStream(firstStream))
	// A second deletion of the same stream is rejected.
	requireStreamStateError(m.DeleteStream(firstStream), "tried to delete unknown outgoing stream")

	// After deletion, GetStream returns neither a stream nor an error.
	deleted, err := m.GetStream(firstStream)
	require.NoError(t, err)
	require.Nil(t, deleted)
}
|
|
|
|
// TestStreamsMapOutgoingLimits runs the stream-limit scenario for outgoing
// unidirectional streams, once as client and once as server.
func TestStreamsMapOutgoingLimits(t *testing.T) {
	for _, tc := range []struct {
		name        string
		perspective protocol.Perspective
		firstStream protocol.StreamID
	}{
		{name: "client", perspective: protocol.PerspectiveClient, firstStream: protocol.FirstOutgoingUniStreamClient},
		{name: "server", perspective: protocol.PerspectiveServer, firstStream: protocol.FirstOutgoingUniStreamServer},
	} {
		t.Run(tc.name, func(t *testing.T) {
			testStreamsMapOutgoingLimits(t, tc.perspective, tc.firstStream)
		})
	}
}
|
|
|
|
// testStreamsMapOutgoingLimits checks the behavior of an outgoing
// unidirectional streams map at its stream limit, for one perspective.
//
// It checks the following:
//   - OpenStream fails with a StreamLimitReachedError;
//   - OpenStreamSync returns immediately for an already canceled context;
//   - OpenStreamSync blocks until its context is canceled;
//   - OpenStreamSync is unblocked once SetMaxStream raises the limit.
func testStreamsMapOutgoingLimits(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	synctest.Test(t, func(t *testing.T) {
		m := newOutgoingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) {},
			perspective,
		)
		m.SetMaxStream(firstStream)

		str, err := m.OpenStream()
		require.NoError(t, err)
		require.Equal(t, firstStream, str.id)

		// We've now reached the limit. OpenStream returns an error
		_, err = m.OpenStream()
		require.ErrorIs(t, err, &StreamLimitReachedError{})

		// OpenStreamSync with a canceled context will return an error immediately
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		_, err = m.OpenStreamSync(ctx)
		require.ErrorIs(t, err, context.Canceled)

		// OpenStreamSync blocks until the context is canceled...
		ctx, cancel = context.WithCancel(context.Background())
		errChan := make(chan error, 1)
		go func() {
			_, err := m.OpenStreamSync(ctx)
			errChan <- err
		}()

		synctest.Wait()
		select {
		case <-errChan:
			t.Fatal("didn't expect OpenStreamSync to return")
		default:
		}
		// OpenStream still returns an error
		_, err = m.OpenStream()
		require.ErrorIs(t, err, &StreamLimitReachedError{})
		// cancelling the context unblocks OpenStreamSync
		cancel()
		synctest.Wait()
		select {
		case err := <-errChan:
			require.ErrorIs(t, err, context.Canceled)
		default:
			t.Fatal("OpenStreamSync did not return after the context was canceled")
		}

		// ... or until it's possible to open a new stream
		type result struct {
			str *mockStream
			err error
		}
		resultChan := make(chan result, 1)
		go func() {
			str, err := m.OpenStreamSync(context.Background())
			resultChan <- result{str: str, err: err}
		}()
		// Make sure the call is actually blocked before raising the limit.
		// Otherwise it could run after SetMaxStream, and this part would
		// not test unblocking at all.
		synctest.Wait()
		select {
		case <-resultChan:
			t.Fatal("didn't expect OpenStreamSync to return")
		default:
		}

		m.SetMaxStream(firstStream + 4)
		synctest.Wait()
		select {
		case res := <-resultChan:
			require.NoError(t, res.err)
			require.Equal(t, firstStream+4, res.str.id)
		default:
			t.Fatal("OpenStreamSync did not return after the stream limit was increased")
		}
	})
}
|
|
|
|
// TestStreamsMapOutgoingConcurrentOpenStreamSync checks that concurrent
// OpenStreamSync calls are served in the order in which they were made.
//
// Three calls are started one fake minute apart. Raising the limit to two
// streams must unblock the first two calls, with IDs matching their start
// order. Raising it to three streams then unblocks the last call.
func TestStreamsMapOutgoingConcurrentOpenStreamSync(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		m := newOutgoingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(wire.Frame) {},
			protocol.PerspectiveClient,
		)
		first := protocol.FirstOutgoingUniStreamClient

		type result struct {
			index  int
			stream *mockStream
			err    error
		}
		results := make(chan result, 3)
		for i := 0; i < 3; i++ {
			go func(idx int) {
				str, err := m.OpenStreamSync(context.Background())
				results <- result{index: idx, stream: str, err: err}
			}(i)
			// Inside the synctest bubble the fake clock only advances once the
			// other goroutines are blocked. This orders the calls deterministically.
			time.Sleep(time.Minute)
		}

		m.SetMaxStream(first + 4)
		synctest.Wait()
		received := make(map[protocol.StreamID]struct{})
		for k := 0; k < 2; k++ {
			var res result
			select {
			case res = <-results:
			default:
				t.Fatal("OpenStreamSync did not return after the stream limit was increased")
			}
			require.NoError(t, res.err)
			require.Equal(t, first+4*protocol.StreamID(res.index), res.stream.id)
			received[res.stream.id] = struct{}{}
		}
		require.Contains(t, received, first)
		require.Contains(t, received, first+4)

		// The third call must still be waiting.
		select {
		case <-results:
			t.Fatal("expected OpenStreamSync to be blocked")
		default:
		}

		m.SetMaxStream(first + 8)
		synctest.Wait()
		select {
		case res := <-results:
			require.NoError(t, res.err)
			require.Equal(t, first+8, res.stream.id)
		default:
			t.Fatal("OpenStreamSync did not return after the stream limit was increased")
		}
	})
}
|
|
|
|
// TestStreamsMapOutgoingClosing checks two effects of CloseWithError. It
// closes every open stream with the given error. It also makes a pending
// OpenStreamSync call return that error.
func TestStreamsMapOutgoingClosing(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		m := newOutgoingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) {},
			protocol.PerspectiveServer,
		)

		// allow exactly two streams, and open both of them
		m.SetMaxStream(protocol.FirstOutgoingUniStreamServer + 4)
		str1, err := m.OpenStream()
		require.NoError(t, err)
		str2, err := m.OpenStream()
		require.NoError(t, err)

		errChan := make(chan error, 1)
		go func() {
			_, err := m.OpenStreamSync(context.Background())
			errChan <- err
		}()
		// Make sure OpenStreamSync is blocked on the stream limit before closing.
		// Otherwise the goroutine might only start after the map was already
		// closed, and waking a waiting call would go untested.
		synctest.Wait()
		select {
		case <-errChan:
			t.Fatal("didn't expect OpenStreamSync to return")
		default:
		}

		m.CloseWithError(assert.AnError)
		synctest.Wait()

		// both streams should be closed
		assert.True(t, str1.closed)
		assert.Equal(t, assert.AnError, str1.closeErr)
		assert.True(t, str2.closed)
		assert.Equal(t, assert.AnError, str2.closeErr)

		select {
		case err := <-errChan:
			require.ErrorIs(t, err, assert.AnError)
		default:
			t.Fatal("OpenStreamSync did not return after the stream was closed")
		}
	})
}
|
|
|
|
// TestStreamsMapOutgoingBlockedFrames checks when STREAMS_BLOCKED frames are
// queued.
//
// It checks the following:
//   - hitting the current limit queues exactly one frame;
//   - further attempts at the same limit queue nothing;
//   - a raised limit that still leaves callers waiting queues a new frame
//     for that limit;
//   - a limit that satisfies everyone queues nothing.
func TestStreamsMapOutgoingBlockedFrames(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		var queued []wire.Frame
		m := newOutgoingStreamsMap(
			protocol.StreamTypeBidi,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) { queued = append(queued, f) },
			protocol.PerspectiveClient,
		)

		// Three streams are allowed. Opening them doesn't queue anything.
		m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 8)
		for i := 0; i < 3; i++ {
			_, err := m.OpenStream()
			require.NoError(t, err)
		}
		require.Empty(t, queued)

		// The fourth stream exceeds the limit.
		_, err := m.OpenStream()
		require.ErrorIs(t, err, &StreamLimitReachedError{})
		require.Equal(t, []wire.Frame{
			&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 3},
		}, queued)
		queued = queued[:0]

		// only a single STREAMS_BLOCKED frame is queued per offset
		for i := 0; i < 5; i++ {
			_, err = m.OpenStream()
			require.ErrorIs(t, err, &StreamLimitReachedError{})
			require.Empty(t, queued)
		}

		const numWaiting = 3
		errChan := make(chan error, numWaiting)
		for i := 0; i < numWaiting; i++ {
			go func() {
				_, err := m.OpenStreamSync(context.Background())
				errChan <- err
			}()
		}
		synctest.Wait()

		// Allow two more streams. Two of the waiting calls return.
		m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 16)
		synctest.Wait()
		for i := 0; i < 2; i++ {
			select {
			case err := <-errChan:
				require.NoError(t, err)
			default:
				t.Fatal("OpenStreamSync did not return after the stream limit was increased")
			}
		}
		// The remaining caller is blocked at the new limit.
		require.Equal(t, []wire.Frame{
			&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 5},
		}, queued)
		queued = queued[:0]

		// Allowing one more stream serves the last caller without queuing a frame.
		m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 20)
		synctest.Wait()
		select {
		case err := <-errChan:
			require.NoError(t, err)
		default:
			t.Fatal("OpenStreamSync did not return after the stream limit was increased")
		}
		require.Empty(t, queued)
	})
}
|
|
|
|
// TestStreamsMapOutgoingRandomizedOpenStreamSync starts n concurrent
// OpenStreamSync calls against a map that has no stream credit yet. It then
// raises the limit in random steps.
//
// After each step, every newly opened stream ID must be unique and must not
// exceed the limit. OpenStream must fail while the limit doesn't cover all n
// streams. At the end, a STREAMS_BLOCKED frame must have been queued for
// every limit that didn't cover all n streams.
func TestStreamsMapOutgoingRandomizedOpenStreamSync(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
		firstStream := protocol.FirstOutgoingUniStreamServer
		if streamType == protocol.StreamTypeBidi {
			firstStream = protocol.FirstOutgoingBidiStreamServer
		}

		const n = 100

		frameQueue := make(chan wire.Frame, n)
		m := newOutgoingStreamsMap(
			streamType,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) { frameQueue <- f },
			protocol.PerspectiveServer,
		)

		type result struct {
			id  protocol.StreamID
			err error
		}
		resultChan := make(chan result, n)
		for range n {
			go func() {
				str, err := m.OpenStreamSync(context.Background())
				if err != nil {
					// No usable stream is returned on error. Report the error so the
					// NoError assertion below fails with a clear message, instead of
					// panicking on a nil pointer dereference here.
					resultChan <- result{err: err}
					return
				}
				resultChan <- result{id: str.id}
			}()
		}
		synctest.Wait()

		// No stream credit has been granted yet.
		select {
		case f := <-frameQueue:
			require.IsType(t, &wire.StreamsBlockedFrame{}, f)
			require.Zero(t, f.(*wire.StreamsBlockedFrame).StreamLimit)
		default:
			t.Fatal("expected a STREAMS_BLOCKED frame to be queued")
		}

		limit := firstStream - 4
		var limits []protocol.StreamID
		seen := make(map[protocol.StreamID]struct{})
		maxStream := firstStream + 4*(n-1)
		for limit < maxStream {
			limit += 4 * protocol.StreamID(rand.IntN(n/5)+1)
			// While the limit doesn't exceed the last of the n streams, the
			// OpenStream call below fails (and callers may still be waiting),
			// so a STREAMS_BLOCKED frame is expected for this limit.
			if limit <= maxStream {
				limits = append(limits, limit)
			}
			t.Logf("setting stream limit to %d", limit)
			m.SetMaxStream(limit)
			synctest.Wait()

			// Collect all streams unblocked by the new limit.
		loop:
			for {
				select {
				case res := <-resultChan:
					require.NoError(t, res.err)
					require.NotContains(t, seen, res.id)
					require.LessOrEqual(t, res.id, limit)
					seen[res.id] = struct{}{}
					if len(seen) == int(limit.StreamNum()) || len(seen) == n {
						break loop
					}
				default:
					t.Fatal("OpenStreamSync did not return after the stream limit was increased")
				}
			}

			str, err := m.OpenStream()
			if limit <= maxStream {
				require.ErrorIs(t, err, &StreamLimitReachedError{})
			} else {
				require.NoError(t, err)
				require.Equal(t, maxStream+4, str.id)
			}
		}
		require.Len(t, seen, n)

		close(frameQueue)
		var blockedAt []protocol.StreamID
		for f := range frameQueue {
			// check the type before asserting, so an unexpected frame fails cleanly
			require.IsType(t, &wire.StreamsBlockedFrame{}, f)
			sbf := f.(*wire.StreamsBlockedFrame)
			require.Equal(t, streamType, sbf.Type)
			if sbf.StreamLimit <= n {
				blockedAt = append(blockedAt, sbf.StreamLimit.StreamID(streamType, protocol.PerspectiveServer))
			}
		}
		require.Equal(t, limits, blockedAt)
	})
}
|
|
|
|
// TestStreamsMapOutgoingRandomizedWithCancellation mixes n OpenStreamSync
// calls that must succeed with a random number of calls whose shared context
// gets canceled. The stream limit is then raised in random steps.
//
// It checks the following:
//   - each of the n streams is opened exactly once;
//   - every canceled call reports context.Canceled;
//   - STREAMS_BLOCKED frames are queued exactly for the intermediate limits.
func TestStreamsMapOutgoingRandomizedWithCancellation(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const n = 100

		streamType := protocol.StreamTypeUni
		firstStream := protocol.FirstOutgoingUniStreamClient
		if rand.IntN(2) == 1 {
			streamType = protocol.StreamTypeBidi
			firstStream = protocol.FirstOutgoingBidiStreamClient
		}

		frameQueue := make(chan wire.Frame, n)
		m := newOutgoingStreamsMap(
			streamType,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) { frameQueue <- f },
			protocol.PerspectiveClient,
		)

		type result struct {
			str *mockStream
			err error
		}

		// Roughly one in five calls uses the cancelable context. Calls keep
		// being started until n calls with a background context are running.
		ctx, cancel := context.WithCancel(context.Background())
		resultChan := make(chan result, 10*n)
		var numSucceeding, numCancelled int
		for numSucceeding < n {
			callCtx := context.Background()
			if rand.IntN(n)%5 == 0 {
				callCtx = ctx
				numCancelled++
			} else {
				numSucceeding++
			}
			go func() {
				str, err := m.OpenStreamSync(callCtx)
				resultChan <- result{str: str, err: err}
			}()
		}

		// No credit has been granted yet, so a frame for limit 0 is expected.
		synctest.Wait()
		select {
		case f := <-frameQueue:
			require.IsType(t, &wire.StreamsBlockedFrame{}, f)
			require.Zero(t, f.(*wire.StreamsBlockedFrame).StreamLimit)
		default:
			t.Fatal("timed out waiting for STREAMS_BLOCKED frame")
		}

		synctest.Wait()
		cancel()

		maxStream := firstStream + 4*(n-1)
		limit := firstStream - 4
		var limits []protocol.StreamID
		seen := make(map[protocol.StreamID]struct{})
		var lastStreamSeen protocol.StreamID
		var numCancelledSeen int
		for limit < maxStream {
			limit += 4 * protocol.StreamID(rand.IntN(n/5)+1)
			// Only a limit below the last stream leaves callers waiting, and
			// so causes a STREAMS_BLOCKED frame.
			if limit < maxStream {
				limits = append(limits, limit)
			}
			t.Logf("setting stream limit to %d", limit)
			m.SetMaxStream(limit)

			target := min(maxStream, limit)
			for lastStreamSeen < target {
				select {
				case res := <-resultChan:
					switch {
					case errors.Is(res.err, context.Canceled):
						numCancelledSeen++
					default:
						require.NoError(t, res.err)
						require.NotContains(t, seen, res.str.id)
						seen[res.str.id] = struct{}{}
						lastStreamSeen = res.str.id
					}
				case <-time.After(time.Second):
					t.Fatalf("timed out waiting for stream to open")
				}
			}
		}
		require.Len(t, seen, n)
		t.Logf("saw %d streams, %d cancelled", len(seen), numCancelledSeen)
		require.Equal(t, numCancelled, numCancelledSeen)

		close(frameQueue)
		var blockedAt []protocol.StreamID
		for f := range frameQueue {
			frame := f.(*wire.StreamsBlockedFrame)
			require.Equal(t, streamType, frame.Type)
			blockedAt = append(blockedAt, frame.StreamLimit.StreamID(streamType, protocol.PerspectiveClient))
		}
		require.Equal(t, limits, blockedAt)
	})
}
|
|
|
|
func TestStreamsMapConcurrent(t *testing.T) {
|
|
for i := range 5 {
|
|
t.Run(fmt.Sprintf("iteration %d", i+1), func(t *testing.T) {
|
|
testStreamsMapConcurrent(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
// testStreamsMapConcurrent races three events against num pending
// OpenStreamSync calls: CloseWithError, context cancellation, and a stream
// limit increase. Every call must return. A call that fails must fail with
// either the close error or context.Canceled.
func testStreamsMapConcurrent(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		m := newOutgoingStreamsMap(
			protocol.StreamTypeBidi,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			func(f wire.Frame) {},
			protocol.PerspectiveClient,
		)

		const num = 100

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		errChan := make(chan error, num)
		for range num {
			go func() {
				_, err := m.OpenStreamSync(ctx)
				errChan <- err
			}()
		}

		// Trigger the three events concurrently, so that their order relative to
		// each other and to the OpenStreamSync calls is not fixed.
		go m.CloseWithError(assert.AnError)
		go cancel()
		go m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 4*num/2)

		synctest.Wait()

		for range num {
			select {
			case err := <-errChan:
				// A nil error is fine: the call may have been served by the limit increase.
				if err != nil {
					require.Truef(t,
						errors.Is(err, assert.AnError) || errors.Is(err, context.Canceled),
						"unexpected error from OpenStreamSync: %v", err,
					)
				}
			default:
				t.Fatal("OpenStreamSync should have returned")
			}
		}
	})
}
|