remove the port from the hostname used for tls.Config.ServerName (#4046)

This commit is contained in:
Marten Seemann
2023-08-25 07:53:02 +07:00
committed by GitHub
parent f633dca488
commit d22854641a
2 changed files with 33 additions and 10 deletions

View File

@@ -159,7 +159,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
return t.dial(ctx, addr, "", tlsConf, conf, true) return t.dial(ctx, addr, "", tlsConf, conf, true)
} }
func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
if err := validateConfig(conf); err != nil { if err := validateConfig(conf); err != nil {
return nil, err return nil, err
} }
@@ -173,15 +173,7 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tl
} }
tlsConf = tlsConf.Clone() tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13 tlsConf.MinVersion = tls.VersionTLS13
// If no ServerName is set, infer the ServerName from the hostname we're connecting to. setTLSConfigServerName(tlsConf, addr, host)
if tlsConf.ServerName == "" {
if hostname == "" {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
hostname = udpAddr.IP.String()
}
}
tlsConf.ServerName = hostname
}
return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
} }
@@ -478,3 +470,22 @@ func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.A
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
} }
} }
func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
// If no ServerName is set, infer the ServerName from the host we're connecting to.
if tlsConf.ServerName != "" {
return
}
if host == "" {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
tlsConf.ServerName = udpAddr.IP.String()
return
}
}
h, _, err := net.SplitHostPort(host)
if err != nil { // This happens if the host doesn't contain a port number.
tlsConf.ServerName = host
return
}
tlsConf.ServerName = h
}

View File

@@ -396,6 +396,18 @@ var _ = Describe("Transport", func() {
close(packetChan) close(packetChan)
tr.Close() tr.Close()
}) })
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}
DescribeTable("setting the tls.Config.ServerName",
func(expected string, conf *tls.Config, addr net.Addr, host string) {
setTLSConfigServerName(conf, addr, host)
Expect(conf.ServerName).To(Equal(expected))
},
Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"),
Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"),
Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"),
Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""),
)
}) })
type mockSyscallConn struct { type mockSyscallConn struct {