From 8fd2674ce41cb225d211d04faed450d134d95934 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 8 Apr 2020 16:19:45 +0700 Subject: [PATCH] fix conversion of qtls.ClientHelloInfo in GetCertificate --- internal/handshake/qtls.go | 21 ++++++++++++----- internal/handshake/qtls_test.go | 41 +++++++++++++++++++++++++++++++++ internal/handshake/unsafe.go | 3 +++ 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 33c98904a..98b3f4102 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -62,11 +62,7 @@ func tlsConfigToQtlsConfig( var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) if c.GetConfigForClient != nil { getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) { - var chi *tls.ClientHelloInfo - if ch != nil { - chi = toTLSClientHelloInfo(ch) - } - tlsConf, err := c.GetConfigForClient(chi) + tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch)) if err != nil { return nil, err } @@ -76,6 +72,19 @@ func tlsConfigToQtlsConfig( return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil } } + var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) + if c.GetCertificate != nil { + getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) { + cert, err := c.GetCertificate(toTLSClientHelloInfo(ch)) + if err != nil { + return nil, err + } + if cert == nil { + return nil, nil + } + return (*qtls.Certificate)(unsafe.Pointer(cert)), nil + } + } var csc qtls.ClientSessionCache if c.ClientSessionCache != nil { csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState) @@ -87,7 +96,7 @@ func tlsConfigToQtlsConfig( // NameToCertificate is deprecated, but we still need to copy it if the user sets it. //nolint:staticcheck NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)), - GetCertificate: *(*func(*qtls.ClientHelloInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetCertificate)), + GetCertificate: getCertificate, GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)), GetConfigForClient: getConfigForClient, VerifyPeerCertificate: c.VerifyPeerCertificate, diff --git a/internal/handshake/qtls_test.go b/internal/handshake/qtls_test.go index ea472e206..576607ba6 100644 --- a/internal/handshake/qtls_test.go +++ b/internal/handshake/qtls_test.go @@ -132,6 +132,47 @@ var _ = Describe("qtls.Config generation", func() { }) }) + Context("GetCertificate callback", func() { + It("returns a certificate", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + qtlsCert, err := qtlsConf.GetCertificate(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(qtlsCert).ToNot(BeNil()) + Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")})) + }) + + It("doesn't set it if absent", func() { + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + Expect(qtlsConf.GetCertificate).To(BeNil()) + }) + + It("returns errors", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, errors.New("test") + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + _, err := qtlsConf.GetCertificate(nil) + Expect(err).To(MatchError("test")) + }) + + It("returns nil when the callback returns nil", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, nil + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + Expect(qtlsConf.GetCertificate(nil)).To(BeNil()) + }) + }) + Context("ClientSessionCache", func() { It("doesn't set if absent", func() { qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) diff --git a/internal/handshake/unsafe.go b/internal/handshake/unsafe.go index 6aa5f35e5..c78031176 100644 --- a/internal/handshake/unsafe.go +++ b/internal/handshake/unsafe.go @@ -81,6 +81,9 @@ type qtlsClientHelloInfo struct { } func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { + if chi == nil { + return nil + } qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi)) var config *tls.Config if qtlsCHI.config != nil {