diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 02014c47f..f8b0e542c 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -165,7 +165,13 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. defer cl.useCount.Add(-1) rsp, err := cl.rt.RoundTripOpt(req, opt) if err != nil { - r.removeClient(hostname) + // non-nil errors on roundtrip are likely due to a problem with the connection + // so we remove the client from the cache so that subsequent trips reconnect + // context cancelation is excluded as is does not signify a connection error + if !errors.Is(err, context.Canceled) { + r.removeClient(hostname) + } + if isReused { if nerr, ok := err.(net.Error); ok && nerr.Timeout() { return r.RoundTripOpt(req, opt) diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index fd8583b99..c314f2733 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -296,6 +296,37 @@ var _ = Describe("RoundTripper", func() { Expect(count).To(Equal(2)) }) + It("does not remove a client when a request returns context canceled error", func() { + cl1 := NewMockSingleRoundTripper(mockCtrl) + clientChan <- cl1 + cl2 := NewMockSingleRoundTripper(mockCtrl) + clientChan <- cl2 + + req1, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + req2, err := http.NewRequest("GET", "https://quic-go.net/bar.html", nil) + Expect(err).ToNot(HaveOccurred()) + + conn := mockquic.NewMockEarlyConnection(mockCtrl) + var count int + rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + count++ + return conn, nil + } + testErr := context.Canceled + handshakeChan := make(chan struct{}) + close(handshakeChan) + conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2) + cl1.EXPECT().RoundTripOpt(req1, gomock.Any()).Return(nil, testErr) + cl1.EXPECT().RoundTripOpt(req2, gomock.Any()).Return(&http.Response{Request: req2}, nil) + _, err = rt.RoundTrip(req1) + Expect(err).To(MatchError(testErr)) + rsp, err := rt.RoundTrip(req2) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.Request).To(Equal(req2)) + Expect(count).To(Equal(1)) + }) + It("recreates a client when a request times out", func() { var reqCount int cl1 := NewMockSingleRoundTripper(mockCtrl)