http3: fix race condition when closing the RoundTripper (#4458)

This commit is contained in:
Marten Seemann
2024-04-23 22:23:48 +02:00
committed by GitHub
parent eb1c16bd0e
commit 6a4512a6f0
2 changed files with 46 additions and 8 deletions

View File

@@ -41,6 +41,7 @@ type singleRoundTripper interface {
}
type roundTripperWithCount struct {
cancel context.CancelFunc
dialing chan struct{} // closed as soon as quic.Dial(Early) returned
dialErr error
conn quic.EarlyConnection
@@ -49,6 +50,15 @@ type roundTripperWithCount struct {
useCount atomic.Int64
}
func (r *roundTripperWithCount) Close() error {
r.cancel()
<-r.dialing
if r.conn != nil {
return r.conn.CloseWithError(0, "")
}
return nil
}
// RoundTripper implements the http.RoundTripper interface
type RoundTripper struct {
mutex sync.Mutex
@@ -227,11 +237,14 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
if onlyCached {
return nil, false, ErrNoCachedConn
}
ctx, cancel := context.WithCancel(ctx)
cl = &roundTripperWithCount{
dialing: make(chan struct{}),
cancel: cancel,
}
go func() {
defer close(cl.dialing)
defer cancel()
conn, rt, err := r.dial(ctx, hostname)
if err != nil {
cl.dialErr = err
@@ -315,8 +328,8 @@ func (r *RoundTripper) removeClient(hostname string) {
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.conn.CloseWithError(0, ""); err != nil {
for _, cl := range r.clients {
if err := cl.Close(); err != nil {
return err
}
}
@@ -364,9 +377,9 @@ func isNotToken(r rune) bool {
func (r *RoundTripper) CloseIdleConnections() {
r.mutex.Lock()
defer r.mutex.Unlock()
for hostname, client := range r.clients {
if client.useCount.Load() == 0 {
client.conn.CloseWithError(0, "")
for hostname, cl := range r.clients {
if cl.useCount.Load() == 0 {
cl.Close()
delete(r.clients, hostname)
}
}

View File

@@ -146,13 +146,13 @@ var _ = Describe("RoundTripper", func() {
testErr := errors.New("test done")
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
// nolint:staticcheck // This is a test.
ctx := context.WithValue(context.Background(), "foo", "bar")
var dialerCalled bool
rt := &RoundTripper{
Dial: func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Expect(ctxP).To(Equal(ctx))
Expect(ctx.Value("foo").(string)).To(Equal("bar"))
Expect(address).To(Equal("www.example.org:443"))
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
@@ -515,6 +515,31 @@ var _ = Describe("RoundTripper", func() {
Expect(rt.Close()).To(Succeed())
})
It("closes while dialing", func() {
rt := &RoundTripper{
Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Eventually(ctx.Done()).Should(BeClosed())
return nil, errors.New("cancelled")
},
}
req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
_, err := rt.RoundTrip(req)
errChan <- err
}()
Consistently(errChan, scaleDuration(30*time.Millisecond)).ShouldNot(Receive())
Expect(rt.Close()).To(Succeed())
var rtErr error
Eventually(errChan).Should(Receive(&rtErr))
Expect(rtErr).To(MatchError("cancelled"))
})
It("closes idle connections", func() {
conn1 := mockquic.NewMockEarlyConnection(mockCtrl)
conn2 := mockquic.NewMockEarlyConnection(mockCtrl)