Implement http3.Server.ServeListener (#3349)

* feat(http3): implement serving from quic.Listener

ServeListener method added to http3.Server allowing serving from an existing listener
ConfigureTLSConfig function added to http3 which should be used to create listeners meant for serving http3.

* docs(http3): add note about using ConfigureTLSConfig to ServeListener

* fix(http3): stop serving non-created listeners after Server.Close

* refactor(http3): return ErrServerClosed once server closes instead of context.Canceled

* feat(http3): close listeners from ServeListener as well

* fix(http3): fix logger not being setup during ServeListener

* test(http3): add unit tests for serving listeners

* test(http3): add tests for ConfigureTLSConfig

* test(http3): added server hotswapping integration test

* fix: race condition in listener tests
This commit is contained in:
Artem Mikheev
2022-03-21 12:20:29 +03:00
committed by GitHub
parent 9c8cadba9e
commit b7e93b54c9
3 changed files with 385 additions and 52 deletions

View File

@@ -51,6 +51,44 @@ func versionToALPN(v protocol.VersionNumber) string {
return ""
}
// ConfigureTLSConfig creates a new tls.Config which can be used
// to create a quic.Listener meant for serving http3. The created
// tls.Config adds the functionality of detecting the used QUIC version
// in order to set the correct ALPN value for the http3 connection.
func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
// The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set.
// That way, we can get the QUIC version and set the correct ALPN value.
return &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
// determine the ALPN from the QUIC version used
proto := nextProtoH3Draft29
if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok {
if qconn.GetQUICVersion() == protocol.Version1 {
proto = nextProtoH3
}
}
config := tlsConf
if tlsConf.GetConfigForClient != nil {
getConfigForClient := tlsConf.GetConfigForClient
var err error
conf, err := getConfigForClient(ch)
if err != nil {
return nil, err
}
if conf != nil {
config = conf
}
}
if config == nil {
return nil, nil
}
config = config.Clone()
config.NextProtos = []string{proto}
return config, nil
},
}
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
@@ -111,7 +149,7 @@ func (s *Server) ListenAndServe() error {
if s.Server == nil {
return errors.New("use of http3.Server without http.Server")
}
return s.serveImpl(s.TLSConfig, nil)
return s.serveConn(s.TLSConfig, nil)
}
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
@@ -127,17 +165,52 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
config := &tls.Config{
Certificates: certs,
}
return s.serveImpl(config, nil)
return s.serveConn(config, nil)
}
// Serve an existing UDP connection.
// It is possible to reuse the same connection for outgoing connections.
// Closing the server does not close the packet conn.
func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn)
return s.serveConn(s.TLSConfig, conn)
}
func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error {
// Serve an existing QUIC listener.
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener.
func (s *Server) ServeListener(listener quic.EarlyListener) error {
return s.serveImpl(func() (quic.EarlyListener, error) { return listener, nil })
}
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
return s.serveImpl(func() (quic.EarlyListener, error) {
baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig
if quicConf == nil {
quicConf = &quic.Config{}
} else {
quicConf = s.QuicConfig.Clone()
}
if s.EnableDatagrams {
quicConf.EnableDatagrams = true
}
var ln quic.EarlyListener
var err error
if conn == nil {
ln, err = quicListenAddr(s.Addr, baseConf, quicConf)
} else {
ln, err = quicListen(conn, baseConf, quicConf)
}
if err != nil {
return nil, err
}
return ln, nil
})
}
func (s *Server) serveImpl(startListener func() (quic.EarlyListener, error)) error {
if s.closed.Get() {
return http.ErrServerClosed
}
@@ -148,54 +221,7 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error {
s.logger = utils.DefaultLogger.WithPrefix("server")
})
// The tls.Config we pass to Listen needs to have the GetConfigForClient callback set.
// That way, we can get the QUIC version and set the correct ALPN value.
baseConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
// determine the ALPN from the QUIC version used
proto := nextProtoH3Draft29
if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok {
if qconn.GetQUICVersion() == protocol.Version1 {
proto = nextProtoH3
}
}
config := tlsConf
if tlsConf.GetConfigForClient != nil {
getConfigForClient := tlsConf.GetConfigForClient
var err error
conf, err := getConfigForClient(ch)
if err != nil {
return nil, err
}
if conf != nil {
config = conf
}
}
if config == nil {
return nil, nil
}
config = config.Clone()
config.NextProtos = []string{proto}
return config, nil
},
}
var ln quic.EarlyListener
var err error
quicConf := s.QuicConfig
if quicConf == nil {
quicConf = &quic.Config{}
} else {
quicConf = s.QuicConfig.Clone()
}
if s.EnableDatagrams {
quicConf.EnableDatagrams = true
}
if conn == nil {
ln, err = quicListenAddr(s.Addr, baseConf, quicConf)
} else {
ln, err = quicListen(conn, baseConf, quicConf)
}
ln, err := startListener()
if err != nil {
return err
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"net"
"net/http"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go"
@@ -619,6 +620,35 @@ var _ = Describe("Server", func() {
Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed))
})
Context("ConfigureTLSConfig", func() {
var tlsConf *tls.Config
var ch *tls.ClientHelloInfo
BeforeEach(func() {
tlsConf = &tls.Config{}
ch = &tls.ClientHelloInfo{}
})
It("advertises draft by default", func() {
tlsConf = ConfigureTLSConfig(tlsConf)
Expect(tlsConf.GetConfigForClient).NotTo(BeNil())
config, err := tlsConf.GetConfigForClient(ch)
Expect(err).NotTo(HaveOccurred())
Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
})
It("advertises h3 for quic version 1", func() {
tlsConf = ConfigureTLSConfig(tlsConf)
Expect(tlsConf.GetConfigForClient).NotTo(BeNil())
ch.Conn = newMockConn(protocol.Version1)
config, err := tlsConf.GetConfigForClient(ch)
Expect(err).NotTo(HaveOccurred())
Expect(config.NextProtos).To(Equal([]string{nextProtoH3}))
})
})
Context("Serve", func() {
origQuicListen := quicListen
@@ -704,6 +734,93 @@ var _ = Describe("Server", func() {
})
})
Context("ServeListener", func() {
origQuicListen := quicListen
AfterEach(func() {
quicListen = origQuicListen
})
It("serves a listener", func() {
var called int32
ln := mockquic.NewMockEarlyListener(mockCtrl)
quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) {
atomic.StoreInt32(&called, 1)
return ln, nil
}
s := &Server{Server: &http.Server{}}
s.TLSConfig = &tls.Config{}
stopAccept := make(chan struct{})
ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) {
<-stopAccept
return nil, errors.New("closed")
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
s.ServeListener(ln)
}()
Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
Consistently(done).ShouldNot(BeClosed())
ln.EXPECT().Close().Do(func() { close(stopAccept) })
Expect(s.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("serves two listeners", func() {
var called int32
ln1 := mockquic.NewMockEarlyListener(mockCtrl)
ln2 := mockquic.NewMockEarlyListener(mockCtrl)
lns := make(chan quic.EarlyListener, 2)
lns <- ln1
lns <- ln2
quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) {
atomic.StoreInt32(&called, 1)
return <-lns, nil
}
s := &Server{Server: &http.Server{}}
s.TLSConfig = &tls.Config{}
stopAccept1 := make(chan struct{})
ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) {
<-stopAccept1
return nil, errors.New("closed")
})
stopAccept2 := make(chan struct{})
ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) {
<-stopAccept2
return nil, errors.New("closed")
})
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done1)
s.ServeListener(ln1)
}()
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done2)
s.ServeListener(ln2)
}()
Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
Consistently(done1).ShouldNot(BeClosed())
Expect(done2).ToNot(BeClosed())
ln1.EXPECT().Close().Do(func() { close(stopAccept1) })
ln2.EXPECT().Close().Do(func() { close(stopAccept2) })
Expect(s.Close()).To(Succeed())
Eventually(done1).Should(BeClosed())
Eventually(done2).Should(BeClosed())
})
})
Context("ListenAndServe", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"

View File

@@ -0,0 +1,190 @@
package self_test
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strconv"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
type listenerWrapper struct {
quic.EarlyListener
listenerClosed bool
count int32
}
func (ln *listenerWrapper) Close() error {
ln.listenerClosed = true
return ln.EarlyListener.Close()
}
func (ln *listenerWrapper) Faker() *fakeClosingListener {
atomic.AddInt32(&ln.count, 1)
ctx, cancel := context.WithCancel(context.Background())
return &fakeClosingListener{ln, 0, ctx, cancel}
}
type fakeClosingListener struct {
*listenerWrapper
closed int32
ctx context.Context
cancel context.CancelFunc
}
func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlySession, error) {
Expect(ctx).To(Equal(context.Background()))
return ln.listenerWrapper.Accept(ln.ctx)
}
func (ln *fakeClosingListener) Close() error {
if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) {
ln.cancel()
if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 {
ln.listenerWrapper.Close()
}
}
return nil
}
var _ = Describe("HTTP3 Server hotswap test", func() {
var (
mux1 *http.ServeMux
mux2 *http.ServeMux
client *http.Client
server1 *http3.Server
server2 *http3.Server
ln *listenerWrapper
port string
)
versions := protocol.SupportedVersions
BeforeEach(func() {
mux1 = http.NewServeMux()
mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
})
mux2 = http.NewServeMux()
mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
})
server1 = &http3.Server{
Server: &http.Server{
Handler: mux1,
TLSConfig: testdata.GetTLSConfig(),
},
QuicConfig: getQuicConfig(&quic.Config{Versions: versions}),
}
server2 = &http3.Server{
Server: &http.Server{
Handler: mux2,
TLSConfig: testdata.GetTLSConfig(),
},
QuicConfig: getQuicConfig(&quic.Config{Versions: versions}),
}
tlsConf := http3.ConfigureTLSConfig(testdata.GetTLSConfig())
quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(&quic.Config{Versions: versions}))
ln = &listenerWrapper{EarlyListener: quicln}
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
Expect(ln.Close()).NotTo(HaveOccurred())
})
for _, v := range versions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
BeforeEach(func() {
client = &http.Client{
Transport: &http3.RoundTripper{
TLSClientConfig: &tls.Config{
RootCAs: testdata.GetRootCA(),
},
DisableCompression: true,
QuicConfig: getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
MaxIdleTimeout: 10 * time.Second,
}),
},
}
})
It("hotswap works", func() {
// open first server and make single request to it
fake1 := ln.Faker()
stoppedServing1 := make(chan struct{})
go func() {
defer GinkgoRecover()
server1.ServeListener(fake1)
close(stoppedServing1)
}()
resp, err := client.Get("https://localhost:" + port + "/hello1")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World 1!\n"))
// open second server with same underlying listener,
// make sure it opened and both servers are currently running
fake2 := ln.Faker()
stoppedServing2 := make(chan struct{})
go func() {
defer GinkgoRecover()
server2.ServeListener(fake2)
close(stoppedServing2)
}()
Consistently(stoppedServing1).ShouldNot(BeClosed())
Consistently(stoppedServing2).ShouldNot(BeClosed())
// now close first server, no errors should occur here
// and only the fake listener should be closed
Expect(server1.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing1).Should(BeClosed())
Expect(fake1.closed).To(Equal(int32(1)))
Expect(fake2.closed).To(Equal(int32(0)))
Expect(ln.listenerClosed).ToNot(BeTrue())
Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred())
// verify that new sessions are being initiated from the second server now
resp, err = client.Get("https://localhost:" + port + "/hello2")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World 2!\n"))
// close the other server - both the fake and the actual listeners must close now
Expect(server2.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing2).Should(BeClosed())
Expect(fake2.closed).To(Equal(int32(1)))
Expect(ln.listenerClosed).To(BeTrue())
})
})
}
})