diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index 9d5230d1..a9a2b25e 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -95,10 +95,11 @@ func (_mr *MockStreamIMockRecorder) Finished() *gomock.Call { } // GetDataForWriting mocks base method -func (_m *MockStreamI) GetDataForWriting(_param0 protocol.ByteCount) []byte { +func (_m *MockStreamI) GetDataForWriting(_param0 protocol.ByteCount) ([]byte, bool) { ret := _m.ctrl.Call(_m, "GetDataForWriting", _param0) ret0, _ := ret[0].([]byte) - return ret0 + ret1, _ := ret[1].(bool) + return ret0, ret1 } // GetDataForWriting indicates an expected call of GetDataForWriting @@ -235,18 +236,6 @@ func (_mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.C return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) } -// ShouldSendFin mocks base method -func (_m *MockStreamI) ShouldSendFin() bool { - ret := _m.ctrl.Call(_m, "ShouldSendFin") - ret0, _ := ret[0].(bool) - return ret0 -} - -// ShouldSendFin indicates an expected call of ShouldSendFin -func (_mr *MockStreamIMockRecorder) ShouldSendFin() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ShouldSendFin", reflect.TypeOf((*MockStreamI)(nil).ShouldSendFin)) -} - // StreamID mocks base method func (_m *MockStreamI) StreamID() protocol.StreamID { ret := _m.ctrl.Call(_m, "StreamID") diff --git a/stream.go b/stream.go index f8d9a629..6e827d50 100644 --- a/stream.go +++ b/stream.go @@ -20,11 +20,10 @@ type streamI interface { AddStreamFrame(*wire.StreamFrame) error RegisterRemoteError(error, protocol.ByteCount) error HasDataForWriting() bool - GetDataForWriting(maxBytes protocol.ByteCount) []byte + GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) GetWriteOffset() protocol.ByteCount Finished() bool Cancel(error) - ShouldSendFin() bool SentFin() // methods needed for flow control GetWindowUpdate() protocol.ByteCount @@ -266,17 +265,19 @@ func (s *stream) GetWriteOffset() protocol.ByteCount { // HasDataForWriting says if there's stream available to be dequeued for writing func (s *stream) HasDataForWriting() bool { s.mutex.Lock() - hasData := s.err == nil && len(s.dataForWriting) > 0 + hasData := s.err == nil && // nothing should be sent if an error occurred + (len(s.dataForWriting) > 0 || // there is data queued for sending + s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet s.mutex.Unlock() return hasData } -func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { +func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { s.mutex.Lock() defer s.mutex.Unlock() if s.err != nil || s.dataForWriting == nil { - return nil + return nil, s.finishedWriting.Get() && !s.finSent.Get() } // TODO(#657): Flow control for the crypto stream @@ -284,7 +285,7 @@ func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) } if maxBytes == 0 { - return nil + return nil, false } var ret []byte @@ -298,7 +299,7 @@ func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { } s.writeOffset += protocol.ByteCount(len(ret)) s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) - return ret + return ret, s.finishedWriting.Get() && s.dataForWriting == nil && !s.finSent.Get() } // Close implements io.Closer @@ -316,13 +317,6 @@ func (s *stream) shouldSendReset() bool { return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() } -func (s *stream) ShouldSendFin() bool { - s.mutex.Lock() - res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil - s.mutex.Unlock() - return res -} - func (s *stream) SentFin() { s.finSent.Set(true) } diff --git a/stream_framer.go b/stream_framer.go index 8e886cac..d4bcfaf5 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -67,7 +67,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str Offset: f.cryptoStream.GetWriteOffset(), } frameHeaderBytes, _ := frame.MinLength(f.version) // can never error - frame.Data = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) + frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) return frame } @@ -115,24 +115,16 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] } maxLen := maxBytes - currentLen - frameHeaderBytes - var data []byte if s.HasDataForWriting() { - data = s.GetDataForWriting(maxLen) + frame.Data, frame.FinBit = s.GetDataForWriting(maxLen) } - - // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. - shouldSendFin := s.ShouldSendFin() - if data == nil && !shouldSendFin { + if len(frame.Data) == 0 && !frame.FinBit { return true, nil } - - if shouldSendFin { - frame.FinBit = true + if frame.FinBit { s.SentFin() } - frame.Data = data - // Finally, check if we are now FC blocked and should queue a BLOCKED frame if !frame.FinBit && s.IsFlowControlBlocked() { f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()}) diff --git a/stream_framer_test.go b/stream_framer_test.go index 97dcb11e..ea0669e2 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -51,8 +51,7 @@ var _ = Describe("Stream Framer", func() { setNoData := func(str *mocks.MockStreamI) { str.EXPECT().HasDataForWriting().Return(false).AnyTimes() - str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil).AnyTimes() - str.EXPECT().ShouldSendFin().Return(false).AnyTimes() + str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, false).AnyTimes() str.EXPECT().GetWriteOffset().AnyTimes() } @@ -76,9 +75,8 @@ var _ = Describe("Stream Framer", func() { setNoData(stream2) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().IsFlowControlBlocked() - stream1.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(protocol.MaxByteCount) Expect(fs).To(HaveLen(1)) Expect(fs[0].DataLenPresent).To(BeTrue()) @@ -111,10 +109,9 @@ var _ = Describe("Stream Framer", func() { }) It("returns normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() setNoData(stream2) fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) @@ -124,14 +121,12 @@ var _ = Describe("Stream Framer", func() { }) It("returns multiple normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() - stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz")) + stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz"), false) stream2.EXPECT().HasDataForWriting().Return(true) stream2.EXPECT().GetWriteOffset() - stream2.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(2)) // Swap if we dequeued in other order @@ -145,10 +140,9 @@ var _ = Describe("Stream Framer", func() { }) It("returns retransmission frames before normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(1000) @@ -159,7 +153,6 @@ var _ = Describe("Stream Framer", func() { It("does not pop empty frames", func() { stream1.EXPECT().HasDataForWriting().Return(false) - stream1.EXPECT().ShouldSendFin() stream1.EXPECT().GetWriteOffset() setNoData(stream2) fs := framer.PopStreamFrames(5) @@ -168,14 +161,12 @@ var _ = Describe("Stream Framer", func() { It("uses the round-robin scheduling", func() { streamFrameHeaderLen := protocol.ByteCount(4) - stream1.EXPECT().GetDataForWriting(10 - streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen))) + stream1.EXPECT().GetDataForWriting(10-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen)), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() - stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10 - streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen))) + stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10-streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen)), false) stream2.EXPECT().HasDataForWriting().Return(true) stream2.EXPECT().GetWriteOffset() - stream2.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(10) Expect(fs).To(HaveLen(1)) // it doesn't matter here if this data is from stream1 or from stream2... @@ -279,9 +270,9 @@ var _ = Describe("Stream Framer", func() { Context("sending FINs", func() { It("sends FINs when streams are closed", func() { offset := protocol.ByteCount(42) - stream1.EXPECT().HasDataForWriting().Return(false) + stream1.EXPECT().HasDataForWriting().Return(true) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true) stream1.EXPECT().GetWriteOffset().Return(offset) - stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() setNoData(stream2) @@ -295,10 +286,9 @@ var _ = Describe("Stream Framer", func() { It("bundles FINs with data", func() { offset := protocol.ByteCount(42) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset().Return(offset) - stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() setNoData(stream2) @@ -318,10 +308,9 @@ var _ = Describe("Stream Framer", func() { It("queues and pops BLOCKED frames for individually blocked streams", func() { connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() stream1.EXPECT().IsFlowControlBlocked().Return(true) setNoData(stream2) frames := framer.PopStreamFrames(1000) @@ -335,10 +324,9 @@ var _ = Describe("Stream Framer", func() { It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() { connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() setNoData(stream2) frames := framer.PopStreamFrames(1000) @@ -351,10 +339,9 @@ var _ = Describe("Stream Framer", func() { It("queues and pops BLOCKED frames for connection blocked streams", func() { connFC.EXPECT().IsBlocked().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() stream1.EXPECT().IsFlowControlBlocked().Return(false) setNoData(stream2) framer.PopStreamFrames(1000) diff --git a/stream_test.go b/stream_test.go index fbb67f1e..0ac723c1 100644 --- a/stream_test.go +++ b/stream_test.go @@ -9,6 +9,7 @@ import ( "os" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -617,7 +618,7 @@ var _ = Describe("Stream", func() { close(done) }() - Eventually(func() []byte { return str.GetDataForWriting(4) }).ShouldNot(BeEmpty()) + Eventually(func() []byte { data, _ := str.GetDataForWriting(4); return data }).ShouldNot(BeEmpty()) str.RegisterRemoteError(testErr, 10) Eventually(done).Should(BeClosed()) }) @@ -775,8 +776,9 @@ var _ = Describe("Stream", func() { Consistently(done).ShouldNot(BeClosed()) Expect(onDataCalled).To(BeTrue()) Expect(str.HasDataForWriting()).To(BeTrue()) - data := str.GetDataForWriting(1000) + data, sendFin := str.GetDataForWriting(1000) Expect(data).To(Equal([]byte("foobar"))) + Expect(sendFin).To(BeFalse()) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) Eventually(done).Should(BeClosed()) @@ -800,13 +802,15 @@ var _ = Describe("Stream", func() { }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) Expect(str.HasDataForWriting()).To(BeTrue()) - data := str.GetDataForWriting(3) + data, sendFin := str.GetDataForWriting(3) Expect(data).To(Equal([]byte("foo"))) + Expect(sendFin).To(BeFalse()) Expect(str.writeOffset).To(Equal(protocol.ByteCount(3))) Expect(str.dataForWriting).ToNot(BeNil()) Expect(str.HasDataForWriting()).To(BeTrue()) - data = str.GetDataForWriting(3) + data, sendFin = str.GetDataForWriting(3) Expect(data).To(Equal([]byte("bar"))) + Expect(sendFin).To(BeFalse()) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) Expect(str.HasDataForWriting()).To(BeFalse()) @@ -905,11 +909,6 @@ var _ = Describe("Stream", func() { }) Context("closing", func() { - It("sets finishedWriting when calling Close", func() { - str.Close() - Expect(str.finishedWriting.Get()).To(BeTrue()) - }) - It("doesn't allow writes after it has been closed", func() { str.Close() _, err := strWithTimeout.Write([]byte("foobar")) @@ -918,29 +917,51 @@ var _ = Describe("Stream", func() { It("allows FIN", func() { str.Close() - Expect(str.ShouldSendFin()).To(BeTrue()) + Expect(str.HasDataForWriting()).To(BeTrue()) + data, sendFin := str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeTrue()) }) It("does not allow FIN when there's still data", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) str.dataForWriting = []byte("foobar") str.Close() - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeTrue()) + data, sendFin := str.GetDataForWriting(3) + Expect(data).To(Equal([]byte("foo"))) + Expect(sendFin).To(BeFalse()) + data, sendFin = str.GetDataForWriting(3) + Expect(data).To(Equal([]byte("bar"))) + Expect(sendFin).To(BeTrue()) }) It("does not allow FIN when the stream is not closed", func() { - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeFalse()) + _, sendFin := str.GetDataForWriting(3) + Expect(sendFin).To(BeFalse()) }) It("does not allow FIN after an error", func() { str.Cancel(errors.New("test")) - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeFalse()) + data, sendFin := str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeFalse()) }) It("does not allow FIN twice", func() { str.Close() - Expect(str.ShouldSendFin()).To(BeTrue()) + Expect(str.HasDataForWriting()).To(BeTrue()) + data, sendFin := str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeTrue()) str.SentFin() - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeFalse()) + data, sendFin = str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeFalse()) }) }) @@ -963,8 +984,9 @@ var _ = Describe("Stream", func() { Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) Expect(str.HasDataForWriting()).To(BeTrue()) str.Cancel(testErr) - data := str.GetDataForWriting(6) + data, sendFin := str.GetDataForWriting(6) Expect(data).To(BeNil()) + Expect(sendFin).To(BeFalse()) Expect(str.HasDataForWriting()).To(BeFalse()) }) })