move generation of qtls.Config to a separate file, add tests

This commit is contained in:
Marten Seemann
2019-03-25 16:23:16 +01:00
parent 358fce241a
commit 9ddf9129cf
3 changed files with 172 additions and 63 deletions

View File

@@ -62,8 +62,7 @@ type cryptoSetup struct {
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
extHandler tlsExtensionHandler
paramsChan <-chan []byte
handleParamsCallback func([]byte)
alertChan chan uint8
@@ -190,7 +189,7 @@ func newCryptoSetup(
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
extHandler: extHandler,
paramsChan: extHandler.TransportParameters(),
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
@@ -203,7 +202,7 @@ func newCryptoSetup(
writeRecord: make(chan struct{}),
closeChan: make(chan struct{}),
}
qtlsConf := cs.tlsConfigToQtlsConfig(tlsConf)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler)
cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan, nil
}
@@ -310,7 +309,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
// If it accepts the ClientHello, it will first read the transport parameters.
h.logger.Debugf("Sending HelloRetryRequest")
return false
case data := <-h.extHandler.TransportParameters():
case data := <-h.paramsChan:
h.handleParamsCallback(data)
case <-h.handshakeDone:
return false
@@ -374,7 +373,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
return true
case typeEncryptedExtensions:
select {
case data := <-h.extHandler.TransportParameters():
case data := <-h.paramsChan:
h.handleParamsCallback(data)
case <-h.handshakeDone:
return false
@@ -551,60 +550,3 @@ func (h *cryptoSetup) ConnectionState() tls.ConnectionState {
// In unsafe.go we check that the two objects are actually identical.
return *(*tls.ConnectionState)(unsafe.Pointer(&cs))
}
func (h *cryptoSetup) tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
if c == nil {
c = &tls.Config{}
}
// QUIC requires TLS 1.3 or newer
minVersion := c.MinVersion
if minVersion < qtls.VersionTLS13 {
minVersion = qtls.VersionTLS13
}
maxVersion := c.MaxVersion
if maxVersion < qtls.VersionTLS13 {
maxVersion = qtls.VersionTLS13
}
var getConfigForClient func(ch *tls.ClientHelloInfo) (*qtls.Config, error)
if c.GetConfigForClient != nil {
getConfigForClient = func(ch *tls.ClientHelloInfo) (*qtls.Config, error) {
tlsConf, err := c.GetConfigForClient(ch)
if err != nil {
return nil, err
}
if tlsConf == nil {
return nil, nil
}
return h.tlsConfigToQtlsConfig(tlsConf), nil
}
}
return &qtls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
GetClientCertificate: c.GetClientCertificate,
GetConfigForClient: getConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
// no need to copy Renegotiation, it's not supported by TLS 1.3
KeyLogWriter: c.KeyLogWriter,
AlternativeRecordLayer: h,
GetExtensions: h.extHandler.GetExtensions,
ReceivedExtensions: h.extHandler.ReceivedExtensions,
}
}

View File

@@ -0,0 +1,68 @@
package handshake
import (
"crypto/tls"
"github.com/marten-seemann/qtls"
)
func tlsConfigToQtlsConfig(
c *tls.Config,
recordLayer qtls.RecordLayer,
extHandler tlsExtensionHandler,
) *qtls.Config {
if c == nil {
c = &tls.Config{}
}
// QUIC requires TLS 1.3 or newer
minVersion := c.MinVersion
if minVersion < qtls.VersionTLS13 {
minVersion = qtls.VersionTLS13
}
maxVersion := c.MaxVersion
if maxVersion < qtls.VersionTLS13 {
maxVersion = qtls.VersionTLS13
}
var getConfigForClient func(ch *tls.ClientHelloInfo) (*qtls.Config, error)
if c.GetConfigForClient != nil {
getConfigForClient = func(ch *tls.ClientHelloInfo) (*qtls.Config, error) {
tlsConf, err := c.GetConfigForClient(ch)
if err != nil {
return nil, err
}
if tlsConf == nil {
return nil, nil
}
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler), nil
}
}
return &qtls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
GetClientCertificate: c.GetClientCertificate,
GetConfigForClient: getConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
// no need to copy Renegotiation, it's not supported by TLS 1.3
KeyLogWriter: c.KeyLogWriter,
AlternativeRecordLayer: recordLayer,
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
}
}

View File

@@ -0,0 +1,99 @@
package handshake
import (
"crypto/tls"
"errors"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockExtensionHandler struct {
get, received bool
}
var _ tlsExtensionHandler = &mockExtensionHandler{}
func (h *mockExtensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
h.get = true
return nil
}
func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
h.received = true
}
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
var _ = Describe("qtls.Config generation", func() {
It("sets MinVersion and MaxVersion", func() {
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
})
It("works when called with a nil config", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{})
Expect(qtlsConf).ToNot(BeNil())
})
It("sets the setter and getter function for TLS extensions", func() {
extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler)
Expect(extHandler.get).To(BeFalse())
qtlsConf.GetExtensions(10)
Expect(extHandler.get).To(BeTrue())
Expect(extHandler.received).To(BeFalse())
qtlsConf.ReceivedExtensions(10, nil)
Expect(extHandler.received).To(BeTrue())
})
Context("GetConfigForClient callback", func() {
It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{})
Expect(qtlsConf.GetConfigForClient).To(BeNil())
})
It("returns a qtls.Config", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return &tls.Config{ServerName: "foo.bar"}, nil
},
}
extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler)
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
confForClient, err := qtlsConf.GetConfigForClient(nil)
Expect(err).ToNot(HaveOccurred())
Expect(confForClient.ServerName).To(Equal("foo.bar"))
Expect(confForClient).ToNot(BeNil())
Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(extHandler.get).To(BeFalse())
confForClient.GetExtensions(10)
Expect(extHandler.get).To(BeTrue())
})
It("returns errors", func() {
testErr := errors.New("test")
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return nil, testErr
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
_, err := qtlsConf.GetConfigForClient(nil)
Expect(err).To(MatchError(testErr))
})
It("returns nil when the callback returns nil", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return nil, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
})
})
})