diff --git a/client.go b/client.go index f0b4c0608..e5e6abe1c 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "crypto/tls" "errors" "net" "strings" @@ -24,6 +25,7 @@ type Client struct { versionNegotiated bool closed uint32 // atomic bool + tlsConfig *tls.Config cryptoChangeCallback CryptoChangeCallback versionNegotiateCallback VersionNegotiateCallback @@ -40,7 +42,7 @@ var ( ) // NewClient makes a new client -func NewClient(host string, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { +func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { udpAddr, err := net.ResolveUDPAddr("udp", host) if err != nil { return nil, err @@ -67,6 +69,7 @@ func NewClient(host string, cryptoChangeCallback CryptoChangeCallback, versionNe hostname: hostname, version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default connectionID: connectionID, + tlsConfig: tlsConfig, cryptoChangeCallback: cryptoChangeCallback, versionNegotiateCallback: versionNegotiateCallback, } @@ -200,7 +203,7 @@ func (c *Client) handlePacket(packet []byte) error { func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { var err error - c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) + c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) if err != nil { return err } diff --git a/client_test.go b/client_test.go index e088ff8b7..4a38f187c 100644 --- a/client_test.go +++ b/client_test.go @@ -48,7 +48,7 @@ var _ = Describe("Client", func() { It("creates a new client", func() { var err error - client, err = NewClient("quic.clemente.io:1337", nil, nil) + client, err = NewClient("quic.clemente.io:1337", nil, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(client.hostname).To(Equal("quic.clemente.io")) Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*Session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) diff --git a/h2quic/client.go b/h2quic/client.go index dd7525f0e..ca1802143 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -1,6 +1,7 @@ package h2quic import ( + "crypto/tls" "errors" "fmt" "io" @@ -47,7 +48,7 @@ type Client struct { var _ h2quicClient = &Client{} // NewClient creates a new client -func NewClient(t *QuicRoundTripper, hostname string) (*Client, error) { +func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) { c := &Client{ t: t, hostname: authorityAddr("https", hostname), @@ -57,7 +58,7 @@ func NewClient(t *QuicRoundTripper, hostname string) (*Client, error) { c.cryptoChangedCond = sync.Cond{L: &c.mutex} var err error - c.client, err = quic.NewClient(c.hostname, c.cryptoChangeCallback, c.versionNegotiateCallback) + c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback) if err != nil { return nil, err } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 52ecdfe60..430341add 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -54,7 +54,7 @@ var _ = Describe("Client", func() { var err error quicTransport = &QuicRoundTripper{} hostname := "quic.clemente.io:1337" - client, err = NewClient(quicTransport, hostname) + client, err = NewClient(quicTransport, nil, hostname) Expect(err).ToNot(HaveOccurred()) Expect(client.hostname).To(Equal(hostname)) qClient = newMockQuicClient() @@ -68,7 +68,7 @@ var _ = Describe("Client", func() { It("adds the port to the hostname, if none is given", func() { var err error - client, err = NewClient(quicTransport, "quic.clemente.io") + client, err = NewClient(quicTransport, nil, "quic.clemente.io") Expect(err).ToNot(HaveOccurred()) Expect(client.hostname).To(Equal("quic.clemente.io:443")) }) @@ -192,7 +192,7 @@ var _ = Describe("Client", func() { It("adds the port for request URLs without one", func(done Done) { var err error - client, err = NewClient(quicTransport, "quic.clemente.io") + client, err = NewClient(quicTransport, nil, "quic.clemente.io") Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index 99bd6e0f6..85faf8e2a 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -1,6 +1,7 @@ package h2quic import ( + "crypto/tls" "errors" "fmt" "net/http" @@ -28,6 +29,10 @@ type QuicRoundTripper struct { // uncompressed. DisableCompression bool + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + clients map[string]h2quicClient } @@ -88,7 +93,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { client, ok := r.clients[hostname] if !ok { var err error - client, err = NewClient(r, hostname) + client, err = NewClient(r, r.TLSClientConfig, hostname) if err != nil { return nil, err } diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index da98c4781..e0350f50d 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -57,7 +57,8 @@ var _ = Describe("RoundTripper", func() { It("reuses existing clients", func() { rt.clients = make(map[string]h2quicClient) rt.clients["www.example.org:443"] = &mockQuicRoundTripper{} - rsp, _ := rt.RoundTrip(req1) + rsp, err := rt.RoundTrip(req1) + Expect(err).ToNot(HaveOccurred()) Expect(rsp.Request).To(Equal(req1)) Expect(rt.clients).To(HaveLen(1)) }) diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index 240ba0a0f..0fc46e506 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -3,6 +3,7 @@ package handshake import ( "bytes" "crypto/rand" + "crypto/tls" "encoding/binary" "errors" "fmt" @@ -64,6 +65,7 @@ func NewCryptoSetupClient( connID protocol.ConnectionID, version protocol.VersionNumber, cryptoStream utils.Stream, + tlsConfig *tls.Config, connectionParameters ConnectionParametersManager, aeadChanged chan struct{}, negotiatedVersions []protocol.VersionNumber, @@ -73,7 +75,7 @@ func NewCryptoSetupClient( connID: connID, version: version, cryptoStream: cryptoStream, - certManager: crypto.NewCertManager(nil), + certManager: crypto.NewCertManager(tlsConfig), connectionParameters: connectionParameters, keyDerivation: crypto.DeriveKeysAESGCM, aeadChanged: aeadChanged, diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index 168b85f82..a1bcf62f1 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -122,7 +122,7 @@ var _ = Describe("Crypto setup", func() { stream = &mockStream{} certManager = &mockCertManager{} version := protocol.Version36 - csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1), nil) + csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, nil, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1), nil) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupClient) cs.certManager = certManager diff --git a/session.go b/session.go index 346b2d3d0..ee67fe859 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package quic import ( + "crypto/tls" "errors" "fmt" "net" @@ -127,7 +128,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return session, err } -func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { +func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { session := &Session{ conn: &udpConn{conn: conn, currentAddr: addr}, connectionID: connectionID, @@ -145,7 +146,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p cryptoStream, _ := session.OpenStream(1) var err error - session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged, negotiatedVersions) + session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index 2abe94adf..c5375c8ba 100644 --- a/session_test.go +++ b/session_test.go @@ -155,6 +155,7 @@ var _ = Describe("Session", func() { "hostname", protocol.Version35, 0, + nil, func(*Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, func(isForwardSecure bool) {},