diff --git a/send_stream.go b/send_stream.go index 13c0a149..a588cc8a 100644 --- a/send_stream.go +++ b/send_stream.go @@ -37,10 +37,10 @@ type sendStream struct { writeOffset protocol.ByteCount - cancelWriteErr *StreamError - closeForShutdownErr error - - queuedResetStreamFrame bool + // finalError is the error that is returned by Write. + // It can either be a cancellation error or the shutdown error. + finalError error + queuedResetStreamFrame *wire.ResetStreamFrame finishedWriting bool // set once Close() is called finSent bool // set when a STREAM_FRAME with FIN bit has been sent @@ -48,6 +48,8 @@ type sendStream struct { // This can happen because the application called CancelWrite, // or because Write returned the error (for remote cancellations). cancellationFlagged bool + cancelled bool // both local and remote cancellations + closedForShutdown bool // set by closeForShutdown completed bool // set when this stream has been reported to the streamSender as completed dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out @@ -105,16 +107,15 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) s.mutex.Lock() defer s.mutex.Unlock() + if s.finalError != nil { + if s.cancelled { + s.cancellationFlagged = true + } + return s.isNewlyCompleted(), 0, s.finalError + } if s.finishedWriting { return false, 0, fmt.Errorf("write on closed stream %d", s.streamID) } - if s.cancelWriteErr != nil { - s.cancellationFlagged = true - return s.isNewlyCompleted(), 0, s.cancelWriteErr - } - if s.closeForShutdownErr != nil { - return false, 0, s.closeForShutdownErr - } if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { return false, 0, errDeadline } @@ -168,7 +169,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) } deadlineTimer.Reset(deadline) } - if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil { + if s.dataForWriting == nil || s.finalError != nil { break } } @@ -197,11 +198,11 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) if bytesWritten == len(p) { return false, bytesWritten, nil } - if s.closeForShutdownErr != nil { - return false, bytesWritten, s.closeForShutdownErr - } else if s.cancelWriteErr != nil { - s.cancellationFlagged = true - return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr + if s.finalError != nil { + if s.cancelled { + s.cancellationFlagged = true + } + return s.isNewlyCompleted(), bytesWritten, s.finalError } return false, bytesWritten, nil } @@ -234,7 +235,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers } func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) { - if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { + if s.finalError != nil { return nil, nil, false } @@ -374,7 +375,7 @@ func (s *sendStream) isNewlyCompleted() bool { return false } // We need to keep the stream around until all frames have been sent and acknowledged. - if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame { + if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil { return false } // The stream is completed if we sent the FIN. @@ -387,7 +388,7 @@ func (s *sendStream) isNewlyCompleted() bool { // 2. we received a STOP_SENDING, and // * the application consumed the error via Write, or // * the application called Close - if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) { + if s.cancelled && (s.cancellationFlagged || s.finishedWriting) { s.completed = true return true } @@ -396,13 +397,13 @@ func (s *sendStream) isNewlyCompleted() bool { func (s *sendStream) Close() error { s.mutex.Lock() - if s.closeForShutdownErr != nil || s.finishedWriting { + if s.closedForShutdown || s.finishedWriting { s.mutex.Unlock() return nil } s.finishedWriting = true - cancelWriteErr := s.cancelWriteErr - if cancelWriteErr != nil { + cancelled := s.cancelled + if cancelled { s.cancellationFlagged = true } completed := s.isNewlyCompleted() @@ -411,7 +412,7 @@ func (s *sendStream) Close() error { if completed { s.sender.onStreamCompleted(s.streamID) } - if cancelWriteErr != nil { + if cancelled { return fmt.Errorf("close called for canceled stream %d", s.streamID) } s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex @@ -421,18 +422,21 @@ func (s *sendStream) Close() error { } func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { - s.cancelWriteImpl(errorCode, false) + s.cancelWrite(errorCode, false) } -func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { +// cancelWrite cancels the stream +// It is possible to cancel a stream after it has been closed, both locally and remotely. +// This is useful to prevent the retransmission of outstanding stream data. +func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() - if s.closeForShutdownErr != nil { + if s.closedForShutdown { s.mutex.Unlock() return } if !remote { s.cancellationFlagged = true - if s.cancelWriteErr != nil { + if s.cancelled { completed := s.isNewlyCompleted() s.mutex.Unlock() // The user has called CancelWrite. If the previous cancellation was @@ -444,15 +448,20 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool return } } - if s.cancelWriteErr != nil { + if s.cancelled { s.mutex.Unlock() return } - s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} - s.ctxCancel(s.cancelWriteErr) + s.cancelled = true + s.finalError = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} + s.ctxCancel(s.finalError) s.numOutstandingFrames = 0 s.retransmissionQueue = nil - s.queuedResetStreamFrame = true + s.queuedResetStreamFrame = &wire.ResetStreamFrame{ + StreamID: s.streamID, + FinalSize: s.writeOffset, + ErrorCode: errorCode, + } s.mutex.Unlock() s.signalWrite() @@ -473,26 +482,23 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { } func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.cancelWriteImpl(frame.ErrorCode, true) + s.cancelWrite(frame.ErrorCode, true) } func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) { s.mutex.Lock() defer s.mutex.Unlock() - if !s.queuedResetStreamFrame { + if s.queuedResetStreamFrame == nil { return ackhandler.Frame{}, false, false } - s.queuedResetStreamFrame = false s.numOutstandingFrames++ - return ackhandler.Frame{ - Frame: &wire.ResetStreamFrame{ - StreamID: s.streamID, - FinalSize: s.writeOffset, - ErrorCode: s.cancelWriteErr.ErrorCode, - }, + f := ackhandler.Frame{ + Frame: s.queuedResetStreamFrame, Handler: (*sendStreamResetStreamHandler)(s), - }, true, false + } + s.queuedResetStreamFrame = nil + return f, true, false } func (s *sendStream) Context() context.Context { @@ -512,7 +518,10 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error { // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. func (s *sendStream) closeForShutdown(err error) { s.mutex.Lock() - s.closeForShutdownErr = err + s.closedForShutdown = true + if s.finalError == nil && !s.finishedWriting { + s.finalError = err + } s.mutex.Unlock() s.signalWrite() } @@ -533,7 +542,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { sf := f.(*wire.StreamFrame) sf.PutBack() s.mutex.Lock() - if s.cancelWriteErr != nil { + if s.cancelled { s.mutex.Unlock() return } @@ -552,7 +561,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { func (s *sendStreamAckHandler) OnLost(f wire.Frame) { sf := f.(*wire.StreamFrame) s.mutex.Lock() - if s.cancelWriteErr != nil { + if s.cancelled { s.mutex.Unlock() return } @@ -585,9 +594,9 @@ func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) { } } -func (s *sendStreamResetStreamHandler) OnLost(wire.Frame) { +func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) { s.mutex.Lock() - s.queuedResetStreamFrame = true + s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame) s.numOutstandingFrames-- s.mutex.Unlock() s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s)) diff --git a/send_stream_test.go b/send_stream_test.go index f169db7e..9e743e93 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -432,6 +432,11 @@ func TestSendStreamClose(t *testing.T) { require.Nil(t, frame.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) + + // shutting down has no effect + str.closeForShutdown(errors.New("goodbye")) + _, err = strWithTimeout.Write([]byte("foobar")) + require.ErrorContains(t, err, "write on closed stream 1234") } func TestSendStreamImmediateClose(t *testing.T) { @@ -630,21 +635,28 @@ func TestSendStreamCancellation(t *testing.T) { require.ErrorContains(t, str.Close(), "close called for canceled stream") frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) - _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + _, err = strWithTimeout.Write([]byte("foobar")) require.Error(t, err) - // TODO(#4808):error code and remote flag are unchanged - // require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) + + // shutting down has no effect + str.closeForShutdown(errors.New("goodbyte")) + _, err = strWithTimeout.Write([]byte("foobar")) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) } +// It is possible to cancel a stream after it has been closed. +// This is useful if the applications wants to prevent the retransmission of outstanding stream data. func TestSendStreamCancellationAfterClose(t *testing.T) { const streamID protocol.StreamID = 1234 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC) + strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) - _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + _, err := strWithTimeout.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, str.Close()) @@ -659,6 +671,10 @@ func TestSendStreamCancellationAfterClose(t *testing.T) { require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, cf.Frame) require.False(t, hasMore) + + _, err = strWithTimeout.Write([]byte("foobar")) + require.Error(t, err) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: false}) } func TestSendStreamCancellationStreamRetransmission(t *testing.T) { @@ -809,8 +825,7 @@ func TestSendStreamStopSending(t *testing.T) { require.Nil(t, frame.Frame) _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.Error(t, err) - // TODO(#4808):error code and remote flag are unchanged - // require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) } // This test is inherently racy, as it tests a concurrent call to Write() and CancelRead().