http3: implement client-side GOAWAY handling (#5143)

When receiving a GOAWAY frame, the client:
* immediately closes the connection if there are no active requests
* refuses to open streams with stream IDs larger than the stream ID in
the GOAWAY frame
* closes the connection once the stream count drops to zero
This commit is contained in:
Marten Seemann
2025-05-18 13:33:43 +08:00
committed by GitHub
parent 06e8ee1bcf
commit 363e0ccafb
5 changed files with 524 additions and 53 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/quic-go/quic-go/http3"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -39,33 +40,93 @@ func TestHTTPShutdown(t *testing.T) {
}
// TestGracefulShutdownShortRequest: a request that is already in flight when
// the server begins a graceful shutdown should be allowed to complete; once
// the stream count drops to zero, the client itself closes the connection
// with H3_NO_ERROR.
//
// NOTE(review): this span was captured from a rendered diff with the +/-
// markers stripped, so removed (old) and added (new) lines appear interleaved
// below — e.g. the client/ctx setup and the response-body read each appear
// twice. It will not compile as-is; verify against the actual source file.
func TestGracefulShutdownShortRequest(t *testing.T) {
delay := scaleDuration(25 * time.Millisecond)
var server *http3.Server
mux := http.NewServeMux()
port := startHTTPServer(t, mux, func(s *http3.Server) { server = s })
errChan := make(chan error, 1)
proceed := make(chan struct{})
// The handler triggers the shutdown itself, flushes the response headers,
// then stalls the body until the test signals proceed.
mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
go func() {
defer close(errChan)
errChan <- server.Shutdown(context.Background())
}()
time.Sleep(delay)
w.WriteHeader(http.StatusOK)
w.(http.Flusher).Flush()
<-proceed
w.Write([]byte("shutdown"))
})
// NOTE(review): the next two lines are presumably the pre-diff (removed)
// version — they are superseded by the manual Transport setup below.
client := newHTTP3Client(t)
ctx, cancel := context.WithTimeout(context.Background(), 3*delay)
// New version: dial manually so the test can capture the quic.EarlyConnection
// and later observe when (and why) it is closed.
connChan := make(chan quic.EarlyConnection, 1)
tr := &http3.Transport{
TLSClientConfig: getTLSClientConfigWithoutServerName(),
Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
addr, err := net.ResolveUDPAddr("udp", a)
if err != nil {
return nil, err
}
conn, err := quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf)
connChan <- conn
return conn, err
},
}
t.Cleanup(func() { tr.Close() })
client := &http.Client{Transport: tr}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/shutdown", port), nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
// NOTE(review): the synchronous body read and manual transport close below
// look like pre-diff (removed) lines; the new version reads the body in a
// goroutine further down.
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, []byte("shutdown"), body)
client.Transport.(*http3.Transport).Close() // manually close the client
var conn quic.EarlyConnection
select {
case conn = <-connChan:
default:
t.Fatal("expected a connection")
}
type result struct {
body []byte
err error
}
// Read the body asynchronously: it blocks until the handler's <-proceed
// receive is released.
resultChan := make(chan result, 1)
go func() {
body, err := io.ReadAll(resp.Body)
resultChan <- result{body: body, err: err}
}()
// The body must not be readable yet, and the connection must still be open:
// graceful shutdown lets the active request finish.
select {
case <-resultChan:
t.Fatal("request body shouldn't have been read yet")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
select {
case <-conn.Context().Done():
t.Fatal("connection shouldn't have been closed")
default:
}
// allow the request to proceed
close(proceed)
select {
case res := <-resultChan:
require.NoError(t, res.err)
require.Equal(t, []byte("shutdown"), res.body)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// now that the stream count dropped to 0, the client should close the connection
select {
case <-conn.Context().Done():
// The close must originate locally (Remote == false) with H3_NO_ERROR.
var appErr *quic.ApplicationError
require.ErrorAs(t, context.Cause(conn.Context()), &appErr)
assert.False(t, appErr.Remote)
assert.Equal(t, quic.ApplicationErrorCode(http3.ErrCodeNoError), appErr.ErrorCode)
case <-time.After(time.Second):
t.Fatal("timeout")
}
select {
case err := <-errChan:
@@ -75,6 +136,63 @@ func TestGracefulShutdownShortRequest(t *testing.T) {
}
}
// TestGracefulShutdownIdleConnection verifies that a client with no active
// requests closes its connection immediately upon receiving the server's
// GOAWAY during graceful shutdown, and that the close originates locally
// (not from the server) with the H3_NO_ERROR code.
func TestGracefulShutdownIdleConnection(t *testing.T) {
	var server *http3.Server
	port := startHTTPServer(t, http.NewServeMux(), func(s *http3.Server) { server = s })

	// Dial manually so the test can capture the underlying QUIC connection
	// and observe when (and why) it is closed.
	connChan := make(chan quic.EarlyConnection, 1)
	tr := &http3.Transport{
		TLSClientConfig: getTLSClientConfigWithoutServerName(),
		Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
			addr, err := net.ResolveUDPAddr("udp", a)
			if err != nil {
				return nil, err
			}
			conn, err := quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf)
			connChan <- conn
			return conn, err
		},
	}
	t.Cleanup(func() { tr.Close() })
	client := &http.Client{Transport: tr}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/", port), nil)
	require.NoError(t, err)
	resp, err := client.Do(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusNotFound, resp.StatusCode)
	require.NoError(t, resp.Body.Close())

	var conn quic.EarlyConnection
	select {
	case conn = <-connChan:
	default:
		t.Fatal("expected a connection")
	}
	// the connection should still be alive (and idle)
	select {
	case <-conn.Context().Done():
		t.Fatal("connection shouldn't have been closed")
	default:
	}

	shutdownChan := make(chan error, 1)
	go func() { shutdownChan <- server.Shutdown(context.Background()) }()
	// since the connection is idle, the client should close it immediately
	select {
	case <-conn.Context().Done():
		var appErr *quic.ApplicationError
		require.ErrorAs(t, context.Cause(conn.Context()), &appErr)
		assert.False(t, appErr.Remote)
		assert.Equal(t, quic.ApplicationErrorCode(http3.ErrCodeNoError), appErr.ErrorCode)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// Shutdown should now complete without error. Previously the error sent
	// to shutdownChan was silently dropped.
	select {
	case err := <-shutdownChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("server.Shutdown didn't return")
	}
}
// TestGracefulShutdownLongLivedRequest: a request that outlives the server's
// graceful-shutdown period is expected to be terminated when the server
// finally shuts down (unlike the short request above, which completes first).
//
// NOTE(review): this span is a diff fragment — the "@@" hunk header below
// marks omitted lines, so the function body is incomplete here; verify
// against the actual source file.
func TestGracefulShutdownLongLivedRequest(t *testing.T) {
delay := scaleDuration(25 * time.Millisecond)
errChan := make(chan error, 1)
@@ -88,6 +206,8 @@ func TestGracefulShutdownLongLivedRequest(t *testing.T) {
// Handler: flush the response headers immediately, then keep the stream
// open past the shutdown deadline.
w.WriteHeader(http.StatusOK)
w.(http.Flusher).Flush()
// The request simulated here takes longer than the server's graceful shutdown period.
// We expect it to be terminated once the server shuts down.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), delay)
defer cancel()

View File

@@ -1099,3 +1099,115 @@ func testHTTPRequestRetryAfterIdleTimeout(t *testing.T, onlyCachedConn bool) {
require.Equal(t, 2, headersCount)
require.Empty(t, conns) // make sure we dialed 2 connections
}
// TestHTTPRequestAfterGracefulShutdown runs the request-after-graceful-shutdown
// scenario once with and once without Request.GetBody set.
func TestHTTPRequestAfterGracefulShutdown(t *testing.T) {
	for _, tc := range []struct {
		name       string
		setGetBody bool
	}{
		{name: "Request.GetBody set", setGetBody: true},
		{name: "Request.GetBody not set", setGetBody: false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			testHTTPRequestAfterGracefulShutdown(t, tc.setGetBody)
		})
	}
}
// testHTTPRequestAfterGracefulShutdown exercises the client's behavior when a
// request is sent to a server that has just started a graceful shutdown: the
// old connection is unusable for new requests after the GOAWAY, so the
// transport must retry on a fresh connection — which is only possible if the
// request body can be replayed.
//
// setGetBody controls whether the request carries a GetBody callback:
//   - true:  the transport retries on a new connection and the request succeeds
//   - false: the request fails with an error telling the user to set GetBody
func testHTTPRequestAfterGracefulShutdown(t *testing.T, setGetBody bool) {
	t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
	ln, err := quic.ListenEarly(
		newUDPConnLocalhost(t),
		http3.ConfigureTLSConfig(getTLSConfig()),
		getQuicConfig(nil),
	)
	require.NoError(t, err)

	// Delay packets more once shutdown has started, giving the test a window
	// while the GOAWAY is in flight.
	var inShutdown atomic.Bool
	proxy := quicproxy.Proxy{
		Conn:       newUDPConnLocalhost(t),
		ServerAddr: ln.Addr().(*net.UDPAddr),
		DelayPacket: func(_ quicproxy.Direction, _, _ net.Addr, data []byte) time.Duration {
			if inShutdown.Load() {
				return scaleDuration(10 * time.Millisecond)
			}
			return scaleDuration(2 * time.Millisecond)
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()
	// NOTE(review): the requests below dial ln's port directly, so the proxy's
	// DelayPacket never sees this traffic — confirm whether the URLs should
	// target proxy.LocalAddr() instead.

	mux2 := http.NewServeMux()
	mux2.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
		data, _ := io.ReadAll(r.Body)
		w.Write(data)
	})
	server2 := &http3.Server{Handler: mux2}
	server1 := &http3.Server{Handler: http.NewServeMux()}
	go server1.ServeListener(ln)

	tlsConf := getTLSClientConfigWithoutServerName()
	tlsConf.NextProtos = []string{http3.NextProtoH3}
	var dialCount int
	tr := &http3.Transport{
		TLSClientConfig: tlsConf,
		Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
			addr, err := net.ResolveUDPAddr("udp", a)
			if err != nil {
				return nil, err
			}
			dialCount++
			return quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf)
		},
	}
	t.Cleanup(func() { tr.Close() })
	cl := &http.Client{Transport: tr}

	// first request to establish the connection
	resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/", ln.Addr().(*net.UDPAddr).Port))
	require.NoError(t, err)
	require.Equal(t, http.StatusNotFound, resp.StatusCode)
	require.NoError(t, resp.Body.Close()) // was previously leaked

	// If the body is a strings.Reader, http.NewRequest automatically sets the GetBody callback.
	// This can be prevented by using a different kind of reader, e.g. the io.LimitReader.
	var headersCount int
	req, err := http.NewRequestWithContext(
		httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			WroteHeaders: func() { headersCount++ },
		}),
		http.MethodGet,
		fmt.Sprintf("https://localhost:%d/echo", ln.Addr().(*net.UDPAddr).Port),
		io.LimitReader(strings.NewReader("foobar"), 1000),
	)
	require.NoError(t, err)
	if setGetBody {
		// On retry the transport replays the body from GetBody, so the echoed
		// payload asserted below is "foobaz", not "foobar".
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(strings.NewReader("foobaz")), nil
		}
	} else {
		require.Nil(t, req.GetBody)
	}

	inShutdown.Store(true)
	go server1.Shutdown(context.Background())
	// server2 takes over the listener for connections dialed after the shutdown
	go server2.ServeListener(ln)
	defer server2.Close()
	// so that graceful shutdown can actually start
	time.Sleep(scaleDuration(10 * time.Millisecond))

	resp, err = cl.Do(req)
	if !setGetBody {
		// Without GetBody the transport cannot replay the request body on a
		// new connection, so the request fails after a single dial.
		require.ErrorContains(t, err, "after Request.Body was written; define Request.GetBody to avoid this error")
		require.Equal(t, 1, dialCount)
		require.Equal(t, 1, headersCount)
		return
	}
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close()) // was previously leaked
	require.Equal(t, "foobaz", string(body))
	// two dials and two header writes: the original attempt plus the retry
	require.Equal(t, 2, dialCount)
	require.Equal(t, 2, headersCount)
}