From 1fb970cbac27c01352879e8d026b2ead29aaf3b5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 11 Jun 2019 20:25:51 +0800 Subject: [PATCH] perform a key update when receiving a packet with a different key phase --- internal/handshake/updatable_aead.go | 116 +++++++++++++----- internal/handshake/updatable_aead_test.go | 136 ++++++++++++++++------ 2 files changed, 182 insertions(+), 70 deletions(-) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 3cf02997..f62d4a84 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -1,15 +1,30 @@ package handshake import ( + "crypto" "crypto/cipher" "encoding/binary" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/marten-seemann/qtls" ) type updatableAEAD struct { - openAEAD cipher.AEAD - sealAEAD cipher.AEAD + suite cipherSuite + + keyPhase protocol.KeyPhase + + prevRcvAEAD cipher.AEAD + + firstRcvdWithCurrentKey protocol.PacketNumber + firstSentWithCurrentKey protocol.PacketNumber + rcvAEAD cipher.AEAD + sendAEAD cipher.AEAD + + nextRcvAEAD cipher.AEAD + nextSendAEAD cipher.AEAD + nextRcvTrafficSecret []byte + nextSendTrafficSecret []byte hpDecrypter cipher.Block hpEncrypter cipher.Block @@ -23,63 +38,100 @@ var _ ShortHeaderOpener = &updatableAEAD{} var _ ShortHeaderSealer = &updatableAEAD{} func newUpdatableAEAD() *updatableAEAD { - return &updatableAEAD{} + return &updatableAEAD{ + firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, + firstSentWithCurrentKey: protocol.InvalidPacketNumber, + } } +func (a *updatableAEAD) rollKeys() { + a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber + a.firstSentWithCurrentKey = protocol.InvalidPacketNumber + a.keyPhase = a.keyPhase.Next() + a.prevRcvAEAD = a.rcvAEAD + a.rcvAEAD = a.nextRcvAEAD + a.sendAEAD = a.nextSendAEAD + + a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextRcvTrafficSecret) + a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextSendTrafficSecret) + a.nextRcvAEAD, _ = createAEAD(a.suite, a.nextRcvTrafficSecret) + a.nextSendAEAD, _ = createAEAD(a.suite, a.nextSendTrafficSecret) +} + +func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { + return qtls.HkdfExpandLabel(hash, ts, []byte{}, "traffic upd", hash.Size()) +} + +// For the client, this function is called before SetWriteKey. +// For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { - aead, hpDecrypter := createAEAD(suite, trafficSecret) - if len(a.nonceBuf) == 0 { - a.nonceBuf = make([]byte, aead.NonceSize()) - } else if len(a.nonceBuf) != aead.NonceSize() { - panic("invalid nonce size") + a.rcvAEAD, a.hpDecrypter = createAEAD(suite, trafficSecret) + if a.suite == nil { + a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize()) + a.hpMask = make([]byte, a.hpDecrypter.BlockSize()) + a.suite = suite } - a.openAEAD = aead - if len(a.hpMask) == 0 { - a.hpMask = make([]byte, hpDecrypter.BlockSize()) - } else if len(a.hpMask) != hpDecrypter.BlockSize() { - panic("invalid header protection block size") - } - a.hpDecrypter = hpDecrypter + a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret) + a.nextRcvAEAD, _ = createAEAD(suite, a.nextRcvTrafficSecret) } +// For the client, this function is called after SetReadKey. +// For the server, this function is called before SetWriteKey. func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { - aead, hpEncrypter := createAEAD(suite, trafficSecret) - if len(a.nonceBuf) == 0 { - a.nonceBuf = make([]byte, aead.NonceSize()) - } else if len(a.nonceBuf) != aead.NonceSize() { - panic("invalid nonce size") + a.sendAEAD, a.hpEncrypter = createAEAD(suite, trafficSecret) + if a.suite == nil { + a.nonceBuf = make([]byte, a.sendAEAD.NonceSize()) + a.hpMask = make([]byte, a.hpEncrypter.BlockSize()) + a.suite = suite } - a.sealAEAD = aead - if len(a.hpMask) == 0 { - a.hpMask = make([]byte, hpEncrypter.BlockSize()) - } else if len(a.hpMask) != hpEncrypter.BlockSize() { - panic("invalid header protection block size") - } - a.hpEncrypter = hpEncrypter + a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret) + a.nextSendAEAD, _ = createAEAD(suite, a.nextSendTrafficSecret) } -func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) { +func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, ad []byte) ([]byte, error) { binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + if kp != a.keyPhase { + if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { + // TODO: check that prevRcv actually exists + // we updated the key, but the peer hasn't updated yet + return a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) + } + // try opening the packet with the next key phase + dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err == nil { + // if opening succeeds, roll over to the next key phase + a.rollKeys() + a.firstRcvdWithCurrentKey = pn + } + return dec, err + } // The AEAD we're using here will be the qtls.aeadAESGCM13. // It uses the nonce provided here and XOR it with the IV. - return a.openAEAD.Open(dst, a.nonceBuf, src, ad) + dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err == nil && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { + a.firstRcvdWithCurrentKey = pn + } + return dec, err } func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + a.firstSentWithCurrentKey = pn + } binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) // The AEAD we're using here will be the qtls.aeadAESGCM13. // It uses the nonce provided here and XOR it with the IV. - return a.sealAEAD.Seal(dst, a.nonceBuf, src, ad) + return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) } func (a *updatableAEAD) KeyPhase() protocol.KeyPhase { - return protocol.KeyPhaseOne + return a.keyPhase } func (a *updatableAEAD) Overhead() int { - return a.sealAEAD.Overhead() + return a.sendAEAD.Overhead() } func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 05626de1..b72393b7 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -27,62 +27,122 @@ func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD { } var _ = Describe("Updatable AEAD", func() { - getAEAD := func() *updatableAEAD { - trafficSecret := make([]byte, 16) - rand.Read(trafficSecret) + getPeers := func() (client, server *updatableAEAD) { + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) - u := newUpdatableAEAD() - u.SetReadKey(&mockCipherSuite{}, trafficSecret) - u.SetWriteKey(&mockCipherSuite{}, trafficSecret) - return u + client = newUpdatableAEAD() + server = newUpdatableAEAD() + client.SetReadKey(&mockCipherSuite{}, trafficSecret2) + client.SetWriteKey(&mockCipherSuite{}, trafficSecret1) + server.SetReadKey(&mockCipherSuite{}, trafficSecret1) + server.SetWriteKey(&mockCipherSuite{}, trafficSecret2) + return } - Context("message encryption", func() { - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad := []byte("Donec in velit neque.") - - It("encrypts and decrypts a message", func() { - aead := getAEAD() - encrypted := aead.Seal(nil, msg, 0x1337, ad) - opened, err := aead.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - aead := getAEAD() - encrypted := aead.Seal(nil, msg, 0x1337, ad) - _, err := aead.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) - Expect(err).To(MatchError("cipher: message authentication failed")) - }) - - It("fails to open a message if the packet number is not the same", func() { - aead := getAEAD() - encrypted := aead.Seal(nil, msg, 0x1337, ad) - _, err := aead.Open(nil, encrypted, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError("cipher: message authentication failed")) - }) - }) - - Context("header encryption", func() { + Context("header protection", func() { It("encrypts and decrypts the header", func() { - aead := getAEAD() + server, client := getPeers() var lastFiveBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) rand.Read(sample) header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - aead.EncryptHeader(sample, &header[0], header[9:13]) + client.EncryptHeader(sample, &header[0], header[9:13]) if header[0]&0x1f != 0xb5&0x1f { lastFiveBitsDifferent++ } Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - aead.DecryptHeader(sample, &header[0], header[9:13]) + server.DecryptHeader(sample, &header[0], header[9:13]) Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) } Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) }) }) + + Context("message encryption", func() { + var msg, ad []byte + var server, client *updatableAEAD + + BeforeEach(func() { + server, client = getPeers() + msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad = []byte("Donec in velit neque.") + }) + + It("encrypts and decrypts a message", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + opened, err := client.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) + + It("fails to open a message if the packet number is not the same", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) + + Context("key updates", func() { + It("updates keys", func() { + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted0 := server.Seal(nil, msg, 0x1337, ad) + server.rollKeys() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted1 := server.Seal(nil, msg, 0x1337, ad) + Expect(encrypted0).ToNot(Equal(encrypted1)) + // expect opening to fail. The client didn't roll keys yet + _, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError("cipher: message authentication failed")) + client.rollKeys() + decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + }) + + It("updates the keys when receiving a packet with the next key phase", func() { + encrypted0 := client.Seal(nil, msg, 0x42, ad) + decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x43, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("opens a reordered packet with the old keys after an update", func() { + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + }) + }) })