From d22854641ace1792ca27d7d911a6d62c0075e2c2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 25 Aug 2023 07:53:02 +0700 Subject: [PATCH] remove the port from the hostname used for tls.Config.ServerName (#4046) --- transport.go | 31 +++++++++++++++++++++---------- transport_test.go | 12 ++++++++++++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/transport.go b/transport.go index fe6dc1fc..d8da9b1a 100644 --- a/transport.go +++ b/transport.go @@ -159,7 +159,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C return t.dial(ctx, addr, "", tlsConf, conf, true) } -func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { +func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { if err := validateConfig(conf); err != nil { return nil, err } @@ -173,15 +173,7 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tl } tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - // If no ServerName is set, infer the ServerName from the hostname we're connecting to. - if tlsConf.ServerName == "" { - if hostname == "" { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - hostname = udpAddr.IP.String() - } - } - tlsConf.ServerName = hostname - } + setTLSConfigServerName(tlsConf, addr, host) return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) } @@ -478,3 +470,22 @@ func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.A return 0, nil, errors.New("closed") } } + +func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) { + // If no ServerName is set, infer the ServerName from the host we're connecting to. + if tlsConf.ServerName != "" { + return + } + if host == "" { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + tlsConf.ServerName = udpAddr.IP.String() + return + } + } + h, _, err := net.SplitHostPort(host) + if err != nil { // This happens if the host doesn't contain a port number. + tlsConf.ServerName = host + return + } + tlsConf.ServerName = h +} diff --git a/transport_test.go b/transport_test.go index 93e1d32a..cf38e325 100644 --- a/transport_test.go +++ b/transport_test.go @@ -396,6 +396,18 @@ var _ = Describe("Transport", func() { close(packetChan) tr.Close() }) + + remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234} + DescribeTable("setting the tls.Config.ServerName", + func(expected string, conf *tls.Config, addr net.Addr, host string) { + setTLSConfigServerName(conf, addr, host) + Expect(conf.ServerName).To(Equal(expected)) + }, + Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"), + Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"), + Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"), + Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""), + ) }) type mockSyscallConn struct {