forked from quic-go/quic-go
http3: fix race condition when closing the RoundTripper (#4458)
This commit is contained in:
@@ -41,6 +41,7 @@ type singleRoundTripper interface {
|
||||
}
|
||||
|
||||
type roundTripperWithCount struct {
|
||||
cancel context.CancelFunc
|
||||
dialing chan struct{} // closed as soon as quic.Dial(Early) returned
|
||||
dialErr error
|
||||
conn quic.EarlyConnection
|
||||
@@ -49,6 +50,15 @@ type roundTripperWithCount struct {
|
||||
useCount atomic.Int64
|
||||
}
|
||||
|
||||
func (r *roundTripperWithCount) Close() error {
|
||||
r.cancel()
|
||||
<-r.dialing
|
||||
if r.conn != nil {
|
||||
return r.conn.CloseWithError(0, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
@@ -227,11 +237,14 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
|
||||
if onlyCached {
|
||||
return nil, false, ErrNoCachedConn
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
cl = &roundTripperWithCount{
|
||||
dialing: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
}
|
||||
go func() {
|
||||
defer close(cl.dialing)
|
||||
defer cancel()
|
||||
conn, rt, err := r.dial(ctx, hostname)
|
||||
if err != nil {
|
||||
cl.dialErr = err
|
||||
@@ -315,8 +328,8 @@ func (r *RoundTripper) removeClient(hostname string) {
|
||||
func (r *RoundTripper) Close() error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for _, client := range r.clients {
|
||||
if err := client.conn.CloseWithError(0, ""); err != nil {
|
||||
for _, cl := range r.clients {
|
||||
if err := cl.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -364,9 +377,9 @@ func isNotToken(r rune) bool {
|
||||
func (r *RoundTripper) CloseIdleConnections() {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for hostname, client := range r.clients {
|
||||
if client.useCount.Load() == 0 {
|
||||
client.conn.CloseWithError(0, "")
|
||||
for hostname, cl := range r.clients {
|
||||
if cl.useCount.Load() == 0 {
|
||||
cl.Close()
|
||||
delete(r.clients, hostname)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,13 +146,13 @@ var _ = Describe("RoundTripper", func() {
|
||||
testErr := errors.New("test done")
|
||||
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
||||
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
|
||||
defer cancel()
|
||||
// nolint:staticcheck // This is a test.
|
||||
ctx := context.WithValue(context.Background(), "foo", "bar")
|
||||
var dialerCalled bool
|
||||
rt := &RoundTripper{
|
||||
Dial: func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(ctxP).To(Equal(ctx))
|
||||
Expect(ctx.Value("foo").(string)).To(Equal("bar"))
|
||||
Expect(address).To(Equal("www.example.org:443"))
|
||||
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
|
||||
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
|
||||
@@ -515,6 +515,31 @@ var _ = Describe("RoundTripper", func() {
|
||||
Expect(rt.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("closes while dialing", func() {
|
||||
rt := &RoundTripper{
|
||||
Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Eventually(ctx.Done()).Should(BeClosed())
|
||||
return nil, errors.New("cancelled")
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := rt.RoundTrip(req)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
Consistently(errChan, scaleDuration(30*time.Millisecond)).ShouldNot(Receive())
|
||||
Expect(rt.Close()).To(Succeed())
|
||||
var rtErr error
|
||||
Eventually(errChan).Should(Receive(&rtErr))
|
||||
Expect(rtErr).To(MatchError("cancelled"))
|
||||
})
|
||||
|
||||
It("closes idle connections", func() {
|
||||
conn1 := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
conn2 := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
|
||||
Reference in New Issue
Block a user