use synctest to make streams map tests fully deterministic (#5296)

* use synctest to make streams map tests fully deterministic

* ignore synctest helper package for code coverage
This commit is contained in:
Marten Seemann
2025-08-24 00:11:59 +08:00
committed by GitHub
parent 257281b6e2
commit 258d6b1352
5 changed files with 569 additions and 555 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/assert"
@@ -92,51 +93,67 @@ func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Per
}
func TestStreamsMapIncomingAcceptingStreams(t *testing.T) {
m := newIncomingStreamsMap(
protocol.StreamTypeUni,
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
5,
func(f wire.Frame) {},
protocol.PerspectiveClient,
)
synctest.Test(t, func(t *testing.T) {
m := newIncomingStreamsMap(
protocol.StreamTypeUni,
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
5,
func(f wire.Frame) {},
protocol.PerspectiveClient,
)
// AcceptStream should respect the context
errChan := make(chan error, 1)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(10*time.Millisecond))
// AcceptStream should respect the context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := m.AcceptStream(ctx)
errChan <- err
}()
errChan := make(chan error, 1)
go func() {
_, err := m.AcceptStream(ctx)
errChan <- err
}()
select {
case err := <-errChan:
require.Equal(t, context.DeadlineExceeded, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
synctest.Wait()
// AcceptStream should block if there are no streams available
go func() {
_, err := m.AcceptStream(context.Background())
errChan <- err
}()
select {
case <-errChan:
t.Fatal("AcceptStream should not return")
default:
}
select {
case <-errChan:
t.Fatal("AcceptStream should block")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
cancel()
synctest.Wait()
select {
case err := <-errChan:
require.Equal(t, context.Canceled, err)
default:
t.Fatal("timeout")
}
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient)
require.NoError(t, err)
// AcceptStream should block if there are no streams available
go func() {
_, err := m.AcceptStream(context.Background())
errChan <- err
}()
select {
case err := <-errChan:
synctest.Wait()
select {
case <-errChan:
t.Fatal("AcceptStream should block")
default:
}
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient)
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
synctest.Wait()
select {
case err := <-errChan:
require.NoError(t, err)
default:
t.Fatal("timeout")
}
})
}
func TestStreamsMapIncomingDeletingStreams(t *testing.T) {
@@ -241,103 +258,107 @@ func testStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T, pers prot
}
func TestStreamsMapIncomingClosing(t *testing.T) {
m := newIncomingStreamsMap(
protocol.StreamTypeUni,
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
5,
func(f wire.Frame) {},
protocol.PerspectiveServer,
)
synctest.Test(t, func(t *testing.T) {
m := newIncomingStreamsMap(
protocol.StreamTypeUni,
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
5,
func(f wire.Frame) {},
protocol.PerspectiveServer,
)
var streams []*mockStream
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8)
require.NoError(t, err)
for range 3 {
str, err := m.AcceptStream(context.Background())
var streams []*mockStream
_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8)
require.NoError(t, err)
streams = append(streams, str)
}
for range 3 {
str, err := m.AcceptStream(context.Background())
require.NoError(t, err)
streams = append(streams, str)
}
errChan := make(chan error, 1)
go func() {
_, err := m.AcceptStream(context.Background())
errChan <- err
}()
errChan := make(chan error, 1)
go func() {
_, err := m.AcceptStream(context.Background())
errChan <- err
}()
m.CloseWithError(assert.AnError)
m.CloseWithError(assert.AnError)
synctest.Wait()
// accepted streams should be closed
for _, str := range streams {
require.True(t, str.closed)
require.ErrorIs(t, str.closeErr, assert.AnError)
}
// AcceptStream should return the error
select {
case err := <-errChan:
require.ErrorIs(t, err, assert.AnError)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// accepted streams should be closed
for _, str := range streams {
require.True(t, str.closed)
require.ErrorIs(t, str.closeErr, assert.AnError)
}
// AcceptStream should return the error
select {
case err := <-errChan:
require.ErrorIs(t, err, assert.AnError)
default:
t.Fatal("timeout")
}
})
}
func TestStreamsMapIncomingRandomized(t *testing.T) {
const num = 1000
synctest.Test(t, func(t *testing.T) {
const num = 1000
streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
firstStream := protocol.FirstIncomingUniStreamServer
if streamType == protocol.StreamTypeBidi {
firstStream = protocol.FirstIncomingBidiStreamServer
}
m := newIncomingStreamsMap(
streamType,
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
num,
func(f wire.Frame) {},
protocol.PerspectiveServer,
)
ids := make([]protocol.StreamID, num)
for i := range num {
ids[i] = firstStream + 4*protocol.StreamID(i)
}
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
timeout := scaleDuration(time.Second)
errChan1 := make(chan error, 1)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for range num {
if _, err := m.AcceptStream(ctx); err != nil {
errChan1 <- err
return
}
streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
firstStream := protocol.FirstIncomingUniStreamServer
if streamType == protocol.StreamTypeBidi {
firstStream = protocol.FirstIncomingBidiStreamServer
}
close(errChan1)
}()
errChan2 := make(chan error, 1)
go func() {
m := newIncomingStreamsMap(
streamType,
func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
num,
func(f wire.Frame) {},
protocol.PerspectiveServer,
)
ids := make([]protocol.StreamID, num)
for i := range num {
if _, err := m.GetOrOpenStream(ids[i]); err != nil {
errChan2 <- err
return
}
ids[i] = firstStream + 4*protocol.StreamID(i)
}
close(errChan2)
}()
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
select {
case err := <-errChan1:
require.NoError(t, err)
case <-time.After(timeout * 3 / 2):
t.Fatal("timeout")
}
select {
case err := <-errChan2:
require.NoError(t, err)
case <-time.After(timeout * 3 / 2):
t.Fatal("timeout")
}
errChan1 := make(chan error, 1)
go func() {
for range num {
if _, err := m.AcceptStream(context.Background()); err != nil {
errChan1 <- err
return
}
}
close(errChan1)
}()
errChan2 := make(chan error, 1)
go func() {
for i := range num {
if _, err := m.GetOrOpenStream(ids[i]); err != nil {
errChan2 <- err
return
}
}
close(errChan2)
}()
synctest.Wait()
select {
case err := <-errChan1:
require.NoError(t, err)
default:
t.Fatal("should have accepted all streams")
}
select {
case err := <-errChan2:
require.NoError(t, err)
default:
t.Fatal("should have opened all streams")
}
})
}