http3: don't allow usage of closed Transport (#5324)

* do not allow use a closed transport

* update tests

* add test

* Update http3/transport.go

Co-authored-by: Marten Seemann <martenseemann@gmail.com>

* unit tests

* update test

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
Glonee
2025-09-07 18:52:25 +08:00
committed by GitHub
parent 6dcac15a12
commit 0ae8c03816
3 changed files with 47 additions and 9 deletions

View File

@@ -111,6 +111,7 @@ type Transport struct {
clients map[string]*roundTripperWithCount clients map[string]*roundTripperWithCount
transport *quic.Transport transport *quic.Transport
closed bool
} }
var ( var (
@@ -118,8 +119,12 @@ var (
_ io.Closer = &Transport{} _ io.Closer = &Transport{}
) )
var (
// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set // ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("http3: no cached connection was available") ErrNoCachedConn = errors.New("http3: no cached connection was available")
// ErrTransportClosed is returned when attempting to use a closed Transport
ErrTransportClosed = errors.New("http3: transport is closed")
)
func (t *Transport) init() error { func (t *Transport) init() error {
if t.newClientConn == nil { if t.newClientConn == nil {
@@ -292,6 +297,9 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) { func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
t.mutex.Lock() t.mutex.Lock()
defer t.mutex.Unlock() defer t.mutex.Unlock()
if t.closed {
return nil, false, ErrTransportClosed
}
if t.clients == nil { if t.clients == nil {
t.clients = make(map[string]*roundTripperWithCount) t.clients = make(map[string]*roundTripperWithCount)
@@ -431,6 +439,7 @@ func (t *Transport) NewClientConn(conn *quic.Conn) *ClientConn {
} }
// Close closes the QUIC connections that this Transport has used. // Close closes the QUIC connections that this Transport has used.
// A Transport cannot be used after it has been closed.
func (t *Transport) Close() error { func (t *Transport) Close() error {
t.mutex.Lock() t.mutex.Lock()
defer t.mutex.Unlock() defer t.mutex.Unlock()
@@ -448,8 +457,8 @@ func (t *Transport) Close() error {
return err return err
} }
t.transport = nil t.transport = nil
t.initOnce = sync.Once{}
} }
t.closed = true
return nil return nil
} }

View File

@@ -548,3 +548,24 @@ func TestTransportCloseIdleConnections(t *testing.T) {
t.Fatal("timeout") t.Fatal("timeout")
} }
} }
func TestTransportClose(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr := &Transport{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
return nil, nil
},
newClientConn: func(*quic.Conn) clientConn {
cl := NewMockClientConn(mockCtrl)
cl.EXPECT().RoundTrip(gomock.Any()).Return(nil, nil)
return cl
},
}
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
require.NoError(t, err)
_, err = tr.RoundTrip(req)
require.NoError(t, err)
require.NoError(t, tr.Close())
_, err = tr.RoundTrip(req)
require.ErrorIs(t, err, ErrTransportClosed)
}

View File

@@ -38,15 +38,19 @@ func TestHTTP3ServerHotswap(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
rt := &http3.Transport{ newClient := func() *http.Client {
return &http.Client{
Transport: &http3.Transport{
TLSClientConfig: getTLSClientConfig(), TLSClientConfig: getTLSClientConfig(),
DisableCompression: true, DisableCompression: true,
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
},
} }
client := &http.Client{Transport: rt} }
client := newClient()
defer func() { defer func() {
require.NoError(t, rt.Close())
require.NoError(t, ln.Close()) require.NoError(t, ln.Close())
}() }()
@@ -83,6 +87,10 @@ func TestHTTP3ServerHotswap(t *testing.T) {
t.Fatal("timed out waiting for server1 to stop") t.Fatal("timed out waiting for server1 to stop")
} }
require.NoError(t, client.Transport.(*http3.Transport).Close()) require.NoError(t, client.Transport.(*http3.Transport).Close())
client = newClient()
defer func() {
require.NoError(t, client.Transport.(*http3.Transport).Close())
}()
// verify that new connections are handled by the second server now // verify that new connections are handled by the second server now
resp, err = client.Get("https://localhost:" + port + "/hello2") resp, err = client.Get("https://localhost:" + port + "/hello2")