From 0a298f2aef433ac63c887bce6285cf8f5b9e64c2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 24 Aug 2019 09:47:49 +0700 Subject: [PATCH] implement client-side request cancelations --- http3/body.go | 30 ++++++++++++++++++++++++-- http3/body_test.go | 33 +++++++++++++++++++++++++---- http3/client.go | 15 ++++++++++++- http3/client_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 7 deletions(-) diff --git a/http3/body.go b/http3/body.go index cbc8e9e99..e450a56a5 100644 --- a/http3/body.go +++ b/http3/body.go @@ -13,6 +13,12 @@ type body struct { isRequest bool + // only set for the http.Response + // The channel is closed when the user is done with this response: + // either when Read() errors, or when Close() is called. + reqDone chan<- struct{} + reqDoneClosed bool + bytesRemainingInFrame uint64 } @@ -25,11 +31,22 @@ func newRequestBody(str quic.Stream) *body { } } -func newResponseBody(str quic.Stream) *body { - return &body{str: str} +func newResponseBody(str quic.Stream, done chan<- struct{}) *body { + return &body{ + str: str, + reqDone: done, + } } func (r *body) Read(b []byte) (int, error) { + n, err := r.readImpl(b) + if err != nil && !r.isRequest { + r.requestDone() + } + return n, err +} + +func (r *body) readImpl(b []byte) (int, error) { if r.bytesRemainingInFrame == 0 { parseLoop: for { @@ -61,11 +78,20 @@ func (r *body) Read(b []byte) (int, error) { return n, err } +func (r *body) requestDone() { + if r.reqDoneClosed { + return + } + close(r.reqDone) + r.reqDoneClosed = true +} + func (r *body) Close() error { // quic.Stream.Close() closes the write side, not the read side if r.isRequest { return r.str.Close() } + r.requestDone() r.str.CancelRead(quic.ErrorCode(errorRequestCanceled)) return nil } diff --git a/http3/body_test.go b/http3/body_test.go index dac6e2d52..9843237ec 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -29,9 +29,10 @@ func (t bodyType) String() string { var _ = Describe("Body", func() { var ( - rb *body - str *mockquic.MockStream - buf *bytes.Buffer + rb *body + str *mockquic.MockStream + buf *bytes.Buffer + reqDone chan struct{} ) getDataFrame := func(data []byte) []byte { @@ -62,7 +63,8 @@ var _ = Describe("Body", func() { case bodyTypeRequest: rb = newRequestBody(str) case bodyTypeResponse: - rb = newResponseBody(str) + reqDone = make(chan struct{}) + rb = newResponseBody(str, reqDone) } }) @@ -156,10 +158,33 @@ var _ = Describe("Body", func() { } if bodyType == bodyTypeResponse { + It("closes the reqDone channel when Read errors", func() { + buf.Write([]byte("invalid")) + _, err := rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(reqDone).To(BeClosed()) + }) + + It("allows multiple calls to Read, when Read errors", func() { + buf.Write([]byte("invalid")) + _, err := rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(reqDone).To(BeClosed()) + _, err = rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) + It("closes responses", func() { str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)) Expect(rb.Close()).To(Succeed()) }) + + It("allows multiple calls to Close", func() { + str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).MaxTimes(2) + Expect(rb.Close()).To(Succeed()) + Expect(reqDone).To(BeClosed()) + Expect(rb.Close()).To(Succeed()) + }) } }) } diff --git a/http3/client.go b/http3/client.go index 95fb671dd..be352f700 100644 --- a/http3/client.go +++ b/http3/client.go @@ -153,6 +153,19 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } + // Request Cancelation: + // This go routine keeps running even after RoundTrip() returns. + // It is shut down when the application is done processing the body. + reqDone := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + str.CancelWrite(quic.ErrorCode(errorRequestCanceled)) + str.CancelRead(quic.ErrorCode(errorRequestCanceled)) + case <-reqDone: + } + }() + var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true @@ -198,7 +211,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(str) + respBody := newResponseBody(str, reqDone) if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") diff --git a/http3/client_test.go b/http3/client_test.go index 1b65d2530..cbe0a9413 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -299,6 +299,56 @@ var _ = Describe("Client", func() { }) }) + Context("request cancelations", func() { + It("cancels a request while the request is still in flight", func() { + ctx, cancel := context.WithCancel(context.Background()) + req := request.WithContext(ctx) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + done := make(chan struct{}) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Write(p) + }) + str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)) + str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + cancel() + return 0, errors.New("test done") + }) + _, err := client.RoundTrip(req) + Expect(err).To(MatchError("test done")) + Eventually(done).Should(BeClosed()) + }) + + It("cancels a request after the response arrived", func() { + rspBuf := &bytes.Buffer{} + rw := newResponseWriter(rspBuf, utils.DefaultLogger) + rw.WriteHeader(418) + + ctx, cancel := context.WithCancel(context.Background()) + req := request.WithContext(ctx) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + done := make(chan struct{}) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Write(p) + }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { + return rspBuf.Read(b) + }).AnyTimes() + str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)) + str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) }) + _, err := client.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + cancel() + Eventually(done).Should(BeClosed()) + }) + }) + Context("gzip compression", func() { var gzippedData []byte // a gzipped foobar var response *http.Response