fix conversion of qtls.ClientHelloInfo in GetCertificate

This commit is contained in:
Marten Seemann
2020-04-08 16:19:45 +07:00
parent 66d50b4289
commit 8fd2674ce4
3 changed files with 59 additions and 6 deletions

View File

@@ -62,11 +62,7 @@ func tlsConfigToQtlsConfig(
var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error)
if c.GetConfigForClient != nil {
getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) {
var chi *tls.ClientHelloInfo
if ch != nil {
chi = toTLSClientHelloInfo(ch)
}
tlsConf, err := c.GetConfigForClient(chi)
tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch))
if err != nil {
return nil, err
}
@@ -76,6 +72,19 @@ func tlsConfigToQtlsConfig(
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil
}
}
var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error)
if c.GetCertificate != nil {
getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) {
cert, err := c.GetCertificate(toTLSClientHelloInfo(ch))
if err != nil {
return nil, err
}
if cert == nil {
return nil, nil
}
return (*qtls.Certificate)(unsafe.Pointer(cert)), nil
}
}
var csc qtls.ClientSessionCache
if c.ClientSessionCache != nil {
csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState)
@@ -87,7 +96,7 @@ func tlsConfigToQtlsConfig(
// NameToCertificate is deprecated, but we still need to copy it if the user sets it.
//nolint:staticcheck
NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)),
GetCertificate: *(*func(*qtls.ClientHelloInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetCertificate)),
GetCertificate: getCertificate,
GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)),
GetConfigForClient: getConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,

View File

@@ -132,6 +132,47 @@ var _ = Describe("qtls.Config generation", func() {
})
})
Context("GetCertificate callback", func() {
It("returns a certificate", func() {
tlsConf := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
qtlsCert, err := qtlsConf.GetCertificate(nil)
Expect(err).ToNot(HaveOccurred())
Expect(qtlsCert).ToNot(BeNil())
Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")}))
})
It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.GetCertificate).To(BeNil())
})
It("returns errors", func() {
tlsConf := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, errors.New("test")
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
_, err := qtlsConf.GetCertificate(nil)
Expect(err).To(MatchError("test"))
})
It("returns nil when the callback returns nil", func() {
tlsConf := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
})
})
Context("ClientSessionCache", func() {
It("doesn't set if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)

View File

@@ -81,6 +81,9 @@ type qtlsClientHelloInfo struct {
}
func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
if chi == nil {
return nil
}
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
var config *tls.Config
if qtlsCHI.config != nil {