From 475b4f02cb9ab3eb137fc124e940a5a18da9c472 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 7 May 2024 10:03:41 +0800 Subject: [PATCH] http3: ignore deadline errors when tracking QUIC stream states (#4495) --- http3/state_tracking_stream.go | 5 +-- http3/state_tracking_stream_test.go | 52 ++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/http3/state_tracking_stream.go b/http3/state_tracking_stream.go index 4a73460e..a5cd834c 100644 --- a/http3/state_tracking_stream.go +++ b/http3/state_tracking_stream.go @@ -3,6 +3,7 @@ package http3 import ( "context" "errors" + "os" "sync" "github.com/quic-go/quic-go" @@ -75,7 +76,7 @@ func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { func (s *stateTrackingStream) Write(b []byte) (int, error) { n, err := s.Stream.Write(b) - if err != nil { + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { s.closeSend(err) } return n, err @@ -88,7 +89,7 @@ func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { func (s *stateTrackingStream) Read(b []byte) (int, error) { n, err := s.Stream.Read(b) - if err != nil { + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { s.closeReceive(err) } return n, err diff --git a/http3/state_tracking_stream_test.go b/http3/state_tracking_stream_test.go index 4bf1e227..e900cf1a 100644 --- a/http3/state_tracking_stream_test.go +++ b/http3/state_tracking_stream_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "io" + "os" "github.com/quic-go/quic-go" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" @@ -43,7 +44,7 @@ var _ = Describe("State Tracking Stream", func() { Expect(states[0].err).To(Equal(io.EOF)) }) - It("recognizes read cancellations", func() { + It("recognizes local read cancellations", func() { qstr := mockquic.NewMockStream(mockCtrl) qstr.EXPECT().StreamID().AnyTimes() qstr.EXPECT().Context().Return(context.Background()).AnyTimes() @@ -64,6 +65,39 @@ var _ = Describe("State Tracking Stream", func() { Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337})) }) + It("recognizes remote cancellations", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + var states []stateTransition + str := newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + }) + + testErr := errors.New("test error") + qstr.EXPECT().Read(gomock.Any()).Return(0, testErr) + _, err := str.Read(make([]byte, 3)) + Expect(err).To(MatchError(testErr)) + Expect(states).To(HaveLen(1)) + Expect(states[0].state).To(Equal(streamStateReceiveClosed)) + Expect(states[0].err).To(MatchError(testErr)) + }) + + It("doesn't misinterpret read deadline errors", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + var states []stateTransition + str := newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + }) + + qstr.EXPECT().Read(gomock.Any()).Return(0, os.ErrDeadlineExceeded) + _, err := str.Read(make([]byte, 3)) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + Expect(states).To(BeEmpty()) + }) + It("recognizes when the send side is closed, when write errors", func() { qstr := mockquic.NewMockStream(mockCtrl) qstr.EXPECT().StreamID().AnyTimes() @@ -86,6 +120,22 @@ var _ = Describe("State Tracking Stream", func() { Expect(states[0].err).To(Equal(testErr)) }) + It("recognizes when the send side is closed, when write errors", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + var states []stateTransition + str := newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + }) + + qstr.EXPECT().Write([]byte("foo")).Return(0, os.ErrDeadlineExceeded) + Expect(states).To(BeEmpty()) + _, err := str.Write([]byte("foo")) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + Expect(states).To(BeEmpty()) + }) + It("recognizes when the send side is closed, when CancelWrite is called", func() { qstr := mockquic.NewMockStream(mockCtrl) qstr.EXPECT().StreamID().AnyTimes()