diff --git a/send_stream.go b/send_stream.go index 94883392..6be4b898 100644 --- a/send_stream.go +++ b/send_stream.go @@ -157,7 +157,7 @@ func (s *sendStream) popStreamFrameImpl(maxBytes protocol.ByteCount) (bool /* co s.mutex.Lock() defer s.mutex.Unlock() - if s.closeForShutdownErr != nil { + if s.canceledWrite || s.closeForShutdownErr != nil { return false, nil, false } @@ -273,12 +273,6 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr return true, nil } -func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - if completed := s.handleStopSendingFrameImpl(frame); completed { - s.sender.onStreamCompleted(s.streamID) - } -} - func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { s.mutex.Lock() hasStreamData := s.dataForWriting != nil @@ -289,6 +283,12 @@ func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { } } +func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { + if completed := s.handleStopSendingFrameImpl(frame); completed { + s.sender.onStreamCompleted(s.streamID) + } +} + // must be called after locking the mutex func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) bool /*completed*/ { s.mutex.Lock() diff --git a/send_stream_test.go b/send_stream_test.go index c2955c29..01814d84 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -549,12 +549,35 @@ var _ = Describe("Send Stream", func() { waitForWrite() frame, _ := str.popStreamFrame(50) Expect(frame).ToNot(BeNil()) - err := str.CancelWrite(1234) - Expect(err).ToNot(HaveOccurred()) + Expect(str.CancelWrite(1234)).To(Succeed()) Eventually(writeReturned).Should(BeClosed()) Expect(n).To(BeEquivalentTo(frame.DataLen())) }) + It("doesn't pop STREAM frames after being canceled", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().onStreamCompleted(streamID) + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(hasMoreData).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + Expect(str.CancelWrite(1234)).To(Succeed()) + frame, hasMoreData = str.popStreamFrame(10) + Expect(hasMoreData).To(BeFalse()) + Expect(frame).To(BeNil()) + Eventually(writeReturned).Should(BeClosed()) + }) + It("cancels the context", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().onStreamCompleted(streamID)