diff --git a/receive_stream_test.go b/receive_stream_test.go index fd73e6ae1..9a8238f49 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -261,10 +261,12 @@ func TestReceiveStreamDeadlineRemoval(t *testing.T) { // now set the deadline to the past to make Read return immediately require.NoError(t, str.SetReadDeadline(time.Now().Add(-time.Second))) + synctest.Wait() + select { case err := <-errChan: require.ErrorIs(t, err, os.ErrDeadlineExceeded) - case <-time.After(time.Second): + default: t.Fatal("timeout") } }) @@ -356,18 +358,22 @@ func TestReceiveStreamCloseForShutdown(t *testing.T) { errChan <- err }() + synctest.Wait() + select { case err := <-errChan: t.Fatalf("read returned before closeForShutdown: %v", err) - case <-time.After(time.Second): // short wait to ensure read is blocked + default: } str.closeForShutdown(assert.AnError) + synctest.Wait() + select { case err := <-errChan: require.ErrorIs(t, err, assert.AnError) - case <-time.After(time.Second): - t.Fatal("read did not return after closeForShutdown") + default: + t.Fatal("read should have returned") } // following calls to Read should return the error @@ -404,11 +410,7 @@ func TestReceiveStreamCancellation(t *testing.T) { errChan <- err }() - select { - case err := <-errChan: - t.Fatalf("read returned before CancelRead: %v", err) - case <-time.After(time.Second): - } + synctest.Wait() str.CancelRead(1234) // this queues a STOP_SENDING frame @@ -418,12 +420,14 @@ func TestReceiveStreamCancellation(t *testing.T) { require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) + synctest.Wait() + select { case err := <-errChan: var streamErr *StreamError require.ErrorAs(t, err, &streamErr) require.Equal(t, StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}, *streamErr) - case <-time.After(time.Second): + default: t.Fatal("Read was not unblocked") } @@ -529,11 +533,7 @@ func TestReceiveStreamReset(t *testing.T) { errChan <- err }() - select { - case err := <-errChan: - t.Fatalf("read returned before reset: %v", err) - case <-time.After(time.Second): - } + synctest.Wait() mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) gomock.InOrder( @@ -545,10 +545,12 @@ func TestReceiveStreamReset(t *testing.T) { time.Now(), )) + synctest.Wait() + select { case err := <-errChan: require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true}) - case <-time.After(time.Second): + default: t.Fatal("Read was not unblocked") } @@ -616,42 +618,46 @@ func TestReceiveStreamResetAfterFINRead(t *testing.T) { // Note that even without the protection built into the receiveStream, this test // is very timing-dependent, and would need to run a few hundred times to trigger the failure. func TestReceiveStreamConcurrentReads(t *testing.T) { - mockCtrl := gomock.NewController(t) - mockFC := mocks.NewMockStreamFlowController(mockCtrl) - mockSender := NewMockStreamSender(mockCtrl) - str := newReceiveStream(42, mockSender, mockFC) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes() - var bytesRead protocol.ByteCount - mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) (bool, bool) { - bytesRead += n - return false, false - }).AnyTimes() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes() + var bytesRead protocol.ByteCount + mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) (bool, bool) { + bytesRead += n + return false, false + }).AnyTimes() - var numCompleted atomic.Int32 - mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)).Do(func(protocol.StreamID) { - numCompleted.Add(1) - }).AnyTimes() + var numCompleted atomic.Int32 + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)).Do(func(protocol.StreamID) { + numCompleted.Add(1) + }).AnyTimes() - const num = 3 - errChan := make(chan error, num) - for range num { - go func() { - _, err := str.Read(make([]byte, 8)) - errChan <- err - }() - } - require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) - for range num { - select { - case err := <-errChan: - require.ErrorIs(t, err, io.EOF) - case <-time.After(time.Second): - t.Fatal("timeout") + const num = 3 + errChan := make(chan error, num) + for range num { + go func() { + _, err := str.Read(make([]byte, 8)) + errChan <- err + }() } - } - require.Equal(t, protocol.ByteCount(6), bytesRead) - require.Equal(t, int32(1), numCompleted.Load()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) + synctest.Wait() + + for range num { + select { + case err := <-errChan: + require.ErrorIs(t, err, io.EOF) + default: + t.Fatal("read should have returned") + } + } + require.Equal(t, protocol.ByteCount(6), bytesRead) + require.Equal(t, int32(1), numCompleted.Load()) + }) } func TestReceiveStreamResetStreamAtBeforeReadOffset(t *testing.T) {