From 6f3d3919df562299fe499b3eef30b80fca553c66 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 2 Oct 2017 11:11:42 +0700 Subject: [PATCH] use separate constructors for client and server for the TLS crypto setup --- internal/handshake/crypto_setup_tls.go | 58 +++++++++++++-------- internal/handshake/crypto_setup_tls_test.go | 6 +-- session.go | 11 ++-- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 682cd162..64478865 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -35,36 +35,52 @@ var newMintController = func(conn *mint.Conn) crypto.MintController { return &mintController{conn} } -// NewCryptoSetupTLS creates a new CryptoSetup instance for a server -func NewCryptoSetupTLS( - hostname string, // only needed for the client - perspective protocol.Perspective, - version protocol.VersionNumber, +// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server +func NewCryptoSetupTLSServer( tlsConfig *tls.Config, transportParams *TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, + version protocol.VersionNumber, ) (CryptoSetup, ParamsNegotiator, error) { - mintConf, err := tlsToMintConfig(tlsConfig, perspective) + mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer) + if err != nil { + return nil, nil, err + } + + params := newParamsNegotiator(protocol.PerspectiveServer, version, transportParams) + return &cryptoSetupTLS{ + perspective: protocol.PerspectiveServer, + mintConf: mintConf, + nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, + extensionHandler: newExtensionHandlerServer(params), + }, params, nil +} + +// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client +func NewCryptoSetupTLSClient( + hostname string, // only needed for the client + tlsConfig *tls.Config, + transportParams *TransportParameters, + aeadChanged chan<- protocol.EncryptionLevel, + version protocol.VersionNumber, +) (CryptoSetup, ParamsNegotiator, error) { + mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient) if err != nil { return nil, nil, err } mintConf.ServerName = hostname - params := newParamsNegotiator(perspective, version, transportParams) - cs := &cryptoSetupTLS{ - perspective: perspective, - mintConf: mintConf, - nullAEAD: crypto.NewNullAEAD(perspective, version), - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, - } - if perspective == protocol.PerspectiveClient { - cs.extensionHandler = newExtensionHandlerClient(params) - } else { - cs.extensionHandler = newExtensionHandlerServer(params) - } - - return cs, params, nil + params := newParamsNegotiator(protocol.PerspectiveClient, version, transportParams) + return &cryptoSetupTLS{ + perspective: protocol.PerspectiveClient, + mintConf: mintConf, + nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, + extensionHandler: newExtensionHandlerClient(params), + }, params, nil } func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error { diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 5c340f36..2524bb5a 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -40,13 +40,11 @@ var _ = Describe("TLS Crypto Setup", func() { BeforeEach(func() { aeadChanged = make(chan protocol.EncryptionLevel, 2) - csInt, _, err := NewCryptoSetupTLS( - "", - protocol.PerspectiveServer, - protocol.VersionTLS, + csInt, _, err := NewCryptoSetupTLSServer( testdata.GetTLSConfig(), &TransportParameters{}, aeadChanged, + protocol.VersionTLS, ) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupTLS) diff --git a/session.go b/session.go index b1875d77..2c46d318 100644 --- a/session.go +++ b/session.go @@ -194,13 +194,11 @@ func (s *session) setup( return s.config.AcceptCookie(clientAddr, cookie) } if s.version.UsesTLS() { - s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLS( - "", - s.perspective, - s.version, + s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSServer( tlsConf, transportParams, aeadChanged, + s.version, ) } else { s.cryptoSetup, s.connParams, err = newCryptoSetup( @@ -216,13 +214,12 @@ func (s *session) setup( } } else { if s.version.UsesTLS() { - s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLS( + s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSClient( hostname, - s.perspective, - s.version, tlsConf, transportParams, aeadChanged, + s.version, ) } else { transportParams.RequestConnectionIDOmission = s.config.RequestConnectionIDOmission