From b7e93b54c9ccade8daa6346a83788f1f80651342 Mon Sep 17 00:00:00 2001 From: Artem Mikheev <30644072+renbou@users.noreply.github.com> Date: Mon, 21 Mar 2022 12:20:29 +0300 Subject: [PATCH] Implement http3.Server.ServeListener (#3349) * feat(http3): implement serving from quic.Listener ServeListener method added to http3.Server allowing serving from an existing listener ConfigureTLSConfig function added to http3 which should be used to create listeners meant for serving http3. * docs(http3): add note about using ConfigureTLSConfig to ServeListener * fix(http3): stop serving non-created listeners after Server.Close * refactor(http3): return ErrServerClosed once server closes instead of context.Canceled * feat(http3): close listeners from ServeListener as well * fix(http3): fix logger not being setup during ServeListener * test(http3): add unit tests for serving listeners * test(http3): add tests for ConfigureTLSConfig * test(http3): added server hotswapping integration test * fix: race condition in listener tests --- http3/server.go | 130 +++++++++++------- http3/server_test.go | 117 ++++++++++++++++ integrationtests/self/hotswap_test.go | 190 ++++++++++++++++++++++++++ 3 files changed, 385 insertions(+), 52 deletions(-) create mode 100644 integrationtests/self/hotswap_test.go diff --git a/http3/server.go b/http3/server.go index 2ae3fef5a..c568decbd 100644 --- a/http3/server.go +++ b/http3/server.go @@ -51,6 +51,44 @@ func versionToALPN(v protocol.VersionNumber) string { return "" } +// ConfigureTLSConfig creates a new tls.Config which can be used +// to create a quic.Listener meant for serving http3. The created +// tls.Config adds the functionality of detecting the used QUIC version +// in order to set the correct ALPN value for the http3 connection. +func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { + // The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set. + // That way, we can get the QUIC version and set the correct ALPN value. + return &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + // determine the ALPN from the QUIC version used + proto := nextProtoH3Draft29 + if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { + if qconn.GetQUICVersion() == protocol.Version1 { + proto = nextProtoH3 + } + } + config := tlsConf + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + var err error + conf, err := getConfigForClient(ch) + if err != nil { + return nil, err + } + if conf != nil { + config = conf + } + } + if config == nil { + return nil, nil + } + config = config.Clone() + config.NextProtos = []string{proto} + return config, nil + }, + } +} + // contextKey is a value for use with context.WithValue. It's used as // a pointer so it fits in an interface{} without allocation. type contextKey struct { @@ -111,7 +149,7 @@ func (s *Server) ListenAndServe() error { if s.Server == nil { return errors.New("use of http3.Server without http.Server") } - return s.serveImpl(s.TLSConfig, nil) + return s.serveConn(s.TLSConfig, nil) } // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. @@ -127,17 +165,52 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { config := &tls.Config{ Certificates: certs, } - return s.serveImpl(config, nil) + return s.serveConn(config, nil) } // Serve an existing UDP connection. // It is possible to reuse the same connection for outgoing connections. // Closing the server does not close the packet conn. func (s *Server) Serve(conn net.PacketConn) error { - return s.serveImpl(s.TLSConfig, conn) + return s.serveConn(s.TLSConfig, conn) } -func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { +// Serve an existing QUIC listener. +// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config +// and use it to construct a http3-friendly QUIC listener. +// Closing the server does close the listener. +func (s *Server) ServeListener(listener quic.EarlyListener) error { + return s.serveImpl(func() (quic.EarlyListener, error) { return listener, nil }) +} + +func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { + return s.serveImpl(func() (quic.EarlyListener, error) { + baseConf := ConfigureTLSConfig(tlsConf) + quicConf := s.QuicConfig + if quicConf == nil { + quicConf = &quic.Config{} + } else { + quicConf = s.QuicConfig.Clone() + } + if s.EnableDatagrams { + quicConf.EnableDatagrams = true + } + + var ln quic.EarlyListener + var err error + if conn == nil { + ln, err = quicListenAddr(s.Addr, baseConf, quicConf) + } else { + ln, err = quicListen(conn, baseConf, quicConf) + } + if err != nil { + return nil, err + } + return ln, nil + }) +} + +func (s *Server) serveImpl(startListener func() (quic.EarlyListener, error)) error { if s.closed.Get() { return http.ErrServerClosed } @@ -148,54 +221,7 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { s.logger = utils.DefaultLogger.WithPrefix("server") }) - // The tls.Config we pass to Listen needs to have the GetConfigForClient callback set. - // That way, we can get the QUIC version and set the correct ALPN value. - baseConf := &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - // determine the ALPN from the QUIC version used - proto := nextProtoH3Draft29 - if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { - if qconn.GetQUICVersion() == protocol.Version1 { - proto = nextProtoH3 - } - } - config := tlsConf - if tlsConf.GetConfigForClient != nil { - getConfigForClient := tlsConf.GetConfigForClient - var err error - conf, err := getConfigForClient(ch) - if err != nil { - return nil, err - } - if conf != nil { - config = conf - } - } - if config == nil { - return nil, nil - } - config = config.Clone() - config.NextProtos = []string{proto} - return config, nil - }, - } - - var ln quic.EarlyListener - var err error - quicConf := s.QuicConfig - if quicConf == nil { - quicConf = &quic.Config{} - } else { - quicConf = s.QuicConfig.Clone() - } - if s.EnableDatagrams { - quicConf.EnableDatagrams = true - } - if conn == nil { - ln, err = quicListenAddr(s.Addr, baseConf, quicConf) - } else { - ln, err = quicListen(conn, baseConf, quicConf) - } + ln, err := startListener() if err != nil { return err } diff --git a/http3/server_test.go b/http3/server_test.go index 02e9c4166..0ff6bdf8e 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go" @@ -619,6 +620,35 @@ var _ = Describe("Server", func() { Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed)) }) + Context("ConfigureTLSConfig", func() { + var tlsConf *tls.Config + var ch *tls.ClientHelloInfo + + BeforeEach(func() { + tlsConf = &tls.Config{} + ch = &tls.ClientHelloInfo{} + }) + + It("advertises draft by default", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + }) + + It("advertises h3 for quic version 1", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + ch.Conn = newMockConn(protocol.Version1) + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3})) + }) + }) + Context("Serve", func() { origQuicListen := quicListen @@ -704,6 +734,93 @@ var _ = Describe("Server", func() { }) }) + Context("ServeListener", func() { + origQuicListen := quicListen + + AfterEach(func() { + quicListen = origQuicListen + }) + + It("serves a listener", func() { + var called int32 + ln := mockquic.NewMockEarlyListener(mockCtrl) + quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return ln, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept := make(chan struct{}) + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept + return nil, errors.New("closed") + }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.ServeListener(ln) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done).ShouldNot(BeClosed()) + ln.EXPECT().Close().Do(func() { close(stopAccept) }) + Expect(s.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("serves two listeners", func() { + var called int32 + ln1 := mockquic.NewMockEarlyListener(mockCtrl) + ln2 := mockquic.NewMockEarlyListener(mockCtrl) + lns := make(chan quic.EarlyListener, 2) + lns <- ln1 + lns <- ln2 + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return <-lns, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept1 := make(chan struct{}) + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept1 + return nil, errors.New("closed") + }) + stopAccept2 := make(chan struct{}) + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept2 + return nil, errors.New("closed") + }) + + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + s.ServeListener(ln1) + }() + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done2) + s.ServeListener(ln2) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done1).ShouldNot(BeClosed()) + Expect(done2).ToNot(BeClosed()) + ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) + ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + Expect(s.Close()).To(Succeed()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + }) + }) + Context("ListenAndServe", func() { BeforeEach(func() { s.Server.Addr = "localhost:0" diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go new file mode 100644 index 000000000..112e63166 --- /dev/null +++ b/integrationtests/self/hotswap_test.go @@ -0,0 +1,190 @@ +package self_test + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +type listenerWrapper struct { + quic.EarlyListener + listenerClosed bool + count int32 +} + +func (ln *listenerWrapper) Close() error { + ln.listenerClosed = true + return ln.EarlyListener.Close() +} + +func (ln *listenerWrapper) Faker() *fakeClosingListener { + atomic.AddInt32(&ln.count, 1) + ctx, cancel := context.WithCancel(context.Background()) + return &fakeClosingListener{ln, 0, ctx, cancel} +} + +type fakeClosingListener struct { + *listenerWrapper + closed int32 + ctx context.Context + cancel context.CancelFunc +} + +func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlySession, error) { + Expect(ctx).To(Equal(context.Background())) + return ln.listenerWrapper.Accept(ln.ctx) +} + +func (ln *fakeClosingListener) Close() error { + if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) { + ln.cancel() + if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 { + ln.listenerWrapper.Close() + } + } + return nil +} + +var _ = Describe("HTTP3 Server hotswap test", func() { + var ( + mux1 *http.ServeMux + mux2 *http.ServeMux + client *http.Client + server1 *http3.Server + server2 *http3.Server + ln *listenerWrapper + port string + ) + + versions := protocol.SupportedVersions + + BeforeEach(func() { + mux1 = http.NewServeMux() + mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset. + }) + + mux2 = http.NewServeMux() + mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset. + }) + + server1 = &http3.Server{ + Server: &http.Server{ + Handler: mux1, + TLSConfig: testdata.GetTLSConfig(), + }, + QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + } + + server2 = &http3.Server{ + Server: &http.Server{ + Handler: mux2, + TLSConfig: testdata.GetTLSConfig(), + }, + QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + } + + tlsConf := http3.ConfigureTLSConfig(testdata.GetTLSConfig()) + quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(&quic.Config{Versions: versions})) + ln = &listenerWrapper{EarlyListener: quicln} + Expect(err).NotTo(HaveOccurred()) + port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) + }) + + AfterEach(func() { + Expect(ln.Close()).NotTo(HaveOccurred()) + }) + + for _, v := range versions { + version := v + + Context(fmt.Sprintf("with QUIC version %s", version), func() { + BeforeEach(func() { + client = &http.Client{ + Transport: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{ + RootCAs: testdata.GetRootCA(), + }, + DisableCompression: true, + QuicConfig: getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + MaxIdleTimeout: 10 * time.Second, + }), + }, + } + }) + + It("hotswap works", func() { + // open first server and make single request to it + fake1 := ln.Faker() + stoppedServing1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server1.ServeListener(fake1) + close(stoppedServing1) + }() + + resp, err := client.Get("https://localhost:" + port + "/hello1") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 1!\n")) + + // open second server with same underlying listener, + // make sure it opened and both servers are currently running + fake2 := ln.Faker() + stoppedServing2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server2.ServeListener(fake2) + close(stoppedServing2) + }() + + Consistently(stoppedServing1).ShouldNot(BeClosed()) + Consistently(stoppedServing2).ShouldNot(BeClosed()) + + // now close first server, no errors should occur here + // and only the fake listener should be closed + Expect(server1.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing1).Should(BeClosed()) + Expect(fake1.closed).To(Equal(int32(1))) + Expect(fake2.closed).To(Equal(int32(0))) + Expect(ln.listenerClosed).ToNot(BeTrue()) + Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) + + // verify that new sessions are being initiated from the second server now + resp, err = client.Get("https://localhost:" + port + "/hello2") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 2!\n")) + + // close the other server - both the fake and the actual listeners must close now + Expect(server2.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing2).Should(BeClosed()) + Expect(fake2.closed).To(Equal(int32(1))) + Expect(ln.listenerClosed).To(BeTrue()) + }) + }) + } +})