From 3eed9d11048e3f7fc14a982045a9895cd34ac228 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 30 May 2022 20:03:21 +0200 Subject: [PATCH] don't close the stream if it was hijacked --- http3/body.go | 8 ++++++++ http3/server.go | 7 ++++++- http3/server_test.go | 28 ++++++++++++++++++++++++---- integrationtests/self/http_test.go | 11 ++++++++--- 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/http3/body.go b/http3/body.go index 3a5f1150..b3d1afd7 100644 --- a/http3/body.go +++ b/http3/body.go @@ -12,6 +12,7 @@ import ( // * for the server: the http.Request.Body // * for the client: the http.Response.Body // On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set. +// When a stream is taken over, it's the caller's responsibility to close the stream. type HTTPStreamer interface { HTTPStream() Stream } @@ -36,6 +37,8 @@ type Hijacker interface { // The body of a http.Request or http.Response. type body struct { str quic.Stream + + wasHijacked bool // set when HTTPStream is called } var ( @@ -48,9 +51,14 @@ func newRequestBody(str Stream) *body { } func (r *body) HTTPStream() Stream { + r.wasHijacked = true return r.str } +func (r *body) wasStreamHijacked() bool { + return r.wasHijacked +} + func (r *body) Read(b []byte) (int, error) { return r.str.Read(b) } diff --git a/http3/server.go b/http3/server.go index 040770e0..cc904433 100644 --- a/http3/server.go +++ b/http3/server.go @@ -549,7 +549,8 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q } req.RemoteAddr = conn.RemoteAddr().String() - req.Body = newRequestBody(newStream(str, onFrameError)) + body := newRequestBody(newStream(str, onFrameError)) + req.Body = body if s.logger.Debug() { s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID()) @@ -583,6 +584,10 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q handler.ServeHTTP(r, req) }() + if body.wasStreamHijacked() { + return requestError{err: errHijacked} + } + if panicked { r.WriteHeader(500) } else { diff --git a/http3/server_test.go b/http3/server_test.go index e7fc0f3e..6b4cae02 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -159,7 +159,6 @@ var _ = Describe("Server", func() { qpackDecoder = qpack.NewDecoder(nil) str = mockquic.NewMockStream(mockCtrl) - conn = mockquic.NewMockEarlyConnection(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() @@ -480,7 +479,7 @@ var _ = Describe("Server", func() { }) } - It("reset streams Other than the control stream and the QPACK streams", func() { + It("reset streams other than the control stream and the QPACK streams", func() { buf := &bytes.Buffer{} quicvarint.Write(buf, 1337) str := mockquic.NewMockStream(mockCtrl) @@ -623,9 +622,9 @@ var _ = Describe("Server", func() { AfterEach(func() { testDone <- struct{}{} }) It("cancels reading when client sends a body in GET request", func() { - handlerCalled := make(chan struct{}) + var handlerCalled bool s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(handlerCalled) + handlerCalled = true }) requestData := encodeRequest(exampleGetRequest) @@ -644,6 +643,27 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(handlerCalled).To(BeTrue()) + }) + + It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(handlerCalled) + r.Body.(HTTPStreamer).HTTPStream() + str.Write([]byte("foobar")) + }) + + requestData := encodeRequest(exampleGetRequest) + buf := &bytes.Buffer{} + (&dataFrame{Length: 6}).Write(buf) // add a body + buf.Write([]byte("foobar")) + setRequest(append(requestData, buf.Bytes()...)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write([]byte("foobar")).Return(6, nil) + + s.handleConn(conn) + Eventually(handlerCalled).Should(BeClosed()) }) It("errors when the client sends a too large header frame", func() { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 15e1a9ed..728511e1 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -321,9 +321,14 @@ var _ = Describe("HTTP tests", func() { str := r.Body.(http3.HTTPStreamer).HTTPStream() str.Write([]byte("foobar")) - _, err := io.Copy(str, str) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) + // Do this in a Go routine, so that the handler returns early. + // This way, we can also check that the HTTP/3 doesn't close the stream. + go func() { + defer GinkgoRecover() + _, err := io.Copy(str, str) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() }) req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/httpstreamer", nil)