diff --git a/internal/crypto/key_derivation.go b/internal/crypto/key_derivation.go index f135db271..316bd1b3b 100644 --- a/internal/crypto/key_derivation.go +++ b/internal/crypto/key_derivation.go @@ -10,15 +10,14 @@ const ( serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" ) -// MintController is an interface that bundles all methods needed to interact with mint -type MintController interface { - Handshake() mint.Alert +// A TLSExporter gets the negotiated ciphersuite and computes exporter +type TLSExporter interface { GetCipherSuite() mint.CipherSuiteParams ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) } // DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance -func DeriveAESKeys(mc MintController, pers protocol.Perspective) (AEAD, error) { +func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { var myLabel, otherLabel string if pers == protocol.PerspectiveClient { myLabel = clientExporterLabel @@ -27,20 +26,20 @@ func DeriveAESKeys(mc MintController, pers protocol.Perspective) (AEAD, error) { myLabel = serverExporterLabel otherLabel = clientExporterLabel } - myKey, myIV, err := computeKeyAndIV(mc, myLabel) + myKey, myIV, err := computeKeyAndIV(tls, myLabel) if err != nil { return nil, err } - otherKey, otherIV, err := computeKeyAndIV(mc, otherLabel) + otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel) if err != nil { return nil, err } return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) } -func computeKeyAndIV(mc MintController, label string) (key, iv []byte, err error) { - cs := mc.GetCipherSuite() - secret, err := mc.ComputeExporter(label, nil, cs.Hash.Size()) +func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) { + cs := tls.GetCipherSuite() + secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size()) if err != nil { return nil, nil, err } diff --git a/internal/crypto/key_derivation_test.go b/internal/crypto/key_derivation_test.go index 5f3667569..6b2499d75 100644 --- a/internal/crypto/key_derivation_test.go +++ b/internal/crypto/key_derivation_test.go @@ -10,16 +10,16 @@ import ( . "github.com/onsi/gomega" ) -type mockMintController struct { +type mockTLSExporter struct { hash crypto.Hash computerError error } -var _ MintController = &mockMintController{} +var _ TLSExporter = &mockTLSExporter{} -func (c *mockMintController) Handshake() mint.Alert { panic("not implemented") } +func (c *mockTLSExporter) Handshake() mint.Alert { panic("not implemented") } -func (c *mockMintController) GetCipherSuite() mint.CipherSuiteParams { +func (c *mockTLSExporter) GetCipherSuite() mint.CipherSuiteParams { return mint.CipherSuiteParams{ Hash: c.hash, KeyLen: 32, @@ -27,7 +27,7 @@ func (c *mockMintController) GetCipherSuite() mint.CipherSuiteParams { } } -func (c *mockMintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { +func (c *mockTLSExporter) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { if c.computerError != nil { return nil, c.computerError } @@ -36,9 +36,9 @@ func (c *mockMintController) ComputeExporter(label string, context []byte, keyLe var _ = Describe("Key Derivation", func() { It("derives keys", func() { - clientAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256}, protocol.PerspectiveClient) + clientAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) - serverAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256}, protocol.PerspectiveServer) + serverAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) ciphertext := clientAEAD.Seal(nil, []byte("foobar"), 0, []byte("aad")) data, err := serverAEAD.Open(nil, ciphertext, 0, []byte("aad")) @@ -47,9 +47,9 @@ var _ = Describe("Key Derivation", func() { }) It("fails when different hash functions are used", func() { - clientAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256}, protocol.PerspectiveClient) + clientAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) - serverAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA512}, protocol.PerspectiveServer) + serverAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA512}, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) ciphertext := clientAEAD.Seal(nil, []byte("foobar"), 0, []byte("aad")) _, err = serverAEAD.Open(nil, ciphertext, 0, []byte("aad")) @@ -58,7 +58,7 @@ var _ = Describe("Key Derivation", func() { It("fails when computing the exporter fails", func() { testErr := errors.New("test error") - _, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256, computerError: testErr}, protocol.PerspectiveClient) + _, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256, computerError: testErr}, protocol.PerspectiveClient) Expect(err).To(MatchError(testErr)) }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 00276d64f..e360fbf6c 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -12,7 +12,7 @@ import ( ) // KeyDerivationFunction is used for key derivation -type KeyDerivationFunction func(crypto.MintController, protocol.Perspective) (crypto.AEAD, error) +type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) type cryptoSetupTLS struct { mutex sync.RWMutex @@ -21,8 +21,7 @@ type cryptoSetupTLS struct { keyDerivation KeyDerivationFunction - conn *mint.Conn - extensionHandler mint.AppExtensionHandler + tls mintTLS nullAEAD crypto.AEAD aead crypto.AEAD @@ -30,10 +29,6 @@ type cryptoSetupTLS struct { aeadChanged chan<- protocol.EncryptionLevel } -var newMintController = func(conn *mint.Conn) crypto.MintController { - return &mintController{conn} -} - // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server func NewCryptoSetupTLSServer( cryptoStream io.ReadWriter, @@ -48,14 +43,18 @@ func NewCryptoSetupTLSServer( if err != nil { return nil, err } + mintConn := mint.Server(&fakeConn{cryptoStream}, mintConf) + eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version) + if err := mintConn.SetExtensionHandler(eh); err != nil { + return nil, err + } return &cryptoSetupTLS{ - perspective: protocol.PerspectiveServer, - conn: mint.Server(&fakeConn{cryptoStream}, mintConf), - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, - extensionHandler: newExtensionHandlerServer(params, paramsChan, supportedVersions, version), + perspective: protocol.PerspectiveServer, + tls: &mintController{mintConn}, + nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, }, nil } @@ -76,28 +75,28 @@ func NewCryptoSetupTLSClient( return nil, err } mintConf.ServerName = hostname + mintConn := mint.Client(&fakeConn{cryptoStream}, mintConf) + eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version) + if err := mintConn.SetExtensionHandler(eh); err != nil { + return nil, err + } return &cryptoSetupTLS{ - perspective: protocol.PerspectiveClient, - conn: mint.Client(&fakeConn{cryptoStream}, mintConf), - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, - extensionHandler: newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version), + perspective: protocol.PerspectiveClient, + tls: &mintController{mintConn}, + nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, }, nil } func (h *cryptoSetupTLS) HandleCryptoStream() error { - if err := h.conn.SetExtensionHandler(h.extensionHandler); err != nil { - return err - } - mc := newMintController(h.conn) - if alert := mc.Handshake(); alert != mint.AlertNoAlert { + if alert := h.tls.Handshake(); alert != mint.AlertNoAlert { return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) } - aead, err := h.keyDerivation(mc, h.perspective) + aead, err := h.keyDerivation(h.tls, h.perspective) if err != nil { return err } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 0ce65f206..6a33a2613 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -12,21 +12,24 @@ import ( . "github.com/onsi/gomega" ) -type fakeMintController struct { +type fakeMintTLS struct { result mint.Alert } -var _ crypto.MintController = &fakeMintController{} +var _ mintTLS = &fakeMintTLS{} -func (h *fakeMintController) Handshake() mint.Alert { +func (h *fakeMintTLS) Handshake() mint.Alert { return h.result } -func (h *fakeMintController) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") } -func (h *fakeMintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { +func (h *fakeMintTLS) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") } +func (h *fakeMintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + panic("not implemented") +} +func (h *fakeMintTLS) SetExtensionHandler(mint.AppExtensionHandler) error { panic("not implemented") } -func mockKeyDerivation(crypto.MintController, protocol.Perspective) (crypto.AEAD, error) { +func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) { return &mockAEAD{encLevel: protocol.EncryptionForwardSecure}, nil } @@ -35,8 +38,6 @@ var _ = Describe("TLS Crypto Setup", func() { cs *cryptoSetupTLS paramsChan chan TransportParameters aeadChanged chan protocol.EncryptionLevel - - mintControllerConstructor = newMintController ) BeforeEach(func() { @@ -55,23 +56,15 @@ var _ = Describe("TLS Crypto Setup", func() { cs = csInt.(*cryptoSetupTLS) }) - AfterEach(func() { - newMintController = mintControllerConstructor - }) - It("errors when the handshake fails", func() { alert := mint.AlertBadRecordMAC - newMintController = func(*mint.Conn) crypto.MintController { - return &fakeMintController{result: alert} - } + cs.tls = &fakeMintTLS{result: alert} err := cs.HandleCryptoStream() Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) }) It("derives keys", func() { - newMintController = func(*mint.Conn) crypto.MintController { - return &fakeMintController{result: mint.AlertNoAlert} - } + cs.tls = &fakeMintTLS{result: mint.AlertNoAlert} cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) @@ -83,9 +76,7 @@ var _ = Describe("TLS Crypto Setup", func() { var foobarFNVSigned []byte // a "foobar", FNV signed doHandshake := func() { - newMintController = func(*mint.Conn) crypto.MintController { - return &fakeMintController{result: mint.AlertNoAlert} - } + cs.tls = &fakeMintTLS{result: mint.AlertNoAlert} cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) diff --git a/internal/handshake/mint_utils.go b/internal/handshake/mint_utils.go index f35b643b0..63e8d8c4d 100644 --- a/internal/handshake/mint_utils.go +++ b/internal/handshake/mint_utils.go @@ -43,15 +43,16 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf return mconf, nil } +type mintTLS interface { + crypto.TLSExporter + Handshake() mint.Alert +} + type mintController struct { conn *mint.Conn } -var _ crypto.MintController = &mintController{} - -func (mc *mintController) Handshake() mint.Alert { - return mc.conn.Handshake() -} +var _ mintTLS = &mintController{} func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { return mc.conn.State().CipherSuite @@ -61,6 +62,10 @@ func (mc *mintController) ComputeExporter(label string, context []byte, keyLengt return mc.conn.ComputeExporter(label, context, keyLength) } +func (mc *mintController) Handshake() mint.Alert { + return mc.conn.Handshake() +} + // mint expects a net.Conn, but we're doing the handshake on a stream // so we wrap a stream such that implements a net.Conn type fakeConn struct {