fix race conditions in h2quic.Server

This commit is contained in:
Lucas Clemente
2016-06-02 21:19:43 +02:00
parent cdff2e5531
commit bd228ffe27
2 changed files with 40 additions and 14 deletions

View File

@@ -8,6 +8,7 @@ import (
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go"
@@ -29,8 +30,10 @@ type Server struct {
// Private flag for demo, do not use
CloseAfterFirstRequest bool
port int
port uint32 // used atomically
server *quic.Server
serverMutex sync.Mutex
}
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
@@ -38,19 +41,27 @@ func (s *Server) ListenAndServe() error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
s.serverMutex.Lock()
if s.server != nil {
s.serverMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
}
var err error
s.server, err = quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
server, err := quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
if err != nil {
s.serverMutex.Unlock()
return err
}
return s.server.ListenAndServe()
s.server = server
s.serverMutex.Unlock()
return server.ListenAndServe()
}
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
@@ -62,14 +73,19 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
config := &tls.Config{
Certificates: certs,
}
s.serverMutex.Lock()
if s.server != nil {
s.serverMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
}
s.server, err = quic.NewServer(s.Addr, config, s.handleStreamCb)
server, err := quic.NewServer(s.Addr, config, s.handleStreamCb)
if err != nil {
s.serverMutex.Unlock()
return err
}
return s.server.ListenAndServe()
s.server = server
s.serverMutex.Unlock()
return server.ListenAndServe()
}
// Serve should not be called, since it only works properly for TCP listeners.
@@ -155,8 +171,12 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients
func (s *Server) Close() error {
s.serverMutex.Lock()
defer s.serverMutex.Unlock()
if s.server != nil {
return s.server.Close()
err := s.server.Close()
s.server = nil
return err
}
return nil
}
@@ -172,21 +192,24 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
// Alternate-Protocol: 443:quic
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error {
if s.port == 0 {
port := atomic.LoadUint32(&s.port)
if port == 0 {
// Extract port from s.Server.Addr
_, portStr, err := net.SplitHostPort(s.Server.Addr)
if err != nil {
return err
}
port, err := net.LookupPort("tcp", portStr)
portInt, err := net.LookupPort("tcp", portStr)
if err != nil {
return err
}
s.port = port
port = uint32(portInt)
atomic.StoreUint32(&s.port, port)
}
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", s.port))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, s.port, protocol.SupportedVersionsAsString))
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, protocol.SupportedVersionsAsString))
return nil
}

View File

@@ -212,6 +212,11 @@ var _ = Describe("H2 server", func() {
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
err := (&Server{}).ListenAndServeTLS("", "")
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should nop-Close() when s.server is nil", func() {
err := (&Server{}).Close()
Expect(err).NotTo(HaveOccurred())
@@ -223,7 +228,6 @@ var _ = Describe("H2 server", func() {
})
AfterEach(func() {
time.Sleep(10 * time.Millisecond)
err := s.Close()
Expect(err).NotTo(HaveOccurred())
})
@@ -264,7 +268,6 @@ var _ = Describe("H2 server", func() {
})
AfterEach(func() {
time.Sleep(10 * time.Millisecond)
err := s.Close()
Expect(err).NotTo(HaveOccurred())
})