diff --git a/session.go b/session.go index 83e10564..692d750d 100644 --- a/session.go +++ b/session.go @@ -849,14 +849,6 @@ func (s *session) WaitUntilHandshakeComplete() error { return <-s.handshakeCompleteChan } -func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { - s.packer.QueueControlFrame(&wire.RstStreamFrame{ - StreamID: id, - ByteOffset: offset, - }) - s.scheduleSending() -} - func (s *session) newStream(id protocol.StreamID) streamI { var initialSendWindow protocol.ByteCount if s.peerParams != nil { @@ -871,7 +863,7 @@ func (s *session) newStream(id protocol.StreamID) streamI { initialSendWindow, s.rttStats, ) - return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version) + return newStream(id, s.scheduleSending, s.packer.QueueControlFrame, flowController, s.version) } func (s *session) newCryptoStream() cryptoStreamI { diff --git a/session_test.go b/session_test.go index d697097b..e36f510c 100644 --- a/session_test.go +++ b/session_test.go @@ -357,15 +357,6 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) }) - It("queues a RST_STERAM frame", func() { - sess.queueResetStreamFrame(5, 0x1337) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - Expect(sess.packer.controlFrames[0].(*wire.RstStreamFrame)).To(Equal(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x1337, - })) - }) - It("returns errors", func() { testErr := errors.New("flow control violation") str, err := sess.GetOrOpenStream(5) diff --git a/stream.go b/stream.go index cb24f6d7..c1150c09 100644 --- a/stream.go +++ b/stream.go @@ -38,9 +38,11 @@ type stream struct { ctxCancel context.CancelFunc streamID protocol.StreamID - onData func() - // onReset is a callback that should send a RST_STREAM - onReset func(protocol.StreamID, protocol.ByteCount) + // onData tells the session that there's stuff to pack into a new packet + onData func() + // queueControlFrame queues a new control frame for sending + // it does not call onData + queueControlFrame func(wire.Frame) readPosInFrame int writeOffset protocol.ByteCount @@ -88,19 +90,19 @@ var errDeadline net.Error = &deadlineError{} // newStream creates a new Stream func newStream(StreamID protocol.StreamID, onData func(), - onReset func(protocol.StreamID, protocol.ByteCount), + queueControlFrame func(wire.Frame), flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *stream { s := &stream{ - onData: onData, - onReset: onReset, - streamID: StreamID, - flowController: flowController, - frameQueue: newStreamFrameSorter(), - readChan: make(chan struct{}, 1), - writeChan: make(chan struct{}, 1), - version: version, + onData: onData, + queueControlFrame: queueControlFrame, + streamID: StreamID, + flowController: flowController, + frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), + writeChan: make(chan struct{}, 1), + version: version, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) return s @@ -421,7 +423,11 @@ func (s *stream) Reset(err error) { s.signalWrite() } if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) + s.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: s.writeOffset, + }) + s.onData() s.rstSent.Set(true) } s.mutex.Unlock() @@ -444,7 +450,11 @@ func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error return err } if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) + s.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: s.writeOffset, + }) + s.onData() s.rstSent.Set(true) } s.mutex.Unlock() diff --git a/stream_test.go b/stream_test.go index e5e1cfe5..4efe2127 100644 --- a/stream_test.go +++ b/stream_test.go @@ -28,9 +28,7 @@ var _ = Describe("Stream", func() { strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer} onDataCalled bool - resetCalled bool - resetCalledForStream protocol.StreamID - resetCalledAtOffset protocol.ByteCount + queuedControlFrames []wire.Frame mockFC *mocks.MockStreamFlowController ) @@ -51,17 +49,15 @@ var _ = Describe("Stream", func() { onDataCalled = true } - onReset := func(id protocol.StreamID, offset protocol.ByteCount) { - resetCalled = true - resetCalledForStream = id - resetCalledAtOffset = offset + queueControlFrame := func(f wire.Frame) { + queuedControlFrames = append(queuedControlFrames, f) } BeforeEach(func() { + queuedControlFrames = queuedControlFrames[:0] onDataCalled = false - resetCalled = false mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newStream(streamID, onData, onReset, mockFC, protocol.VersionWhatever) + str = newStream(streamID, onData, queueControlFrame, mockFC, protocol.VersionWhatever) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = struct { @@ -631,9 +627,12 @@ var _ = Describe("Stream", func() { close(done) }() str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeTrue()) - Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337))) - Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000))) + Expect(queuedControlFrames).To(Equal([]wire.Frame{ + &wire.RstStreamFrame{ + StreamID: 1337, + ByteOffset: 0x1000, + }, + })) Eventually(done).Should(BeClosed()) }) @@ -643,25 +642,23 @@ var _ = Describe("Stream", func() { f := str.PopStreamFrame(100) Expect(f.FinBit).To(BeTrue()) str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeFalse()) + Expect(queuedControlFrames).To(BeEmpty()) }) - It("doesn't call onReset if the stream was reset locally before", func() { + It("doesn't call queue a RST_STREAM if the stream was reset locally before", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.Reset(testErr) - Expect(resetCalled).To(BeTrue()) - resetCalled = false + Expect(queuedControlFrames).To(HaveLen(1)) str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeFalse()) + Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame }) - It("doesn't call onReset twice, when it gets two remote errors", func() { + It("doesn't queue two RST_STREAMs twice, when it gets two remote errors", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeTrue()) - resetCalled = false + Expect(queuedControlFrames).To(HaveLen(1)) str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeFalse()) + Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame }) }) @@ -716,37 +713,38 @@ var _ = Describe("Stream", func() { Expect(err).To(MatchError(testErr)) }) - It("calls onReset", func() { + It("queues a RST_STREAM frame", func() { str.writeOffset = 0x1000 str.Reset(testErr) - Expect(resetCalled).To(BeTrue()) - Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337))) - Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000))) + Expect(queuedControlFrames).To(Equal([]wire.Frame{ + &wire.RstStreamFrame{ + StreamID: 1337, + ByteOffset: 0x1000, + }, + })) }) - It("doesn't call onReset if it already sent a FIN", func() { + It("doesn't queue a RST_STREAM if it already sent a FIN", func() { str.Close() f := str.PopStreamFrame(1000) Expect(f.FinBit).To(BeTrue()) str.Reset(testErr) - Expect(resetCalled).To(BeFalse()) + Expect(queuedControlFrames).To(BeEmpty()) }) - It("doesn't call onReset if the stream was reset remotely before", func() { + It("doesn't queue a new RST_STREAM, if the stream was reset remotely before", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeTrue()) - resetCalled = false + Expect(queuedControlFrames).To(HaveLen(1)) str.Reset(testErr) - Expect(resetCalled).To(BeFalse()) + Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame }) It("doesn't call onReset twice", func() { str.Reset(testErr) - Expect(resetCalled).To(BeTrue()) - resetCalled = false + Expect(queuedControlFrames).To(HaveLen(1)) str.Reset(testErr) - Expect(resetCalled).To(BeFalse()) + Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame }) It("cancels the context", func() {