forked from quic-go/quic-go
http3: implement client-side GOAWAY handling (#5143)
When receiving a GOAWAY frame, the client: * immediately closes the connection if there are no active requests * refuses to open streams with stream IDs larger than the stream ID in the GOAWAY frame * closes the connection once the stream count drops to zero
This commit is contained in:
@@ -1099,3 +1099,115 @@ func testHTTPRequestRetryAfterIdleTimeout(t *testing.T, onlyCachedConn bool) {
|
||||
require.Equal(t, 2, headersCount)
|
||||
require.Empty(t, conns) // make sure we dialed 2 connections
|
||||
}
|
||||
|
||||
func TestHTTPRequestAfterGracefulShutdown(t *testing.T) {
|
||||
t.Run("Request.GetBody set", func(t *testing.T) {
|
||||
testHTTPRequestAfterGracefulShutdown(t, true)
|
||||
})
|
||||
t.Run("Request.GetBody not set", func(t *testing.T) {
|
||||
testHTTPRequestAfterGracefulShutdown(t, false)
|
||||
})
|
||||
}
|
||||
|
||||
func testHTTPRequestAfterGracefulShutdown(t *testing.T, setGetBody bool) {
|
||||
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
|
||||
|
||||
ln, err := quic.ListenEarly(
|
||||
newUDPConnLocalhost(t),
|
||||
http3.ConfigureTLSConfig(getTLSConfig()),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
var inShutdown atomic.Bool
|
||||
proxy := quicproxy.Proxy{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ServerAddr: ln.Addr().(*net.UDPAddr),
|
||||
DelayPacket: func(_ quicproxy.Direction, _, _ net.Addr, data []byte) time.Duration {
|
||||
if inShutdown.Load() {
|
||||
return scaleDuration(10 * time.Millisecond)
|
||||
}
|
||||
return scaleDuration(2 * time.Millisecond)
|
||||
},
|
||||
}
|
||||
require.NoError(t, proxy.Start())
|
||||
defer proxy.Close()
|
||||
|
||||
mux2 := http.NewServeMux()
|
||||
mux2.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
||||
data, _ := io.ReadAll(r.Body)
|
||||
w.Write(data)
|
||||
})
|
||||
server2 := &http3.Server{Handler: mux2}
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
server1 := &http3.Server{Handler: http.NewServeMux()}
|
||||
|
||||
go server1.ServeListener(ln)
|
||||
|
||||
tlsConf := getTLSClientConfigWithoutServerName()
|
||||
tlsConf.NextProtos = []string{http3.NextProtoH3}
|
||||
var dialCount int
|
||||
tr := &http3.Transport{
|
||||
TLSClientConfig: tlsConf,
|
||||
Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dialCount++
|
||||
return quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf)
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { tr.Close() })
|
||||
cl := &http.Client{Transport: tr}
|
||||
|
||||
// first request to establish the connection
|
||||
resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/", ln.Addr().(*net.UDPAddr).Port))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
|
||||
// If the body is a strings.Reader, http.NewRequest automatically sets the GetBody callback.
|
||||
// This can be prevented by using a different kind of reader, e.g. the io.LimitReader.
|
||||
var headersCount int
|
||||
req, err := http.NewRequestWithContext(
|
||||
httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||||
WroteHeaders: func() { headersCount++ },
|
||||
}),
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("https://localhost:%d/echo", ln.Addr().(*net.UDPAddr).Port),
|
||||
io.LimitReader(strings.NewReader("foobar"), 1000),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
if setGetBody {
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(strings.NewReader("foobaz")), nil
|
||||
}
|
||||
} else {
|
||||
require.Nil(t, req.GetBody)
|
||||
}
|
||||
|
||||
inShutdown.Store(true)
|
||||
go server1.Shutdown(context.Background())
|
||||
go server2.ServeListener(ln)
|
||||
defer server2.Close()
|
||||
|
||||
// so that graceful shutdown can actually start
|
||||
time.Sleep(scaleDuration(10 * time.Millisecond))
|
||||
|
||||
resp, err = cl.Do(req)
|
||||
if !setGetBody {
|
||||
require.ErrorContains(t, err, "after Request.Body was written; define Request.GetBody to avoid this error")
|
||||
require.Equal(t, 1, dialCount)
|
||||
require.Equal(t, 1, headersCount)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foobaz", string(body))
|
||||
require.Equal(t, 2, dialCount)
|
||||
require.Equal(t, 2, headersCount)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user