From fd30146de549b8758548872fd55fb485b6228689 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 24 Aug 2019 10:51:01 +0700 Subject: [PATCH 1/2] replace the list of ALPN instead of appending to it in http3 --- http3/client.go | 7 ++++--- http3/client_test.go | 6 ++++-- http3/server.go | 27 +++++++++------------------ http3/server_test.go | 9 ++++++--- 4 files changed, 23 insertions(+), 26 deletions(-) diff --git a/http3/client.go b/http3/client.go index 47bf8b3fd..b4ae76e9d 100644 --- a/http3/client.go +++ b/http3/client.go @@ -57,10 +57,11 @@ func newClient( ) *client { if tlsConf == nil { tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() } - if !strSliceContains(tlsConf.NextProtos, nextProtoH3) { - tlsConf.NextProtos = append(tlsConf.NextProtos, nextProtoH3) - } + // Replace existing ALPNs by H3 + tlsConf.NextProtos = []string{nextProtoH3} if quicConfig == nil { quicConfig = defaultQuicConfig } diff --git a/http3/client_test.go b/http3/client_test.go index 1b65d2530..e6709c107 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -85,13 +85,15 @@ var _ = Describe("Client", func() { ) (quic.Session, error) { Expect(hostname).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) - Expect(tlsConfP.NextProtos).To(Equal([]string{"proto foo", "proto bar", nextProtoH3})) + Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout)) dialAddrCalled = true return nil, errors.New("test done") } client.RoundTrip(req) Expect(dialAddrCalled).To(BeTrue()) + // make sure the original tls.Config was not modified + Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) }) It("uses the custom dialer, if provided", func() { @@ -102,7 +104,7 @@ var _ = Describe("Client", func() { dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) { Expect(network).To(Equal("udp")) Expect(address).To(Equal("localhost:1337")) - Expect(tlsConfP).To(Equal(tlsConf)) + Expect(tlsConfP.ServerName).To(Equal("foo.bar")) Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout)) dialerCalled = true return nil, testErr diff --git a/http3/server.go b/http3/server.go index e505a443c..3f93d1dfc 100644 --- a/http3/server.go +++ b/http3/server.go @@ -77,7 +77,7 @@ func (s *Server) Serve(conn net.PacketConn) error { return s.serveImpl(s.TLSConfig, conn) } -func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { +func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { if s.Server == nil { return errors.New("use of http3.Server without http.Server") } @@ -92,20 +92,20 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { return errors.New("ListenAndServe may only be called once") } - if tlsConfig == nil { - tlsConfig = &tls.Config{} - } - - if !strSliceContains(tlsConfig.NextProtos, nextProtoH3) { - tlsConfig.NextProtos = append(tlsConfig.NextProtos, nextProtoH3) + if tlsConf == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() } + // Replace existing ALPNs by H3 + tlsConf.NextProtos = []string{nextProtoH3} var ln quic.Listener var err error if conn == nil { - ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig) + ln, err = quicListenAddr(s.Addr, tlsConf, s.QuicConfig) } else { - ln, err = quicListen(conn, tlsConfig, s.QuicConfig) + ln, err = quicListen(conn, tlsConf, s.QuicConfig) } if err != nil { s.listenerMutex.Unlock() @@ -385,12 +385,3 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error return err } } - -func strSliceContains(ss []string, s string) bool { - for _, v := range ss { - if v == s { - return true - } - } - return false -} diff --git a/http3/server_test.go b/http3/server_test.go index 254095531..92b92aa3a 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -383,15 +383,18 @@ var _ = Describe("Server", func() { Expect(receivedConf).To(Equal(conf)) }) - It("adds the ALPN token to the tls.Config", func() { + It("replaces the ALPN token to the tls.Config", func() { + tlsConf := &tls.Config{NextProtos: []string{"foo", "bar"}} var receivedConf *tls.Config quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.Listener, error) { receivedConf = tlsConf return nil, errors.New("listen err") } - s.TLSConfig = &tls.Config{NextProtos: []string{"foo", "bar"}} + s.TLSConfig = tlsConf Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf.NextProtos).To(Equal([]string{"foo", "bar", nextProtoH3})) + Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3})) + // make sure the original tls.Config was not modified + Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"})) }) It("uses the ALPN token if no tls.Config is given", func() { From 328dd2c8485f2397d5bdfa50fbaf493b647fab5c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 24 Aug 2019 11:08:06 +0700 Subject: [PATCH 2/2] set the H3 ALPN on tls.Configs returned by GetConfigForClient --- http3/server.go | 11 +++++++++++ http3/server_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/http3/server.go b/http3/server.go index 3f93d1dfc..c65e46316 100644 --- a/http3/server.go +++ b/http3/server.go @@ -99,6 +99,17 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { } // Replace existing ALPNs by H3 tlsConf.NextProtos = []string{nextProtoH3} + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + tlsConf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + conf, err := getConfigForClient(ch) + if err != nil || conf == nil { + return conf, err + } + conf.NextProtos = []string{nextProtoH3} + return conf, nil + } + } var ln quic.Listener var err error diff --git a/http3/server_test.go b/http3/server_test.go index 92b92aa3a..8ef958f59 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -406,6 +406,30 @@ var _ = Describe("Server", func() { Expect(s.ListenAndServe()).To(HaveOccurred()) Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3})) }) + + It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { + tlsConf := &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{NextProtos: []string{"foo", "bar"}}, nil + }, + } + + var receivedConf *tls.Config + quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.Listener, error) { + receivedConf = conf + return nil, errors.New("listen err") + } + s.TLSConfig = tlsConf + Expect(s.ListenAndServe()).To(HaveOccurred()) + // check that the config used by QUIC uses the h3 ALPN + conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.NextProtos).To(Equal([]string{nextProtoH3})) + // check that the original config was not modified + conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) + }) }) Context("ListenAndServeTLS", func() {