From afa4615c4d01d47e3372d88c489fdd451156c738 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 20 May 2019 17:55:59 +0100 Subject: [PATCH] make sure the TLS config contains the H3 ALPN token in server and client --- http3/client.go | 4 +++- http3/client_test.go | 13 +++++++++---- http3/server.go | 19 +++++++++++++++++++ http3/server_test.go | 25 +++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 7 deletions(-) diff --git a/http3/client.go b/http3/client.go index 9f0e88311..f9b66adfd 100644 --- a/http3/client.go +++ b/http3/client.go @@ -55,7 +55,9 @@ func newClient( if tlsConf == nil { tlsConf = &tls.Config{} } - tlsConf.NextProtos = []string{"h3-19"} + if !strSliceContains(tlsConf.NextProtos, nextProtoH3) { + tlsConf.NextProtos = append(tlsConf.NextProtos, nextProtoH3) + } if quicConfig == nil { quicConfig = defaultQuicConfig } diff --git a/http3/client_test.go b/http3/client_test.go index 62f3c79ae..c67a3d2b9 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -42,11 +42,12 @@ var _ = Describe("Client", func() { dialAddr = origDialAddr }) - It("uses the default QUIC config if none is give", func() { + It("uses the default QUIC and TLS config if none is give", func() { client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) var dialAddrCalled bool - dialAddr = func(_ string, _ *tls.Config, quicConf *quic.Config) (quic.Session, error) { + dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) { Expect(quicConf).To(Equal(defaultQuicConfig)) + Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) dialAddrCalled = true return nil, errors.New("test done") } @@ -69,7 +70,10 @@ var _ = Describe("Client", func() { }) It("uses the TLS config and QUIC config", func() { - tlsConf := &tls.Config{ServerName: "foo.bar"} + tlsConf := &tls.Config{ + ServerName: "foo.bar", + NextProtos: []string{"proto foo", "proto bar"}, + } quicConf := &quic.Config{IdleTimeout: time.Nanosecond} client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) var dialAddrCalled bool @@ -79,7 +83,8 @@ var _ = Describe("Client", func() { quicConfP *quic.Config, ) (quic.Session, error) { Expect(hostname).To(Equal("localhost:1337")) - Expect(tlsConfP).To(Equal(tlsConf)) + Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) + Expect(tlsConfP.NextProtos).To(Equal([]string{"proto foo", "proto bar", nextProtoH3})) Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout)) dialAddrCalled = true return nil, errors.New("test done") diff --git a/http3/server.go b/http3/server.go index 9e9711f6b..558d9d9ae 100644 --- a/http3/server.go +++ b/http3/server.go @@ -25,6 +25,8 @@ var ( quicListenAddr = quic.ListenAddr ) +const nextProtoH3 = "h3-20" + // Server is a HTTP2 server listening for QUIC connections. type Server struct { *http.Server @@ -88,6 +90,14 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { return errors.New("ListenAndServe may only be called once") } + if tlsConfig == nil { + tlsConfig = &tls.Config{} + } + + if !strSliceContains(tlsConfig.NextProtos, nextProtoH3) { + tlsConfig.NextProtos = append(tlsConfig.NextProtos, nextProtoH3) + } + var ln quic.Listener var err error if conn == nil { @@ -353,3 +363,12 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error return err } } + +func strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} diff --git a/http3/server_test.go b/http3/server_test.go index 012817b61..314201e21 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -349,10 +349,10 @@ var _ = Describe("Server", func() { Expect(s.Close()).To(Succeed()) }) - It("uses the quic.Config to start the quic server", func() { + It("uses the quic.Config to start the QUIC server", func() { conf := &quic.Config{HandshakeTimeout: time.Nanosecond} var receivedConf *quic.Config - quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) { + quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.Listener, error) { receivedConf = config return nil, errors.New("listen err") } @@ -360,6 +360,27 @@ var _ = Describe("Server", func() { Expect(s.ListenAndServe()).To(HaveOccurred()) Expect(receivedConf).To(Equal(conf)) }) + + It("adds the ALPN token to the tls.Config", func() { + var receivedConf *tls.Config + quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.Listener, error) { + receivedConf = tlsConf + return nil, errors.New("listen err") + } + s.TLSConfig = &tls.Config{NextProtos: []string{"foo", "bar"}} + Expect(s.ListenAndServe()).To(HaveOccurred()) + Expect(receivedConf.NextProtos).To(Equal([]string{"foo", "bar", nextProtoH3})) + }) + + It("uses the ALPN token if no tls.Config is given", func() { + var receivedConf *tls.Config + quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.Listener, error) { + receivedConf = tlsConf + return nil, errors.New("listen err") + } + Expect(s.ListenAndServe()).To(HaveOccurred()) + Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3})) + }) }) Context("ListenAndServeTLS", func() {