From d22d579733594ecb1b35851c223c08e6e39973d1 Mon Sep 17 00:00:00 2001 From: sukun Date: Sat, 14 Sep 2024 10:17:48 +0530 Subject: [PATCH] don't cancel streams after shutdown (#4673) This ensures that `stream.Write` and `stream.Read` return the error code from connection close, if the stream was closed as a result of connection close. --- receive_stream.go | 3 +++ receive_stream_test.go | 11 +++++++++++ send_stream.go | 4 ++++ send_stream_test.go | 9 +++++++++ 4 files changed, 27 insertions(+) diff --git a/receive_stream.go b/receive_stream.go index 803409235..b8535ef52 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -253,6 +253,9 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNe if s.cancelledLocally { // duplicate call to CancelRead return false } + if s.closeForShutdownErr != nil { + return false + } s.cancelledLocally = true if s.errorRead || s.cancelledRemotely { return false diff --git a/receive_stream_test.go b/receive_stream_test.go index a01931aad..cb4eaa1aa 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -630,6 +630,17 @@ var _ = Describe("Receive Stream", func() { Fin: true, })).To(Succeed()) }) + + It("ignores cancellations after closeForShutdown", func() { + closeErr := errors.New("closed for shutdown") + str.closeForShutdown(closeErr) + buf := make([]byte, 100) + _, err := str.Read(buf) + Expect(err).To(Equal(closeErr)) + str.CancelRead(42) + _, err = str.Read(buf) + Expect(err).To(Equal(closeErr)) + }) }) Context("receiving RESET_STREAM frames", func() { diff --git a/send_stream.go b/send_stream.go index bcaf2abfd..699c40ef6 100644 --- a/send_stream.go +++ b/send_stream.go @@ -423,6 +423,10 @@ func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() + if s.closeForShutdownErr != nil { + s.mutex.Unlock() + return + } if !remote { s.cancellationFlagged = true if s.cancelWriteErr != nil { diff --git a/send_stream_test.go b/send_stream_test.go index 0cd032cef..c5d62436c 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -981,6 +981,15 @@ var _ = Describe("Send Stream", func() { ErrorCode: 123, }) }) + It("ignores cancellations after closeForShutdown", func() { + closeErr := errors.New("closed for shutdown") + str.closeForShutdown(closeErr) + _, err := str.Write([]byte("hello")) + Expect(err).To(Equal(closeErr)) + str.CancelWrite(42) + _, err = str.Write([]byte("hello")) + Expect(err).To(Equal(closeErr)) + }) }) })