From 6a4512a6f09e2fec4097fcef4404ae427b201e62 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 23 Apr 2024 22:23:48 +0200 Subject: [PATCH] http3: fix race condition when closing the RoundTripper (#4458) --- http3/roundtrip.go | 23 ++++++++++++++++++----- http3/roundtrip_test.go | 31 ++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index e2b9171c..e25c2dac 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -41,6 +41,7 @@ type singleRoundTripper interface { } type roundTripperWithCount struct { + cancel context.CancelFunc dialing chan struct{} // closed as soon as quic.Dial(Early) returned dialErr error conn quic.EarlyConnection @@ -49,6 +50,15 @@ type roundTripperWithCount struct { useCount atomic.Int64 } +func (r *roundTripperWithCount) Close() error { + r.cancel() + <-r.dialing + if r.conn != nil { + return r.conn.CloseWithError(0, "") + } + return nil +} + // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { mutex sync.Mutex @@ -227,11 +237,14 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache if onlyCached { return nil, false, ErrNoCachedConn } + ctx, cancel := context.WithCancel(ctx) cl = &roundTripperWithCount{ dialing: make(chan struct{}), + cancel: cancel, } go func() { defer close(cl.dialing) + defer cancel() conn, rt, err := r.dial(ctx, hostname) if err != nil { cl.dialErr = err @@ -315,8 +328,8 @@ func (r *RoundTripper) removeClient(hostname string) { func (r *RoundTripper) Close() error { r.mutex.Lock() defer r.mutex.Unlock() - for _, client := range r.clients { - if err := client.conn.CloseWithError(0, ""); err != nil { + for _, cl := range r.clients { + if err := cl.Close(); err != nil { return err } } @@ -364,9 +377,9 @@ func isNotToken(r rune) bool { func (r *RoundTripper) CloseIdleConnections() { r.mutex.Lock() defer r.mutex.Unlock() - for hostname, client := range r.clients { - if client.useCount.Load() == 0 { - client.conn.CloseWithError(0, "") + for hostname, cl := range r.clients { + if cl.useCount.Load() == 0 { + cl.Close() delete(r.clients, hostname) } } diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index b21f9313..35b22af1 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -146,13 +146,13 @@ var _ = Describe("RoundTripper", func() { testErr := errors.New("test done") tlsConf := &tls.Config{ServerName: "foo.bar"} quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} - ctx, cancel := context.WithTimeout(context.Background(), time.Hour) - defer cancel() + // nolint:staticcheck // This is a test. + ctx := context.WithValue(context.Background(), "foo", "bar") var dialerCalled bool rt := &RoundTripper{ Dial: func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { defer GinkgoRecover() - Expect(ctxP).To(Equal(ctx)) + Expect(ctx.Value("foo").(string)).To(Equal("bar")) Expect(address).To(Equal("www.example.org:443")) Expect(tlsConfP.ServerName).To(Equal("foo.bar")) Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) @@ -515,6 +515,31 @@ var _ = Describe("RoundTripper", func() { Expect(rt.Close()).To(Succeed()) }) + It("closes while dialing", func() { + rt := &RoundTripper{ + Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + defer GinkgoRecover() + Eventually(ctx.Done()).Should(BeClosed()) + return nil, errors.New("cancelled") + }, + } + req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + _, err := rt.RoundTrip(req) + errChan <- err + }() + + Consistently(errChan, scaleDuration(30*time.Millisecond)).ShouldNot(Receive()) + Expect(rt.Close()).To(Succeed()) + var rtErr error + Eventually(errChan).Should(Receive(&rtErr)) + Expect(rtErr).To(MatchError("cancelled")) + }) + It("closes idle connections", func() { conn1 := mockquic.NewMockEarlyConnection(mockCtrl) conn2 := mockquic.NewMockEarlyConnection(mockCtrl)