diff --git a/h2quic/server.go b/h2quic/server.go index 1ef5c924..9c4178e5 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "sync" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go" @@ -29,8 +30,10 @@ type Server struct { // Private flag for demo, do not use CloseAfterFirstRequest bool - port int - server *quic.Server + port uint32 // used atomically + + server *quic.Server + serverMutex sync.Mutex } // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. @@ -38,19 +41,27 @@ func (s *Server) ListenAndServe() error { if s.Server == nil { return errors.New("use of h2quic.Server without http.Server") } + s.serverMutex.Lock() if s.server != nil { + s.serverMutex.Unlock() return errors.New("ListenAndServe may only be called once") } var err error - s.server, err = quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb) + server, err := quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb) if err != nil { + s.serverMutex.Unlock() return err } - return s.server.ListenAndServe() + s.server = server + s.serverMutex.Unlock() + return server.ListenAndServe() } // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { + if s.Server == nil { + return errors.New("use of h2quic.Server without http.Server") + } var err error certs := make([]tls.Certificate, 1) certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) @@ -62,14 +73,19 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { config := &tls.Config{ Certificates: certs, } + s.serverMutex.Lock() if s.server != nil { + s.serverMutex.Unlock() return errors.New("ListenAndServe may only be called once") } - s.server, err = quic.NewServer(s.Addr, config, s.handleStreamCb) + server, err := quic.NewServer(s.Addr, config, s.handleStreamCb) if err != nil { + s.serverMutex.Unlock() return err } - return s.server.ListenAndServe() + s.server = server + s.serverMutex.Unlock() + return server.ListenAndServe() } // Serve should not be called, since it only works properly for TCP listeners. @@ -155,8 +171,12 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients func (s *Server) Close() error { + s.serverMutex.Lock() + defer s.serverMutex.Unlock() if s.server != nil { - return s.server.Close() + err := s.server.Close() + s.server = nil + return err } return nil } @@ -172,21 +192,24 @@ func (s *Server) CloseGracefully(timeout time.Duration) error { // Alternate-Protocol: 443:quic // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" func (s *Server) SetQuicHeaders(hdr http.Header) error { - if s.port == 0 { + port := atomic.LoadUint32(&s.port) + + if port == 0 { // Extract port from s.Server.Addr _, portStr, err := net.SplitHostPort(s.Server.Addr) if err != nil { return err } - port, err := net.LookupPort("tcp", portStr) + portInt, err := net.LookupPort("tcp", portStr) if err != nil { return err } - s.port = port + port = uint32(portInt) + atomic.StoreUint32(&s.port, port) } - hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", s.port)) - hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, s.port, protocol.SupportedVersionsAsString)) + hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port)) + hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, protocol.SupportedVersionsAsString)) return nil } diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 097540fb..aa6f1b21 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -212,6 +212,11 @@ var _ = Describe("H2 server", func() { Expect(err).To(MatchError("use of h2quic.Server without http.Server")) }) + It("should error when ListenAndServeTLS is called with s.Server nil", func() { + err := (&Server{}).ListenAndServeTLS("", "") + Expect(err).To(MatchError("use of h2quic.Server without http.Server")) + }) + It("should nop-Close() when s.server is nil", func() { err := (&Server{}).Close() Expect(err).NotTo(HaveOccurred()) @@ -223,7 +228,6 @@ var _ = Describe("H2 server", func() { }) AfterEach(func() { - time.Sleep(10 * time.Millisecond) err := s.Close() Expect(err).NotTo(HaveOccurred()) }) @@ -264,7 +268,6 @@ var _ = Describe("H2 server", func() { }) AfterEach(func() { - time.Sleep(10 * time.Millisecond) err := s.Close() Expect(err).NotTo(HaveOccurred()) })