From 66d50b4289e53aba465b2953fb248d4852ecbb0d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 8 Apr 2020 15:56:30 +0700 Subject: [PATCH 1/2] fix conversion of qtls.ClientHelloInfo in GetConfigForClient --- internal/handshake/qtls.go | 6 +++++- internal/handshake/unsafe.go | 6 +++++- internal/handshake/unsafe_test.go | 6 ++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 85fea0cdb..33c98904a 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -62,7 +62,11 @@ func tlsConfigToQtlsConfig( var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) if c.GetConfigForClient != nil { getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) { - tlsConf, err := c.GetConfigForClient((*tls.ClientHelloInfo)(unsafe.Pointer(ch))) + var chi *tls.ClientHelloInfo + if ch != nil { + chi = toTLSClientHelloInfo(ch) + } + tlsConf, err := c.GetConfigForClient(chi) if err != nil { return nil, err } diff --git a/internal/handshake/unsafe.go b/internal/handshake/unsafe.go index e44afb696..6aa5f35e5 100644 --- a/internal/handshake/unsafe.go +++ b/internal/handshake/unsafe.go @@ -82,6 +82,10 @@ type qtlsClientHelloInfo struct { func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi)) + var config *tls.Config + if qtlsCHI.config != nil { + config = qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config))) + } return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{ CipherSuites: chi.CipherSuites, ServerName: chi.ServerName, @@ -91,6 +95,6 @@ func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { SupportedProtos: chi.SupportedProtos, SupportedVersions: chi.SupportedVersions, Conn: chi.Conn, - config: qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config))), + config: config, })) } diff --git a/internal/handshake/unsafe_test.go b/internal/handshake/unsafe_test.go index 2b1a2df4f..772e99a6b 100644 --- a/internal/handshake/unsafe_test.go +++ b/internal/handshake/unsafe_test.go @@ -95,4 +95,10 @@ var _ = Describe("Unsafe checks", func() { Expect(c.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21})) }) + + It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() { + chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}} + tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi))) + Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37})) + }) }) From 8fd2674ce41cb225d211d04faed450d134d95934 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 8 Apr 2020 16:19:45 +0700 Subject: [PATCH 2/2] fix conversion of qtls.ClientHelloInfo in GetCertificate --- internal/handshake/qtls.go | 21 ++++++++++++----- internal/handshake/qtls_test.go | 41 +++++++++++++++++++++++++++++++++ internal/handshake/unsafe.go | 3 +++ 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 33c98904a..98b3f4102 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -62,11 +62,7 @@ func tlsConfigToQtlsConfig( var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) if c.GetConfigForClient != nil { getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) { - var chi *tls.ClientHelloInfo - if ch != nil { - chi = toTLSClientHelloInfo(ch) - } - tlsConf, err := c.GetConfigForClient(chi) + tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch)) if err != nil { return nil, err } @@ -76,6 +72,19 @@ func tlsConfigToQtlsConfig( return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil } } + var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) + if c.GetCertificate != nil { + getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) { + cert, err := c.GetCertificate(toTLSClientHelloInfo(ch)) + if err != nil { + return nil, err + } + if cert == nil { + return nil, nil + } + return (*qtls.Certificate)(unsafe.Pointer(cert)), nil + } + } var csc qtls.ClientSessionCache if c.ClientSessionCache != nil { csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState) @@ -87,7 +96,7 @@ func tlsConfigToQtlsConfig( // NameToCertificate is deprecated, but we still need to copy it if the user sets it. //nolint:staticcheck NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)), - GetCertificate: *(*func(*qtls.ClientHelloInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetCertificate)), + GetCertificate: getCertificate, GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)), GetConfigForClient: getConfigForClient, VerifyPeerCertificate: c.VerifyPeerCertificate, diff --git a/internal/handshake/qtls_test.go b/internal/handshake/qtls_test.go index ea472e206..576607ba6 100644 --- a/internal/handshake/qtls_test.go +++ b/internal/handshake/qtls_test.go @@ -132,6 +132,47 @@ var _ = Describe("qtls.Config generation", func() { }) }) + Context("GetCertificate callback", func() { + It("returns a certificate", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + qtlsCert, err := qtlsConf.GetCertificate(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(qtlsCert).ToNot(BeNil()) + Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")})) + }) + + It("doesn't set it if absent", func() { + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + Expect(qtlsConf.GetCertificate).To(BeNil()) + }) + + It("returns errors", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, errors.New("test") + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + _, err := qtlsConf.GetCertificate(nil) + Expect(err).To(MatchError("test")) + }) + + It("returns nil when the callback returns nil", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, nil + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + Expect(qtlsConf.GetCertificate(nil)).To(BeNil()) + }) + }) + Context("ClientSessionCache", func() { It("doesn't set if absent", func() { qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) diff --git a/internal/handshake/unsafe.go b/internal/handshake/unsafe.go index 6aa5f35e5..c78031176 100644 --- a/internal/handshake/unsafe.go +++ b/internal/handshake/unsafe.go @@ -81,6 +81,9 @@ type qtlsClientHelloInfo struct { } func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { + if chi == nil { + return nil + } qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi)) var config *tls.Config if qtlsCHI.config != nil {