From bb6f066aa56f5d0c2d00466d5d1f67dd5a5d6226 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 5 May 2024 10:15:35 +0800 Subject: [PATCH] http3: use the stream context to detect when the send side is closed (#4489) --- http3/client_test.go | 1 + http3/conn_test.go | 2 + http3/datagram.go | 4 +- http3/server_test.go | 25 ++++------- http3/state_tracking_stream.go | 4 ++ http3/state_tracking_stream_test.go | 68 +++++++++++++++++++++++------ integrationtests/self/http_test.go | 40 +++++++++++++++++ 7 files changed, 111 insertions(+), 33 deletions(-) diff --git a/http3/client_test.go b/http3/client_test.go index e3863afe9..4c5330243 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -494,6 +494,7 @@ var _ = Describe("Client", func() { return len(b), nil }) // SETTINGS frame str = mockquic.NewMockStream(mockCtrl) + str.EXPECT().Context().Return(context.Background()).AnyTimes() str.EXPECT().StreamID().AnyTimes() conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) diff --git a/http3/conn_test.go b/http3/conn_test.go index 4aa6e423b..8af906cd0 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -382,6 +382,7 @@ var _ = Describe("Connection", func() { // ... then open the stream qstr := mockquic.NewMockStream(mockCtrl) qstr.EXPECT().StreamID().Return(strID).MinTimes(1) + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil) str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) Expect(err).ToNot(HaveOccurred()) @@ -397,6 +398,7 @@ var _ = Describe("Connection", func() { // first open the stream... qstr := mockquic.NewMockStream(mockCtrl) qstr.EXPECT().StreamID().Return(strID).MinTimes(1) + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil) str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) Expect(err).ToNot(HaveOccurred()) diff --git a/http3/datagram.go b/http3/datagram.go index 0a5028815..491e97ed7 100644 --- a/http3/datagram.go +++ b/http3/datagram.go @@ -85,9 +85,9 @@ start: d.mx.Unlock() return data, nil } - if d.receiveErr != nil { + if receiveErr := d.receiveErr; receiveErr != nil { d.mx.Unlock() - return nil, d.receiveErr + return nil, receiveErr } d.mx.Unlock() diff --git a/http3/server_test.go b/http3/server_test.go index 5597a2815..f6b635fed 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -90,7 +90,7 @@ var _ = Describe("Server", func() { exampleGetRequest *http.Request examplePostRequest *http.Request ) - reqContext := context.Background() + reqContext, reqContextCancel := context.WithCancel(context.Background()) decodeHeader := func(str io.Reader) map[string][]string { fields := make(map[string][]string) @@ -140,6 +140,7 @@ var _ = Describe("Server", func() { qpackDecoder = qpack.NewDecoder(nil) str = mockquic.NewMockStream(mockCtrl) + str.EXPECT().Context().Return(reqContext).AnyTimes() str.EXPECT().StreamID().AnyTimes() qconn := mockquic.NewMockEarlyConnection(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} @@ -157,7 +158,6 @@ var _ = Describe("Server", func() { }) setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return len(p), nil }).AnyTimes() @@ -178,7 +178,6 @@ var _ = Describe("Server", func() { responseBuf := &bytes.Buffer{} setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() @@ -195,7 +194,6 @@ var _ = Describe("Server", func() { responseBuf := &bytes.Buffer{} setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() @@ -216,7 +214,6 @@ var _ = Describe("Server", func() { responseBuf := &bytes.Buffer{} setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() @@ -237,7 +234,6 @@ var _ = Describe("Server", func() { responseBuf := &bytes.Buffer{} setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() @@ -258,7 +254,6 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) responseBuf := &bytes.Buffer{} setRequest(encodeRequest(headRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() @@ -278,7 +273,6 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) responseBuf := &bytes.Buffer{} setRequest(encodeRequest(headRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() @@ -297,7 +291,6 @@ var _ = Describe("Server", func() { responseBuf := &bytes.Buffer{} setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) @@ -315,7 +308,6 @@ var _ = Describe("Server", func() { responseBuf := &bytes.Buffer{} setRequest(encodeRequest(exampleGetRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) @@ -355,6 +347,7 @@ var _ = Describe("Server", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() unknownStr.EXPECT().StreamID().AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) @@ -380,6 +373,7 @@ var _ = Describe("Server", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() unknownStr.EXPECT().StreamID().AnyTimes() unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) @@ -407,6 +401,7 @@ var _ = Describe("Server", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() unknownStr.EXPECT().StreamID().AnyTimes() unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) @@ -428,7 +423,6 @@ var _ = Describe("Server", func() { const strID = protocol.StreamID(1234 * 4) testErr := errors.New("test error") done := make(chan struct{}) - unknownStr := mockquic.NewMockStream(mockCtrl) s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) { defer close(done) Expect(ft).To(BeZero()) @@ -436,6 +430,8 @@ var _ = Describe("Server", func() { Expect(err).To(MatchError(testErr)) return true, nil } + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() unknownStr.EXPECT().StreamID().Return(strID).AnyTimes() unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) @@ -586,7 +582,6 @@ var _ = Describe("Server", func() { responseBuf := &bytes.Buffer{} setRequest(append(requestData, b...)) done := make(chan struct{}) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) str.EXPECT().Close().Do(func() error { close(done); return nil }) @@ -610,7 +605,6 @@ var _ = Describe("Server", func() { b := (&dataFrame{Length: 6}).Append(nil) // add a body b = append(b, []byte("foobar")...) setRequest(append(requestData, b...)) - str.EXPECT().Context().Return(reqContext) var buf bytes.Buffer str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() @@ -726,7 +720,6 @@ var _ = Describe("Server", func() { }) setRequest(encodeRequest(examplePostRequest)) - str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return len(p), nil }).AnyTimes() @@ -747,9 +740,7 @@ var _ = Describe("Server", func() { }) setRequest(encodeRequest(examplePostRequest)) - reqContext, cancel := context.WithCancel(context.Background()) - cancel() - str.EXPECT().Context().Return(reqContext) + reqContextCancel() str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return len(p), nil }).AnyTimes() diff --git a/http3/state_tracking_stream.go b/http3/state_tracking_stream.go index 90914ebcd..4a73460e8 100644 --- a/http3/state_tracking_stream.go +++ b/http3/state_tracking_stream.go @@ -1,6 +1,7 @@ package http3 import ( + "context" "errors" "sync" @@ -26,6 +27,9 @@ type stateTrackingStream struct { } func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream { + context.AfterFunc(s.Context(), func() { + onStateChange(streamStateSendClosed, context.Cause(s.Context())) + }) return &stateTrackingStream{ Stream: s, state: streamStateOpen, diff --git a/http3/state_tracking_stream_test.go b/http3/state_tracking_stream_test.go index 7526467a7..4bf1e227c 100644 --- a/http3/state_tracking_stream_test.go +++ b/http3/state_tracking_stream_test.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "context" "errors" "io" @@ -19,22 +20,15 @@ type stateTransition struct { } var _ = Describe("State Tracking Stream", func() { - var ( - qstr *mockquic.MockStream - str *stateTrackingStream - states []stateTransition - ) - - BeforeEach(func() { - states = nil - qstr = mockquic.NewMockStream(mockCtrl) + It("recognizes when the receive side is closed", func() { + qstr := mockquic.NewMockStream(mockCtrl) qstr.EXPECT().StreamID().AnyTimes() - str = newStateTrackingStream(qstr, func(state streamState, err error) { + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + var states []stateTransition + str := newStateTrackingStream(qstr, func(state streamState, err error) { states = append(states, stateTransition{state, err}) }) - }) - It("recognizes when the receive side is closed", func() { buf := bytes.NewBuffer([]byte("foobar")) qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() for i := 0; i < 3; i++ { @@ -50,6 +44,14 @@ var _ = Describe("State Tracking Stream", func() { }) It("recognizes read cancellations", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + var states []stateTransition + str := newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + }) + buf := bytes.NewBuffer([]byte("foobar")) qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() qstr.EXPECT().CancelRead(quic.StreamErrorCode(1337)) @@ -62,7 +64,15 @@ var _ = Describe("State Tracking Stream", func() { Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337})) }) - It("recognizes when the send side is closed", func() { + It("recognizes when the send side is closed, when write errors", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + var states []stateTransition + str := newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + }) + testErr := errors.New("test error") qstr.EXPECT().Write([]byte("foo")).Return(3, nil) qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) @@ -76,7 +86,15 @@ var _ = Describe("State Tracking Stream", func() { Expect(states[0].err).To(Equal(testErr)) }) - It("recognizes write cancellations", func() { + It("recognizes when the send side is closed, when CancelWrite is called", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + var states []stateTransition + str := newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + }) + qstr.EXPECT().Write(gomock.Any()) qstr.EXPECT().CancelWrite(quic.StreamErrorCode(1337)) _, err := str.Write([]byte("foobar")) @@ -87,4 +105,26 @@ var _ = Describe("State Tracking Stream", func() { Expect(states[0].state).To(Equal(streamStateSendClosed)) Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337})) }) + + It("recognizes when the send side is closed, when the stream context is canceled", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + ctx, cancel := context.WithCancelCause(context.Background()) + qstr.EXPECT().Context().Return(ctx).AnyTimes() + var states []stateTransition + + done := make(chan struct{}) + newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + close(done) + }) + + Expect(states).To(BeEmpty()) + testErr := errors.New("test error") + cancel(testErr) + Eventually(done).Should(BeClosed()) + Expect(states).To(HaveLen(1)) + Expect(states[0].state).To(Equal(streamStateSendClosed)) + Expect(states[0].err).To(Equal(testErr)) + }) }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 176b37931..0264d1a4e 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -846,6 +846,46 @@ var _ = Describe("HTTP tests", func() { // make sure we can't send anymore Expect(str.SendDatagram([]byte("foo"))).ToNot(Succeed()) }) + + It("detecting a stream reset from the server", func() { + errChan := make(chan error, 1) + datagramChan := make(chan []byte, 1) + mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + conn := w.(http3.Hijacker).Connection() + Eventually(conn.ReceivedSettings()).Should(BeClosed()) + Expect(conn.Settings().EnableDatagrams).To(BeTrue()) + w.WriteHeader(http.StatusOK) + + str := w.(http3.HTTPStreamer).HTTPStream() + go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions + + for { + data, err := str.ReceiveDatagram(context.Background()) + if err != nil { + errChan <- err + return + } + str.CancelRead(42) + datagramChan <- data + } + }) + + str, closeFn := openDatagramStream(fmt.Sprintf("https://localhost:%d/datagrams", port)) + defer closeFn() + go str.Read([]byte{0}) + + Expect(str.SendDatagram([]byte("foo"))).To(Succeed()) + Eventually(datagramChan).Should(Receive(Equal([]byte("foo")))) + // signal that we're done sending + + var resetErr error + Eventually(errChan).Should(Receive(&resetErr)) + Expect(resetErr).To(Equal(&quic.StreamError{ErrorCode: 42, Remote: false})) + + // make sure we can't send anymore + Expect(str.SendDatagram([]byte("foo"))).To(Equal(&quic.StreamError{ErrorCode: 42, Remote: true})) + }) }) Context("0-RTT", func() {