diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index 3b3fe6c70..4d04119f9 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -15,6 +15,9 @@ import ( // KeyDerivationFunction is used for key derivation type KeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte) (crypto.AEAD, error) +// KeyExchangeFunction is used to make a new KEX +type KeyExchangeFunction func() crypto.KeyExchange + // The CryptoSetup handles all things crypto for the Session type CryptoSetup struct { connID protocol.ConnectionID @@ -28,6 +31,7 @@ type CryptoSetup struct { receivedSecurePacket bool keyDerivation KeyDerivationFunction + keyExchange KeyExchangeFunction cryptoStream utils.Stream @@ -50,6 +54,7 @@ func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber scfg: scfg, nonce: nonce, keyDerivation: crypto.DeriveKeysChacha20, + keyExchange: crypto.NewCurve25519KEX, cryptoStream: cryptoStream, connectionParametersManager: connectionParametersManager, } @@ -181,8 +186,14 @@ func (h *CryptoSetup) handleCHLO(data []byte, cryptoData map[Tag][]byte) ([]byte if err != nil { return nil, err } - // TODO: Use new curve - h.forwardSecureAEAD, err = h.keyDerivation(true, sharedSecret, nonce.Bytes(), h.connID, data, h.scfg.Get(), h.scfg.signer.GetCertUncompressed()) + + // Generate a new curve instance to derive the forward secure key + ephermalKex := h.keyExchange() + ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) + if err != nil { + return nil, err + } + h.forwardSecureAEAD, err = h.keyDerivation(true, ephermalSharedSecret, nonce.Bytes(), h.connID, data, h.scfg.Get(), h.scfg.signer.GetCertUncompressed()) if err != nil { return nil, err } @@ -194,7 +205,7 @@ func (h *CryptoSetup) handleCHLO(data []byte, cryptoData map[Tag][]byte) ([]byte replyMap := h.connectionParametersManager.GetSHLOMap() // add crypto parameters - replyMap[TagPUBS] = h.scfg.kex.PublicKey() + replyMap[TagPUBS] = ephermalKex.PublicKey() replyMap[TagSNO] = h.nonce replyMap[TagVER] = protocol.SupportedVersionsAsTags diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 35d3fc572..94c5ffdb0 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -11,12 +11,21 @@ import ( . "github.com/onsi/gomega" ) -type mockKEX struct{} - -func (*mockKEX) PublicKey() []byte { - return []byte("pubs-s") +type mockKEX struct { + ephermal bool } -func (*mockKEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) { + +func (m *mockKEX) PublicKey() []byte { + if m.ephermal { + return []byte("ephermal pub") + } + return []byte("initial public") +} + +func (m *mockKEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) { + if m.ephermal { + return []byte("shared ephermal"), nil + } return []byte("shared key"), nil } @@ -39,6 +48,7 @@ func (*mockSigner) GetCertUncompressed() []byte { type mockAEAD struct { forwardSecure bool + sharedSecret []byte } func (m *mockAEAD) Seal(packetNumber protocol.PacketNumber, associatedData []byte, plaintext []byte) []byte { @@ -58,7 +68,7 @@ func (m *mockAEAD) Open(packetNumber protocol.PacketNumber, associatedData []byt } func mockKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte) (crypto.AEAD, error) { - return &mockAEAD{forwardSecure: forwardSecure}, nil + return &mockAEAD{forwardSecure: forwardSecure, sharedSecret: sharedSecret}, nil } type mockStream struct { @@ -101,6 +111,7 @@ var _ = Describe("Crypto setup", func() { cpm = NewConnectionParamatersManager() cs = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream, cpm) cs.keyDerivation = mockKeyDerivation + cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } }) It("has a nonce", func() { @@ -118,7 +129,7 @@ var _ = Describe("Crypto setup", func() { Expect(err).ToNot(HaveOccurred()) Expect(response).To(HavePrefix("REJ")) Expect(response).To(ContainSubstring("certcompressed")) - Expect(response).To(ContainSubstring("pubs-s")) + Expect(response).To(ContainSubstring("initial public")) Expect(signer.gotCHLO).To(BeTrue()) }) @@ -128,11 +139,15 @@ var _ = Describe("Crypto setup", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HavePrefix("SHLO")) - Expect(response).To(ContainSubstring("pubs-s")) // TODO: Should be new pubs + Expect(response).To(ContainSubstring("ephermal pub")) Expect(response).To(ContainSubstring(string(cs.nonce))) Expect(response).To(ContainSubstring(string(protocol.SupportedVersionsAsTags))) Expect(cs.secureAEAD).ToNot(BeNil()) + Expect(cs.secureAEAD.(*mockAEAD).forwardSecure).To(BeFalse()) + Expect(cs.secureAEAD.(*mockAEAD).sharedSecret).To(Equal([]byte("shared key"))) Expect(cs.forwardSecureAEAD).ToNot(BeNil()) + Expect(cs.forwardSecureAEAD.(*mockAEAD).sharedSecret).To(Equal([]byte("shared ephermal"))) + Expect(cs.forwardSecureAEAD.(*mockAEAD).forwardSecure).To(BeTrue()) }) It("handles long handshake", func() {