From 4b87539b1efd2ffd05996fb2b9a7519a9cdb9cd2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 26 Apr 2024 17:48:00 +0200 Subject: [PATCH] delay completion of the receive stream until the reset error was read (#4460) * delay completion of the receive stream until the reset error was read * fix handling of CancelRead after receiving a RESET_STREAM --- receive_stream.go | 136 ++++++++++++++++++++++++----------------- receive_stream_test.go | 65 ++++++++++++++------ 2 files changed, 127 insertions(+), 74 deletions(-) diff --git a/receive_stream.go b/receive_stream.go index 1235ff0e..19759ad9 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -37,10 +37,14 @@ type receiveStream struct { readPosInFrame int currentFrameIsLast bool // is the currentFrame the last frame on this stream - finRead bool // set once we read a frame with a Fin + // Set once we read the io.EOF or the cancellation error. + // Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet. + errorRead bool + completed bool // set once we've called streamSender.onStreamCompleted + cancelledRemotely bool + cancelledLocally bool + cancelErr *StreamError closeForShutdownErr error - cancelReadErr error - resetRemotelyErr *StreamError readChan chan struct{} readOnce chan struct{} // cap: 1, to protect against concurrent use of Read @@ -83,7 +87,8 @@ func (s *receiveStream) Read(p []byte) (int, error) { defer func() { <-s.readOnce }() s.mutex.Lock() - completed, n, err := s.readImpl(p) + n, err := s.readImpl(p) + completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { @@ -92,18 +97,38 @@ func (s *receiveStream) Read(p []byte) (int, error) { return n, err } -func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) { - if s.finRead { - return false, 0, io.EOF +func (s *receiveStream) isNewlyCompleted() bool { + if s.completed { + return false } - if s.cancelReadErr != nil { - return false, 0, s.cancelReadErr + // We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting. + if s.finalOffset == protocol.MaxByteCount { + return false } - if s.resetRemotelyErr != nil { - return false, 0, s.resetRemotelyErr + // We're done with the stream if it was cancelled locally... + if s.cancelledLocally { + s.completed = true + return true + } + // ... or if the error (either io.EOF or the reset error) was read + if s.errorRead { + s.completed = true + return true + } + return false +} + +func (s *receiveStream) readImpl(p []byte) (int, error) { + if s.currentFrameIsLast && s.currentFrame == nil { + s.errorRead = true + return 0, io.EOF + } + if s.cancelledRemotely || s.cancelledLocally { + s.errorRead = true + return 0, s.cancelErr } if s.closeForShutdownErr != nil { - return false, 0, s.closeForShutdownErr + return 0, s.closeForShutdownErr } var bytesRead int @@ -113,25 +138,23 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err s.dequeueNextFrame() } if s.currentFrame == nil && bytesRead > 0 { - return false, bytesRead, s.closeForShutdownErr + return bytesRead, s.closeForShutdownErr } for { // Stop waiting on errors if s.closeForShutdownErr != nil { - return false, bytesRead, s.closeForShutdownErr + return bytesRead, s.closeForShutdownErr } - if s.cancelReadErr != nil { - return false, bytesRead, s.cancelReadErr - } - if s.resetRemotelyErr != nil { - return false, bytesRead, s.resetRemotelyErr + if s.cancelledRemotely || s.cancelledLocally { + s.errorRead = true + return 0, s.cancelErr } deadline := s.deadline if !deadline.IsZero() { if !time.Now().Before(deadline) { - return false, bytesRead, errDeadline + return bytesRead, errDeadline } if deadlineTimer == nil { deadlineTimer = utils.NewTimer() @@ -161,10 +184,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err } if bytesRead > len(p) { - return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) } if s.readPosInFrame > len(s.currentFrame) { - return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) + return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) } m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) @@ -173,20 +196,20 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err // when a RESET_STREAM was received, the flow controller was already // informed about the final byteOffset for this stream - if s.resetRemotelyErr == nil { + if !s.cancelledRemotely { s.flowController.AddBytesRead(protocol.ByteCount(m)) } if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { - s.finRead = true s.currentFrame = nil if s.currentFrameDone != nil { s.currentFrameDone() } - return true, bytesRead, io.EOF + s.errorRead = true + return bytesRead, io.EOF } } - return false, bytesRead, nil + return bytesRead, nil } func (s *receiveStream) dequeueNextFrame() { @@ -202,7 +225,8 @@ func (s *receiveStream) dequeueNextFrame() { func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { s.mutex.Lock() - completed := s.cancelReadImpl(errorCode) + s.cancelReadImpl(errorCode) + completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { @@ -211,23 +235,26 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { } } -func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ { - if s.finRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil { - return false +func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) { + if s.cancelledLocally { // duplicate call to CancelRead + return } - s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false} + s.cancelledLocally = true + if s.errorRead || s.cancelledRemotely { + return + } + s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false} s.signalRead() s.sender.queueControlFrame(&wire.StopSendingFrame{ StreamID: s.streamID, ErrorCode: errorCode, }) - // We're done with this stream if the final offset was already received. - return s.finalOffset != protocol.MaxByteCount } func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { s.mutex.Lock() - completed, err := s.handleStreamFrameImpl(frame) + err := s.handleStreamFrameImpl(frame) + completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { @@ -237,59 +264,58 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { return err } -func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) { +func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil { - return false, err + return err } - var newlyRcvdFinalOffset bool if frame.Fin { - newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount s.finalOffset = maxOffset } - if s.cancelReadErr != nil { - return newlyRcvdFinalOffset, nil + if s.cancelledLocally { + return nil } if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil { - return false, err + return err } s.signalRead() - return false, nil + return nil } func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { s.mutex.Lock() - completed, err := s.handleResetStreamFrameImpl(frame) + err := s.handleResetStreamFrameImpl(frame) + completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { - s.flowController.Abandon() s.sender.onStreamCompleted(s.streamID) } return err } -func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) { +func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error { if s.closeForShutdownErr != nil { - return false, nil + return nil } if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil { - return false, err + return err } - newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount s.finalOffset = frame.FinalSize // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset) - if s.resetRemotelyErr != nil { - return false, nil + if s.cancelledRemotely { + return nil } - s.resetRemotelyErr = &StreamError{ - StreamID: s.streamID, - ErrorCode: frame.ErrorCode, - Remote: true, + s.flowController.Abandon() + // don't save the error if the RESET_STREAM frames was received after CancelRead was called + if s.cancelledLocally { + return nil } + s.cancelledRemotely = true + s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true} s.signalRead() - return newlyRcvdFinalOffset, nil + return nil } func (s *receiveStream) SetReadDeadline(t time.Time) error { diff --git a/receive_stream_test.go b/receive_stream_test.go index 5438a1a0..d0ef7f31 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -226,9 +226,7 @@ var _ = Describe("Receive Stream", func() { It("returns an error when Read is called after the deadline", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() - f := &wire.StreamFrame{Data: []byte("foobar")} - err := str.handleStreamFrame(f) - Expect(err).ToNot(HaveOccurred()) + Expect(str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")})).To(Succeed()) str.SetReadDeadline(time.Now().Add(-time.Second)) b := make([]byte, 6) n, err := strWithTimeout.Read(b) @@ -534,34 +532,46 @@ var _ = Describe("Receive Stream", func() { Fin: true, })).To(Succeed()) mockSender.EXPECT().onStreamCompleted(streamID) - _, err := strWithTimeout.Read(make([]byte, 100)) + n, err := strWithTimeout.Read(make([]byte, 100)) Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(6)) str.CancelRead(1234) }) It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() { - gomock.InOrder( - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true), - mockFC.EXPECT().Abandon(), - ) - mockSender.EXPECT().onStreamCompleted(streamID) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true) + mockFC.EXPECT().Abandon().MinTimes(1) Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{ + ErrorCode: 1337, StreamID: streamID, FinalSize: 42, })).To(Succeed()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.CancelRead(1234) + // check that the error indicates a remote reset + n, err := str.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(n).To(BeZero()) + var streamErr *StreamError + Expect(errors.As(err, &streamErr)).To(BeTrue()) + Expect(streamErr.ErrorCode).To(BeEquivalentTo(1337)) + Expect(streamErr.Remote).To(BeTrue()) }) - It("sends a STOP_SENDING and completes the stream after receiving the final offset", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true) + It("sends a STOP_SENDING after receiving the final offset", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) Expect(str.handleStreamFrame(&wire.StreamFrame{ - Offset: 1000, - Fin: true, + Data: []byte("foobar"), + Fin: true, })).To(Succeed()) mockFC.EXPECT().Abandon() mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().onStreamCompleted(streamID) str.CancelRead(1234) + // read the error + n, err := str.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(n).To(BeZero()) }) It("completes the stream when receiving the Fin after the stream was canceled", func() { @@ -649,25 +659,26 @@ var _ = Describe("Receive Stream", func() { }) It("ignores duplicate RESET_STREAM frames", func() { - mockSender.EXPECT().onStreamCompleted(streamID) - mockFC.EXPECT().Abandon() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) + mockFC.EXPECT().Abandon() Expect(str.handleResetStreamFrame(rst)).To(Succeed()) Expect(str.handleResetStreamFrame(rst)).To(Succeed()) }) It("doesn't call onStreamCompleted again when the final offset was already received via Fin", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.CancelRead(1234) - mockSender.EXPECT().onStreamCompleted(streamID) - mockFC.EXPECT().Abandon() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) Expect(str.handleStreamFrame(&wire.StreamFrame{ StreamID: streamID, Offset: rst.FinalSize, Fin: true, })).To(Succeed()) + mockFC.EXPECT().Abandon().MinTimes(1) + mockSender.EXPECT().onStreamCompleted(streamID) Expect(str.handleResetStreamFrame(rst)).To(Succeed()) + // now read the error + n, err := str.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(n).To(BeZero()) }) It("doesn't do anything when it was closed for shutdown", func() { @@ -675,6 +686,22 @@ var _ = Describe("Receive Stream", func() { err := str.handleResetStreamFrame(rst) Expect(err).ToNot(HaveOccurred()) }) + + It("handles RESET_STREAM after CancelRead", func() { + mockFC.EXPECT().Abandon() + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.CancelRead(1234) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true) + mockSender.EXPECT().onStreamCompleted(streamID) + Expect(str.handleResetStreamFrame(rst)).To(Succeed()) + // check that the error indicates a local reset + n, err := str.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(n).To(BeZero()) + var streamErr *StreamError + Expect(errors.As(err, &streamErr)).To(BeTrue()) + Expect(streamErr.Remote).To(BeFalse()) + }) }) })