fix race conditions in h2quic.Server

This commit is contained in:
Lucas Clemente
2016-06-02 21:19:43 +02:00
parent cdff2e5531
commit bd228ffe27
2 changed files with 40 additions and 14 deletions

View File

@@ -8,6 +8,7 @@ import (
"net" "net"
"net/http" "net/http"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@@ -29,8 +30,10 @@ type Server struct {
// Private flag for demo, do not use // Private flag for demo, do not use
CloseAfterFirstRequest bool CloseAfterFirstRequest bool
port int port uint32 // used atomically
server *quic.Server
server *quic.Server
serverMutex sync.Mutex
} }
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
@@ -38,19 +41,27 @@ func (s *Server) ListenAndServe() error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
s.serverMutex.Lock()
if s.server != nil { if s.server != nil {
s.serverMutex.Unlock()
return errors.New("ListenAndServe may only be called once") return errors.New("ListenAndServe may only be called once")
} }
var err error var err error
s.server, err = quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb) server, err := quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
if err != nil { if err != nil {
s.serverMutex.Unlock()
return err return err
} }
return s.server.ListenAndServe() s.server = server
s.serverMutex.Unlock()
return server.ListenAndServe()
} }
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
var err error var err error
certs := make([]tls.Certificate, 1) certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
@@ -62,14 +73,19 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
config := &tls.Config{ config := &tls.Config{
Certificates: certs, Certificates: certs,
} }
s.serverMutex.Lock()
if s.server != nil { if s.server != nil {
s.serverMutex.Unlock()
return errors.New("ListenAndServe may only be called once") return errors.New("ListenAndServe may only be called once")
} }
s.server, err = quic.NewServer(s.Addr, config, s.handleStreamCb) server, err := quic.NewServer(s.Addr, config, s.handleStreamCb)
if err != nil { if err != nil {
s.serverMutex.Unlock()
return err return err
} }
return s.server.ListenAndServe() s.server = server
s.serverMutex.Unlock()
return server.ListenAndServe()
} }
// Serve should not be called, since it only works properly for TCP listeners. // Serve should not be called, since it only works properly for TCP listeners.
@@ -155,8 +171,12 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients
func (s *Server) Close() error { func (s *Server) Close() error {
s.serverMutex.Lock()
defer s.serverMutex.Unlock()
if s.server != nil { if s.server != nil {
return s.server.Close() err := s.server.Close()
s.server = nil
return err
} }
return nil return nil
} }
@@ -172,21 +192,24 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
// Alternate-Protocol: 443:quic // Alternate-Protocol: 443:quic
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error { func (s *Server) SetQuicHeaders(hdr http.Header) error {
if s.port == 0 { port := atomic.LoadUint32(&s.port)
if port == 0 {
// Extract port from s.Server.Addr // Extract port from s.Server.Addr
_, portStr, err := net.SplitHostPort(s.Server.Addr) _, portStr, err := net.SplitHostPort(s.Server.Addr)
if err != nil { if err != nil {
return err return err
} }
port, err := net.LookupPort("tcp", portStr) portInt, err := net.LookupPort("tcp", portStr)
if err != nil { if err != nil {
return err return err
} }
s.port = port port = uint32(portInt)
atomic.StoreUint32(&s.port, port)
} }
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", s.port)) hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, s.port, protocol.SupportedVersionsAsString)) hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, protocol.SupportedVersionsAsString))
return nil return nil
} }

View File

@@ -212,6 +212,11 @@ var _ = Describe("H2 server", func() {
Expect(err).To(MatchError("use of h2quic.Server without http.Server")) Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
}) })
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
err := (&Server{}).ListenAndServeTLS("", "")
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should nop-Close() when s.server is nil", func() { It("should nop-Close() when s.server is nil", func() {
err := (&Server{}).Close() err := (&Server{}).Close()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@@ -223,7 +228,6 @@ var _ = Describe("H2 server", func() {
}) })
AfterEach(func() { AfterEach(func() {
time.Sleep(10 * time.Millisecond)
err := s.Close() err := s.Close()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
@@ -264,7 +268,6 @@ var _ = Describe("H2 server", func() {
}) })
AfterEach(func() { AfterEach(func() {
time.Sleep(10 * time.Millisecond)
err := s.Close() err := s.Close()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })