From 085624be20d042b5994df51ca80fd734a48b7a6e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 7 Dec 2017 18:19:11 +0700 Subject: [PATCH 1/3] replace stream.LenOfDataForWriting by HasDataForWriting The return value (the length of data for writing) was only used to determine if the stream has data for writing. Therefore it's easier to just return a bool. No functional change expected. --- internal/mocks/stream.go | 24 ++++++++++++------------ stream.go | 12 +++++------- stream_framer.go | 4 ++-- stream_framer_test.go | 28 ++++++++++++++-------------- stream_test.go | 14 +++++++------- 5 files changed, 40 insertions(+), 42 deletions(-) diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index 5592282c..9d5230d1 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -130,6 +130,18 @@ func (_mr *MockStreamIMockRecorder) GetWriteOffset() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWriteOffset", reflect.TypeOf((*MockStreamI)(nil).GetWriteOffset)) } +// HasDataForWriting mocks base method +func (_m *MockStreamI) HasDataForWriting() bool { + ret := _m.ctrl.Call(_m, "HasDataForWriting") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasDataForWriting indicates an expected call of HasDataForWriting +func (_mr *MockStreamIMockRecorder) HasDataForWriting() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "HasDataForWriting", reflect.TypeOf((*MockStreamI)(nil).HasDataForWriting)) +} + // IsFlowControlBlocked mocks base method func (_m *MockStreamI) IsFlowControlBlocked() bool { ret := _m.ctrl.Call(_m, "IsFlowControlBlocked") @@ -142,18 +154,6 @@ func (_mr *MockStreamIMockRecorder) IsFlowControlBlocked() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "IsFlowControlBlocked", reflect.TypeOf((*MockStreamI)(nil).IsFlowControlBlocked)) } -// LenOfDataForWriting mocks base method -func (_m *MockStreamI) LenOfDataForWriting() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "LenOfDataForWriting") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// LenOfDataForWriting indicates an expected call of LenOfDataForWriting -func (_mr *MockStreamIMockRecorder) LenOfDataForWriting() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "LenOfDataForWriting", reflect.TypeOf((*MockStreamI)(nil).LenOfDataForWriting)) -} - // Read mocks base method func (_m *MockStreamI) Read(_param0 []byte) (int, error) { ret := _m.ctrl.Call(_m, "Read", _param0) diff --git a/stream.go b/stream.go index 3fe9a8d1..f8d9a629 100644 --- a/stream.go +++ b/stream.go @@ -19,7 +19,7 @@ type streamI interface { AddStreamFrame(*wire.StreamFrame) error RegisterRemoteError(error, protocol.ByteCount) error - LenOfDataForWriting() protocol.ByteCount + HasDataForWriting() bool GetDataForWriting(maxBytes protocol.ByteCount) []byte GetWriteOffset() protocol.ByteCount Finished() bool @@ -263,14 +263,12 @@ func (s *stream) GetWriteOffset() protocol.ByteCount { return s.writeOffset } -func (s *stream) LenOfDataForWriting() protocol.ByteCount { +// HasDataForWriting says if there's stream available to be dequeued for writing +func (s *stream) HasDataForWriting() bool { s.mutex.Lock() - var l protocol.ByteCount - if s.err == nil { - l = protocol.ByteCount(len(s.dataForWriting)) - } + hasData := s.err == nil && len(s.dataForWriting) > 0 s.mutex.Unlock() - return l + return hasData } func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { diff --git a/stream_framer.go b/stream_framer.go index e16a01e9..8e886cac 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -54,7 +54,7 @@ func (f *streamFramer) HasFramesForRetransmission() bool { } func (f *streamFramer) HasCryptoStreamFrame() bool { - return f.cryptoStream.LenOfDataForWriting() > 0 + return f.cryptoStream.HasDataForWriting() } // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. @@ -116,7 +116,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] maxLen := maxBytes - currentLen - frameHeaderBytes var data []byte - if s.LenOfDataForWriting() > 0 { + if s.HasDataForWriting() { data = s.GetDataForWriting(maxLen) } diff --git a/stream_framer_test.go b/stream_framer_test.go index 84897b9b..97dcb11e 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -50,7 +50,7 @@ var _ = Describe("Stream Framer", func() { }) setNoData := func(str *mocks.MockStreamI) { - str.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(0)).AnyTimes() + str.EXPECT().HasDataForWriting().Return(false).AnyTimes() str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil).AnyTimes() str.EXPECT().ShouldSendFin().Return(false).AnyTimes() str.EXPECT().GetWriteOffset().AnyTimes() @@ -75,7 +75,7 @@ var _ = Describe("Stream Framer", func() { connFC.EXPECT().IsBlocked() setNoData(stream2) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(8)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) stream1.EXPECT().IsFlowControlBlocked() stream1.EXPECT().ShouldSendFin() @@ -112,7 +112,7 @@ var _ = Describe("Stream Framer", func() { It("returns normal frames", func() { stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().ShouldSendFin() setNoData(stream2) @@ -125,11 +125,11 @@ var _ = Describe("Stream Framer", func() { It("returns multiple normal frames", func() { stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().ShouldSendFin() stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz")) - stream2.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream2.EXPECT().HasDataForWriting().Return(true) stream2.EXPECT().GetWriteOffset() stream2.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(1000) @@ -146,7 +146,7 @@ var _ = Describe("Stream Framer", func() { It("returns retransmission frames before normal frames", func() { stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().ShouldSendFin() setNoData(stream2) @@ -158,7 +158,7 @@ var _ = Describe("Stream Framer", func() { }) It("does not pop empty frames", func() { - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(0)) + stream1.EXPECT().HasDataForWriting().Return(false) stream1.EXPECT().ShouldSendFin() stream1.EXPECT().GetWriteOffset() setNoData(stream2) @@ -169,11 +169,11 @@ var _ = Describe("Stream Framer", func() { It("uses the round-robin scheduling", func() { streamFrameHeaderLen := protocol.ByteCount(4) stream1.EXPECT().GetDataForWriting(10 - streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen))) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(100)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().ShouldSendFin() stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10 - streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen))) - stream2.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(100)) + stream2.EXPECT().HasDataForWriting().Return(true) stream2.EXPECT().GetWriteOffset() stream2.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(10) @@ -279,7 +279,7 @@ var _ = Describe("Stream Framer", func() { Context("sending FINs", func() { It("sends FINs when streams are closed", func() { offset := protocol.ByteCount(42) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(0)) + stream1.EXPECT().HasDataForWriting().Return(false) stream1.EXPECT().GetWriteOffset().Return(offset) stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() @@ -296,7 +296,7 @@ var _ = Describe("Stream Framer", func() { It("bundles FINs with data", func() { offset := protocol.ByteCount(42) stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset().Return(offset) stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() @@ -319,7 +319,7 @@ var _ = Describe("Stream Framer", func() { It("queues and pops BLOCKED frames for individually blocked streams", func() { connFC.EXPECT().IsBlocked() stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().ShouldSendFin() stream1.EXPECT().IsFlowControlBlocked().Return(true) @@ -336,7 +336,7 @@ var _ = Describe("Stream Framer", func() { It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() { connFC.EXPECT().IsBlocked() stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(3)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() @@ -352,7 +352,7 @@ var _ = Describe("Stream Framer", func() { It("queues and pops BLOCKED frames for connection blocked streams", func() { connFC.EXPECT().IsBlocked().Return(true) stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) - stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(3)) + stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().ShouldSendFin() stream1.EXPECT().IsFlowControlBlocked().Return(false) diff --git a/stream_test.go b/stream_test.go index 0fec4f02..fbb67f1e 100644 --- a/stream_test.go +++ b/stream_test.go @@ -774,7 +774,7 @@ var _ = Describe("Stream", func() { }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) Expect(onDataCalled).To(BeTrue()) - Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(6))) + Expect(str.HasDataForWriting()).To(BeTrue()) data := str.GetDataForWriting(1000) Expect(data).To(Equal([]byte("foobar"))) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) @@ -799,17 +799,17 @@ var _ = Describe("Stream", func() { return str.dataForWriting }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) - Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(6))) + Expect(str.HasDataForWriting()).To(BeTrue()) data := str.GetDataForWriting(3) Expect(data).To(Equal([]byte("foo"))) Expect(str.writeOffset).To(Equal(protocol.ByteCount(3))) Expect(str.dataForWriting).ToNot(BeNil()) - Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(3))) + Expect(str.HasDataForWriting()).To(BeTrue()) data = str.GetDataForWriting(3) Expect(data).To(Equal([]byte("bar"))) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) - Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(0))) + Expect(str.HasDataForWriting()).To(BeFalse()) Eventually(done).Should(BeClosed()) }) @@ -827,7 +827,7 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(3)) }() - Eventually(func() protocol.ByteCount { return str.LenOfDataForWriting() }).ShouldNot(BeZero()) + Eventually(func() bool { return str.HasDataForWriting() }).Should(BeTrue()) s[0] = 'v' Expect(str.GetDataForWriting(3)).To(Equal([]byte("foo"))) }) @@ -961,11 +961,11 @@ var _ = Describe("Stream", func() { Expect(err).To(MatchError(testErr)) }() Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) - Expect(str.LenOfDataForWriting()).ToNot(BeZero()) + Expect(str.HasDataForWriting()).To(BeTrue()) str.Cancel(testErr) data := str.GetDataForWriting(6) Expect(data).To(BeNil()) - Expect(str.LenOfDataForWriting()).To(BeZero()) + Expect(str.HasDataForWriting()).To(BeFalse()) }) }) }) From 71af5758e2c373a2e2b07273df0fddf9ff7f3c73 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 7 Dec 2017 19:13:21 +0700 Subject: [PATCH 2/3] remove the ShouldSendFin method from the stream GetDataForWriting now has two return parameters: the data and if a FIN should be sent. --- internal/mocks/stream.go | 17 +++---------- stream.go | 22 ++++++---------- stream_framer.go | 16 +++--------- stream_framer_test.go | 41 +++++++++++------------------- stream_test.go | 54 ++++++++++++++++++++++++++++------------ 5 files changed, 67 insertions(+), 83 deletions(-) diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index 9d5230d1..a9a2b25e 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -95,10 +95,11 @@ func (_mr *MockStreamIMockRecorder) Finished() *gomock.Call { } // GetDataForWriting mocks base method -func (_m *MockStreamI) GetDataForWriting(_param0 protocol.ByteCount) []byte { +func (_m *MockStreamI) GetDataForWriting(_param0 protocol.ByteCount) ([]byte, bool) { ret := _m.ctrl.Call(_m, "GetDataForWriting", _param0) ret0, _ := ret[0].([]byte) - return ret0 + ret1, _ := ret[1].(bool) + return ret0, ret1 } // GetDataForWriting indicates an expected call of GetDataForWriting @@ -235,18 +236,6 @@ func (_mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.C return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) } -// ShouldSendFin mocks base method -func (_m *MockStreamI) ShouldSendFin() bool { - ret := _m.ctrl.Call(_m, "ShouldSendFin") - ret0, _ := ret[0].(bool) - return ret0 -} - -// ShouldSendFin indicates an expected call of ShouldSendFin -func (_mr *MockStreamIMockRecorder) ShouldSendFin() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ShouldSendFin", reflect.TypeOf((*MockStreamI)(nil).ShouldSendFin)) -} - // StreamID mocks base method func (_m *MockStreamI) StreamID() protocol.StreamID { ret := _m.ctrl.Call(_m, "StreamID") diff --git a/stream.go b/stream.go index f8d9a629..6e827d50 100644 --- a/stream.go +++ b/stream.go @@ -20,11 +20,10 @@ type streamI interface { AddStreamFrame(*wire.StreamFrame) error RegisterRemoteError(error, protocol.ByteCount) error HasDataForWriting() bool - GetDataForWriting(maxBytes protocol.ByteCount) []byte + GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) GetWriteOffset() protocol.ByteCount Finished() bool Cancel(error) - ShouldSendFin() bool SentFin() // methods needed for flow control GetWindowUpdate() protocol.ByteCount @@ -266,17 +265,19 @@ func (s *stream) GetWriteOffset() protocol.ByteCount { // HasDataForWriting says if there's stream available to be dequeued for writing func (s *stream) HasDataForWriting() bool { s.mutex.Lock() - hasData := s.err == nil && len(s.dataForWriting) > 0 + hasData := s.err == nil && // nothing should be sent if an error occurred + (len(s.dataForWriting) > 0 || // there is data queued for sending + s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet s.mutex.Unlock() return hasData } -func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { +func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { s.mutex.Lock() defer s.mutex.Unlock() if s.err != nil || s.dataForWriting == nil { - return nil + return nil, s.finishedWriting.Get() && !s.finSent.Get() } // TODO(#657): Flow control for the crypto stream @@ -284,7 +285,7 @@ func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) } if maxBytes == 0 { - return nil + return nil, false } var ret []byte @@ -298,7 +299,7 @@ func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { } s.writeOffset += protocol.ByteCount(len(ret)) s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) - return ret + return ret, s.finishedWriting.Get() && s.dataForWriting == nil && !s.finSent.Get() } // Close implements io.Closer @@ -316,13 +317,6 @@ func (s *stream) shouldSendReset() bool { return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() } -func (s *stream) ShouldSendFin() bool { - s.mutex.Lock() - res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil - s.mutex.Unlock() - return res -} - func (s *stream) SentFin() { s.finSent.Set(true) } diff --git a/stream_framer.go b/stream_framer.go index 8e886cac..d4bcfaf5 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -67,7 +67,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str Offset: f.cryptoStream.GetWriteOffset(), } frameHeaderBytes, _ := frame.MinLength(f.version) // can never error - frame.Data = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) + frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) return frame } @@ -115,24 +115,16 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] } maxLen := maxBytes - currentLen - frameHeaderBytes - var data []byte if s.HasDataForWriting() { - data = s.GetDataForWriting(maxLen) + frame.Data, frame.FinBit = s.GetDataForWriting(maxLen) } - - // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. - shouldSendFin := s.ShouldSendFin() - if data == nil && !shouldSendFin { + if len(frame.Data) == 0 && !frame.FinBit { return true, nil } - - if shouldSendFin { - frame.FinBit = true + if frame.FinBit { s.SentFin() } - frame.Data = data - // Finally, check if we are now FC blocked and should queue a BLOCKED frame if !frame.FinBit && s.IsFlowControlBlocked() { f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()}) diff --git a/stream_framer_test.go b/stream_framer_test.go index 97dcb11e..ea0669e2 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -51,8 +51,7 @@ var _ = Describe("Stream Framer", func() { setNoData := func(str *mocks.MockStreamI) { str.EXPECT().HasDataForWriting().Return(false).AnyTimes() - str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil).AnyTimes() - str.EXPECT().ShouldSendFin().Return(false).AnyTimes() + str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, false).AnyTimes() str.EXPECT().GetWriteOffset().AnyTimes() } @@ -76,9 +75,8 @@ var _ = Describe("Stream Framer", func() { setNoData(stream2) stream1.EXPECT().GetWriteOffset() stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().IsFlowControlBlocked() - stream1.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(protocol.MaxByteCount) Expect(fs).To(HaveLen(1)) Expect(fs[0].DataLenPresent).To(BeTrue()) @@ -111,10 +109,9 @@ var _ = Describe("Stream Framer", func() { }) It("returns normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() setNoData(stream2) fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) @@ -124,14 +121,12 @@ var _ = Describe("Stream Framer", func() { }) It("returns multiple normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() - stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz")) + stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz"), false) stream2.EXPECT().HasDataForWriting().Return(true) stream2.EXPECT().GetWriteOffset() - stream2.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(2)) // Swap if we dequeued in other order @@ -145,10 +140,9 @@ var _ = Describe("Stream Framer", func() { }) It("returns retransmission frames before normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(1000) @@ -159,7 +153,6 @@ var _ = Describe("Stream Framer", func() { It("does not pop empty frames", func() { stream1.EXPECT().HasDataForWriting().Return(false) - stream1.EXPECT().ShouldSendFin() stream1.EXPECT().GetWriteOffset() setNoData(stream2) fs := framer.PopStreamFrames(5) @@ -168,14 +161,12 @@ var _ = Describe("Stream Framer", func() { It("uses the round-robin scheduling", func() { streamFrameHeaderLen := protocol.ByteCount(4) - stream1.EXPECT().GetDataForWriting(10 - streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen))) + stream1.EXPECT().GetDataForWriting(10-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen)), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() - stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10 - streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen))) + stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10-streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen)), false) stream2.EXPECT().HasDataForWriting().Return(true) stream2.EXPECT().GetWriteOffset() - stream2.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(10) Expect(fs).To(HaveLen(1)) // it doesn't matter here if this data is from stream1 or from stream2... @@ -279,9 +270,9 @@ var _ = Describe("Stream Framer", func() { Context("sending FINs", func() { It("sends FINs when streams are closed", func() { offset := protocol.ByteCount(42) - stream1.EXPECT().HasDataForWriting().Return(false) + stream1.EXPECT().HasDataForWriting().Return(true) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true) stream1.EXPECT().GetWriteOffset().Return(offset) - stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() setNoData(stream2) @@ -295,10 +286,9 @@ var _ = Describe("Stream Framer", func() { It("bundles FINs with data", func() { offset := protocol.ByteCount(42) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset().Return(offset) - stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() setNoData(stream2) @@ -318,10 +308,9 @@ var _ = Describe("Stream Framer", func() { It("queues and pops BLOCKED frames for individually blocked streams", func() { connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() stream1.EXPECT().IsFlowControlBlocked().Return(true) setNoData(stream2) frames := framer.PopStreamFrames(1000) @@ -335,10 +324,9 @@ var _ = Describe("Stream Framer", func() { It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() { connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin().Return(true) stream1.EXPECT().SentFin() setNoData(stream2) frames := framer.PopStreamFrames(1000) @@ -351,10 +339,9 @@ var _ = Describe("Stream Framer", func() { It("queues and pops BLOCKED frames for connection blocked streams", func() { connFC.EXPECT().IsBlocked().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), false) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().ShouldSendFin() stream1.EXPECT().IsFlowControlBlocked().Return(false) setNoData(stream2) framer.PopStreamFrames(1000) diff --git a/stream_test.go b/stream_test.go index fbb67f1e..0ac723c1 100644 --- a/stream_test.go +++ b/stream_test.go @@ -9,6 +9,7 @@ import ( "os" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -617,7 +618,7 @@ var _ = Describe("Stream", func() { close(done) }() - Eventually(func() []byte { return str.GetDataForWriting(4) }).ShouldNot(BeEmpty()) + Eventually(func() []byte { data, _ := str.GetDataForWriting(4); return data }).ShouldNot(BeEmpty()) str.RegisterRemoteError(testErr, 10) Eventually(done).Should(BeClosed()) }) @@ -775,8 +776,9 @@ var _ = Describe("Stream", func() { Consistently(done).ShouldNot(BeClosed()) Expect(onDataCalled).To(BeTrue()) Expect(str.HasDataForWriting()).To(BeTrue()) - data := str.GetDataForWriting(1000) + data, sendFin := str.GetDataForWriting(1000) Expect(data).To(Equal([]byte("foobar"))) + Expect(sendFin).To(BeFalse()) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) Eventually(done).Should(BeClosed()) @@ -800,13 +802,15 @@ var _ = Describe("Stream", func() { }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) Expect(str.HasDataForWriting()).To(BeTrue()) - data := str.GetDataForWriting(3) + data, sendFin := str.GetDataForWriting(3) Expect(data).To(Equal([]byte("foo"))) + Expect(sendFin).To(BeFalse()) Expect(str.writeOffset).To(Equal(protocol.ByteCount(3))) Expect(str.dataForWriting).ToNot(BeNil()) Expect(str.HasDataForWriting()).To(BeTrue()) - data = str.GetDataForWriting(3) + data, sendFin = str.GetDataForWriting(3) Expect(data).To(Equal([]byte("bar"))) + Expect(sendFin).To(BeFalse()) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) Expect(str.HasDataForWriting()).To(BeFalse()) @@ -905,11 +909,6 @@ var _ = Describe("Stream", func() { }) Context("closing", func() { - It("sets finishedWriting when calling Close", func() { - str.Close() - Expect(str.finishedWriting.Get()).To(BeTrue()) - }) - It("doesn't allow writes after it has been closed", func() { str.Close() _, err := strWithTimeout.Write([]byte("foobar")) @@ -918,29 +917,51 @@ var _ = Describe("Stream", func() { It("allows FIN", func() { str.Close() - Expect(str.ShouldSendFin()).To(BeTrue()) + Expect(str.HasDataForWriting()).To(BeTrue()) + data, sendFin := str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeTrue()) }) It("does not allow FIN when there's still data", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) str.dataForWriting = []byte("foobar") str.Close() - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeTrue()) + data, sendFin := str.GetDataForWriting(3) + Expect(data).To(Equal([]byte("foo"))) + Expect(sendFin).To(BeFalse()) + data, sendFin = str.GetDataForWriting(3) + Expect(data).To(Equal([]byte("bar"))) + Expect(sendFin).To(BeTrue()) }) It("does not allow FIN when the stream is not closed", func() { - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeFalse()) + _, sendFin := str.GetDataForWriting(3) + Expect(sendFin).To(BeFalse()) }) It("does not allow FIN after an error", func() { str.Cancel(errors.New("test")) - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeFalse()) + data, sendFin := str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeFalse()) }) It("does not allow FIN twice", func() { str.Close() - Expect(str.ShouldSendFin()).To(BeTrue()) + Expect(str.HasDataForWriting()).To(BeTrue()) + data, sendFin := str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeTrue()) str.SentFin() - Expect(str.ShouldSendFin()).To(BeFalse()) + Expect(str.HasDataForWriting()).To(BeFalse()) + data, sendFin = str.GetDataForWriting(1000) + Expect(data).To(BeEmpty()) + Expect(sendFin).To(BeFalse()) }) }) @@ -963,8 +984,9 @@ var _ = Describe("Stream", func() { Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) Expect(str.HasDataForWriting()).To(BeTrue()) str.Cancel(testErr) - data := str.GetDataForWriting(6) + data, sendFin := str.GetDataForWriting(6) Expect(data).To(BeNil()) + Expect(sendFin).To(BeFalse()) Expect(str.HasDataForWriting()).To(BeFalse()) }) }) From 8e8892b06439b53fe83367af1c46c8d899f15448 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 7 Dec 2017 19:26:01 +0700 Subject: [PATCH 3/3] remove the SentFin method from the stream When a FIN is dequeued from the stream by the streamFramer, it is guaranteed to be sent out. There's no need to explicitely signal that to the stream. --- internal/mocks/stream.go | 10 ---------- session_test.go | 6 +++--- stream.go | 13 ++++++++----- stream_framer.go | 3 --- stream_framer_test.go | 3 --- stream_test.go | 14 ++++++++------ 6 files changed, 19 insertions(+), 30 deletions(-) diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index a9a2b25e..1e6b4465 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -190,16 +190,6 @@ func (_mr *MockStreamIMockRecorder) Reset(arg0 interface{}) *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Reset", reflect.TypeOf((*MockStreamI)(nil).Reset), arg0) } -// SentFin mocks base method -func (_m *MockStreamI) SentFin() { - _m.ctrl.Call(_m, "SentFin") -} - -// SentFin indicates an expected call of SentFin -func (_mr *MockStreamIMockRecorder) SentFin() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SentFin", reflect.TypeOf((*MockStreamI)(nil).SentFin)) -} - // SetDeadline mocks base method func (_m *MockStreamI) SetDeadline(_param0 time.Time) error { ret := _m.ctrl.Call(_m, "SetDeadline", _param0) diff --git a/session_test.go b/session_test.go index ba731ac6..52c9ad9c 100644 --- a/session_test.go +++ b/session_test.go @@ -1434,9 +1434,9 @@ var _ = Describe("Session", func() { for i := 2; i <= 1000; i++ { s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) - err = s.Close() - Expect(err).NotTo(HaveOccurred()) - s.(*stream).SentFin() + Expect(s.Close()).To(Succeed()) + _, sentFin := s.(*stream).GetDataForWriting(1000) // trigger "sending" of the FIN bit + Expect(sentFin).To(BeTrue()) s.(*stream).CloseRemote(0) _, err = s.Read([]byte("a")) Expect(err).To(MatchError(io.EOF)) diff --git a/stream.go b/stream.go index 6e827d50..0e4f34e7 100644 --- a/stream.go +++ b/stream.go @@ -24,7 +24,6 @@ type streamI interface { GetWriteOffset() protocol.ByteCount Finished() bool Cancel(error) - SentFin() // methods needed for flow control GetWindowUpdate() protocol.ByteCount UpdateSendWindow(protocol.ByteCount) @@ -273,6 +272,14 @@ func (s *stream) HasDataForWriting() bool { } func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { + data, shouldSendFin := s.getDataForWritingImpl(maxBytes) + if shouldSendFin { + s.finSent.Set(true) + } + return data, shouldSendFin +} + +func (s *stream) getDataForWritingImpl(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { s.mutex.Lock() defer s.mutex.Unlock() @@ -317,10 +324,6 @@ func (s *stream) shouldSendReset() bool { return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() } -func (s *stream) SentFin() { - s.finSent.Set(true) -} - // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() diff --git a/stream_framer.go b/stream_framer.go index d4bcfaf5..e275fccd 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -121,9 +121,6 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] if len(frame.Data) == 0 && !frame.FinBit { return true, nil } - if frame.FinBit { - s.SentFin() - } // Finally, check if we are now FC blocked and should queue a BLOCKED frame if !frame.FinBit && s.IsFlowControlBlocked() { diff --git a/stream_framer_test.go b/stream_framer_test.go index ea0669e2..0dabce53 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -273,7 +273,6 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true) stream1.EXPECT().GetWriteOffset().Return(offset) - stream1.EXPECT().SentFin() setNoData(stream2) fs := framer.PopStreamFrames(1000) @@ -289,7 +288,6 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset().Return(offset) - stream1.EXPECT().SentFin() setNoData(stream2) fs := framer.PopStreamFrames(1000) @@ -327,7 +325,6 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true) stream1.EXPECT().HasDataForWriting().Return(true) stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().SentFin() setNoData(stream2) frames := framer.PopStreamFrames(1000) Expect(frames).To(HaveLen(1)) diff --git a/stream_test.go b/stream_test.go index 0ac723c1..ce2ab483 100644 --- a/stream_test.go +++ b/stream_test.go @@ -641,7 +641,8 @@ var _ = Describe("Stream", func() { It("doesn't call onReset if it already sent a FIN", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.Close() - str.SentFin() + _, sentFin := str.GetDataForWriting(1000) + Expect(sentFin).To(BeTrue()) str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeFalse()) }) @@ -726,7 +727,8 @@ var _ = Describe("Stream", func() { It("doesn't call onReset if it already sent a FIN", func() { str.Close() - str.SentFin() + _, sentFin := str.GetDataForWriting(1000) + Expect(sentFin).To(BeTrue()) str.Reset(testErr) Expect(resetCalled).To(BeFalse()) }) @@ -957,7 +959,6 @@ var _ = Describe("Stream", func() { data, sendFin := str.GetDataForWriting(1000) Expect(data).To(BeEmpty()) Expect(sendFin).To(BeTrue()) - str.SentFin() Expect(str.HasDataForWriting()).To(BeFalse()) data, sendFin = str.GetDataForWriting(1000) Expect(data).To(BeEmpty()) @@ -1022,14 +1023,14 @@ var _ = Describe("Stream", func() { It("is not finished if it is only closed for writing", func() { str.Close() - str.SentFin() + _, sentFin := str.GetDataForWriting(1000) + Expect(sentFin).To(BeTrue()) Expect(str.Finished()).To(BeFalse()) }) It("cancels the context after it is closed", func() { Expect(str.Context().Done()).ToNot(BeClosed()) str.Close() - str.SentFin() Expect(str.Context().Done()).To(BeClosed()) }) @@ -1065,7 +1066,8 @@ var _ = Describe("Stream", func() { It("is finished after finishing writing and receiving a RST", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true) str.Close() - str.SentFin() + _, sentFin := str.GetDataForWriting(1000) + Expect(sentFin).To(BeTrue()) str.RegisterRemoteError(testErr, 13) Expect(str.Finished()).To(BeTrue()) })