use synctest.Wait in receive stream tests (#5299)

This commit is contained in:
Marten Seemann
2025-08-24 12:25:04 +08:00
parent 92e7eca419
commit 98b84a3523

View File

@@ -261,10 +261,12 @@ func TestReceiveStreamDeadlineRemoval(t *testing.T) {
// now set the deadline to the past to make Read return immediately
require.NoError(t, str.SetReadDeadline(time.Now().Add(-time.Second)))
synctest.Wait()
select {
case err := <-errChan:
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
case <-time.After(time.Second):
default:
t.Fatal("timeout")
}
})
@@ -356,18 +358,22 @@ func TestReceiveStreamCloseForShutdown(t *testing.T) {
errChan <- err
}()
synctest.Wait()
select {
case err := <-errChan:
t.Fatalf("read returned before closeForShutdown: %v", err)
case <-time.After(time.Second): // short wait to ensure read is blocked
default:
}
str.closeForShutdown(assert.AnError)
synctest.Wait()
select {
case err := <-errChan:
require.ErrorIs(t, err, assert.AnError)
case <-time.After(time.Second):
t.Fatal("read did not return after closeForShutdown")
default:
t.Fatal("read should have returned")
}
// following calls to Read should return the error
@@ -404,11 +410,7 @@ func TestReceiveStreamCancellation(t *testing.T) {
errChan <- err
}()
select {
case err := <-errChan:
t.Fatalf("read returned before CancelRead: %v", err)
case <-time.After(time.Second):
}
synctest.Wait()
str.CancelRead(1234)
// this queues a STOP_SENDING frame
@@ -418,12 +420,14 @@ func TestReceiveStreamCancellation(t *testing.T) {
require.False(t, hasMore)
require.True(t, mockCtrl.Satisfied())
synctest.Wait()
select {
case err := <-errChan:
var streamErr *StreamError
require.ErrorAs(t, err, &streamErr)
require.Equal(t, StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}, *streamErr)
case <-time.After(time.Second):
default:
t.Fatal("Read was not unblocked")
}
@@ -529,11 +533,7 @@ func TestReceiveStreamReset(t *testing.T) {
errChan <- err
}()
select {
case err := <-errChan:
t.Fatalf("read returned before reset: %v", err)
case <-time.After(time.Second):
}
synctest.Wait()
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
gomock.InOrder(
@@ -545,10 +545,12 @@ func TestReceiveStreamReset(t *testing.T) {
time.Now(),
))
synctest.Wait()
select {
case err := <-errChan:
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true})
case <-time.After(time.Second):
default:
t.Fatal("Read was not unblocked")
}
@@ -616,42 +618,46 @@ func TestReceiveStreamResetAfterFINRead(t *testing.T) {
// Note that even without the protection built into the receiveStream, this test
// is very timing-dependent, and would need to run a few hundred times to trigger the failure.
func TestReceiveStreamConcurrentReads(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newReceiveStream(42, mockSender, mockFC)
synctest.Test(t, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newReceiveStream(42, mockSender, mockFC)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes()
var bytesRead protocol.ByteCount
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) (bool, bool) {
bytesRead += n
return false, false
}).AnyTimes()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes()
var bytesRead protocol.ByteCount
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) (bool, bool) {
bytesRead += n
return false, false
}).AnyTimes()
var numCompleted atomic.Int32
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)).Do(func(protocol.StreamID) {
numCompleted.Add(1)
}).AnyTimes()
var numCompleted atomic.Int32
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)).Do(func(protocol.StreamID) {
numCompleted.Add(1)
}).AnyTimes()
const num = 3
errChan := make(chan error, num)
for range num {
go func() {
_, err := str.Read(make([]byte, 8))
errChan <- err
}()
}
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
for range num {
select {
case err := <-errChan:
require.ErrorIs(t, err, io.EOF)
case <-time.After(time.Second):
t.Fatal("timeout")
const num = 3
errChan := make(chan error, num)
for range num {
go func() {
_, err := str.Read(make([]byte, 8))
errChan <- err
}()
}
}
require.Equal(t, protocol.ByteCount(6), bytesRead)
require.Equal(t, int32(1), numCompleted.Load())
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
synctest.Wait()
for range num {
select {
case err := <-errChan:
require.ErrorIs(t, err, io.EOF)
default:
t.Fatal("read should have returned")
}
}
require.Equal(t, protocol.ByteCount(6), bytesRead)
require.Equal(t, int32(1), numCompleted.Load())
})
}
func TestReceiveStreamResetStreamAtBeforeReadOffset(t *testing.T) {