forked from quic-go/quic-go
remove the port from the hostname used for tls.Config.ServerName (#4046)
This commit is contained in:
31
transport.go
31
transport.go
@@ -159,7 +159,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||||||
return t.dial(ctx, addr, "", tlsConf, conf, true)
|
return t.dial(ctx, addr, "", tlsConf, conf, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
|
func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
|
||||||
if err := validateConfig(conf); err != nil {
|
if err := validateConfig(conf); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -173,15 +173,7 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tl
|
|||||||
}
|
}
|
||||||
tlsConf = tlsConf.Clone()
|
tlsConf = tlsConf.Clone()
|
||||||
tlsConf.MinVersion = tls.VersionTLS13
|
tlsConf.MinVersion = tls.VersionTLS13
|
||||||
// If no ServerName is set, infer the ServerName from the hostname we're connecting to.
|
setTLSConfigServerName(tlsConf, addr, host)
|
||||||
if tlsConf.ServerName == "" {
|
|
||||||
if hostname == "" {
|
|
||||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
|
||||||
hostname = udpAddr.IP.String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tlsConf.ServerName = hostname
|
|
||||||
}
|
|
||||||
return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
|
return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -478,3 +470,22 @@ func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.A
|
|||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
|
||||||
|
// If no ServerName is set, infer the ServerName from the host we're connecting to.
|
||||||
|
if tlsConf.ServerName != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if host == "" {
|
||||||
|
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||||
|
tlsConf.ServerName = udpAddr.IP.String()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h, _, err := net.SplitHostPort(host)
|
||||||
|
if err != nil { // This happens if the host doesn't contain a port number.
|
||||||
|
tlsConf.ServerName = host
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlsConf.ServerName = h
|
||||||
|
}
|
||||||
|
|||||||
@@ -396,6 +396,18 @@ var _ = Describe("Transport", func() {
|
|||||||
close(packetChan)
|
close(packetChan)
|
||||||
tr.Close()
|
tr.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}
|
||||||
|
DescribeTable("setting the tls.Config.ServerName",
|
||||||
|
func(expected string, conf *tls.Config, addr net.Addr, host string) {
|
||||||
|
setTLSConfigServerName(conf, addr, host)
|
||||||
|
Expect(conf.ServerName).To(Equal(expected))
|
||||||
|
},
|
||||||
|
Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"),
|
||||||
|
Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"),
|
||||||
|
Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"),
|
||||||
|
Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
type mockSyscallConn struct {
|
type mockSyscallConn struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user