diff --git a/http3/server.go b/http3/server.go index 91a4e5359..3353f0af2 100644 --- a/http3/server.go +++ b/http3/server.go @@ -206,9 +206,9 @@ func (s *Server) ListenAndServe() error { if err != nil { return err } - defer s.removeListener(&ln) + defer s.removeListener(ln) - return s.serveListener(ln) + return s.serveListener(*ln) } // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. @@ -227,9 +227,9 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { if err != nil { return err } - defer s.removeListener(&ln) + defer s.removeListener(ln) - return s.serveListener(ln) + return s.serveListener(*ln) } // Serve an existing UDP connection. @@ -240,9 +240,9 @@ func (s *Server) Serve(conn net.PacketConn) error { if err != nil { return err } - defer s.removeListener(&ln) + defer s.removeListener(ln) - return s.serveListener(ln) + return s.serveListener(*ln) } // init initializes the contexts used for shutting down the server. @@ -314,7 +314,7 @@ func (s *Server) serveListener(ln QUICEarlyListener) error { var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") -func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) (QUICEarlyListener, error) { +func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) (*QUICEarlyListener, error) { if tlsConf == nil { return nil, errServerWithoutTLSConfig } @@ -354,7 +354,7 @@ func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) if err := s.addListener(&ln); err != nil { return nil, err } - return ln, nil + return &ln, nil } func extractPort(addr string) (int, error) { diff --git a/http3/server_test.go b/http3/server_test.go index 4af00e203..3f6f29a3d 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -5,12 +5,14 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "log/slog" "net" "net/http" "runtime" "sync/atomic" + "testing" "time" "github.com/quic-go/quic-go" @@ -18,6 +20,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/testdata" "github.com/quic-go/quic-go/quicvarint" + "github.com/stretchr/testify/require" "github.com/quic-go/qpack" "go.uber.org/mock/gomock" @@ -1223,3 +1226,86 @@ var _ = Describe("Server", func() { Expect(receivedConf.EnableDatagrams).To(BeTrue()) }) }) + +func getAltSvc(s *Server) (string, bool) { + hdr := http.Header{} + s.SetQUICHeaders(hdr) + if altSvc, ok := hdr["Alt-Svc"]; ok { + return altSvc[0], true + } + return "", false +} + +func TestServerAltSvcFromListenersAndConns(t *testing.T) { + t.Run("default", func(t *testing.T) { + testServerAltSvcFromListenersAndConns(t, []quic.Version{}) + }) + t.Run("v1", func(t *testing.T) { + testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1}) + }) + t.Run("v1 and v2", func(t *testing.T) { + testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1, quic.Version2}) + }) +} + +func testServerAltSvcFromListenersAndConns(t *testing.T, versions []quic.Version) { + conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + t.Cleanup(func() { conn1.Close() }) + ln1, err := quic.ListenEarly(conn1, testdata.GetTLSConfig(), nil) + require.NoError(t, err) + port1 := ln1.Addr().(*net.UDPAddr).Port + + s := &Server{ + TLSConfig: testdata.GetTLSConfig(), + QUICConfig: &quic.Config{Versions: versions}, + } + done1 := make(chan struct{}) + go func() { + defer close(done1) + s.ServeListener(ln1) + }() + time.Sleep(scaleDuration(10 * time.Millisecond)) + altSvc, ok := getAltSvc(s) + require.True(t, ok) + require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000`, port1), altSvc) + + conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + t.Cleanup(func() { conn2.Close() }) + port2 := conn2.LocalAddr().(*net.UDPAddr).Port + done2 := make(chan struct{}) + go func() { + defer close(done2) + s.Serve(conn2) + }() + time.Sleep(scaleDuration(10 * time.Millisecond)) + altSvc, ok = getAltSvc(s) + require.True(t, ok) + require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000,h3=":%d"; ma=2592000`, port1, port2), altSvc) + + // Close the first listener. + // This should remove the associated Alt-Svc entry. + require.NoError(t, ln1.Close()) + select { + case <-done1: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + altSvc, ok = getAltSvc(s) + require.True(t, ok) + require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000`, port2), altSvc) + + // Close the second listener. + // This should remove the Alt-Svc entry altogether. + require.NoError(t, conn2.Close()) + select { + case <-done2: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + _, ok = getAltSvc(s) + require.False(t, ok) +}