diff --git a/client.go b/client.go index 71e72dce..0a26fba9 100644 --- a/client.go +++ b/client.go @@ -25,8 +25,6 @@ type client struct { // If it is started with Dial, we take a packet conn as a parameter. createdPacketConn bool - hostname string - packetHandlers packetHandlerManager token []byte @@ -159,13 +157,12 @@ func newClient( closeCallback func(protocol.ConnectionID), createdPacketConn bool, ) (*client, error) { - var hostname string - if tlsConf != nil { - hostname = tlsConf.ServerName + if tlsConf == nil { + tlsConf = &tls.Config{} } - if hostname == "" { + if tlsConf.ServerName == "" { var err error - hostname, _, err = net.SplitHostPort(host) + tlsConf.ServerName, _, err = net.SplitHostPort(host) if err != nil { return nil, err } @@ -186,7 +183,6 @@ func newClient( c := &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, createdPacketConn: createdPacketConn, - hostname: hostname, tlsConf: tlsConf, config: config, version: config.Versions[0], @@ -286,7 +282,7 @@ func (c *client) generateConnectionIDs() error { } func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) var err error if c.version.UsesTLS() { @@ -324,7 +320,6 @@ func (c *client) dialTLS(ctx context.Context) error { return err } mintConf.ExtensionHandler = extHandler - mintConf.ServerName = c.hostname c.mintConf = mintConf if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { @@ -513,7 +508,6 @@ func (c *client) createNewGQUICSession() error { sess, err := newClientSession( c.conn, runner, - c.hostname, c.version, c.destConnID, c.srcConnID, diff --git a/client_test.go b/client_test.go index ef39eec2..287138c1 100644 --- a/client_test.go +++ b/client_test.go @@ -33,7 +33,7 @@ var _ = Describe("Client", func() { supportedVersionsWithoutGQUIC44 []protocol.VersionNumber - originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (quicSession, error) + originalClientSessConstructor func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (quicSession, error) ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -132,7 +132,6 @@ var _ = Describe("Client", func() { newClientSession = func( conn connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -161,17 +160,16 @@ var _ = Describe("Client", func() { newClientSession = func( _ connection, _ sessionRunner, - h string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, - _ *tls.Config, + tlsConf *tls.Config, _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, _ utils.Logger, ) (quicSession, error) { - hostnameChan <- h + hostnameChan <- tlsConf.ServerName sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run() return sess, nil @@ -190,7 +188,6 @@ var _ = Describe("Client", func() { newClientSession = func( _ connection, runner sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -238,7 +235,6 @@ var _ = Describe("Client", func() { newClientSession = func( conn connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -277,7 +273,6 @@ var _ = Describe("Client", func() { newClientSession = func( conn connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -321,7 +316,6 @@ var _ = Describe("Client", func() { newClientSession = func( conn connection, runnerP sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -361,7 +355,6 @@ var _ = Describe("Client", func() { newClientSession = func( connP connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -491,7 +484,6 @@ var _ = Describe("Client", func() { newClientSession = func( _ connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -683,7 +675,6 @@ var _ = Describe("Client", func() { newClientSession = func( conn connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -744,7 +735,6 @@ var _ = Describe("Client", func() { newClientSession = func( _ connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -757,6 +747,7 @@ var _ = Describe("Client", func() { return <-sessionChan, nil } + cl.tlsConf = &tls.Config{} cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2}} dialed := make(chan struct{}) go func() { @@ -791,7 +782,6 @@ var _ = Describe("Client", func() { newClientSession = func( _ connection, _ sessionRunner, - _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -804,6 +794,7 @@ var _ = Describe("Client", func() { return <-sessionChan, nil } + cl.tlsConf = &tls.Config{} cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2, version3}} dialed := make(chan struct{}) go func() { @@ -934,18 +925,17 @@ var _ = Describe("Client", func() { newClientSession = func( connP connection, _ sessionRunner, - hostnameP string, versionP protocol.VersionNumber, connIDP protocol.ConnectionID, _ protocol.ConnectionID, - _ *tls.Config, + tlsConf *tls.Config, configP *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, _ utils.Logger, ) (quicSession, error) { cconn = connP - hostname = hostnameP + hostname = tlsConf.ServerName version = versionP conf = configP connID = connIDP diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index 6687b834..655449c7 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -69,10 +69,9 @@ var ( // NewCryptoSetupClient creates a new CryptoSetup instance for a client func NewCryptoSetupClient( cryptoStream io.ReadWriter, - hostname string, connID protocol.ConnectionID, version protocol.VersionNumber, - tlsConfig *tls.Config, + tlsConf *tls.Config, params *TransportParameters, paramsChan chan<- TransportParameters, handshakeEvent chan<- struct{}, @@ -87,10 +86,10 @@ func NewCryptoSetupClient( divNonceChan := make(chan struct{}) cs := &cryptoSetupClient{ cryptoStream: cryptoStream, - hostname: hostname, + hostname: tlsConf.ServerName, connID: connID, version: version, - certManager: crypto.NewCertManager(tlsConfig), + certManager: crypto.NewCertManager(tlsConf), params: params, keyDerivation: crypto.DeriveQuicCryptoAESKeys, nullAEAD: nullAEAD, diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index 92633fe5..0da07224 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -2,6 +2,7 @@ package handshake import ( "bytes" + "crypto/tls" "crypto/x509" "encoding/binary" "errors" @@ -121,10 +122,9 @@ var _ = Describe("Client Crypto Setup", func() { handshakeEvent = make(chan struct{}, 2) csInt, err := NewCryptoSetupClient( stream, - "hostname", protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, version, - nil, + &tls.Config{ServerName: "hostname"}, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, paramsChan, handshakeEvent, diff --git a/session.go b/session.go index d69821fd..cded206b 100644 --- a/session.go +++ b/session.go @@ -228,7 +228,6 @@ func newSession( var newClientSession = func( conn connection, sessionRunner sessionRunner, - hostname string, v protocol.VersionNumber, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, @@ -263,7 +262,6 @@ var newClientSession = func( } cs, err := newCryptoSetupClient( s.cryptoStream, - hostname, destConnID, s.version, tlsConf, diff --git a/session_test.go b/session_test.go index d0ecf53e..c44bb312 100644 --- a/session_test.go +++ b/session_test.go @@ -1567,7 +1567,6 @@ var _ = Describe("Client Session", func() { cryptoSetup = &mockCryptoSetup{} newCryptoSetupClient = func( _ io.ReadWriter, - _ string, _ protocol.ConnectionID, _ protocol.VersionNumber, _ *tls.Config, @@ -1587,7 +1586,6 @@ var _ = Describe("Client Session", func() { sessP, err := newClientSession( mconn, sessionRunner, - "hostname", protocol.Version39, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},