forked from quic-go/quic-go
fix race conditions in h2quic.Server
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
@@ -29,8 +30,10 @@ type Server struct {
|
|||||||
// Private flag for demo, do not use
|
// Private flag for demo, do not use
|
||||||
CloseAfterFirstRequest bool
|
CloseAfterFirstRequest bool
|
||||||
|
|
||||||
port int
|
port uint32 // used atomically
|
||||||
server *quic.Server
|
|
||||||
|
server *quic.Server
|
||||||
|
serverMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||||
@@ -38,19 +41,27 @@ func (s *Server) ListenAndServe() error {
|
|||||||
if s.Server == nil {
|
if s.Server == nil {
|
||||||
return errors.New("use of h2quic.Server without http.Server")
|
return errors.New("use of h2quic.Server without http.Server")
|
||||||
}
|
}
|
||||||
|
s.serverMutex.Lock()
|
||||||
if s.server != nil {
|
if s.server != nil {
|
||||||
|
s.serverMutex.Unlock()
|
||||||
return errors.New("ListenAndServe may only be called once")
|
return errors.New("ListenAndServe may only be called once")
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
s.server, err = quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
|
server, err := quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.serverMutex.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.server.ListenAndServe()
|
s.server = server
|
||||||
|
s.serverMutex.Unlock()
|
||||||
|
return server.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||||
|
if s.Server == nil {
|
||||||
|
return errors.New("use of h2quic.Server without http.Server")
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
certs := make([]tls.Certificate, 1)
|
certs := make([]tls.Certificate, 1)
|
||||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
@@ -62,14 +73,19 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
|||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
Certificates: certs,
|
Certificates: certs,
|
||||||
}
|
}
|
||||||
|
s.serverMutex.Lock()
|
||||||
if s.server != nil {
|
if s.server != nil {
|
||||||
|
s.serverMutex.Unlock()
|
||||||
return errors.New("ListenAndServe may only be called once")
|
return errors.New("ListenAndServe may only be called once")
|
||||||
}
|
}
|
||||||
s.server, err = quic.NewServer(s.Addr, config, s.handleStreamCb)
|
server, err := quic.NewServer(s.Addr, config, s.handleStreamCb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.serverMutex.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.server.ListenAndServe()
|
s.server = server
|
||||||
|
s.serverMutex.Unlock()
|
||||||
|
return server.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve should not be called, since it only works properly for TCP listeners.
|
// Serve should not be called, since it only works properly for TCP listeners.
|
||||||
@@ -155,8 +171,12 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
|
|||||||
|
|
||||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients
|
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients
|
||||||
func (s *Server) Close() error {
|
func (s *Server) Close() error {
|
||||||
|
s.serverMutex.Lock()
|
||||||
|
defer s.serverMutex.Unlock()
|
||||||
if s.server != nil {
|
if s.server != nil {
|
||||||
return s.server.Close()
|
err := s.server.Close()
|
||||||
|
s.server = nil
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -172,21 +192,24 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
|
|||||||
// Alternate-Protocol: 443:quic
|
// Alternate-Protocol: 443:quic
|
||||||
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
|
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
|
||||||
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
||||||
if s.port == 0 {
|
port := atomic.LoadUint32(&s.port)
|
||||||
|
|
||||||
|
if port == 0 {
|
||||||
// Extract port from s.Server.Addr
|
// Extract port from s.Server.Addr
|
||||||
_, portStr, err := net.SplitHostPort(s.Server.Addr)
|
_, portStr, err := net.SplitHostPort(s.Server.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
port, err := net.LookupPort("tcp", portStr)
|
portInt, err := net.LookupPort("tcp", portStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.port = port
|
port = uint32(portInt)
|
||||||
|
atomic.StoreUint32(&s.port, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", s.port))
|
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
|
||||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, s.port, protocol.SupportedVersionsAsString))
|
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, protocol.SupportedVersionsAsString))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -212,6 +212,11 @@ var _ = Describe("H2 server", func() {
|
|||||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
|
||||||
|
err := (&Server{}).ListenAndServeTLS("", "")
|
||||||
|
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
||||||
|
})
|
||||||
|
|
||||||
It("should nop-Close() when s.server is nil", func() {
|
It("should nop-Close() when s.server is nil", func() {
|
||||||
err := (&Server{}).Close()
|
err := (&Server{}).Close()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
@@ -223,7 +228,6 @@ var _ = Describe("H2 server", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
err := s.Close()
|
err := s.Close()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
@@ -264,7 +268,6 @@ var _ = Describe("H2 server", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
err := s.Close()
|
err := s.Close()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user