forked from quic-go/quic-go
http3: don't allow usage of closed Transport (#5324)
* do not allow use a closed transport * update tests * add test * Update http3/transport.go Co-authored-by: Marten Seemann <martenseemann@gmail.com> * unit tests * update test --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
@@ -111,6 +111,7 @@ type Transport struct {
|
||||
|
||||
clients map[string]*roundTripperWithCount
|
||||
transport *quic.Transport
|
||||
closed bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -118,8 +119,12 @@ var (
|
||||
_ io.Closer = &Transport{}
|
||||
)
|
||||
|
||||
// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
|
||||
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||
var (
|
||||
// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
|
||||
ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||
// ErrTransportClosed is returned when attempting to use a closed Transport
|
||||
ErrTransportClosed = errors.New("http3: transport is closed")
|
||||
)
|
||||
|
||||
func (t *Transport) init() error {
|
||||
if t.newClientConn == nil {
|
||||
@@ -292,6 +297,9 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.closed {
|
||||
return nil, false, ErrTransportClosed
|
||||
}
|
||||
|
||||
if t.clients == nil {
|
||||
t.clients = make(map[string]*roundTripperWithCount)
|
||||
@@ -431,6 +439,7 @@ func (t *Transport) NewClientConn(conn *quic.Conn) *ClientConn {
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this Transport has used.
|
||||
// A Transport cannot be used after it has been closed.
|
||||
func (t *Transport) Close() error {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
@@ -448,8 +457,8 @@ func (t *Transport) Close() error {
|
||||
return err
|
||||
}
|
||||
t.transport = nil
|
||||
t.initOnce = sync.Once{}
|
||||
}
|
||||
t.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -548,3 +548,24 @@ func TestTransportCloseIdleConnections(t *testing.T) {
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportClose(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
tr := &Transport{
|
||||
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
return nil, nil
|
||||
},
|
||||
newClientConn: func(*quic.Conn) clientConn {
|
||||
cl := NewMockClientConn(mockCtrl)
|
||||
cl.EXPECT().RoundTrip(gomock.Any()).Return(nil, nil)
|
||||
return cl
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
require.NoError(t, err)
|
||||
_, err = tr.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tr.Close())
|
||||
_, err = tr.RoundTrip(req)
|
||||
require.ErrorIs(t, err, ErrTransportClosed)
|
||||
}
|
||||
|
||||
@@ -38,15 +38,19 @@ func TestHTTP3ServerHotswap(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
|
||||
|
||||
rt := &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfig(),
|
||||
DisableCompression: true,
|
||||
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||
newClient := func() *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfig(),
|
||||
DisableCompression: true,
|
||||
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||
},
|
||||
}
|
||||
}
|
||||
client := &http.Client{Transport: rt}
|
||||
|
||||
client := newClient()
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, rt.Close())
|
||||
require.NoError(t, ln.Close())
|
||||
}()
|
||||
|
||||
@@ -83,6 +87,10 @@ func TestHTTP3ServerHotswap(t *testing.T) {
|
||||
t.Fatal("timed out waiting for server1 to stop")
|
||||
}
|
||||
require.NoError(t, client.Transport.(*http3.Transport).Close())
|
||||
client = newClient()
|
||||
defer func() {
|
||||
require.NoError(t, client.Transport.(*http3.Transport).Close())
|
||||
}()
|
||||
|
||||
// verify that new connections are handled by the second server now
|
||||
resp, err = client.Get("https://localhost:" + port + "/hello2")
|
||||
|
||||
Reference in New Issue
Block a user