diff --git a/http3/client.go b/http3/client.go index 8aca48078..9608f8824 100644 --- a/http3/client.go +++ b/http3/client.go @@ -254,6 +254,15 @@ func (c *client) maxHeaderBytes() uint64 { // RoundTripOpt executes a request and returns a response func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + rsp, err := c.roundTripOpt(req, opt) + if err != nil && req.Context().Err() != nil { + // if the context was canceled, return the context cancellation error + err = req.Context().Err() + } + return rsp, err +} + +func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) } diff --git a/http3/client_test.go b/http3/client_test.go index 014c29a5b..00d01bcaa 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -922,7 +922,7 @@ var _ = Describe("Client", func() { return 0, errors.New("test done") }) _, err := cl.RoundTripOpt(req, roundTripOpt) - Expect(err).To(MatchError("test done")) + Expect(err).To(MatchError(context.Canceled)) Eventually(done).Should(BeClosed()) }) }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 5ec17e2d9..96e72dc7c 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -319,6 +319,21 @@ var _ = Describe("HTTP tests", func() { Expect(string(body)).To(Equal("Hello, World!\n")) }) + It("handles context cancellations", func() { + mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + }) + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel", port), nil) + Expect(err).ToNot(HaveOccurred()) + time.AfterFunc(50*time.Millisecond, cancel) + + _, err = client.Do(req) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(context.Canceled)) + }) + It("cancels requests", func() { handlerCalled := make(chan struct{}) mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {