From 56bebc4e61c68535f5d11f22fefa1ff53024dabb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 28 Dec 2024 15:06:24 +0800 Subject: [PATCH] migrate the stream tests away from Ginkgo (#4799) * translate the receive stream tests * translate the send stream tests * translate the stream tests --- receive_stream_test.go | 1223 ++++++++++------------ send_stream_test.go | 2215 ++++++++++++++++++---------------------- stream_test.go | 154 +-- 3 files changed, 1602 insertions(+), 1990 deletions(-) diff --git a/receive_stream_test.go b/receive_stream_test.go index 95a4c847b..b4ce411dc 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -2,735 +2,614 @@ package quic import ( "errors" + "fmt" "io" - "runtime" - "sync" + "net" + "os" "sync/atomic" + "testing" "time" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -var _ = Describe("Receive Stream", func() { - const streamID protocol.StreamID = 1337 +type readerWithTimeout struct { + io.Reader + Timeout time.Duration +} - var ( - str *receiveStream - strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader - mockFC *mocks.MockStreamFlowController - mockSender *MockStreamSender +func (r *readerWithTimeout) Read(p []byte) (n int, err error) { + done := make(chan struct{}) + go func() { + defer close(done) + n, err = r.Reader.Read(p) + }() + + select { + case <-done: + return n, err + case <-time.After(r.Timeout): + return 0, fmt.Errorf("read timeout after %s", r.Timeout) + } +} + +func TestReceiveStreamReadData(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newReceiveStream(42, nil, mockFC) + + // read an entire frame + now := time.Now() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad, 0xbe, 0xef}}, now)) + b := make([]byte, 4) + n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 4, n) + require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b) + + // split a frame across multiple reads + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false, now) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 4, Data: []byte{0xca, 0xfe, 0xba, 0xbe}}, now)) + b = make([]byte, 2) + n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte{0xca, 0xfe}, b) + n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte{0xba, 0xbe}, b) + + // combine two frames + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(11), false, now), + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(14), false, now), + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Times(2), ) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 8, Data: []byte{'f', 'o', 'o'}}, now)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 11, Data: []byte{'b', 'a', 'r'}}, now)) + b = make([]byte, 6) + n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 6, n) + require.Equal(t, []byte{'f', 'o', 'o', 'b', 'a', 'r'}, b) - BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newReceiveStream(streamID, mockSender, mockFC) + // reordered frames + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(20), false, now), + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(17), false, now), + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Times(2), + ) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 17, Data: []byte{'b', 'a', 'z'}}, now)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 14, Data: []byte{'f', 'o', 'o'}}, now)) + b = make([]byte, 6) + n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 6, n) + require.Equal(t, []byte{'f', 'o', 'o', 'b', 'a', 'z'}, b) +} - timeout := scaleDuration(250 * time.Millisecond) - strWithTimeout = gbytes.TimeoutReader(str, timeout) - }) +func TestReceiveStreamBlockRead(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) - It("gets stream id", func() { - Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) - }) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + errChan := make(chan error, 1) + go func() { + frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} + time.Sleep(scaleDuration(5 * time.Millisecond)) + errChan <- str.handleStreamFrame(&frame, time.Now()) + }() - Context("reading", func() { - It("reads a single STREAM frame", func() { - now := time.Now() - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - Expect(str.handleStreamFrame(&frame, now)).To(Succeed()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) + n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 2)) + require.NoError(t, err) + require.Equal(t, 2, n) + require.NoError(t, <-errChan) +} - It("reads a single STREAM frame in multiple goes", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - Expect(str.handleStreamFrame(&frame, time.Now())).To(Succeed()) - b := make([]byte, 2) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b).To(Equal([]byte{0xDE, 0xAD})) - n, err = strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b).To(Equal([]byte{0xBE, 0xEF})) - }) +func TestReceiveStreamReadOverlappingData(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newReceiveStream(42, nil, mockFC) - It("queues a flow control update", func() { - now := time.Now() - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Return(true) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xde, 0xad, 0xbe, 0xef}, - } - Expect(str.handleStreamFrame(&frame, now)).To(Succeed()) - mockSender.EXPECT().onHasStreamControlFrame(streamID, str) - n, err := strWithTimeout.Read(make([]byte, 3)) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - now = now.Add(time.Second) - mockFC.EXPECT().GetWindowUpdate(now).Return(protocol.ByteCount(1337)) - f, ok, hasMore := str.getControlFrame(now) - Expect(ok).To(BeTrue()) - Expect(f.Frame).To(Equal(&wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337})) - Expect(hasMore).To(BeFalse()) - }) + // receive the same frame multiple times + now := time.Now() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now).Times(3) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + for i := 0; i < 3; i++ { + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad, 0xbe, 0xef}}, now)) + } + b := make([]byte, 4) + n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 4, n) + require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b) - It("reads all data available", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - Expect(str.handleStreamFrame(&frame1, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&frame2, time.Now())).To(Succeed()) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00})) - }) + // receive overlapping data + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false, now), + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), false, now), + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)), + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), + ) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 4, Data: []byte{'f', 'o', 'o', 'b'}}, now)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 6, Data: []byte{'o', 'b', 'a', 'r'}}, now)) + b = make([]byte, 6) + n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 6, n) + require.Equal(t, []byte{'f', 'o', 'o', 'b', 'a', 'r'}, b) +} - It("assembles multiple STREAM frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - Expect(str.handleStreamFrame(&frame1, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&frame2, time.Now())).To(Succeed()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) +func TestReceiveStreamMaxStreamData(t *testing.T) { + const streamID protocol.StreamID = 42 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(streamID, mockSender, mockFC) - It("waits until data is available", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - go func() { - defer GinkgoRecover() - frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} - time.Sleep(10 * time.Millisecond) - Expect(str.handleStreamFrame(&frame, time.Now())).To(Succeed()) - }() - b := make([]byte, 2) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - }) + now := time.Now() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad, 0xbe, 0xef}}, now)) - It("handles STREAM frames in wrong order", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - Expect(str.handleStreamFrame(&frame1, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&frame2, time.Now())).To(Succeed()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Return(true) + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 3)) + require.NoError(t, err) + require.Equal(t, 3, n) + require.True(t, mockCtrl.Satisfied()) - It("ignores duplicate STREAM frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0x13, 0x37}, - } - frame3 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - Expect(str.handleStreamFrame(&frame1, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&frame2, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&frame3, time.Now())).To(Succeed()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) + now = now.Add(time.Second) + mockFC.EXPECT().GetWindowUpdate(now).Return(protocol.ByteCount(1337)) + f, ok, hasMore := str.getControlFrame(now) + require.True(t, ok) + require.Equal(t, &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}, f.Frame) + require.False(t, hasMore) +} - It("doesn't rejects a STREAM frames with an overlapping data range", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, gomock.Any()) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte("foob"), - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte("obar"), - } - Expect(str.handleStreamFrame(&frame1, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&frame2, time.Now())).To(Succeed()) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) +func TestReceiveStreamDeadlineInThePast(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newReceiveStream(42, nil, mockFC) - Context("deadlines", func() { - It("the deadline error has the right net.Error properties", func() { - Expect(errDeadline.Timeout()).To(BeTrue()) - Expect(errDeadline).To(MatchError("deadline exceeded")) - }) + // no data is read when the deadline is in the past + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()).AnyTimes() + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + require.NoError(t, str.SetReadDeadline(time.Now().Add(-time.Second))) + b := make([]byte, 6) + n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.Error(t, err) + require.Zero(t, n) + var nerr net.Error + require.ErrorAs(t, err, &nerr) + require.True(t, nerr.Timeout()) - It("returns an error when Read is called after the deadline", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()).AnyTimes() - Expect(str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())).To(Succeed()) - str.SetReadDeadline(time.Now().Add(-time.Second)) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) + // data is read when the deadline is in the future + require.NoError(t, str.SetReadDeadline(time.Now().Add(time.Second))) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) + require.NoError(t, err) + require.Equal(t, 6, n) +} - It("unblocks when the deadline is changed to the past", func() { - str.SetReadDeadline(time.Now().Add(time.Hour)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Read(make([]byte, 6)) - Expect(err).To(MatchError(errDeadline)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.SetReadDeadline(time.Now().Add(-time.Hour)) - Eventually(done).Should(BeClosed()) - }) +func TestReceiveStreamDeadlineRemoval(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newReceiveStream(42, nil, mockFC) - It("unblocks after the deadline", func() { - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetReadDeadline(deadline) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - }) + deadline := scaleDuration(20 * time.Millisecond) + require.NoError(t, str.SetReadDeadline(time.Now().Add(deadline))) + errChan := make(chan error, 1) + go func() { + _, err := (&readerWithTimeout{Reader: str, Timeout: 5 * time.Second}).Read([]byte{0}) + errChan <- err + }() + select { + case err := <-errChan: + t.Fatalf("read should not have returned yet: %v", err) + case <-time.After(deadline / 2): + } - It("doesn't unblock if the deadline is changed before the first one expires", func() { - deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) - str.SetReadDeadline(deadline1) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetReadDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline1)) - }() - runtime.Gosched() - b := make([]byte, 10) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - }) + // remove the deadline after a while (but before it expires) + require.NoError(t, str.SetReadDeadline(time.Time{})) - It("unblocks earlier, when a new deadline is set", func() { - deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(10 * time.Millisecond)) - str.SetReadDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline2)) - }() - str.SetReadDeadline(deadline1) - runtime.Gosched() - b := make([]byte, 10) - _, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond))) - }) + select { + case err := <-errChan: + t.Fatalf("read should not have returned yet: %v", err) + case <-time.After(deadline): + } - It("doesn't unblock if the deadline is removed", func() { - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetReadDeadline(deadline) - deadlineUnset := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetReadDeadline(time.Time{}) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline)) - close(deadlineUnset) - }() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Read(make([]byte, 1)) - Expect(err).To(MatchError("test done")) - close(done) - }() - runtime.Gosched() - Eventually(deadlineUnset).Should(BeClosed()) - Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) - // make the go routine return - str.closeForShutdown(errors.New("test done")) - Eventually(done).Should(BeClosed()) - }) - }) + // now set the deadline to the past to make Read return immediately + require.NoError(t, str.SetReadDeadline(time.Now().Add(-time.Second))) + select { + case err := <-errChan: + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + case <-time.After(time.Second): + t.Fatal("timeout") + } +} - Context("closing", func() { - Context("with FIN bit", func() { - It("returns EOFs", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - Fin: true, - }, time.Now())).To(Succeed()) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) +func TestReceiveStreamDeadlineExtension(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newReceiveStream(42, nil, mockFC) - It("handles out-of-order frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any()) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - Fin: true, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - Expect(str.handleStreamFrame(&frame1, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&frame2, time.Now())).To(Succeed()) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) + deadline := scaleDuration(20 * time.Millisecond) + require.NoError(t, str.SetReadDeadline(time.Now().Add(deadline))) + errChan := make(chan error, 1) + go func() { + _, err := (&readerWithTimeout{Reader: str, Timeout: 5 * time.Second}).Read([]byte{0}) + errChan <- err + }() + select { + case err := <-errChan: + t.Fatalf("read should not have returned yet: %v", err) + case <-time.After(deadline / 2): + } - It("returns EOFs with partial read", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Data: []byte{0xde, 0xad}, - Fin: true, - }, time.Now())).To(Succeed()) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte{0xde, 0xad})) - }) + // extend the deadline + require.NoError(t, str.SetReadDeadline(time.Now().Add(deadline))) + select { + case err := <-errChan: + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + case <-time.After(deadline * 3 / 2): + t.Fatal("timeout") + } +} - It("handles immediate FINs", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Fin: true, - }, time.Now())).To(Succeed()) - mockSender.EXPECT().onStreamCompleted(streamID) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) +func TestReceiveStreamEOFWithData(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) - // Calling Read concurrently doesn't make any sense (and is forbidden), - // but we still want to make sure that we don't complete the stream more than once - // if the user misuses our API. - // This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"), - // which can be hard to debug. - // Note that even without the protection built into the receiveStream, this test - // is very timing-dependent, and would need to run a few hundred times to trigger the failure. - It("handles concurrent reads", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes() - var bytesRead protocol.ByteCount - mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) bool { - bytesRead += n - return false - }).AnyTimes() + now := time.Now() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true, now) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, now) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 2, Data: []byte{0xbe, 0xef}, Fin: true}, now)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad}}, now)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) - var numCompleted int32 - mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { - atomic.AddInt32(&numCompleted, 1) - }).AnyTimes() - const num = 3 - var wg sync.WaitGroup - wg.Add(num) - for i := 0; i < num; i++ { - go func() { - defer wg.Done() - defer GinkgoRecover() - _, err := str.Read(make([]byte, 8)) - Expect(err).To(MatchError(io.EOF)) - }() - } - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 0, - Data: []byte("foobar"), - Fin: true, - }, time.Now())).To(Succeed()) - wg.Wait() - Expect(bytesRead).To(BeEquivalentTo(6)) - Expect(atomic.LoadInt32(&numCompleted)).To(BeEquivalentTo(1)) - }) - }) - }) + strWithTimeout := &readerWithTimeout{Reader: str, Timeout: time.Second} + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 4, n) + require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b[:n]) + n, err = strWithTimeout.Read(b) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) +} - Context("closing for shutdown", func() { - testErr := errors.New("test error") +func TestReceiveStreamImmediateFINs(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true, gomock.Any()) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Fin: true}, time.Now())) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 4)) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) +} - It("immediately returns all reads", func() { - done := make(chan struct{}) - b := make([]byte, 4) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.closeForShutdown(testErr) - Eventually(done).Should(BeClosed()) - }) +func TestReceiveStreamCloseForShutdown(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + strWithTimeout := &readerWithTimeout{Reader: str, Timeout: time.Second} - It("errors for all following reads", func() { - str.closeForShutdown(testErr) - b := make([]byte, 1) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - }) - }) + // Test immediate return of reads + errChan := make(chan error, 1) + go func() { + _, err := strWithTimeout.Read([]byte{0}) + errChan <- err + }() - Context("stream cancellations", func() { - Context("canceling read", func() { - It("unblocks Read", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - Remote: false, - })) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.CancelRead(1234) - Eventually(done).Should(BeClosed()) - }) + select { + case err := <-errChan: + t.Fatalf("read returned before closeForShutdown: %v", err) + case <-time.After(scaleDuration(5 * time.Millisecond)): // short wait to ensure read is blocked + } - It("doesn't allow further calls to Read", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelRead(1234) - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - Remote: false, - })) - }) + testErr := errors.New("test error") + str.closeForShutdown(testErr) + select { + case err := <-errChan: + require.ErrorIs(t, err, testErr) + case <-time.After(time.Second): + t.Fatal("read did not return after closeForShutdown") + } - It("does nothing when CancelRead is called twice", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelRead(1234) - str.CancelRead(1234) - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - Remote: false, - })) - }) + // following calls to Read should return the error + n, err := strWithTimeout.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, testErr) - It("queues a STOP_SENDING frame", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelRead(1234) - f, ok, hasMore := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(f.Frame).To(Equal(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 1234, - })) - Expect(hasMore).To(BeFalse()) - }) + // receiving a RESET_STREAM frame after closeForShutdown does nothing + require.NoError(t, str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42}, time.Now())) + n, err = strWithTimeout.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, testErr) - It("doesn't send a STOP_SENDING frame, if the FIN was already read", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - // no calls to mockSender.queueControlFrame - Expect(str.handleStreamFrame(&wire.StreamFrame{ - StreamID: streamID, - Data: []byte("foobar"), - Fin: true, - }, time.Now())).To(Succeed()) - mockSender.EXPECT().onStreamCompleted(streamID) - n, err := strWithTimeout.Read(make([]byte, 100)) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(6)) - str.CancelRead(1234) - }) + // calling CancelRead after closeForShutdown does nothing + str.CancelRead(1234) + n, err = strWithTimeout.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, testErr) +} - It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()) - mockFC.EXPECT().Abandon().MinTimes(1) - Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{ - ErrorCode: 1337, - StreamID: streamID, - FinalSize: 42, - }, time.Now())).To(Succeed()) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - str.CancelRead(1234) - // check that the error indicates a remote reset - n, err := str.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(n).To(BeZero()) - var streamErr *StreamError - Expect(errors.As(err, &streamErr)).To(BeTrue()) - Expect(streamErr.ErrorCode).To(BeEquivalentTo(1337)) - Expect(streamErr.Remote).To(BeTrue()) - }) +func TestReceiveStreamCancellation(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + strWithTimeout := &readerWithTimeout{Reader: str, Timeout: time.Second} - It("sends a STOP_SENDING after receiving the final offset", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Data: []byte("foobar"), - Fin: true, - }, time.Now())).To(Succeed()) - mockFC.EXPECT().Abandon() - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - mockSender.EXPECT().onStreamCompleted(streamID) - str.CancelRead(1234) - // read the error - n, err := str.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(n).To(BeZero()) - }) + mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), gomock.Any()) + errChan := make(chan error, 1) + go func() { + _, err := strWithTimeout.Read([]byte{0}) + errChan <- err + }() - It("completes the stream when receiving the Fin after the stream was canceled", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelRead(1234) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true, gomock.Any()), - mockFC.EXPECT().Abandon(), - ) - mockSender.EXPECT().onStreamCompleted(streamID) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, - }, time.Now())).To(Succeed()) - }) + select { + case err := <-errChan: + t.Fatalf("read returned before CancelRead: %v", err) + case <-time.After(scaleDuration(5 * time.Millisecond)): + } - It("handles duplicate FinBits after the stream was canceled", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelRead(1234) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true, gomock.Any()), - mockFC.EXPECT().Abandon(), - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true, gomock.Any()), - ) - mockSender.EXPECT().onStreamCompleted(streamID) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, - }, time.Now())).To(Succeed()) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, - }, time.Now())).To(Succeed()) - }) + str.CancelRead(1234) + // this queues a STOP_SENDING frame + f, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1234}, f.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) - It("ignores cancellations after closeForShutdown", func() { - closeErr := errors.New("closed for shutdown") - str.closeForShutdown(closeErr) - buf := make([]byte, 100) - _, err := str.Read(buf) - Expect(err).To(Equal(closeErr)) - str.CancelRead(42) - _, err = str.Read(buf) - Expect(err).To(Equal(closeErr)) - }) - }) + select { + case err := <-errChan: + var streamErr *StreamError + require.ErrorAs(t, err, &streamErr) + require.Equal(t, StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}, *streamErr) + case <-time.After(time.Second): + t.Fatal("Read was not unblocked") + } - Context("receiving RESET_STREAM frames", func() { - rst := &wire.ResetStreamFrame{ - StreamID: streamID, - FinalSize: 42, - ErrorCode: 1234, - } + // further Read calls return the error + n, err := strWithTimeout.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}) - It("unblocks Read", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - Remote: true, - })) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - mockSender.EXPECT().onStreamCompleted(streamID) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()), - mockFC.EXPECT().Abandon(), - ) - Expect(str.handleResetStreamFrame(rst, time.Now())).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) + // calling CancelRead again does nothing + // especially: + // 1. no more calls to onHasStreamControlFrame + // 2. no changes of the error code returned by Read + str.CancelRead(1234) + str.CancelRead(4321) + n, err = strWithTimeout.Read([]byte{0}) + require.Zero(t, n) + // error code unchanged + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}) + require.True(t, mockCtrl.Satisfied()) - It("doesn't allow further calls to Read", func() { - mockSender.EXPECT().onStreamCompleted(streamID) - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()), - mockFC.EXPECT().Abandon(), - ) - Expect(str.handleResetStreamFrame(rst, time.Now())).To(Succeed()) - _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - })) - }) + // receiving the FIN bit has no effect + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()).Times(2) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + // receive two of them, to make sure onStreamCompleted is not called twice + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) + require.True(t, mockCtrl.Satisfied()) - It("errors when receiving a RESET_STREAM with an inconsistent offset", func() { - testErr := errors.New("already received a different final offset before") - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()).Return(testErr) - Expect(str.handleResetStreamFrame(rst, time.Now())).To(MatchError(testErr)) - }) + // receiving a RESET_STREAM frame after CancelRead has no effect + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()) + require.NoError(t, str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 4321, FinalSize: 42}, time.Now())) + n, err = strWithTimeout.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}) +} - It("ignores duplicate RESET_STREAM frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()).Times(2) - mockFC.EXPECT().Abandon() - Expect(str.handleResetStreamFrame(rst, time.Now())).To(Succeed()) - Expect(str.handleResetStreamFrame(rst, time.Now())).To(Succeed()) - }) +func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) - It("doesn't call onStreamCompleted again when the final offset was already received via Fin", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()).Times(2) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - StreamID: streamID, - Offset: rst.FinalSize, - Fin: true, - }, time.Now())).To(Succeed()) - mockFC.EXPECT().Abandon().MinTimes(1) - mockSender.EXPECT().onStreamCompleted(streamID) - Expect(str.handleResetStreamFrame(rst, time.Now())).To(Succeed()) - // now read the error - n, err := str.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(n).To(BeZero()) - }) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) - It("doesn't do anything when it was closed for shutdown", func() { - str.closeForShutdown(errors.New("shutdown")) - Expect(str.handleResetStreamFrame(rst, time.Now())).To(Succeed()) - }) + // if the FIN was received, but not read yet, a STOP_SENDING frame is queued + mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str) + mockFC.EXPECT().Abandon() + str.CancelRead(1337) + f, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame) + require.False(t, hasMore) - It("handles RESET_STREAM after CancelRead", func() { - mockFC.EXPECT().Abandon() - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelRead(1234) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()) - mockSender.EXPECT().onStreamCompleted(streamID) - Expect(str.handleResetStreamFrame(rst, time.Now())).To(Succeed()) - // check that the error indicates a local reset - n, err := str.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(n).To(BeZero()) - var streamErr *StreamError - Expect(errors.As(err, &streamErr)).To(BeTrue()) - Expect(streamErr.Remote).To(BeFalse()) - }) - }) - }) + // Read returns the error + n, err := str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false}) +} - It("errors when a STREAM frame causes a flow control violation", func() { - testErr := errors.New("flow control violation") - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false, gomock.Any()).Return(testErr) - Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 2, - Data: []byte("foobar"), - }, time.Now())).To(MatchError(testErr)) - }) -}) +func TestReceiveStreamCancelReadAfterFINRead(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) + n, err := str.Read(make([]byte, 10)) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 6, n) + + // if the EOF was already read, no STOP_SENDING frame is queued + str.CancelRead(1234) + _, ok, hasMore := str.getControlFrame(time.Now()) + require.False(t, ok) + require.False(t, hasMore) + + // Read returns the error + n, err = str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) +} + +func TestReceiveStreamReset(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + strWithTimeout := &readerWithTimeout{Reader: str, Timeout: time.Second} + + errChan := make(chan error, 1) + go func() { + _, err := strWithTimeout.Read([]byte{0}) + errChan <- err + }() + + select { + case err := <-errChan: + t.Fatalf("read returned before reset: %v", err) + case <-time.After(scaleDuration(5 * time.Millisecond)): + } + + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + gomock.InOrder( + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()), + mockFC.EXPECT().Abandon(), + ) + require.NoError(t, str.handleResetStreamFrame( + &wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42}, + time.Now(), + )) + + select { + case err := <-errChan: + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true}) + case <-time.After(time.Second): + t.Fatal("Read was not unblocked") + } + + // Test that further calls to Read return the error + _, err := strWithTimeout.Read([]byte{0}) + require.Equal(t, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true}, err) + + // further RESET_STREAM frames have no effect + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()) + require.NoError(t, str.handleResetStreamFrame( + &wire.ResetStreamFrame{StreamID: 42, ErrorCode: 4321, FinalSize: 42}, + time.Now(), + )) + n, err := str.Read([]byte{0}) + require.Zero(t, n) + // error code unchanged + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true}) + + // CancelRead after a RESET_STREAM frame has no effect + str.CancelRead(100) + n, err = str.Read([]byte{0}) + require.Zero(t, n) + // error code and remote flag unchanged + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true}) +} + +func TestReceiveStreamResetAfterFINRead(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + require.NoError(t, str.handleStreamFrame( + &wire.StreamFrame{StreamID: 42, Data: []byte("foobar"), Fin: true}, + time.Now(), + )) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + n, err := str.Read(make([]byte, 6)) + require.Equal(t, 6, n) + require.ErrorIs(t, err, io.EOF) + // make sure that onStreamCompleted was called due to the EOF + require.True(t, mockCtrl.Satisfied()) + + // Now receive a RESET_STREAM frame. + // We don't expect any more calls to onStreamCompleted. + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) + mockFC.EXPECT().Abandon() + require.NoError(t, str.handleResetStreamFrame( + &wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 6}, + time.Now(), + )) + // now read the error + n, err = str.Read([]byte{0}) + require.Error(t, err) + require.Zero(t, n) +} + +// Calling Read concurrently doesn't make any sense (and is forbidden), +// but we still want to make sure that we don't complete the stream more than once +// if the user misuses our API. +// This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"), +// which can be hard to debug. +// Note that even without the protection built into the receiveStream, this test +// is very timing-dependent, and would need to run a few hundred times to trigger the failure. +func TestReceiveStreamConcurrentReads(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes() + var bytesRead protocol.ByteCount + mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) bool { + bytesRead += n + return false + }).AnyTimes() + + var numCompleted atomic.Int32 + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)).Do(func(protocol.StreamID) { + numCompleted.Add(1) + }).AnyTimes() + + const num = 3 + errChan := make(chan error, num) + for i := 0; i < num; i++ { + go func() { + _, err := str.Read(make([]byte, 8)) + errChan <- err + }() + } + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) + for i := 0; i < num; i++ { + select { + case err := <-errChan: + require.ErrorIs(t, err, io.EOF) + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + require.Equal(t, protocol.ByteCount(6), bytesRead) + require.Equal(t, int32(1), numCompleted.Load()) +} diff --git a/send_stream_test.go b/send_stream_test.go index 81717502c..f4d3de047 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -4,1294 +4,1019 @@ import ( "bytes" "context" "errors" + "fmt" "io" mrand "math/rand" - "runtime" + "net" + "os" + "testing" "time" "golang.org/x/exp/rand" - "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -var _ = Describe("Send Stream", func() { - const streamID protocol.StreamID = 1337 +type writerWithTimeout struct { + io.Writer + Timeout time.Duration +} - var ( - str *sendStream - strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter - mockFC *mocks.MockStreamFlowController - mockSender *MockStreamSender +func (w *writerWithTimeout) Write(p []byte) (n int, err error) { + done := make(chan struct{}) + go func() { + defer close(done) + n, err = w.Writer.Write(p) + }() + + select { + case <-done: + return n, err + case <-time.After(w.Timeout): + return 0, fmt.Errorf("write timeout after %s", w.Timeout) + } +} + +func expectedFrameHeaderLen(strID protocol.StreamID, offset protocol.ByteCount) protocol.ByteCount { + return (&wire.StreamFrame{StreamID: strID, Offset: offset, DataLenPresent: true}).Length(protocol.Version1) +} + +func TestSendStreamSetup(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + ctx := context.WithValue(context.Background(), "foo", "bar") + str := newSendStream(ctx, 1337, nil, mockFC) + require.NotNil(t, str.Context()) + require.Equal(t, "bar", str.Context().Value("foo")) + require.Equal(t, protocol.StreamID(1337), str.StreamID()) +} + +func TestSendStreamWriteData(t *testing.T) { + const streamID protocol.StreamID = 42 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} + + mockSender.EXPECT().onHasStreamData(streamID, str) + n, err := strWithTimeout.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + frame, ok, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.False(t, hasMore) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar"), DataLenPresent: true}, + frame.Frame, + ) + require.True(t, mockCtrl.Satisfied()) + + // nothing more to send at this point + _, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // nil writes don't do anything + n, err = strWithTimeout.Write(nil) + require.NoError(t, err) + require.Zero(t, n) + require.True(t, mockCtrl.Satisfied()) + + // empty slices writes don't do anything + n, err = strWithTimeout.Write([]byte{}) + require.NoError(t, err) + require.Zero(t, n) + require.True(t, mockCtrl.Satisfied()) + + // multiple writes are bundled into a single frame + mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) + n, err = strWithTimeout.Write([]byte{0xde, 0xad}) + require.NoError(t, err) + require.Equal(t, 2, n) + n, err = strWithTimeout.Write([]byte{0xbe, 0xef}) + require.NoError(t, err) + require.Equal(t, 2, n) + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) + frame, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.False(t, hasMore) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 42, Offset: 6, Data: []byte{0xde, 0xad, 0xbe, 0xef}, DataLenPresent: true}, + frame.Frame, ) - BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newSendStream(context.Background(), streamID, mockSender, mockFC) + // a single write is split up into smaller frames + mockSender.EXPECT().onHasStreamData(streamID, str) + n, err = strWithTimeout.Write([]byte("foobaz")) + require.NoError(t, err) + require.Equal(t, 6, n) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) + // TODO(#4807): check that no empty frames are popped + // frame, ok, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 10), protocol.Version1) + // require.False(t, ok) + // require.True(t, hasMore) + frame, ok, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 10)+3, protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: streamID, Offset: 10, Data: []byte("foo"), DataLenPresent: true}, + frame.Frame, + ) + frame, ok, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 13)+3, protocol.Version1) + require.True(t, ok) + require.False(t, hasMore) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: streamID, Offset: 13, Data: []byte("baz"), DataLenPresent: true}, + frame.Frame, + ) +} - timeout := scaleDuration(250 * time.Millisecond) - strWithTimeout = gbytes.TimeoutWriter(str, timeout) +func TestSendStreamLargeWrites(t *testing.T) { + const streamID protocol.StreamID = 1337 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamData(streamID, str) + data := make([]byte, 5000) + rand.Read(data) + errChan := make(chan error, 1) + go func() { + defer str.Close() + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(data) + errChan <- err + }() + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(scaleDuration(5 * time.Millisecond)): // short wait to ensure write is blocked + } + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxPacketBufferSize).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() + var offset protocol.ByteCount + const size = 40 + for offset+size < protocol.ByteCount(len(data))-protocol.MaxPacketBufferSize { + frame, ok, hasMore := str.popStreamFrame(size+expectedFrameHeaderLen(streamID, offset), protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + require.Equal(t, offset, frame.Frame.Offset) + require.Equal(t, data[offset:offset+size], frame.Frame.Data) + offset += size + require.True(t, mockCtrl.Satisfied()) + } + // Write should still be blocked, since there's more than protocol.MaxPacketBufferSize left to send + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(scaleDuration(5 * time.Millisecond)): // short wait to ensure write is blocked + } + + mockSender.EXPECT().onHasStreamData(streamID, str) // from the Close call + frame, ok, hasMore := str.popStreamFrame(size+expectedFrameHeaderLen(streamID, offset), protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + require.Equal(t, data[offset:offset+size], frame.Frame.Data) + require.Equal(t, offset, frame.Frame.Offset) + offset += size + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + frame, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.False(t, hasMore) + require.Equal(t, data[offset:], frame.Frame.Data) + require.True(t, frame.Frame.Fin) +} + +func TestSendStreamLargeWriteBlocking(t *testing.T) { + const streamID protocol.StreamID = 1337 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + errChan := make(chan error, 1) + go func() { + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(make([]byte, protocol.MaxPacketBufferSize)) + errChan <- err + }() + + select { + case err := <-errChan: + t.Fatalf("write should not have returned yet: %v", err) + case <-time.After(scaleDuration(5 * time.Millisecond)): + } + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + frame, ok, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0)+3, protocol.Version1) + require.True(t, ok) + require.True(t, hasMoreData) + require.Equal(t, []byte("foo"), frame.Frame.Data) + + select { + case err := <-errChan: + t.Fatalf("write should not have returned yet: %v", err) + case <-time.After(scaleDuration(5 * time.Millisecond)): + } + + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + frame, ok, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(streamID, 3)+3, protocol.Version1) + require.True(t, ok) + require.True(t, hasMoreData) + require.Equal(t, []byte("bar"), frame.Frame.Data) + + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestSendStreamCopyData(t *testing.T) { + const streamID protocol.StreamID = 42 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} + + // for small writes + data := []byte("foobar") + mockSender.EXPECT().onHasStreamData(streamID, str) + _, err := strWithTimeout.Write(data) + require.NoError(t, err) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + frame, ok, _ := str.popStreamFrame(protocol.MaxPacketBufferSize, protocol.Version1) + require.True(t, ok) + data[1] = 'e' // modify the data after it has been written + require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar"), DataLenPresent: true}, frame.Frame) +} + +func TestSendStreamDeadlineInThePast(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 42, mockSender, mockFC) + + // no data is written when the deadline is in the past + require.NoError(t, str.SetWriteDeadline(time.Now().Add(-time.Second))) + n, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + require.Zero(t, n) + var nerr net.Error + require.ErrorAs(t, err, &nerr) + require.True(t, nerr.Timeout()) + + // data is written when the deadline is in the future + mockSender.EXPECT().onHasStreamData(gomock.Any(), str) + require.NoError(t, str.SetWriteDeadline(time.Now().Add(time.Second))) + n, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) +} + +func TestSendStreamDeadlineRemoval(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 42, mockSender, mockFC) + + deadline := scaleDuration(20 * time.Millisecond) + require.NoError(t, str.SetWriteDeadline(time.Now().Add(deadline))) + mockSender.EXPECT().onHasStreamData(gomock.Any(), str).Times(2) + + // small writes are written immediately + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + + // large writes might block, and therefore subject to the deadline + errChan := make(chan error, 1) + go func() { + _, err := (&writerWithTimeout{Writer: str, Timeout: 5 * time.Second}).Write(make([]byte, 2000)) + errChan <- err + }() + select { + case err := <-errChan: + t.Fatalf("write should not have returned yet: %v", err) + case <-time.After(deadline / 2): + } + + // remove the deadline after a while (but before it expires) + require.NoError(t, str.SetWriteDeadline(time.Time{})) + + select { + case err := <-errChan: + t.Fatalf("write should not have returned yet: %v", err) + case <-time.After(deadline): + } + + // now set the deadline to the past to make Write return immediately + require.NoError(t, str.SetWriteDeadline(time.Now().Add(-time.Second))) + select { + case err := <-errChan: + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + frame, ok, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.False(t, hasMoreData) + require.Equal(t, []byte("foobar"), frame.Frame.Data) +} + +func TestSendStreamDeadlineExtension(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 42, mockSender, mockFC) + + deadline := scaleDuration(20 * time.Millisecond) + require.NoError(t, str.SetWriteDeadline(time.Now().Add(deadline))) + + mockSender.EXPECT().onHasStreamData(gomock.Any(), str) + errChan := make(chan error, 1) + go func() { + _, err := (&writerWithTimeout{Writer: str, Timeout: 5 * time.Second}).Write(make([]byte, 2000)) + errChan <- err + }() + select { + case err := <-errChan: + t.Fatalf("write should not have returned yet: %v", err) + case <-time.After(deadline / 2): + } + + // extend the deadline + require.NoError(t, str.SetWriteDeadline(time.Now().Add(deadline))) + select { + case err := <-errChan: + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + case <-time.After(deadline * 3 / 2): + t.Fatal("timeout") + } + + _, ok, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMoreData) +} + +func TestSendStreamClose(t *testing.T) { + const streamID protocol.StreamID = 1234 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} + + mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) + _, err := strWithTimeout.Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, str.Close()) + + select { + case <-str.Context().Done(): + default: + t.Fatal("stream context should have been canceled") + } + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) + frame, ok, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0)+3, protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: streamID, Offset: 0, Data: []byte("foo"), DataLenPresent: true}, // no FIN yet + frame.Frame, + ) + frame, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.False(t, hasMore) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: streamID, Offset: 3, Fin: true, Data: []byte("bar"), DataLenPresent: true}, + frame.Frame, + ) + require.True(t, mockCtrl.Satisfied()) + + // further calls to Write return an error + _, err = strWithTimeout.Write([]byte("foobar")) + require.ErrorContains(t, err, "write on closed stream 1234") + _, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + + // further calls to Close don't do anything + // TODO(#4800): there shouldn't be any calls to mockSender + mockSender.EXPECT().onHasStreamData(streamID, str) + require.NoError(t, str.Close()) + _, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) +} + +func TestSendStreamImmediateClose(t *testing.T) { + const streamID protocol.StreamID = 1337 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + mockSender.EXPECT().onHasStreamData(streamID, str) + require.NoError(t, str.Close()) + frame, ok, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 13)+3, protocol.Version1) + require.True(t, ok) + require.False(t, hasMore) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: streamID, Fin: true, DataLenPresent: true}, + frame.Frame, + ) +} + +func TestSendStreamFlowControlBlocked(t *testing.T) { + const streamID protocol.StreamID = 42 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamData(streamID, str) + _, err := str.Write([]byte("foobar")) + require.NoError(t, err) + + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(3)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0)) + mockFC.EXPECT().IsNewlyBlocked().Return(true) + frame, ok, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, frame.Frame) + + // TODO(#4771): the STREAM_DATA_BLOCKED frame should be sent immediately + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + frame, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.EqualExportedValues(t, &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 3}, cf.Frame) + require.False(t, hasMore) +} + +func TestSendStreamCloseForShutdown(t *testing.T) { + const streamID protocol.StreamID = 1337 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} + + mockSender.EXPECT().onHasStreamData(streamID, str) + errChan := make(chan error, 1) + go func() { + _, err := strWithTimeout.Write(bytes.Repeat([]byte("foobar"), 1000)) + errChan <- err + }() + + select { + case err := <-errChan: + t.Fatalf("write returned before closeForShutdown: %v", err) + case <-time.After(scaleDuration(5 * time.Millisecond)): // short wait to ensure write is blocked + } + + testErr := errors.New("test error") + str.closeForShutdown(testErr) + require.True(t, mockCtrl.Satisfied()) + + select { + case err := <-errChan: + require.ErrorIs(t, err, testErr) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // future calls to Write should return the error + _, err := strWithTimeout.Write([]byte("foobar")) + require.ErrorIs(t, err, testErr) + + // closing the stream doesn't do anything + require.NoError(t, str.Close()) + + // no STREAM frames popped + _, ok, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + + // canceling the stream doesn't do anything + str.CancelWrite(1234) + _, err = strWithTimeout.Write([]byte("foobar")) + require.ErrorIs(t, err, testErr) // error unchanged +} + +func TestSendStreamUpdateSendWindow(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 42, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamData(gomock.Any(), str) + _, err := str.Write([]byte("foobar")) + require.NoError(t, err) + require.True(t, mockCtrl.Satisfied()) + + // no calls to onHasStreamData if the window size wasn't increased + mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(41)).Return(false) + str.updateSendWindow(41) +} + +func TestSendStreamCancellation(t *testing.T) { + const streamID protocol.StreamID = 42 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} + + mockSender.EXPECT().onHasStreamData(streamID, str) + _, err := strWithTimeout.Write([]byte("foobar")) + require.NoError(t, err) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + frame, ok, hasMore := str.popStreamFrame(3+expectedFrameHeaderLen(streamID, 0), protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + require.Equal(t, []byte("foo"), frame.Frame.Data) + require.True(t, mockCtrl.Satisfied()) + + wrote := make(chan struct{}) + mockSender.EXPECT().onHasStreamData(streamID, str).Do(func(protocol.StreamID, sendStreamI) { close(wrote) }) + errChan := make(chan error, 1) + go func() { + _, err := strWithTimeout.Write(make([]byte, 2000)) + errChan <- err + }() + + select { + case <-wrote: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // cancel the stream + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + str.CancelWrite(1234) + require.True(t, mockCtrl.Satisfied()) + + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + // only the "foo" was sent out, so the final size is 3 + require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 3, ErrorCode: 1234}, cf.Frame) + require.False(t, hasMore) + + // the context was canceled + select { + case <-str.Context().Done(): + default: + t.Fatal("stream context should have been canceled") + } + require.ErrorIs(t, context.Cause(str.Context()), &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) + + // duplicate calls to CancelWrite don't do anything + str.CancelWrite(1234) + _, ok, _ = str.getControlFrame(time.Now()) + require.False(t, ok) + + // the Write call should return an error + select { + case err := <-errChan: + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // no data to send + _, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + + // future calls to Write should return an error + _, err = strWithTimeout.Write([]byte("foo")) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) + _, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + + // Close has no effect + require.ErrorContains(t, str.Close(), "close called for canceled stream") + _, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.Error(t, err) + // TODO(#4808):error code and remote flag are unchanged + // require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) +} + +func TestSendStreamCancellationAfterClose(t *testing.T) { + const streamID protocol.StreamID = 1234 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, str.Close()) + + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + str.CancelWrite(1337) + + _, ok, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, cf.Frame) + require.False(t, hasMore) +} + +func TestSendStreamCancellationStreamRetransmission(t *testing.T) { + t.Run("local", func(t *testing.T) { + testSendStreamCancellationStreamRetransmission(t, false) }) + t.Run("remote", func(t *testing.T) { + testSendStreamCancellationStreamRetransmission(t, true) + }) +} - expectedFrameHeaderLen := func(offset protocol.ByteCount) protocol.ByteCount { - return (&wire.StreamFrame{ - StreamID: streamID, - Offset: offset, - DataLenPresent: true, - }).Length(protocol.Version1) +func testSendStreamCancellationStreamRetransmission(t *testing.T, remote bool) { + const streamID protocol.StreamID = 1000 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamData(streamID, str) + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) + f1, ok, hasMore := str.popStreamFrame(3+expectedFrameHeaderLen(streamID, 0), protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + f2, ok, hasMore := str.popStreamFrame(3+expectedFrameHeaderLen(streamID, 3), protocol.Version1) + require.True(t, ok) + require.False(t, hasMore) + + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + if remote { + str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) + } else { + str.CancelWrite(1337) + } + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.IsType(t, &wire.ResetStreamFrame{}, cf.Frame) + require.False(t, hasMore) + + // it doesn't matter if the STREAM frames are acked or lost + f1.Handler.OnAcked(f1.Frame) + f2.Handler.OnLost(f2.Frame) + _, ok, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) + // if CancelWrite was called, the stream is completed as soon as the RESET_STREAM frame is acked + if !remote { + mockSender.EXPECT().onStreamCompleted(streamID) + } + cf.Handler.OnAcked(cf.Frame) + + // but if it's a remote cancellation, the application has to consume the error first + if remote { + mockSender.EXPECT().onStreamCompleted(streamID) + _, err := str.Write([]byte("foobar")) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) + } +} + +func TestSendStreamCancellationResetStreamRetransmission(t *testing.T) { + const streamID protocol.StreamID = 1000 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + str.CancelWrite(1337) + + f1, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, f1.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // lose the frame + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + f1.Handler.OnLost(f1.Frame) + // get the retransmission + f2, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, f2.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // acknowledge the frame + // TODO(#4803): this should complete the stream, if we correctly accounted for lost RESET_STREAM frames + // mockSender.EXPECT().onStreamCompleted(streamID) + f2.Handler.OnAcked(f2.Frame) +} + +func TestSendStreamStopSending(t *testing.T) { + const streamID protocol.StreamID = 1000 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamData(streamID, str) + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + _, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.True(t, mockCtrl.Satisfied()) + + errChan := make(chan error, 1) + go func() { + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(make([]byte, 2000)) + errChan <- err + }() + + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) + + select { + case err := <-errChan: + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) + case <-time.After(time.Second): + t.Fatal("timeout") } - waitForWrite := func() { - EventuallyWithOffset(0, func() bool { - str.mutex.Lock() - hasData := str.dataForWriting != nil || str.nextFrame != nil - str.mutex.Unlock() - return hasData - }).Should(BeTrue()) - } + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 6, ErrorCode: 1337}, cf.Frame) + require.False(t, hasMore) - getDataAtOffset := func(offset, length protocol.ByteCount) []byte { - b := make([]byte, length) - for i := protocol.ByteCount(0); i < length; i++ { - b[i] = uint8(offset + i) + // calls to Write should return an error + _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) + _, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + + // calls to CancelWrite have no effect + str.CancelWrite(1234) + _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + // error code and remote flag are unchanged + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) + _, ok, _ = str.getControlFrame(time.Now()) + require.False(t, ok) + + // Close has no effect + require.ErrorContains(t, str.Close(), "close called for canceled stream") + _, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.False(t, ok) + _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.Error(t, err) + // TODO(#4808):error code and remote flag are unchanged + // require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) +} + +// This test is inherently racy, as it tests a concurrent call to Write() and CancelRead(). +// A single successful run of this test therefore doesn't mean a lot, +// for reliable results it has to be run many times. +func TestSendStreamConcurrentWriteAndCancel(t *testing.T) { + const streamID protocol.StreamID = 1000 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) + + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).MaxTimes(1) + mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(1) + mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1) + mockFC.EXPECT().AddBytesSent(gomock.Any()).MaxTimes(1) + + errChan := make(chan error, 1) + go func() { + n, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(make([]byte, 100)) + if n == 0 { + errChan <- nil + return } - return b + errChan <- err + }() + + done := make(chan struct{}, 2) + go func() { + str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + done <- struct{}{} + }() + go func() { + str.CancelWrite(1234) + done <- struct{}{} + }() + + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout waiting for write to complete") } - getData := func(length protocol.ByteCount) []byte { - return getDataAtOffset(0, length) + for i := 0; i < 2; i++ { + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for cancel to complete") + } } +} - It("gets stream id", func() { - Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) - }) +func TestSendStreamRetransmissions(t *testing.T) { + const streamID protocol.StreamID = 1000 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) - Context("writing", func() { - It("writes and gets all data at once", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID, str) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - f := frame.Frame - Expect(f.Data).To(Equal([]byte("foobar"))) - Expect(f.Fin).To(BeFalse()) - Expect(f.Offset).To(BeZero()) - Expect(f.DataLenPresent).To(BeTrue()) - Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) - Expect(str.dataForWriting).To(BeNil()) - Eventually(done).Should(BeClosed()) - }) + mockSender.EXPECT().onHasStreamData(streamID, str) + _, err := str.Write([]byte("foo")) + require.NoError(t, err) - It("writes and gets data in two turns", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - mockSender.EXPECT().onHasStreamData(streamID, str) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - close(done) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) - frame, ok, _ := str.popStreamFrame(expectedFrameHeaderLen(0)+3, protocol.Version1) - Expect(ok).To(BeTrue()) - f := frame.Frame - Expect(f.Offset).To(BeZero()) - Expect(f.Fin).To(BeFalse()) - Expect(f.Data).To(Equal([]byte("foo"))) - Expect(f.DataLenPresent).To(BeTrue()) - frame, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - f = frame.Frame - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.Fin).To(BeFalse()) - Expect(f.Offset).To(Equal(protocol.ByteCount(3))) - Expect(f.DataLenPresent).To(BeTrue()) - _, ok, _ = str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeFalse()) - Eventually(done).Should(BeClosed()) - }) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + f1, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, f1.Frame) + require.True(t, mockCtrl.Satisfied()) - It("bundles small writes", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) - n, err := strWithTimeout.Write([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - n, err = strWithTimeout.Write([]byte("bar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - close(done) - }() - Eventually(done).Should(BeClosed()) // both Write calls returned without any data having been dequeued yet - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - f := frame.Frame - Expect(f.Offset).To(BeZero()) - Expect(f.Fin).To(BeFalse()) - Expect(f.Data).To(Equal([]byte("foobar"))) - }) + // write some more data + mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) + _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("bar")) + require.NoError(t, err) + require.NoError(t, str.Close()) + require.True(t, mockCtrl.Satisfied()) - It("writes and gets data in multiple turns, for large writes", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(5) - var totalBytesSent protocol.ByteCount - mockFC.EXPECT().AddBytesSent(gomock.Any()).Do(func(l protocol.ByteCount) { totalBytesSent += l }).Times(5) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - mockSender.EXPECT().onHasStreamData(streamID, str) - n, err := strWithTimeout.Write(getData(5000)) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(5000)) - close(done) - }() - waitForWrite() - for i := 0; i < 5; i++ { - frame, ok, _ := str.popStreamFrame(1100, protocol.Version1) - Expect(ok).To(BeTrue()) - f := frame.Frame - Expect(f.Offset).To(BeNumerically("~", 1100*i, 10*i)) - Expect(f.Fin).To(BeFalse()) - Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) - Expect(f.DataLenPresent).To(BeTrue()) - } - Expect(totalBytesSent).To(Equal(protocol.ByteCount(5000))) - Eventually(done).Should(BeClosed()) - }) + // lose the frame + mockSender.EXPECT().onHasStreamData(streamID, str) + f1.Handler.OnLost(f1.Frame) + require.True(t, mockCtrl.Satisfied()) - It("unblocks Write as soon as a STREAM frame can be buffered", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID, str) - _, err := strWithTimeout.Write(getData(protocol.MaxPacketBufferSize + 3)) - Expect(err).ToNot(HaveOccurred()) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - frame, ok, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+2, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - f := frame.Frame - Expect(f.DataLen()).To(Equal(protocol.ByteCount(2))) - Consistently(done).ShouldNot(BeClosed()) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) - frame, ok, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(1)+1, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - f = frame.Frame - Expect(f.DataLen()).To(Equal(protocol.ByteCount(1))) - Eventually(done).Should(BeClosed()) - }) + // when popping a new frame, we first get the retransmission... + f2, ok, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, f2.Frame) + require.True(t, hasMoreData) + require.True(t, mockCtrl.Satisfied()) - It("only unblocks Write once a previously buffered STREAM frame has been fully dequeued", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID, str) - _, err := str.Write(getData(protocol.MaxPacketBufferSize)) - Expect(err).ToNot(HaveOccurred()) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - frame, ok, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+2, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - f := frame.Frame - Expect(f.Data).To(Equal([]byte("fo"))) - Consistently(done).ShouldNot(BeClosed()) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) - frame, ok, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(2)+4, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - f = frame.Frame - Expect(f.Data).To(Equal([]byte("obar"))) - Eventually(done).Should(BeClosed()) - }) + // ... then we get the new data + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + f3, ok, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 3, Fin: true, Data: []byte("bar"), DataLenPresent: true}, f3.Frame) + require.False(t, hasMoreData) + require.True(t, mockCtrl.Satisfied()) - It("popStreamFrame returns nil if no data is available", func() { - _, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - }) + // acknowledge the retransmission... + f2.Handler.OnAcked(f2.Frame) + // ... and the last frame, which concludes this stream + mockSender.EXPECT().onStreamCompleted(streamID) + f3.Handler.OnAcked(f3.Frame) +} - It("says if it has more data for writing", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID, str) - n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(100)) - }() - waitForWrite() - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame.Frame.Fin).To(BeFalse()) - Expect(hasMoreData).To(BeTrue()) - frame, ok, hasMoreData = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame.Frame.Fin).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - _, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeFalse()) - Eventually(done).Should(BeClosed()) - }) +func TestSendStreamRetransmissionFraming(t *testing.T) { + const streamID protocol.StreamID = 1000 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) - It("copies the slice while writing", func() { - frameHeaderSize := protocol.ByteCount(4) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - s := []byte("foo") - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID, str) - n, err := strWithTimeout.Write(s) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - }() - waitForWrite() - frame, ok, _ := str.popStreamFrame(frameHeaderSize+1, protocol.Version1) - Expect(ok).To(BeTrue()) - f := frame.Frame - Expect(f.Data).To(Equal([]byte("f"))) - frame, ok, _ = str.popStreamFrame(100, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f = frame.Frame - Expect(f.Data).To(Equal([]byte("oo"))) - s[1] = 'e' - Expect(f.Data).To(Equal([]byte("oo"))) - Eventually(done).Should(BeClosed()) - }) + mockSender.EXPECT().onHasStreamData(streamID, str) + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) - It("returns when given a nil input", func() { - n, err := strWithTimeout.Write(nil) - Expect(n).To(BeZero()) - Expect(err).ToNot(HaveOccurred()) - }) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + f, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) - It("returns when given an empty slice", func() { - n, err := strWithTimeout.Write([]byte("")) - Expect(n).To(BeZero()) - Expect(err).ToNot(HaveOccurred()) - }) + // lose the frame + mockSender.EXPECT().onHasStreamData(streamID, str) + f.Handler.OnLost(f.Frame) - It("cancels the context when Close is called", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - Expect(str.Context().Done()).ToNot(BeClosed()) - Expect(str.Close()).To(Succeed()) - Expect(str.Context().Done()).To(BeClosed()) - Expect(context.Cause(str.Context())).To(MatchError(context.Canceled)) - }) + // retransmission doesn't fit + _, ok, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0), protocol.Version1) + require.False(t, ok) + require.True(t, hasMore) - Context("flow control blocking", func() { - It("queues a BLOCKED frame if the stream is flow control blocked", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(3)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0)) - mockFC.EXPECT().IsNewlyBlocked().Return(true) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID, str) - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - }() - waitForWrite() - f, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - Expect(f.Frame.Data).To(HaveLen(3)) - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - _, ok, hasMoreData = str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - cf, ok, hasMore := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(cf.Frame).To(Equal(&wire.StreamDataBlockedFrame{ - StreamID: streamID, - MaximumStreamData: 3, - })) - Expect(hasMore).To(BeFalse()) - // make the Write go routine return - str.closeForShutdown(nil) - Eventually(done).Should(BeClosed()) - }) - }) + // split the retransmission + r1, ok, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0)+3, protocol.Version1) + require.True(t, ok) + require.True(t, hasMore) + require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, r1.Frame) + r2, ok, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 3)+3, protocol.Version1) + require.True(t, ok) + // When popping a retransmission, we always claim that there's more data to send. + // We accept that this might be incorrect. + require.True(t, hasMore) + require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 3, Data: []byte("bar"), DataLenPresent: true}, r2.Frame) + _, ok, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 3)+3, protocol.Version1) + require.False(t, ok) + require.False(t, hasMore) +} - Context("deadlines", func() { - It("returns an error when Write is called after the deadline", func() { - str.SetWriteDeadline(time.Now().Add(-time.Second)) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) +// This test is kind of an integration test. +// It writes 4 MB of data, and pops STREAM frames that sometimes are and sometimes aren't limited by flow control. +// Half of these STREAM frames are then received and their content saved, while the other half is reported lost +// and has to be retransmitted. +func TestSendStreamRetransmitDataUntilAcknowledged(t *testing.T) { + const streamID protocol.StreamID = 123456 + const dataLen = 1 << 22 // 4 MB + mockCtrl := gomock.NewController(t) + mockSender := NewMockStreamSender(mockCtrl) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC) - It("unblocks after the deadline", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - n, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - }) + mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes() + mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { + return protocol.ByteCount(mrand.Intn(500)) + 50 + }).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() - It("unblocks when the deadline is changed to the past", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - str.SetWriteDeadline(time.Now().Add(time.Hour)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.SetWriteDeadline(time.Now().Add(-time.Hour)) - Eventually(done).Should(BeClosed()) - }) + data := make([]byte, dataLen) + _, err := rand.Read(data) + require.NoError(t, err) + done := make(chan struct{}) + go func() { + defer close(done) + _, err := str.Write(data) + require.NoError(t, err) + str.Close() + }() - It("returns the number of bytes written, when the deadline expires", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - var n int - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(writeReturned) - mockSender.EXPECT().onHasStreamData(streamID, str) - var err error - n, err = strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - }() - waitForWrite() - frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(hasMoreData).To(BeTrue()) - Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) - Expect(n).To(BeEquivalentTo(frame.Frame.DataLen())) - }) + var completed bool + mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) - It("doesn't pop any data after the deadline expired", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(writeReturned) - mockSender.EXPECT().onHasStreamData(streamID, str) - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - }() - waitForWrite() - frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(hasMoreData).To(BeTrue()) - Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) - _, ok, hasMoreData = str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - }) - - It("doesn't unblock if the deadline is changed before the first one expires", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) - str.SetWriteDeadline(deadline1) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetWriteDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline1)) - close(done) - }() - runtime.Gosched() - n, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - Eventually(done).Should(BeClosed()) - }) - - It("unblocks earlier, when a new deadline is set", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(10 * time.Millisecond)) - str.SetWriteDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline2)) - close(done) - }() - str.SetWriteDeadline(deadline1) - runtime.Gosched() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't unblock if the deadline is removed", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - deadlineUnset := make(chan struct{}) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetWriteDeadline(time.Time{}) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline)) - close(deadlineUnset) - }() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError("test done")) - close(done) - }() - runtime.Gosched() - Eventually(deadlineUnset).Should(BeClosed()) - Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) - // make the go routine return - str.closeForShutdown(errors.New("test done")) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("closing", func() { - It("doesn't allow writes after it has been closed", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - str.Close() - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError("write on closed stream 1337")) - }) - - It("allows FIN", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - str.Close() - frame, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f := frame.Frame - Expect(f.Data).To(BeEmpty()) - Expect(f.Fin).To(BeTrue()) - Expect(f.DataLenPresent).To(BeTrue()) - Expect(hasMoreData).To(BeFalse()) - }) - - It("doesn't send a FIN when there's still data", func() { - const frameHeaderLen protocol.ByteCount = 4 - mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - frame, ok, _ := str.popStreamFrame(3+frameHeaderLen, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f := frame.Frame - Expect(f.Data).To(Equal([]byte("foo"))) - Expect(f.Fin).To(BeFalse()) - frame, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - f = frame.Frame - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.Fin).To(BeTrue()) - }) - - It("doesn't send a FIN when there's still data, for long writes", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - mockSender.EXPECT().onHasStreamData(streamID, str) - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().onHasStreamData(streamID, str) - Expect(str.Close()).To(Succeed()) - }() - waitForWrite() - for i := 1; i <= 5; i++ { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - if i == 5 { - Eventually(done).Should(BeClosed()) - } - frame, ok, _ := str.popStreamFrame(1100, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f := frame.Frame - Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) - Expect(f.Fin).To(Equal(i == 5)) // the last frame should have the FIN bit set - } - }) - - It("doesn't allow FIN after it is closed for shutdown", func() { - str.closeForShutdown(errors.New("test")) - _, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - - Expect(str.Close()).To(Succeed()) - _, ok, hasMoreData = str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - }) - - It("doesn't allow FIN twice", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - str.Close() - frame, ok, _ := str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f := frame.Frame - Expect(f.Data).To(BeEmpty()) - Expect(f.Fin).To(BeTrue()) - _, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - }) - }) - - Context("closing for shutdown", func() { - testErr := errors.New("test") - - It("returns errors when the stream is cancelled", func() { - str.closeForShutdown(testErr) - n, err := strWithTimeout.Write([]byte("foo")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - - It("doesn't get data for writing if an error occurred", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID, str) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError(testErr)) - close(done) - }() - waitForWrite() - frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) // get a STREAM frame containing some data, but not all - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - Expect(hasMoreData).To(BeTrue()) - str.closeForShutdown(testErr) - _, ok, hasMoreData = str.popStreamFrame(1000, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - Eventually(done).Should(BeClosed()) - }) - }) - }) - - Context("handling MAX_STREAM_DATA frames", func() { - It("informs the flow controller", func() { - mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337)) - str.updateSendWindow(0x1337) - }) - - It("says when it has data for sending", func() { - mockFC.EXPECT().UpdateSendWindow(gomock.Any()).Return(true) - mockSender.EXPECT().onHasStreamData(streamID, str) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - mockSender.EXPECT().onHasStreamData(streamID, str) - str.updateSendWindow(42) - // make sure the Write go routine returns - str.closeForShutdown(nil) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't say it has data for sending if the MAX_STREAM_DATA frame was reordered", func() { - mockFC.EXPECT().UpdateSendWindow(gomock.Any()).Return(false) // reordered frame - mockSender.EXPECT().onHasStreamData(streamID, str) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - // don't expect any calls to onHasStreamData - str.updateSendWindow(42) - // make sure the Write go routine returns - str.closeForShutdown(nil) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("stream cancellations", func() { - Context("canceling writing", func() { - It("queues a RESET_STREAM frame", func() { - mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()) - str.writeOffset = 1234 - str.CancelWrite(9876) - cf, ok, hasMore := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(cf.Frame).To(Equal(&wire.ResetStreamFrame{ - StreamID: streamID, - FinalSize: 1234, - ErrorCode: 9876, - })) - Expect(hasMore).To(BeFalse()) - }) - - It("retransmits a RESET_STREAM frame", func() { - mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()) - str.CancelWrite(9876) - cf, ok, _ := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(cf.Frame).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{})) - - mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()) - cf.Handler.OnLost(cf.Frame) - cf2, ok, _ := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(cf2.Frame).To(Equal(cf.Frame)) - }) - - // This test is inherently racy, as it tests a concurrent call to Write() and CancelRead(). - // A single successful run of this test therefore doesn't mean a lot, - // for reliable results it has to be run many times. - It("returns a nil error when the whole slice has been sent out", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).MaxTimes(1) - mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(1) - mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1) - mockFC.EXPECT().AddBytesSent(gomock.Any()).MaxTimes(1) - errChan := make(chan error) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write(getData(100)) - if n == 0 { - errChan <- nil - return - } - errChan <- err - }() - - runtime.Gosched() - go str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - go str.CancelWrite(1234) - Eventually(errChan).Should(Receive(Not(HaveOccurred()))) - }) - - It("unblocks Write", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID, str) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - var n int - go func() { - defer GinkgoRecover() - var err error - n, err = strWithTimeout.Write(getData(5000)) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - Remote: false, - })) - close(writeReturned) - }() - waitForWrite() - frame, ok, _ := str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - str.CancelWrite(1234) - Eventually(writeReturned).Should(BeClosed()) - Expect(n).To(BeEquivalentTo(frame.Frame.DataLen())) - }) - - It("doesn't pop STREAM frames after being canceled", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID, str) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - strWithTimeout.Write(getData(100)) - close(writeReturned) - }() - waitForWrite() - frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - str.CancelWrite(1234) - _, ok, hasMoreData = str.popStreamFrame(10, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - Eventually(writeReturned).Should(BeClosed()) - }) - - It("doesn't pop STREAM frames after being canceled, for large writes", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID, str) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - Remote: false, - })) - close(writeReturned) - }() - waitForWrite() - frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - str.CancelWrite(1234) - _, ok, hasMoreData = str.popStreamFrame(10, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeFalse()) - Eventually(writeReturned).Should(BeClosed()) - }) - - It("ignores acknowledgements for STREAM frames after it was cancelled", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID, str) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(gomock.Any()) - writeReturned := make(chan struct{}) - go func() { - defer GinkgoRecover() - strWithTimeout.Write(getData(100)) - close(writeReturned) - }() - waitForWrite() - frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - str.CancelWrite(1234) - frame.Handler.OnAcked(frame.Frame) - }) - - It("cancels the context", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - Expect(str.Context().Done()).ToNot(BeClosed()) - str.CancelWrite(1234) - Expect(str.Context().Done()).To(BeClosed()) - Expect(context.Cause(str.Context())).To(BeAssignableToTypeOf(&StreamError{})) - Expect(context.Cause(str.Context()).(*StreamError).ErrorCode).To(Equal(StreamErrorCode(1234))) - }) - - It("doesn't allow further calls to Write", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelWrite(1234) - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(&StreamError{ - StreamID: streamID, - ErrorCode: 1234, - Remote: false, - })) - }) - - It("only cancels once", func() { - mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()) - str.CancelWrite(1234) - str.CancelWrite(4321) - cf, ok, hasMore := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(cf.Frame).To(Equal(&wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 1234})) - Expect(hasMore).To(BeFalse()) - cf, ok, hasMore = str.getControlFrame(time.Now()) - Expect(ok).To(BeFalse()) - Expect(cf.Frame).To(BeNil()) - Expect(hasMore).To(BeFalse()) - }) - - It("queues a RESET_STREAM frame, even if the stream was already closed", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - mockSender.EXPECT().onHasStreamControlFrame(streamID, str) - Expect(str.Close()).To(Succeed()) - // don't EXPECT any calls to queueControlFrame - str.CancelWrite(123) - f, ok, hasMore := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(f.Frame).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{})) - Expect(hasMore).To(BeFalse()) - }) - }) - - Context("receiving STOP_SENDING frames", func() { - It("queues a RESET_STREAM frames, and copies the error code from the STOP_SENDING frame", func() { - mockSender.EXPECT().onHasStreamControlFrame(streamID, str) - // Don't EXPECT calls to onStreamCompleted. - // The application needs to learn about the cancellation first. - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 101, - }) - f, ok, hasMore := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - Expect(f.Frame).To(Equal(&wire.ResetStreamFrame{ - StreamID: streamID, - ErrorCode: 101, - })) - Expect(hasMore).To(BeFalse()) - }) - - It("discards the stream when CancelWrite is called after receiving STOP_SENDING", func() { - mockSender.EXPECT().onHasStreamControlFrame(streamID, str) - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 101, - }) - - str.CancelWrite(101) - }) - - It("unblocks Write", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write(getData(5000)) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 123, - Remote: true, - })) - close(done) - }() - waitForWrite() - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 123, - }) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't allow further calls to Write", func() { - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 123, - }) - _, err := str.Write([]byte("foobar")) - Expect(err).To(Equal(&StreamError{ - StreamID: streamID, - ErrorCode: 123, - Remote: true, - })) - }) - - It("handles Close after STOP_SENDING", func() { - mockSender.EXPECT().onHasStreamControlFrame(streamID, str) - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 123, - }) - str.Close() - }) - - It("handles STOP_SENDING after sending the FIN", func() { - mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any()) - str.Close() - _, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 123, - }) - }) - - It("handles STOP_SENDING after Close, but before sending the FIN", func() { - mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any()) - str.Close() - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 123, - }) - }) - It("ignores cancellations after closeForShutdown", func() { - closeErr := errors.New("closed for shutdown") - str.closeForShutdown(closeErr) - _, err := str.Write([]byte("hello")) - Expect(err).To(Equal(closeErr)) - str.CancelWrite(42) - _, err = str.Write([]byte("hello")) - Expect(err).To(Equal(closeErr)) - }) - }) - }) - - Context("retransmissions", func() { - It("queues and retrieves frames", func() { - str.numOutstandingFrames = 1 - f := &wire.StreamFrame{ - Data: []byte("foobar"), - Offset: 0x42, - DataLenPresent: false, - } - mockSender.EXPECT().onHasStreamData(streamID, str) - (*sendStreamAckHandler)(str).OnLost(f) - frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f = frame.Frame - Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) - Expect(f.Data).To(Equal([]byte("foobar"))) - Expect(f.DataLenPresent).To(BeTrue()) - }) - - It("splits a retransmission", func() { - str.numOutstandingFrames = 1 - sf := &wire.StreamFrame{ - Data: []byte("foobar"), - Offset: 0x42, - DataLenPresent: false, - } - mockSender.EXPECT().onHasStreamData(streamID, str) - (*sendStreamAckHandler)(str).OnLost(sf) - frame, ok, hasMoreData := str.popStreamFrame(sf.Length(protocol.Version1)-3, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f := frame.Frame - Expect(hasMoreData).To(BeTrue()) - Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) - Expect(f.Data).To(Equal([]byte("foo"))) - Expect(f.DataLenPresent).To(BeTrue()) - frame, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - f = frame.Frame - Expect(f.Offset).To(Equal(protocol.ByteCount(0x45))) - Expect(f.Data).To(Equal([]byte("bar"))) - Expect(f.DataLenPresent).To(BeTrue()) - }) - - It("returns nil if the size is too small", func() { - str.numOutstandingFrames = 1 - f := &wire.StreamFrame{ - Data: []byte("foobar"), - Offset: 0x42, - DataLenPresent: false, - } - mockSender.EXPECT().onHasStreamData(streamID, str) - (*sendStreamAckHandler)(str).OnLost(f) - _, ok, hasMoreData := str.popStreamFrame(2, protocol.Version1) - Expect(ok).To(BeFalse()) - Expect(hasMoreData).To(BeTrue()) - }) - - It("queues lost STREAM frames", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Eventually(done).Should(BeClosed()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.Data).To(Equal([]byte("foobar"))) - - // now lose the frame - mockSender.EXPECT().onHasStreamData(streamID, str) - frame.Handler.OnLost(frame.Frame) - newFrame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(newFrame).ToNot(BeNil()) - Expect(newFrame.Frame.Data).To(Equal([]byte("foobar"))) - }) - - It("doesn't queue retransmissions for a stream that was canceled", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - f, ok, _ := str.popStreamFrame(100, protocol.Version1) - Expect(ok).To(BeTrue()) - Eventually(done).Should(BeClosed()) - Expect(f).ToNot(BeNil()) - mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) - str.CancelWrite(9876) - // don't EXPECT any calls to onHasStreamData + received := make([]byte, dataLen) + for { + if completed { + break + } + f, ok, _ := str.popStreamFrame(protocol.ByteCount(mrand.Intn(300)+100), protocol.Version1) + if !ok { + continue + } + sf := f.Frame + // 50%: acknowledge the frame and save the data + // 50%: lose the frame + if mrand.Intn(100) < 50 { + copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) + f.Handler.OnAcked(f.Frame) + } else { f.Handler.OnLost(f.Frame) - Expect(str.retransmissionQueue).To(BeEmpty()) - }) - }) - - Context("determining when a stream is completed", func() { - BeforeEach(func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() - }) - - It("says when a stream is completed", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(make([]byte, 100)) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - - // get a bunch of small frames (max. 20 bytes) - var frames []ackhandler.StreamFrame - for { - frame, ok, hasMoreData := str.popStreamFrame(20, protocol.Version1) - if !ok { - continue - } - frames = append(frames, frame) - if !hasMoreData { - break - } - } - Eventually(done).Should(BeClosed()) - - // Acknowledge all frames. - // We don't expect the stream to be completed, since we still need to send the FIN. - for _, f := range frames { - f.Handler.OnAcked(f.Frame) - } - - // Now close the stream and acknowledge the FIN. - mockSender.EXPECT().onHasStreamData(streamID, str) - Expect(str.Close()).To(Succeed()) - frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - frame.Handler.OnAcked(frame.Frame) - }) - - It("waits until a RESET_STREAM is acknowledged", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - - frame, ok, hasMoreData := str.popStreamFrame(20, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(frame.Frame.DataLen()).To(BeEquivalentTo(6)) - Expect(hasMoreData).To(BeFalse()) - - mockSender.EXPECT().onHasStreamControlFrame(streamID, str) - str.CancelWrite(1234) - cf, ok, _ := str.getControlFrame(time.Now()) - Expect(ok).To(BeTrue()) - - // Acknowledge the STREAM frame. - // This doesn't complete the stream, since we're still waiting for - // the acknowledgment of the RESET_STREAM frame. - frame.Handler.OnAcked(frame.Frame) - - mockSender.EXPECT().onStreamCompleted(streamID) - cf.Handler.OnAcked(cf.Frame) - }) - - It("says when a stream is completed, if Close() is called before popping the frame", func() { - mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(make([]byte, 100)) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - waitForWrite() - Eventually(done).Should(BeClosed()) - Expect(str.Close()).To(Succeed()) - - frame, ok, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(hasMoreData).To(BeFalse()) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.Fin).To(BeTrue()) - - mockSender.EXPECT().onStreamCompleted(streamID) - frame.Handler.OnAcked(frame.Frame) - }) - - It("doesn't say it's completed when there are frames waiting to be retransmitted", func() { - mockSender.EXPECT().onHasStreamData(streamID, str) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := strWithTimeout.Write(getData(100)) - Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().onHasStreamData(streamID, str) - Expect(str.Close()).To(Succeed()) - close(done) - }() - waitForWrite() - - // get a bunch of small frames (max. 20 bytes) - var frames []ackhandler.StreamFrame - for { - frame, ok, _ := str.popStreamFrame(20, protocol.Version1) - if !ok { - continue - } - frames = append(frames, frame) - if frame.Frame.Fin { - break - } - } - Eventually(done).Should(BeClosed()) - - // lose the first frame, acknowledge all others - for _, f := range frames[1:] { - f.Handler.OnAcked(f.Frame) - } - mockSender.EXPECT().onHasStreamData(streamID, str) - frames[0].Handler.OnLost(frames[0].Frame) - - // get the retransmission and acknowledge it - ret, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(ok).To(BeTrue()) - Expect(ret).ToNot(BeNil()) - mockSender.EXPECT().onStreamCompleted(streamID) - ret.Handler.OnAcked(ret.Frame) - }) - - // This test is kind of an integration test. - // It writes 4 MB of data, and pops STREAM frames that sometimes are and sometimes aren't limited by flow control. - // Half of these STREAM frames are then received and their content saved, while the other half is reported lost - // and has to be retransmitted. - It("retransmits data until everything has been acknowledged", func() { - const dataLen = 1 << 22 // 4 MB - mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes() - mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { - return protocol.ByteCount(mrand.Intn(500)) + 50 - }).AnyTimes() - mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() - - data := make([]byte, dataLen) - _, err := rand.Read(data) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - _, err := str.Write(data) - Expect(err).ToNot(HaveOccurred()) - str.Close() - }() - - var completed bool - mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) - - received := make([]byte, dataLen) - for { - if completed { - break - } - f, ok, _ := str.popStreamFrame(protocol.ByteCount(mrand.Intn(300)+100), protocol.Version1) - if !ok { - continue - } - sf := f.Frame - // 50%: acknowledge the frame and save the data - // 50%: lose the frame - if mrand.Intn(100) < 50 { - copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) - f.Handler.OnAcked(f.Frame) - } else { - f.Handler.OnLost(f.Frame) - } - } - Expect(received).To(Equal(data)) - }) - }) -}) + } + } + require.Equal(t, data, received) +} diff --git a/stream_test.go b/stream_test.go index f01293938..44c103103 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,93 +2,101 @@ package quic import ( "context" - "errors" "io" "os" + "testing" "time" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -var _ = Describe("Stream", func() { +func TestStreamDeadlines(t *testing.T) { + const streamID protocol.StreamID = 1337 + mockCtrl := gomock.NewController(t) + mockSender := NewMockStreamSender(mockCtrl) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newStream(context.Background(), streamID, mockSender, mockFC) + + // SetDeadline sets both read and write deadlines + str.SetDeadline(time.Now().Add(-time.Second)) + n, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + require.Zero(t, n) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()).AnyTimes() + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 6)) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + require.Zero(t, n) +} + +func TestStreamCompletion(t *testing.T) { + completeReadSide := func( + t *testing.T, + str *stream, + mockCtrl *gomock.Controller, + mockFC *mocks.MockStreamFlowController, + ) { + t.Helper() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{ + StreamID: str.StreamID(), + Data: []byte("foobar"), + Fin: true, + }, time.Now())) + _, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 6)) + require.ErrorIs(t, err, io.EOF) + require.True(t, mockCtrl.Satisfied()) + } + + completeWriteSide := func( + t *testing.T, + str *stream, + mockCtrl *gomock.Controller, + mockFC *mocks.MockStreamFlowController, + mockSender *MockStreamSender, + ) { + t.Helper() + mockSender.EXPECT().onHasStreamData(str.StreamID(), gomock.Any()).Times(2) + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, str.Close()) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + f, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.True(t, ok) + require.True(t, f.Frame.Fin) + f.Handler.OnAcked(f.Frame) + require.True(t, mockCtrl.Satisfied()) + } + const streamID protocol.StreamID = 1337 - var ( - str *stream - strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer} - mockFC *mocks.MockStreamFlowController - mockSender *MockStreamSender - ) + t.Run("first read, then write", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockSender := NewMockStreamSender(mockCtrl) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newStream(context.Background(), streamID, mockSender, mockFC) - BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newStream(context.Background(), streamID, mockSender, mockFC) - - timeout := scaleDuration(250 * time.Millisecond) - strWithTimeout = struct { - io.Reader - io.Writer - }{ - gbytes.TimeoutReader(str, timeout), - gbytes.TimeoutWriter(str, timeout), - } + completeReadSide(t, str, mockCtrl, mockFC) + mockSender.EXPECT().onStreamCompleted(streamID) + completeWriteSide(t, str, mockCtrl, mockFC, mockSender) }) - It("gets stream id", func() { - Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) + t.Run("first write, then read", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockSender := NewMockStreamSender(mockCtrl) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newStream(context.Background(), streamID, mockSender, mockFC) + + completeWriteSide(t, str, mockCtrl, mockFC, mockSender) + mockSender.EXPECT().onStreamCompleted(streamID) + completeReadSide(t, str, mockCtrl, mockFC) }) - - Context("deadlines", func() { - It("sets a write deadline, when SetDeadline is called", func() { - str.SetDeadline(time.Now().Add(-time.Second)) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - - It("sets a read deadline, when SetDeadline is called", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()).AnyTimes() - Expect(str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())).To(Succeed()) - str.SetDeadline(time.Now().Add(-time.Second)) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - }) - - Context("completing", func() { - It("is not completed when only the receive side is completed", func() { - // don't EXPECT a call to mockSender.onStreamCompleted() - str.receiveStream.sender.onStreamCompleted(streamID) - }) - - It("is not completed when only the send side is completed", func() { - // don't EXPECT a call to mockSender.onStreamCompleted() - str.sendStream.sender.onStreamCompleted(streamID) - }) - - It("is completed when both sides are completed", func() { - mockSender.EXPECT().onStreamCompleted(streamID) - str.sendStream.sender.onStreamCompleted(streamID) - str.receiveStream.sender.onStreamCompleted(streamID) - }) - }) -}) - -var _ = Describe("Deadline Error", func() { - It("is a net.Error that wraps os.ErrDeadlineError", func() { - err := deadlineError{} - Expect(err.Timeout()).To(BeTrue()) - Expect(errors.Is(err, os.ErrDeadlineExceeded)).To(BeTrue()) - Expect(errors.Unwrap(err)).To(Equal(os.ErrDeadlineExceeded)) - }) -}) +}