forked from quic-go/quic-go
fix race conditions in h2quic.Server
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
@@ -29,8 +30,10 @@ type Server struct {
|
||||
// Private flag for demo, do not use
|
||||
CloseAfterFirstRequest bool
|
||||
|
||||
port int
|
||||
port uint32 // used atomically
|
||||
|
||||
server *quic.Server
|
||||
serverMutex sync.Mutex
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||
@@ -38,19 +41,27 @@ func (s *Server) ListenAndServe() error {
|
||||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.serverMutex.Lock()
|
||||
if s.server != nil {
|
||||
s.serverMutex.Unlock()
|
||||
return errors.New("ListenAndServe may only be called once")
|
||||
}
|
||||
var err error
|
||||
s.server, err = quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
|
||||
server, err := quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
|
||||
if err != nil {
|
||||
s.serverMutex.Unlock()
|
||||
return err
|
||||
}
|
||||
return s.server.ListenAndServe()
|
||||
s.server = server
|
||||
s.serverMutex.Unlock()
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
@@ -62,14 +73,19 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
s.serverMutex.Lock()
|
||||
if s.server != nil {
|
||||
s.serverMutex.Unlock()
|
||||
return errors.New("ListenAndServe may only be called once")
|
||||
}
|
||||
s.server, err = quic.NewServer(s.Addr, config, s.handleStreamCb)
|
||||
server, err := quic.NewServer(s.Addr, config, s.handleStreamCb)
|
||||
if err != nil {
|
||||
s.serverMutex.Unlock()
|
||||
return err
|
||||
}
|
||||
return s.server.ListenAndServe()
|
||||
s.server = server
|
||||
s.serverMutex.Unlock()
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Serve should not be called, since it only works properly for TCP listeners.
|
||||
@@ -155,8 +171,12 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
|
||||
|
||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients
|
||||
func (s *Server) Close() error {
|
||||
s.serverMutex.Lock()
|
||||
defer s.serverMutex.Unlock()
|
||||
if s.server != nil {
|
||||
return s.server.Close()
|
||||
err := s.server.Close()
|
||||
s.server = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -172,21 +192,24 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
|
||||
// Alternate-Protocol: 443:quic
|
||||
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
|
||||
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
||||
if s.port == 0 {
|
||||
port := atomic.LoadUint32(&s.port)
|
||||
|
||||
if port == 0 {
|
||||
// Extract port from s.Server.Addr
|
||||
_, portStr, err := net.SplitHostPort(s.Server.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
port, err := net.LookupPort("tcp", portStr)
|
||||
portInt, err := net.LookupPort("tcp", portStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.port = port
|
||||
port = uint32(portInt)
|
||||
atomic.StoreUint32(&s.port, port)
|
||||
}
|
||||
|
||||
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", s.port))
|
||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, s.port, protocol.SupportedVersionsAsString))
|
||||
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
|
||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, protocol.SupportedVersionsAsString))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -212,6 +212,11 @@ var _ = Describe("H2 server", func() {
|
||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
||||
})
|
||||
|
||||
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
|
||||
err := (&Server{}).ListenAndServeTLS("", "")
|
||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
||||
})
|
||||
|
||||
It("should nop-Close() when s.server is nil", func() {
|
||||
err := (&Server{}).Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@@ -223,7 +228,6 @@ var _ = Describe("H2 server", func() {
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := s.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
@@ -264,7 +268,6 @@ var _ = Describe("H2 server", func() {
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := s.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user