forked from quic-go/quic-go
Merge pull request #2475 from lucas-clemente/fix-client-hello-conversion
fix conversion of qtls.ClientHelloInfo to tls.ClientHelloInfo
This commit is contained in:
@@ -62,7 +62,7 @@ func tlsConfigToQtlsConfig(
|
||||
var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error)
|
||||
if c.GetConfigForClient != nil {
|
||||
getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) {
|
||||
tlsConf, err := c.GetConfigForClient((*tls.ClientHelloInfo)(unsafe.Pointer(ch)))
|
||||
tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -72,6 +72,19 @@ func tlsConfigToQtlsConfig(
|
||||
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil
|
||||
}
|
||||
}
|
||||
var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error)
|
||||
if c.GetCertificate != nil {
|
||||
getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) {
|
||||
cert, err := c.GetCertificate(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return (*qtls.Certificate)(unsafe.Pointer(cert)), nil
|
||||
}
|
||||
}
|
||||
var csc qtls.ClientSessionCache
|
||||
if c.ClientSessionCache != nil {
|
||||
csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState)
|
||||
@@ -83,7 +96,7 @@ func tlsConfigToQtlsConfig(
|
||||
// NameToCertificate is deprecated, but we still need to copy it if the user sets it.
|
||||
//nolint:staticcheck
|
||||
NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)),
|
||||
GetCertificate: *(*func(*qtls.ClientHelloInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetCertificate)),
|
||||
GetCertificate: getCertificate,
|
||||
GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)),
|
||||
GetConfigForClient: getConfigForClient,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
|
||||
@@ -132,6 +132,47 @@ var _ = Describe("qtls.Config generation", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetCertificate callback", func() {
|
||||
It("returns a certificate", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsCert, err := qtlsConf.GetCertificate(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(qtlsCert).ToNot(BeNil())
|
||||
Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")}))
|
||||
})
|
||||
|
||||
It("doesn't set it if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
Expect(qtlsConf.GetCertificate).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns errors", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, errors.New("test")
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
_, err := qtlsConf.GetCertificate(nil)
|
||||
Expect(err).To(MatchError("test"))
|
||||
})
|
||||
|
||||
It("returns nil when the callback returns nil", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ClientSessionCache", func() {
|
||||
It("doesn't set if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
|
||||
@@ -81,7 +81,14 @@ type qtlsClientHelloInfo struct {
|
||||
}
|
||||
|
||||
func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
|
||||
if chi == nil {
|
||||
return nil
|
||||
}
|
||||
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
|
||||
var config *tls.Config
|
||||
if qtlsCHI.config != nil {
|
||||
config = qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config)))
|
||||
}
|
||||
return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{
|
||||
CipherSuites: chi.CipherSuites,
|
||||
ServerName: chi.ServerName,
|
||||
@@ -91,6 +98,6 @@ func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
|
||||
SupportedProtos: chi.SupportedProtos,
|
||||
SupportedVersions: chi.SupportedVersions,
|
||||
Conn: chi.Conn,
|
||||
config: qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config))),
|
||||
config: config,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -95,4 +95,10 @@ var _ = Describe("Unsafe checks", func() {
|
||||
Expect(c.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
|
||||
Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21}))
|
||||
})
|
||||
|
||||
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
|
||||
chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}}
|
||||
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
|
||||
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37}))
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user