automatically set the tls.Config.ServerName if unset (#4032)

This commit is contained in:
Marten Seemann
2023-08-16 20:54:42 +07:00
committed by GitHub
parent 4f696569a2
commit ca3842d6c8
2 changed files with 20 additions and 26 deletions

View File

@@ -55,11 +55,11 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
return dl.Dial(ctx, udpAddr, tlsConf, conf)
return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
@@ -73,13 +73,13 @@ func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
if err != nil {
dl.Close()
tr.Close()
return nil, err
}
return conn, nil
@@ -163,12 +163,6 @@ func dial(
}
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = tlsConf.Clone()
}
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err

View File

@@ -148,24 +148,15 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
// Dial dials a new connection to a remote host (not using 0-RTT).
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
return t.dial(ctx, addr, "", tlsConf, conf, false)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
return t.dial(ctx, addr, "", tlsConf, conf, true)
}
func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
@@ -179,7 +170,16 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
// If no ServerName is set, infer the ServerName from the hostname we're connecting to.
if tlsConf.ServerName == "" {
if hostname == "" {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
hostname = udpAddr.IP.String()
}
}
tlsConf.ServerName = hostname
}
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
}
func (t *Transport) init(allowZeroLengthConnIDs bool) error {