forked from quic-go/quic-go
use synctest to make streams map tests fully deterministic (#5296)
* use synctest to make streams map tests fully deterministic * ignore synctest helper package for code coverage
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/synctest"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -92,51 +93,67 @@ func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Per
|
||||
}
|
||||
|
||||
func TestStreamsMapIncomingAcceptingStreams(t *testing.T) {
|
||||
m := newIncomingStreamsMap(
|
||||
protocol.StreamTypeUni,
|
||||
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
||||
5,
|
||||
func(f wire.Frame) {},
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
m := newIncomingStreamsMap(
|
||||
protocol.StreamTypeUni,
|
||||
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
||||
5,
|
||||
func(f wire.Frame) {},
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
|
||||
// AcceptStream should respect the context
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(10*time.Millisecond))
|
||||
// AcceptStream should respect the context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := m.AcceptStream(ctx)
|
||||
errChan <- err
|
||||
}()
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := m.AcceptStream(ctx)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.Equal(t, context.DeadlineExceeded, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
synctest.Wait()
|
||||
|
||||
// AcceptStream should block if there are no streams available
|
||||
go func() {
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
errChan <- err
|
||||
}()
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("AcceptStream should not return")
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("AcceptStream should block")
|
||||
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
||||
}
|
||||
cancel()
|
||||
synctest.Wait()
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.Equal(t, context.Canceled, err)
|
||||
default:
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient)
|
||||
require.NoError(t, err)
|
||||
// AcceptStream should block if there are no streams available
|
||||
go func() {
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
synctest.Wait()
|
||||
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("AcceptStream should block")
|
||||
default:
|
||||
}
|
||||
|
||||
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient)
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
synctest.Wait()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamsMapIncomingDeletingStreams(t *testing.T) {
|
||||
@@ -241,103 +258,107 @@ func testStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T, pers prot
|
||||
}
|
||||
|
||||
func TestStreamsMapIncomingClosing(t *testing.T) {
|
||||
m := newIncomingStreamsMap(
|
||||
protocol.StreamTypeUni,
|
||||
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
||||
5,
|
||||
func(f wire.Frame) {},
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
m := newIncomingStreamsMap(
|
||||
protocol.StreamTypeUni,
|
||||
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
||||
5,
|
||||
func(f wire.Frame) {},
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
|
||||
var streams []*mockStream
|
||||
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8)
|
||||
require.NoError(t, err)
|
||||
for range 3 {
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
var streams []*mockStream
|
||||
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8)
|
||||
require.NoError(t, err)
|
||||
streams = append(streams, str)
|
||||
}
|
||||
for range 3 {
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
streams = append(streams, str)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
errChan <- err
|
||||
}()
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
m.CloseWithError(assert.AnError)
|
||||
m.CloseWithError(assert.AnError)
|
||||
synctest.Wait()
|
||||
|
||||
// accepted streams should be closed
|
||||
for _, str := range streams {
|
||||
require.True(t, str.closed)
|
||||
require.ErrorIs(t, str.closeErr, assert.AnError)
|
||||
}
|
||||
// AcceptStream should return the error
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
// accepted streams should be closed
|
||||
for _, str := range streams {
|
||||
require.True(t, str.closed)
|
||||
require.ErrorIs(t, str.closeErr, assert.AnError)
|
||||
}
|
||||
// AcceptStream should return the error
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
default:
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamsMapIncomingRandomized(t *testing.T) {
|
||||
const num = 1000
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
const num = 1000
|
||||
|
||||
streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
|
||||
firstStream := protocol.FirstIncomingUniStreamServer
|
||||
if streamType == protocol.StreamTypeBidi {
|
||||
firstStream = protocol.FirstIncomingBidiStreamServer
|
||||
}
|
||||
|
||||
m := newIncomingStreamsMap(
|
||||
streamType,
|
||||
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
||||
num,
|
||||
func(f wire.Frame) {},
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
|
||||
ids := make([]protocol.StreamID, num)
|
||||
for i := range num {
|
||||
ids[i] = firstStream + 4*protocol.StreamID(i)
|
||||
}
|
||||
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
|
||||
|
||||
timeout := scaleDuration(time.Second)
|
||||
errChan1 := make(chan error, 1)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
for range num {
|
||||
if _, err := m.AcceptStream(ctx); err != nil {
|
||||
errChan1 <- err
|
||||
return
|
||||
}
|
||||
streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
|
||||
firstStream := protocol.FirstIncomingUniStreamServer
|
||||
if streamType == protocol.StreamTypeBidi {
|
||||
firstStream = protocol.FirstIncomingBidiStreamServer
|
||||
}
|
||||
close(errChan1)
|
||||
}()
|
||||
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
m := newIncomingStreamsMap(
|
||||
streamType,
|
||||
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
|
||||
num,
|
||||
func(f wire.Frame) {},
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
|
||||
ids := make([]protocol.StreamID, num)
|
||||
for i := range num {
|
||||
if _, err := m.GetOrOpenStream(ids[i]); err != nil {
|
||||
errChan2 <- err
|
||||
return
|
||||
}
|
||||
ids[i] = firstStream + 4*protocol.StreamID(i)
|
||||
}
|
||||
close(errChan2)
|
||||
}()
|
||||
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
|
||||
|
||||
select {
|
||||
case err := <-errChan1:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(timeout * 3 / 2):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
select {
|
||||
case err := <-errChan2:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(timeout * 3 / 2):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
errChan1 := make(chan error, 1)
|
||||
go func() {
|
||||
for range num {
|
||||
if _, err := m.AcceptStream(context.Background()); err != nil {
|
||||
errChan1 <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
close(errChan1)
|
||||
}()
|
||||
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
for i := range num {
|
||||
if _, err := m.GetOrOpenStream(ids[i]); err != nil {
|
||||
errChan2 <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
close(errChan2)
|
||||
}()
|
||||
|
||||
synctest.Wait()
|
||||
|
||||
select {
|
||||
case err := <-errChan1:
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
t.Fatal("should have accepted all streams")
|
||||
}
|
||||
select {
|
||||
case err := <-errChan2:
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
t.Fatal("should have opened all streams")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user