diff --git a/send_stream.go b/send_stream.go index 7f9b0eb40..10235c477 100644 --- a/send_stream.go +++ b/send_stream.go @@ -42,7 +42,11 @@ type sendStream struct { finishedWriting bool // set once Close() is called finSent bool // set when a STREAM_FRAME with FIN bit has been sent - completed bool // set when this stream has been reported to the streamSender as completed + // Set when the application knows about the cancellation. + // This can happen because the application called CancelWrite, + // or because Write returned the error (for remote cancellations). + cancellationFlagged bool + completed bool // set when this stream has been reported to the streamSender as completed dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out nextFrame *wire.StreamFrame @@ -87,23 +91,32 @@ func (s *sendStream) Write(p []byte) (int, error) { s.writeOnce <- struct{}{} defer func() { <-s.writeOnce }() + isNewlyCompleted, n, err := s.write(p) + if isNewlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } + return n, err +} + +func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.finishedWriting { - return 0, fmt.Errorf("write on closed stream %d", s.streamID) + return false, 0, fmt.Errorf("write on closed stream %d", s.streamID) } if s.cancelWriteErr != nil { - return 0, s.cancelWriteErr + s.cancellationFlagged = true + return s.isNewlyCompleted(), 0, s.cancelWriteErr } if s.closeForShutdownErr != nil { - return 0, s.closeForShutdownErr + return false, 0, s.closeForShutdownErr } if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { - return 0, errDeadline + return false, 0, errDeadline } if len(p) == 0 { - return 0, nil + return false, 0, nil } s.dataForWriting = p @@ -144,7 +157,7 @@ func (s *sendStream) Write(p []byte) (int, error) { if !deadline.IsZero() { if !time.Now().Before(deadline) { s.dataForWriting = nil - return bytesWritten, errDeadline + return false, bytesWritten, errDeadline } if deadlineTimer == nil { deadlineTimer = utils.NewTimer() @@ -179,14 +192,15 @@ func (s *sendStream) Write(p []byte) (int, error) { } if bytesWritten == len(p) { - return bytesWritten, nil + return false, bytesWritten, nil } if s.closeForShutdownErr != nil { - return bytesWritten, s.closeForShutdownErr + return false, bytesWritten, s.closeForShutdownErr } else if s.cancelWriteErr != nil { - return bytesWritten, s.cancelWriteErr + s.cancellationFlagged = true + return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr } - return bytesWritten, nil + return false, bytesWritten, nil } func (s *sendStream) canBufferStreamFrame() bool { @@ -349,8 +363,24 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By } func (s *sendStream) isNewlyCompleted() bool { - completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 - if completed && !s.completed { + if s.completed { + return false + } + // We need to keep the stream around until all frames have been sent and acknowledged. + if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 { + return false + } + // The stream is completed if we sent the FIN. + if s.finSent { + s.completed = true + return true + } + // The stream is also completed if: + // 1. the application called CancelWrite, or + // 2. we received a STOP_SENDING, and + // * the application consumed the error via Write, or + // * the application called CLsoe + if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) { s.completed = true return true } @@ -363,15 +393,23 @@ func (s *sendStream) Close() error { s.mutex.Unlock() return nil } - if s.cancelWriteErr != nil { - s.mutex.Unlock() - return fmt.Errorf("close called for canceled stream %d", s.streamID) - } - s.ctxCancel(nil) s.finishedWriting = true + cancelWriteErr := s.cancelWriteErr + if cancelWriteErr != nil { + s.cancellationFlagged = true + } + completed := s.isNewlyCompleted() s.mutex.Unlock() + if completed { + s.sender.onStreamCompleted(s.streamID) + } + if cancelWriteErr != nil { + return fmt.Errorf("close called for canceled stream %d", s.streamID) + } s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex + + s.ctxCancel(nil) return nil } @@ -379,9 +417,11 @@ func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { s.cancelWriteImpl(errorCode, false) } -// must be called after locking the mutex func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() + if !remote { + s.cancellationFlagged = true + } if s.cancelWriteErr != nil { s.mutex.Unlock() return diff --git a/send_stream_test.go b/send_stream_test.go index 1f9308baa..204a0f14e 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -908,8 +908,8 @@ var _ = Describe("Send Stream", func() { StreamID: streamID, ErrorCode: 101, }) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) - + // Don't EXPECT calls to onStreamCompleted. + // The application needs to learn about the cancellation first. str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 101, @@ -919,10 +919,10 @@ var _ = Describe("Send Stream", func() { It("unblocks Write", func() { mockSender.EXPECT().onHasStreamData(streamID) mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() + mockSender.EXPECT().onStreamCompleted(gomock.Any()) _, err := str.Write(getData(5000)) Expect(err).To(Equal(&StreamError{ StreamID: streamID, @@ -941,11 +941,11 @@ var _ = Describe("Send Stream", func() { It("doesn't allow further calls to Write", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) - mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 123, }) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) _, err := str.Write([]byte("foobar")) Expect(err).To(Equal(&StreamError{ StreamID: streamID, @@ -953,6 +953,44 @@ var _ = Describe("Send Stream", func() { Remote: true, })) }) + + It("handles Close after STOP_SENDING", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + str.Close() + }) + + It("handles STOP_SENDING after sending the FIN", func() { + mockSender.EXPECT().onHasStreamData(gomock.Any()) + str.Close() + _, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) + gomock.InOrder( + mockSender.EXPECT().queueControlFrame(gomock.Any()), + mockSender.EXPECT().onStreamCompleted(gomock.Any()), + ) + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + }) + + It("handles STOP_SENDING after Close, but before sending the FIN", func() { + mockSender.EXPECT().onHasStreamData(gomock.Any()) + str.Close() + gomock.InOrder( + mockSender.EXPECT().queueControlFrame(gomock.Any()), + mockSender.EXPECT().onStreamCompleted(gomock.Any()), + ) + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + }) }) })