diff --git a/http3/server.go b/http3/server.go index 10f3f80a..be004f13 100644 --- a/http3/server.go +++ b/http3/server.go @@ -94,6 +94,9 @@ var RemoteAddrContextKey = &contextKey{"remote-addr"} type listener struct { ln *QUICEarlyListener port int // 0 means that no info about port is available + + // if this listener was constructed by the application, it won't be closed when the server is closed + createdLocally bool } // Server is a HTTP/3 server. @@ -273,7 +276,7 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error { // ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed. func (s *Server) ServeListener(ln QUICEarlyListener) error { s.mutex.Lock() - if err := s.addListener(&ln); err != nil { + if err := s.addListener(&ln, false); err != nil { s.mutex.Unlock() return err } @@ -344,7 +347,7 @@ func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) if err != nil { return nil, err } - if err := s.addListener(&ln); err != nil { + if err := s.addListener(&ln, true); err != nil { return nil, err } return &ln, nil @@ -401,7 +404,7 @@ func (s *Server) generateAltSvcHeader() { s.altSvcHeader = strings.Join(altSvc, ",") } -func (s *Server) addListener(l *QUICEarlyListener) error { +func (s *Server) addListener(l *QUICEarlyListener, createdLocally bool) error { if s.closed { return http.ErrServerClosed } @@ -409,14 +412,14 @@ func (s *Server) addListener(l *QUICEarlyListener) error { laddr := (*l).Addr() if port, err := extractPort(laddr.String()); err == nil { - s.listeners = append(s.listeners, listener{ln: l, port: port}) + s.listeners = append(s.listeners, listener{ln: l, port: port, createdLocally: createdLocally}) } else { logger := s.Logger if logger == nil { logger = slog.Default() } logger.Error("Unable to extract port from listener, will not be announced using SetQUICHeaders", "local addr", laddr, "error", err) - s.listeners = append(s.listeners, listener{ln: l, port: 0}) + s.listeners = append(s.listeners, listener{ln: l, port: 0, createdLocally: createdLocally}) } s.generateAltSvcHeader() return nil @@ -688,9 +691,11 @@ func (s *Server) Close() error { s.closeCancel() var err error - for _, info := range s.listeners { - if cerr := (*info.ln).Close(); cerr != nil && err == nil { - err = cerr + for _, l := range s.listeners { + if l.createdLocally { + if cerr := (*l.ln).Close(); cerr != nil && err == nil { + err = cerr + } } } if s.connCount.Load() == 0 { @@ -708,7 +713,7 @@ func (s *Server) Close() error { func (s *Server) Shutdown(ctx context.Context) error { s.mutex.Lock() s.closed = true - // server is never used + // server was never used if s.closeCtx == nil { s.mutex.Unlock() return nil diff --git a/http3/server_test.go b/http3/server_test.go index b27d85f5..02df9a0a 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -880,6 +880,18 @@ func TestServerConcurrentServeAndClose(t *testing.T) { } } +func TestServerImmediateGracefulShutdown(t *testing.T) { + s := &Server{TLSConfig: testdata.GetTLSConfig()} + errChan := make(chan error, 1) + go func() { errChan <- s.Shutdown(context.Background()) }() + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + func TestServerGracefulShutdown(t *testing.T) { s := &Server{TLSConfig: testdata.GetTLSConfig()} s.init() diff --git a/integrationtests/self/http_hotswap_test.go b/integrationtests/self/http_hotswap_test.go index 638d8e24..246322c9 100644 --- a/integrationtests/self/http_hotswap_test.go +++ b/integrationtests/self/http_hotswap_test.go @@ -1,12 +1,10 @@ package self_test import ( - "context" "io" "net" "net/http" "strconv" - "sync/atomic" "testing" "time" @@ -15,48 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -type listenerWrapper struct { - http3.QUICEarlyListener - listenerClosed bool - count atomic.Int32 -} - -func (ln *listenerWrapper) Close() error { - ln.listenerClosed = true - return ln.QUICEarlyListener.Close() -} - -func (ln *listenerWrapper) Faker() *fakeClosingListener { - ln.count.Add(1) - ctx, cancel := context.WithCancel(context.Background()) - return &fakeClosingListener{ - listenerWrapper: ln, - ctx: ctx, - cancel: cancel, - } -} - -type fakeClosingListener struct { - *listenerWrapper - closed atomic.Bool - ctx context.Context - cancel context.CancelFunc -} - -func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) { - return ln.listenerWrapper.Accept(ln.ctx) -} - -func (ln *fakeClosingListener) Close() error { - if ln.closed.CompareAndSwap(false, true) { - ln.cancel() - if ln.count.Add(-1) == 0 { - ln.listenerWrapper.Close() - } - } - return nil -} - func TestHTTP3ServerHotswap(t *testing.T) { mux1 := http.NewServeMux() mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) { @@ -78,9 +34,8 @@ func TestHTTP3ServerHotswap(t *testing.T) { } tlsConf := http3.ConfigureTLSConfig(getTLSConfig()) - quicLn, err := quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil)) + ln, err := quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil)) require.NoError(t, err) - ln := &listenerWrapper{QUICEarlyListener: quicLn} port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) rt := &http3.Transport{ @@ -96,12 +51,8 @@ func TestHTTP3ServerHotswap(t *testing.T) { }() // open first server and make single request to it - fake1 := ln.Faker() - stoppedServing1 := make(chan struct{}) - go func() { - server1.ServeListener(fake1) - close(stoppedServing1) - }() + errChan1 := make(chan error, 1) + go func() { errChan1 <- server1.ServeListener(ln) }() resp, err := client.Get("https://localhost:" + port + "/hello1") require.NoError(t, err) @@ -111,36 +62,29 @@ func TestHTTP3ServerHotswap(t *testing.T) { require.Equal(t, "Hello, World 1!\n", string(body)) // open second server with same underlying listener - fake2 := ln.Faker() - stoppedServing2 := make(chan struct{}) - go func() { - server2.ServeListener(fake2) - close(stoppedServing2) - }() + errChan2 := make(chan error, 1) + go func() { errChan2 <- server2.ServeListener(ln) }() - // Verify both servers are running by waiting a bit and checking channels aren't closed - time.Sleep(50 * time.Millisecond) + time.Sleep(scaleDuration(20 * time.Millisecond)) select { - case <-stoppedServing1: - t.Fatal("server1 stopped unexpectedly") - case <-stoppedServing2: - t.Fatal("server2 stopped unexpectedly") + case err := <-errChan1: + t.Fatalf("server1 stopped unexpectedly: %v", err) + case err := <-errChan2: + t.Fatalf("server2 stopped unexpectedly: %v", err) default: } // now close first server require.NoError(t, server1.Close()) select { - case <-stoppedServing1: - case <-time.After(time.Second): + case err := <-errChan1: + require.ErrorIs(t, err, http.ErrServerClosed) + case <-time.After(5 * time.Second): t.Fatal("timed out waiting for server1 to stop") } - require.True(t, fake1.closed.Load()) - require.False(t, fake2.closed.Load()) - require.False(t, ln.listenerClosed) require.NoError(t, client.Transport.(*http3.Transport).Close()) - // verify that new connections are being initiated from the second server now + // verify that new connections are handled by the second server now resp, err = client.Get("https://localhost:" + port + "/hello2") require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -148,13 +92,12 @@ func TestHTTP3ServerHotswap(t *testing.T) { require.NoError(t, err) require.Equal(t, "Hello, World 2!\n", string(body)) - // close the other server - both the fake and the actual listeners must close now + // close the other server require.NoError(t, server2.Close()) select { - case <-stoppedServing2: + case err := <-errChan2: + require.ErrorIs(t, err, http.ErrServerClosed) case <-time.After(time.Second): t.Fatal("timed out waiting for server2 to stop") } - require.True(t, fake2.closed.Load()) - require.True(t, ln.listenerClosed) } diff --git a/integrationtests/self/http_shutdown_test.go b/integrationtests/self/http_shutdown_test.go index fd1e00e2..0ff40e3e 100644 --- a/integrationtests/self/http_shutdown_test.go +++ b/integrationtests/self/http_shutdown_test.go @@ -6,9 +6,11 @@ import ( "io" "net" "net/http" + "net/url" "testing" "time" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" @@ -196,3 +198,118 @@ func TestGracefulShutdownPendingStreams(t *testing.T) { t.Fatal("shutdown did not complete") } } + +func TestHTTP3ListenerClosing(t *testing.T) { + t.Run("application listener", func(t *testing.T) { + testHTTP3ListenerClosing(t, true) + }) + t.Run("listener created by the http3.Server", func(t *testing.T) { + testHTTP3ListenerClosing(t, false) + }) +} + +func testHTTP3ListenerClosing(t *testing.T, useApplicationListener bool) { + dial := func(t *testing.T, ctx context.Context, u *url.URL) error { + t.Helper() + tlsConf := getTLSClientConfig() + tlsConf.NextProtos = []string{http3.NextProtoH3} + tr := &http3.Transport{TLSClientConfig: tlsConf} + defer tr.Close() + cl := &http.Client{Transport: tr} + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + require.NoError(t, err) + resp, err := cl.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + return nil + } + + mux := http.NewServeMux() + mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + tlsConf := http3.ConfigureTLSConfig(getTLSConfig()) + server := &http3.Server{ + Handler: mux, + // the following values will be ignored when using ServeListener + TLSConfig: tlsConf, + QUICConfig: getQuicConfig(nil), + Addr: "127.0.0.1:47283", + } + + serveChan := make(chan error, 1) + var host string + var ln *quic.EarlyListener // only set when using application listener + if useApplicationListener { + var err error + ln, err = quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil)) + require.NoError(t, err) + defer ln.Close() + host = ln.Addr().String() + go func() { serveChan <- server.ServeListener(ln) }() + } else { + go func() { serveChan <- server.ListenAndServe() }() + host = server.Addr + } + + u := &url.URL{Scheme: "https", Host: host, Path: "/ok"} + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, dial(t, ctx, u)) + + // close the server + require.NoError(t, server.Close()) + + select { + case err := <-serveChan: + require.ErrorIs(t, err, http.ErrServerClosed) + case <-time.After(time.Second): + t.Fatal("server did not stop") + } + + // If the listener was created by the http3.Server, it will now be closed. + if !useApplicationListener { + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(10*time.Millisecond)) + defer cancel() + require.ErrorIs(t, dial(t, ctx, u), context.DeadlineExceeded) + return + } + + // If the listener was created by the application, it will not be closed, + // and it can be used to accept new connections. + errChan := make(chan error, 1) + go func() { + for { + conn, err := ln.Accept(context.Background()) + if err != nil { + errChan <- err + return + } + select { + case <-conn.HandshakeComplete(): + conn.CloseWithError(1337, "") + case <-time.After(time.Second): + errChan <- fmt.Errorf("connection did not complete handshake") + } + errChan <- nil + } + }() + + for range 3 { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := dial(t, ctx, u) + var h3Err *http3.Error + require.ErrorAs(t, err, &h3Err) + require.Equal(t, http3.ErrCode(1337), h3Err.ErrorCode) + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("server did not accept connection") + } + } +}