forked from quic-go/quic-go
fix conversion of qtls.ClientHelloInfo in GetCertificate
This commit is contained in:
@@ -62,11 +62,7 @@ func tlsConfigToQtlsConfig(
|
||||
var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error)
|
||||
if c.GetConfigForClient != nil {
|
||||
getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) {
|
||||
var chi *tls.ClientHelloInfo
|
||||
if ch != nil {
|
||||
chi = toTLSClientHelloInfo(ch)
|
||||
}
|
||||
tlsConf, err := c.GetConfigForClient(chi)
|
||||
tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -76,6 +72,19 @@ func tlsConfigToQtlsConfig(
|
||||
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil
|
||||
}
|
||||
}
|
||||
var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error)
|
||||
if c.GetCertificate != nil {
|
||||
getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) {
|
||||
cert, err := c.GetCertificate(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return (*qtls.Certificate)(unsafe.Pointer(cert)), nil
|
||||
}
|
||||
}
|
||||
var csc qtls.ClientSessionCache
|
||||
if c.ClientSessionCache != nil {
|
||||
csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState)
|
||||
@@ -87,7 +96,7 @@ func tlsConfigToQtlsConfig(
|
||||
// NameToCertificate is deprecated, but we still need to copy it if the user sets it.
|
||||
//nolint:staticcheck
|
||||
NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)),
|
||||
GetCertificate: *(*func(*qtls.ClientHelloInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetCertificate)),
|
||||
GetCertificate: getCertificate,
|
||||
GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)),
|
||||
GetConfigForClient: getConfigForClient,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
|
||||
@@ -132,6 +132,47 @@ var _ = Describe("qtls.Config generation", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetCertificate callback", func() {
|
||||
It("returns a certificate", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsCert, err := qtlsConf.GetCertificate(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(qtlsCert).ToNot(BeNil())
|
||||
Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")}))
|
||||
})
|
||||
|
||||
It("doesn't set it if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
Expect(qtlsConf.GetCertificate).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns errors", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, errors.New("test")
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
_, err := qtlsConf.GetCertificate(nil)
|
||||
Expect(err).To(MatchError("test"))
|
||||
})
|
||||
|
||||
It("returns nil when the callback returns nil", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ClientSessionCache", func() {
|
||||
It("doesn't set if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
|
||||
@@ -81,6 +81,9 @@ type qtlsClientHelloInfo struct {
|
||||
}
|
||||
|
||||
func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
|
||||
if chi == nil {
|
||||
return nil
|
||||
}
|
||||
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
|
||||
var config *tls.Config
|
||||
if qtlsCHI.config != nil {
|
||||
|
||||
Reference in New Issue
Block a user