forked from quic-go/quic-go
move generation of qtls.Config to a separate file, add tests
This commit is contained in:
@@ -62,8 +62,7 @@ type cryptoSetup struct {
|
|||||||
readEncLevel protocol.EncryptionLevel
|
readEncLevel protocol.EncryptionLevel
|
||||||
writeEncLevel protocol.EncryptionLevel
|
writeEncLevel protocol.EncryptionLevel
|
||||||
|
|
||||||
extHandler tlsExtensionHandler
|
paramsChan <-chan []byte
|
||||||
|
|
||||||
handleParamsCallback func([]byte)
|
handleParamsCallback func([]byte)
|
||||||
|
|
||||||
alertChan chan uint8
|
alertChan chan uint8
|
||||||
@@ -190,7 +189,7 @@ func newCryptoSetup(
|
|||||||
readEncLevel: protocol.EncryptionInitial,
|
readEncLevel: protocol.EncryptionInitial,
|
||||||
writeEncLevel: protocol.EncryptionInitial,
|
writeEncLevel: protocol.EncryptionInitial,
|
||||||
handleParamsCallback: handleParams,
|
handleParamsCallback: handleParams,
|
||||||
extHandler: extHandler,
|
paramsChan: extHandler.TransportParameters(),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
perspective: perspective,
|
perspective: perspective,
|
||||||
handshakeDone: make(chan struct{}),
|
handshakeDone: make(chan struct{}),
|
||||||
@@ -203,7 +202,7 @@ func newCryptoSetup(
|
|||||||
writeRecord: make(chan struct{}),
|
writeRecord: make(chan struct{}),
|
||||||
closeChan: make(chan struct{}),
|
closeChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
qtlsConf := cs.tlsConfigToQtlsConfig(tlsConf)
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler)
|
||||||
cs.tlsConf = qtlsConf
|
cs.tlsConf = qtlsConf
|
||||||
return cs, cs.clientHelloWrittenChan, nil
|
return cs, cs.clientHelloWrittenChan, nil
|
||||||
}
|
}
|
||||||
@@ -310,7 +309,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
|
|||||||
// If it accepts the ClientHello, it will first read the transport parameters.
|
// If it accepts the ClientHello, it will first read the transport parameters.
|
||||||
h.logger.Debugf("Sending HelloRetryRequest")
|
h.logger.Debugf("Sending HelloRetryRequest")
|
||||||
return false
|
return false
|
||||||
case data := <-h.extHandler.TransportParameters():
|
case data := <-h.paramsChan:
|
||||||
h.handleParamsCallback(data)
|
h.handleParamsCallback(data)
|
||||||
case <-h.handshakeDone:
|
case <-h.handshakeDone:
|
||||||
return false
|
return false
|
||||||
@@ -374,7 +373,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
|
|||||||
return true
|
return true
|
||||||
case typeEncryptedExtensions:
|
case typeEncryptedExtensions:
|
||||||
select {
|
select {
|
||||||
case data := <-h.extHandler.TransportParameters():
|
case data := <-h.paramsChan:
|
||||||
h.handleParamsCallback(data)
|
h.handleParamsCallback(data)
|
||||||
case <-h.handshakeDone:
|
case <-h.handshakeDone:
|
||||||
return false
|
return false
|
||||||
@@ -551,60 +550,3 @@ func (h *cryptoSetup) ConnectionState() tls.ConnectionState {
|
|||||||
// In unsafe.go we check that the two objects are actually identical.
|
// In unsafe.go we check that the two objects are actually identical.
|
||||||
return *(*tls.ConnectionState)(unsafe.Pointer(&cs))
|
return *(*tls.ConnectionState)(unsafe.Pointer(&cs))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
|
||||||
if c == nil {
|
|
||||||
c = &tls.Config{}
|
|
||||||
}
|
|
||||||
// QUIC requires TLS 1.3 or newer
|
|
||||||
minVersion := c.MinVersion
|
|
||||||
if minVersion < qtls.VersionTLS13 {
|
|
||||||
minVersion = qtls.VersionTLS13
|
|
||||||
}
|
|
||||||
maxVersion := c.MaxVersion
|
|
||||||
if maxVersion < qtls.VersionTLS13 {
|
|
||||||
maxVersion = qtls.VersionTLS13
|
|
||||||
}
|
|
||||||
var getConfigForClient func(ch *tls.ClientHelloInfo) (*qtls.Config, error)
|
|
||||||
if c.GetConfigForClient != nil {
|
|
||||||
getConfigForClient = func(ch *tls.ClientHelloInfo) (*qtls.Config, error) {
|
|
||||||
tlsConf, err := c.GetConfigForClient(ch)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if tlsConf == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return h.tlsConfigToQtlsConfig(tlsConf), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &qtls.Config{
|
|
||||||
Rand: c.Rand,
|
|
||||||
Time: c.Time,
|
|
||||||
Certificates: c.Certificates,
|
|
||||||
NameToCertificate: c.NameToCertificate,
|
|
||||||
GetCertificate: c.GetCertificate,
|
|
||||||
GetClientCertificate: c.GetClientCertificate,
|
|
||||||
GetConfigForClient: getConfigForClient,
|
|
||||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
|
||||||
RootCAs: c.RootCAs,
|
|
||||||
NextProtos: c.NextProtos,
|
|
||||||
ServerName: c.ServerName,
|
|
||||||
ClientAuth: c.ClientAuth,
|
|
||||||
ClientCAs: c.ClientCAs,
|
|
||||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
|
||||||
CipherSuites: c.CipherSuites,
|
|
||||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
|
||||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
|
||||||
SessionTicketKey: c.SessionTicketKey,
|
|
||||||
MinVersion: minVersion,
|
|
||||||
MaxVersion: maxVersion,
|
|
||||||
CurvePreferences: c.CurvePreferences,
|
|
||||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
|
||||||
// no need to copy Renegotiation, it's not supported by TLS 1.3
|
|
||||||
KeyLogWriter: c.KeyLogWriter,
|
|
||||||
AlternativeRecordLayer: h,
|
|
||||||
GetExtensions: h.extHandler.GetExtensions,
|
|
||||||
ReceivedExtensions: h.extHandler.ReceivedExtensions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
68
internal/handshake/qtls.go
Normal file
68
internal/handshake/qtls.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
|
)
|
||||||
|
|
||||||
|
func tlsConfigToQtlsConfig(
|
||||||
|
c *tls.Config,
|
||||||
|
recordLayer qtls.RecordLayer,
|
||||||
|
extHandler tlsExtensionHandler,
|
||||||
|
) *qtls.Config {
|
||||||
|
if c == nil {
|
||||||
|
c = &tls.Config{}
|
||||||
|
}
|
||||||
|
// QUIC requires TLS 1.3 or newer
|
||||||
|
minVersion := c.MinVersion
|
||||||
|
if minVersion < qtls.VersionTLS13 {
|
||||||
|
minVersion = qtls.VersionTLS13
|
||||||
|
}
|
||||||
|
maxVersion := c.MaxVersion
|
||||||
|
if maxVersion < qtls.VersionTLS13 {
|
||||||
|
maxVersion = qtls.VersionTLS13
|
||||||
|
}
|
||||||
|
var getConfigForClient func(ch *tls.ClientHelloInfo) (*qtls.Config, error)
|
||||||
|
if c.GetConfigForClient != nil {
|
||||||
|
getConfigForClient = func(ch *tls.ClientHelloInfo) (*qtls.Config, error) {
|
||||||
|
tlsConf, err := c.GetConfigForClient(ch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tlsConf == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &qtls.Config{
|
||||||
|
Rand: c.Rand,
|
||||||
|
Time: c.Time,
|
||||||
|
Certificates: c.Certificates,
|
||||||
|
NameToCertificate: c.NameToCertificate,
|
||||||
|
GetCertificate: c.GetCertificate,
|
||||||
|
GetClientCertificate: c.GetClientCertificate,
|
||||||
|
GetConfigForClient: getConfigForClient,
|
||||||
|
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||||
|
RootCAs: c.RootCAs,
|
||||||
|
NextProtos: c.NextProtos,
|
||||||
|
ServerName: c.ServerName,
|
||||||
|
ClientAuth: c.ClientAuth,
|
||||||
|
ClientCAs: c.ClientCAs,
|
||||||
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||||
|
CipherSuites: c.CipherSuites,
|
||||||
|
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||||
|
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||||
|
SessionTicketKey: c.SessionTicketKey,
|
||||||
|
MinVersion: minVersion,
|
||||||
|
MaxVersion: maxVersion,
|
||||||
|
CurvePreferences: c.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||||
|
// no need to copy Renegotiation, it's not supported by TLS 1.3
|
||||||
|
KeyLogWriter: c.KeyLogWriter,
|
||||||
|
AlternativeRecordLayer: recordLayer,
|
||||||
|
GetExtensions: extHandler.GetExtensions,
|
||||||
|
ReceivedExtensions: extHandler.ReceivedExtensions,
|
||||||
|
}
|
||||||
|
}
|
||||||
99
internal/handshake/qtls_test.go
Normal file
99
internal/handshake/qtls_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockExtensionHandler struct {
|
||||||
|
get, received bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ tlsExtensionHandler = &mockExtensionHandler{}
|
||||||
|
|
||||||
|
func (h *mockExtensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
|
||||||
|
h.get = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
|
||||||
|
h.received = true
|
||||||
|
}
|
||||||
|
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
|
||||||
|
|
||||||
|
var _ = Describe("qtls.Config generation", func() {
|
||||||
|
It("sets MinVersion and MaxVersion", func() {
|
||||||
|
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
||||||
|
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||||
|
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("works when called with a nil config", func() {
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{})
|
||||||
|
Expect(qtlsConf).ToNot(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets the setter and getter function for TLS extensions", func() {
|
||||||
|
extHandler := &mockExtensionHandler{}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler)
|
||||||
|
Expect(extHandler.get).To(BeFalse())
|
||||||
|
qtlsConf.GetExtensions(10)
|
||||||
|
Expect(extHandler.get).To(BeTrue())
|
||||||
|
Expect(extHandler.received).To(BeFalse())
|
||||||
|
qtlsConf.ReceivedExtensions(10, nil)
|
||||||
|
Expect(extHandler.received).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GetConfigForClient callback", func() {
|
||||||
|
It("doesn't set it if absent", func() {
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{})
|
||||||
|
Expect(qtlsConf.GetConfigForClient).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns a qtls.Config", func() {
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return &tls.Config{ServerName: "foo.bar"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
extHandler := &mockExtensionHandler{}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler)
|
||||||
|
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
|
||||||
|
confForClient, err := qtlsConf.GetConfigForClient(nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(confForClient.ServerName).To(Equal("foo.bar"))
|
||||||
|
Expect(confForClient).ToNot(BeNil())
|
||||||
|
Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||||
|
Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||||
|
Expect(extHandler.get).To(BeFalse())
|
||||||
|
confForClient.GetExtensions(10)
|
||||||
|
Expect(extHandler.get).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns errors", func() {
|
||||||
|
testErr := errors.New("test")
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return nil, testErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
||||||
|
_, err := qtlsConf.GetConfigForClient(nil)
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns nil when the callback returns nil", func() {
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
||||||
|
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user