From 328dd2c8485f2397d5bdfa50fbaf493b647fab5c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 24 Aug 2019 11:08:06 +0700 Subject: [PATCH] set the H3 ALPN on tls.Configs returned by GetConfigForClient --- http3/server.go | 11 +++++++++++ http3/server_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/http3/server.go b/http3/server.go index 3f93d1dfc..c65e46316 100644 --- a/http3/server.go +++ b/http3/server.go @@ -99,6 +99,17 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { } // Replace existing ALPNs by H3 tlsConf.NextProtos = []string{nextProtoH3} + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + tlsConf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + conf, err := getConfigForClient(ch) + if err != nil || conf == nil { + return conf, err + } + conf.NextProtos = []string{nextProtoH3} + return conf, nil + } + } var ln quic.Listener var err error diff --git a/http3/server_test.go b/http3/server_test.go index 92b92aa3a..8ef958f59 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -406,6 +406,30 @@ var _ = Describe("Server", func() { Expect(s.ListenAndServe()).To(HaveOccurred()) Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3})) }) + + It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { + tlsConf := &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{NextProtos: []string{"foo", "bar"}}, nil + }, + } + + var receivedConf *tls.Config + quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.Listener, error) { + receivedConf = conf + return nil, errors.New("listen err") + } + s.TLSConfig = tlsConf + Expect(s.ListenAndServe()).To(HaveOccurred()) + // check that the config used by QUIC uses the h3 ALPN + conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.NextProtos).To(Equal([]string{nextProtoH3})) + // check that the original config was not modified + conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) + }) }) Context("ListenAndServeTLS", func() {