From 2a082f973a06c3b34fedc27cfeb68149b65b0813 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 23 Jun 2024 12:38:49 +0800 Subject: [PATCH] http3: allow re-dialing of connection after a dial error (#4573) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * http3: do not cache dial error * add an integration test * http3: add a unit test for dial failures --------- Co-authored-by: 世界 --- http3/roundtrip.go | 2 ++ http3/roundtrip_test.go | 31 ++++++++++++++++++++++++++++++ integrationtests/self/http_test.go | 23 ++++++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 148e33735..a9b169ee1 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -166,6 +166,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } if cl.dialErr != nil { + r.removeClient(hostname) return nil, cl.dialErr } defer cl.useCount.Add(-1) @@ -258,6 +259,7 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache select { case <-cl.dialing: if cl.dialErr != nil { + delete(r.clients, hostname) return nil, false, cl.dialErr } select { diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index fb081abe3..c8472e04b 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -265,6 +265,37 @@ var _ = Describe("RoundTripper", func() { Expect(count).To(Equal(1)) }) + It("redials a connection if dialing failed", func() { + cl1 := NewMockSingleRoundTripper(mockCtrl) + clientChan <- cl1 + + req1, err := http.NewRequest("GET", "https://quic-go.net/foo.html", nil) + Expect(err).ToNot(HaveOccurred()) + req2, err := http.NewRequest("GET", "https://quic-go.net/bar.html", nil) + Expect(err).ToNot(HaveOccurred()) + + testErr := errors.New("handshake error") + conn := mockquic.NewMockEarlyConnection(mockCtrl) + var count int + rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + count++ + if count == 1 { + return nil, testErr + } + return conn, nil + } + handshakeChan := make(chan struct{}) + close(handshakeChan) + conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2) + cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil) + _, err = rt.RoundTrip(req1) + Expect(err).To(MatchError(testErr)) + rsp, err := rt.RoundTrip(req2) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.Request).To(Equal(req2)) + Expect(count).To(Equal(2)) + }) + It("immediately removes a clients when a request errored", func() { cl1 := NewMockSingleRoundTripper(mockCtrl) clientChan <- cl1 diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 0688a4374..e2cf5fb25 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -150,6 +150,29 @@ var _ = Describe("HTTP tests", func() { Expect(resp.Header.Get("Content-Length")).To(Equal("6")) }) + It("re-establishes a QUIC connection after a dial error", func() { + var dialCounter int + testErr := errors.New("test error") + cl := http.Client{ + Transport: &http3.RoundTripper{ + TLSClientConfig: getTLSClientConfig(), + Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) { + dialCounter++ + if dialCounter == 1 { // make the first dial fail + return nil, testErr + } + return quic.DialAddrEarly(ctx, addr, tlsConf, conf) + }, + }, + } + defer cl.Transport.(io.Closer).Close() + _, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) + Expect(err).To(MatchError(testErr)) + resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + It("detects stream errors when server panics when writing response", func() { respChan := make(chan struct{}) mux.HandleFunc("/writing_and_panicking", func(w http.ResponseWriter, r *http.Request) {