forked from quic-go/quic-go
copy the GetConfigForClient callback when creating the qtls.Config
This commit is contained in:
@@ -198,10 +198,7 @@ func newCryptoSetup(
|
|||||||
receivedWriteKey: make(chan struct{}),
|
receivedWriteKey: make(chan struct{}),
|
||||||
closeChan: make(chan struct{}),
|
closeChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf)
|
qtlsConf := cs.tlsConfigToQtlsConfig(tlsConf)
|
||||||
qtlsConf.AlternativeRecordLayer = cs
|
|
||||||
qtlsConf.GetExtensions = extHandler.GetExtensions
|
|
||||||
qtlsConf.ReceivedExtensions = extHandler.ReceivedExtensions
|
|
||||||
cs.tlsConf = qtlsConf
|
cs.tlsConf = qtlsConf
|
||||||
return cs, cs.clientHelloWrittenChan, nil
|
return cs, cs.clientHelloWrittenChan, nil
|
||||||
}
|
}
|
||||||
@@ -528,3 +525,60 @@ func (h *cryptoSetup) ConnectionState() ConnectionState {
|
|||||||
PeerCertificates: connState.PeerCertificates,
|
PeerCertificates: connState.PeerCertificates,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
||||||
|
if c == nil {
|
||||||
|
c = &tls.Config{}
|
||||||
|
}
|
||||||
|
// QUIC requires TLS 1.3 or newer
|
||||||
|
minVersion := c.MinVersion
|
||||||
|
if minVersion < qtls.VersionTLS13 {
|
||||||
|
minVersion = qtls.VersionTLS13
|
||||||
|
}
|
||||||
|
maxVersion := c.MaxVersion
|
||||||
|
if maxVersion < qtls.VersionTLS13 {
|
||||||
|
maxVersion = qtls.VersionTLS13
|
||||||
|
}
|
||||||
|
var getConfigForClient func(ch *tls.ClientHelloInfo) (*qtls.Config, error)
|
||||||
|
if c.GetConfigForClient != nil {
|
||||||
|
getConfigForClient = func(ch *tls.ClientHelloInfo) (*qtls.Config, error) {
|
||||||
|
tlsConf, err := c.GetConfigForClient(ch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tlsConf == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return h.tlsConfigToQtlsConfig(tlsConf), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &qtls.Config{
|
||||||
|
Rand: c.Rand,
|
||||||
|
Time: c.Time,
|
||||||
|
Certificates: c.Certificates,
|
||||||
|
NameToCertificate: c.NameToCertificate,
|
||||||
|
GetCertificate: c.GetCertificate,
|
||||||
|
GetClientCertificate: c.GetClientCertificate,
|
||||||
|
GetConfigForClient: getConfigForClient,
|
||||||
|
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||||
|
RootCAs: c.RootCAs,
|
||||||
|
NextProtos: c.NextProtos,
|
||||||
|
ServerName: c.ServerName,
|
||||||
|
ClientAuth: c.ClientAuth,
|
||||||
|
ClientCAs: c.ClientCAs,
|
||||||
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||||
|
CipherSuites: c.CipherSuites,
|
||||||
|
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||||
|
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||||
|
SessionTicketKey: c.SessionTicketKey,
|
||||||
|
MinVersion: minVersion,
|
||||||
|
MaxVersion: maxVersion,
|
||||||
|
CurvePreferences: c.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||||
|
// no need to copy Renegotiation, it's not supported by TLS 1.3
|
||||||
|
KeyLogWriter: c.KeyLogWriter,
|
||||||
|
AlternativeRecordLayer: h,
|
||||||
|
GetExtensions: h.extHandler.GetExtensions,
|
||||||
|
ReceivedExtensions: h.extHandler.ReceivedExtensions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/big"
|
"math/big"
|
||||||
"time"
|
"time"
|
||||||
@@ -65,6 +66,47 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("creates a qtls.Config", func() {
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
ServerName: "quic.clemente.io",
|
||||||
|
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return nil, errors.New("GetCertificate")
|
||||||
|
},
|
||||||
|
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
|
return nil, errors.New("GetClientCertificate")
|
||||||
|
},
|
||||||
|
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return &tls.Config{ServerName: ch.ServerName}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server, err := NewCryptoSetupServer(
|
||||||
|
&bytes.Buffer{},
|
||||||
|
&bytes.Buffer{},
|
||||||
|
ioutil.Discard,
|
||||||
|
protocol.ConnectionID{},
|
||||||
|
&EncryptedExtensionsTransportParameters{
|
||||||
|
NegotiatedVersion: protocol.VersionTLS,
|
||||||
|
SupportedVersions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
},
|
||||||
|
func([]byte) {},
|
||||||
|
tlsConf,
|
||||||
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
qtlsConf := server.(*cryptoSetup).tlsConf
|
||||||
|
Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName))
|
||||||
|
_, getCertificateErr := qtlsConf.GetCertificate(nil)
|
||||||
|
Expect(getCertificateErr).To(MatchError("GetCertificate"))
|
||||||
|
_, getClientCertificateErr := qtlsConf.GetClientCertificate(nil)
|
||||||
|
Expect(getClientCertificateErr).To(MatchError("GetClientCertificate"))
|
||||||
|
cconf, err := qtlsConf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "foo.bar"})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cconf.ServerName).To(Equal("foo.bar"))
|
||||||
|
Expect(cconf.AlternativeRecordLayer).ToNot(BeNil())
|
||||||
|
Expect(cconf.GetExtensions).ToNot(BeNil())
|
||||||
|
Expect(cconf.ReceivedExtensions).ToNot(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
It("returns Handshake() when an error occurs", func() {
|
It("returns Handshake() when an error occurs", func() {
|
||||||
_, sInitialStream, sHandshakeStream := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
server, err := NewCryptoSetupServer(
|
server, err := NewCryptoSetupServer(
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
|
|
||||||
"github.com/marten-seemann/qtls"
|
|
||||||
)
|
|
||||||
|
|
||||||
func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
|
||||||
if c == nil {
|
|
||||||
c = &tls.Config{}
|
|
||||||
}
|
|
||||||
// QUIC requires TLS 1.3 or newer
|
|
||||||
minVersion := c.MinVersion
|
|
||||||
if minVersion < qtls.VersionTLS13 {
|
|
||||||
minVersion = qtls.VersionTLS13
|
|
||||||
}
|
|
||||||
maxVersion := c.MaxVersion
|
|
||||||
if maxVersion < qtls.VersionTLS13 {
|
|
||||||
maxVersion = qtls.VersionTLS13
|
|
||||||
}
|
|
||||||
return &qtls.Config{
|
|
||||||
Rand: c.Rand,
|
|
||||||
Time: c.Time,
|
|
||||||
Certificates: c.Certificates,
|
|
||||||
NameToCertificate: c.NameToCertificate,
|
|
||||||
GetCertificate: c.GetCertificate,
|
|
||||||
GetClientCertificate: c.GetClientCertificate,
|
|
||||||
// TODO: make GetConfigForClient work
|
|
||||||
// GetConfigForClient: c.GetConfigForClient,
|
|
||||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
|
||||||
RootCAs: c.RootCAs,
|
|
||||||
NextProtos: c.NextProtos,
|
|
||||||
ServerName: c.ServerName,
|
|
||||||
ClientAuth: c.ClientAuth,
|
|
||||||
ClientCAs: c.ClientCAs,
|
|
||||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
|
||||||
CipherSuites: c.CipherSuites,
|
|
||||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
|
||||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
|
||||||
SessionTicketKey: c.SessionTicketKey,
|
|
||||||
MinVersion: minVersion,
|
|
||||||
MaxVersion: maxVersion,
|
|
||||||
CurvePreferences: c.CurvePreferences,
|
|
||||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
|
||||||
// Renegotiation is not supported by TLS 1.3
|
|
||||||
KeyLogWriter: c.KeyLogWriter,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user