fix conversion of qtls.ClientHelloInfo in GetConfigForClient

This commit is contained in:
Marten Seemann
2020-04-08 15:56:30 +07:00
parent 799d80197f
commit 66d50b4289
3 changed files with 16 additions and 2 deletions

View File

@@ -62,7 +62,11 @@ func tlsConfigToQtlsConfig(
var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error)
if c.GetConfigForClient != nil { if c.GetConfigForClient != nil {
getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) { getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) {
tlsConf, err := c.GetConfigForClient((*tls.ClientHelloInfo)(unsafe.Pointer(ch))) var chi *tls.ClientHelloInfo
if ch != nil {
chi = toTLSClientHelloInfo(ch)
}
tlsConf, err := c.GetConfigForClient(chi)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -82,6 +82,10 @@ type qtlsClientHelloInfo struct {
func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi)) qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
var config *tls.Config
if qtlsCHI.config != nil {
config = qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config)))
}
return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{ return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{
CipherSuites: chi.CipherSuites, CipherSuites: chi.CipherSuites,
ServerName: chi.ServerName, ServerName: chi.ServerName,
@@ -91,6 +95,6 @@ func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
SupportedProtos: chi.SupportedProtos, SupportedProtos: chi.SupportedProtos,
SupportedVersions: chi.SupportedVersions, SupportedVersions: chi.SupportedVersions,
Conn: chi.Conn, Conn: chi.Conn,
config: qtlsConfigToTLSConfig((*qtls.Config)(unsafe.Pointer(qtlsCHI.config))), config: config,
})) }))
} }

View File

@@ -95,4 +95,10 @@ var _ = Describe("Unsafe checks", func() {
Expect(c.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) Expect(c.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21})) Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21}))
}) })
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}}
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37}))
})
}) })