diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 85fea0cdb..98b3f4102 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -62,7 +62,7 @@ func tlsConfigToQtlsConfig( var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) if c.GetConfigForClient != nil { getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) { - tlsConf, err := c.GetConfigForClient((*tls.ClientHelloInfo)(unsafe.Pointer(ch))) + tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch)) if err != nil { return nil, err } @@ -72,6 +72,19 @@ func tlsConfigToQtlsConfig( return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil } } + var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) + if c.GetCertificate != nil { + getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) { + cert, err := c.GetCertificate(toTLSClientHelloInfo(ch)) + if err != nil { + return nil, err + } + if cert == nil { + return nil, nil + } + return (*qtls.Certificate)(unsafe.Pointer(cert)), nil + } + } var csc qtls.ClientSessionCache if c.ClientSessionCache != nil { csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState) @@ -83,7 +96,7 @@ func tlsConfigToQtlsConfig( // NameToCertificate is deprecated, but we still need to copy it if the user sets it. //nolint:staticcheck NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)), - GetCertificate: *(*func(*qtls.ClientHelloInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetCertificate)), + GetCertificate: getCertificate, GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)), GetConfigForClient: getConfigForClient, VerifyPeerCertificate: c.VerifyPeerCertificate, diff --git a/internal/handshake/qtls_test.go b/internal/handshake/qtls_test.go index ea472e206..576607ba6 100644 --- a/internal/handshake/qtls_test.go +++ b/internal/handshake/qtls_test.go @@ -132,6 +132,47 @@ var _ = Describe("qtls.Config generation", func() { }) }) + Context("GetCertificate callback", func() { + It("returns a certificate", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + qtlsCert, err := qtlsConf.GetCertificate(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(qtlsCert).ToNot(BeNil()) + Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")})) + }) + + It("doesn't set it if absent", func() { + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + Expect(qtlsConf.GetCertificate).To(BeNil()) + }) + + It("returns errors", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, errors.New("test") + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + _, err := qtlsConf.GetCertificate(nil) + Expect(err).To(MatchError("test")) + }) + + It("returns nil when the callback returns nil", func() { + tlsConf := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, nil + }, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) + Expect(qtlsConf.GetCertificate(nil)).To(BeNil()) + }) + }) + Context("ClientSessionCache", func() { It("doesn't set if absent", func() { qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) diff --git a/internal/handshake/unsafe.go b/internal/handshake/unsafe.go index e44afb696..c78031176 100644 --- a/internal/handshake/unsafe.go +++ b/internal/handshake/unsafe.go @@ -81,7 +81,14 @@ type qtlsClientHelloInfo struct { } func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { + if chi == nil { + return nil + } qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi)) + var config *tls.Config + if qtlsCHI.config != nil { + config = qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config))) + } return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{ CipherSuites: chi.CipherSuites, ServerName: chi.ServerName, @@ -91,6 +98,6 @@ func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { SupportedProtos: chi.SupportedProtos, SupportedVersions: chi.SupportedVersions, Conn: chi.Conn, - config: qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config))), + config: config, })) } diff --git a/internal/handshake/unsafe_test.go b/internal/handshake/unsafe_test.go index 2b1a2df4f..772e99a6b 100644 --- a/internal/handshake/unsafe_test.go +++ b/internal/handshake/unsafe_test.go @@ -95,4 +95,10 @@ var _ = Describe("Unsafe checks", func() { Expect(c.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21})) }) + + It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() { + chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}} + tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi))) + Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37})) + }) })