From 10acc677aaf0e68e183b1b4bc302483391116b15 Mon Sep 17 00:00:00 2001 From: Lorenzo Saino Date: Wed, 7 Feb 2018 14:37:49 +0000 Subject: [PATCH] Make RoundTripper.RoundTrip(...) return if client timeout expires Currently, the implementation of h2quic.RoundTripper.RoundTrip(req *http.Request) ignores the context of req. As a result, if the RoundTripper is used as transport of an http.Client with a timeout value set, that is ignored. For example, in the following snippet, client.Do(req) does not promptly return if the task takes more than client.Timeout to complete. client := http.Client{ Timeout: 50 * time.Millisecond, Transport = &h2quic.RoundTripper{} } req, err := http.NewRequest("GET", "https://www.example.com", nil) response, err := client.Do(req) This commit updates h2quic.client.RoundTrip(req *http.Request) to return an error if the request is cancelled. --- h2quic/client.go | 9 +++++ h2quic/client_test.go | 61 ++++++++++++++++++++++++++++++++++ h2quic/response_writer_test.go | 15 +++++---- 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index c81224264..8fadc58ff 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -202,6 +202,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { bodySent = true } + ctx := req.Context() for !(bodySent && receivedResponse) { select { case res = <-responseChan: @@ -214,6 +215,14 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return nil, err } + case <-ctx.Done(): + // error code 6 signals that stream was canceled + dataStream.CancelRead(6) + dataStream.CancelWrite(6) + c.mutex.Lock() + delete(c.responses, dataStream.StreamID()) + c.mutex.Unlock() + return nil, ctx.Err() case <-c.headerErrored: // an error occurred on the header stream _ = c.CloseWithError(c.headerErr) diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 83151ff21..dc23d13a6 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -231,6 +231,67 @@ var _ = Describe("Client", func() { Eventually(done).Should(BeClosed()) }) + It("errors if a request without a body is canceled", func() { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer GinkgoRecover() + request = request.WithContext(ctx) + rsp, err := client.RoundTrip(request) + Expect(err).To(MatchError(context.Canceled)) + Expect(rsp).To(BeNil()) + close(done) + }() + + cancel() + Eventually(done).Should(BeClosed()) + Expect(dataStream.reset).To(BeTrue()) + Expect(dataStream.canceledWrite).To(BeTrue()) + Expect(client.headerErrored).ToNot(BeClosed()) + }) + + It("errors if a request with a body is canceled after the body is sent", func() { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer GinkgoRecover() + request = request.WithContext(ctx) + request.Body = &mockBody{} + rsp, err := client.RoundTrip(request) + Expect(err).To(MatchError(context.Canceled)) + Expect(rsp).To(BeNil()) + close(done) + }() + + time.Sleep(10 * time.Millisecond) + cancel() + Eventually(done).Should(BeClosed()) + Expect(dataStream.reset).To(BeTrue()) + Expect(dataStream.canceledWrite).To(BeTrue()) + Expect(client.headerErrored).ToNot(BeClosed()) + }) + + It("errors if a request with a body is canceled before the body is sent", func() { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer GinkgoRecover() + request = request.WithContext(ctx) + request.Body = &mockBody{} + cancel() + time.Sleep(10 * time.Millisecond) + rsp, err := client.RoundTrip(request) + Expect(err).To(MatchError(context.Canceled)) + Expect(rsp).To(BeNil()) + close(done) + }() + + Eventually(done).Should(BeClosed()) + Expect(dataStream.reset).To(BeTrue()) + Expect(dataStream.canceledWrite).To(BeTrue()) + Expect(client.headerErrored).ToNot(BeClosed()) + }) + It("closes the quic client when encountering an error on the header stream", func() { headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) done := make(chan struct{}) diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 4ea0701c5..847cf645f 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -18,12 +18,13 @@ import ( ) type mockStream struct { - id protocol.StreamID - dataToRead bytes.Buffer - dataWritten bytes.Buffer - reset bool - closed bool - remoteClosed bool + id protocol.StreamID + dataToRead bytes.Buffer + dataWritten bytes.Buffer + reset bool + canceledWrite bool + closed bool + remoteClosed bool unblockRead chan struct{} ctx context.Context @@ -43,7 +44,7 @@ func newMockStream(id protocol.StreamID) *mockStream { func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil } func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil } -func (s *mockStream) CancelWrite(quic.ErrorCode) error { panic("not implemented") } +func (s *mockStream) CancelWrite(quic.ErrorCode) error { s.canceledWrite = true; return nil } func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() } func (s mockStream) StreamID() protocol.StreamID { return s.id } func (s *mockStream) Context() context.Context { return s.ctx }