From a09c0453244227768bbb9bba96028eeabf6dcc10 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 13 Jun 2019 16:24:06 +0800 Subject: [PATCH] initiate a key update after sending / receiving 100000 packets --- Changelog.md | 2 +- internal/handshake/crypto_setup.go | 3 +- internal/handshake/updatable_aead.go | 48 +++++- internal/handshake/updatable_aead_test.go | 174 +++++++++++++--------- internal/protocol/params.go | 3 + 5 files changed, 156 insertions(+), 74 deletions(-) diff --git a/Changelog.md b/Changelog.md index fe92d7b3..cbf56a81 100644 --- a/Changelog.md +++ b/Changelog.md @@ -9,7 +9,7 @@ - Use a varint for error codes. - Add support for [quic-trace](https://github.com/google/quic-trace). - Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`. -- Implement receiving of TLS key updates. +- Implement TLS key updates. ## v0.11.0 (2019-04-05) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index b0859d8b..509c13c3 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -223,7 +223,8 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error { return nil } -func (h *cryptoSetup) SetLargest1RTTAcked(_ protocol.PacketNumber) { +func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) { + h.aead.SetLargestAcked(pn) // drop initial keys // TODO: do this earlier if h.initialOpener != nil { diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 0967b373..c05e5dae 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -15,12 +15,16 @@ import ( type updatableAEAD struct { suite cipherSuite - keyPhase protocol.KeyPhase + keyPhase protocol.KeyPhase + largestAcked protocol.PacketNumber + keyUpdateInterval uint64 prevRcvAEAD cipher.AEAD firstRcvdWithCurrentKey protocol.PacketNumber firstSentWithCurrentKey protocol.PacketNumber + numRcvdWithCurrentKey uint64 + numSentWithCurrentKey uint64 rcvAEAD cipher.AEAD sendAEAD cipher.AEAD @@ -44,17 +48,20 @@ var _ ShortHeaderSealer = &updatableAEAD{} func newUpdatableAEAD(logger utils.Logger) *updatableAEAD { return &updatableAEAD{ + largestAcked: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber, + keyUpdateInterval: protocol.KeyUpdateInterval, logger: logger, } } func (a *updatableAEAD) rollKeys() { a.keyPhase = a.keyPhase.Next() - a.logger.Debugf("Updating keys to the next key phase: %s", a.keyPhase) a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber a.firstSentWithCurrentKey = protocol.InvalidPacketNumber + a.numRcvdWithCurrentKey = 0 + a.numSentWithCurrentKey = 0 a.prevRcvAEAD = a.rcvAEAD a.rcvAEAD = a.nextRcvAEAD a.sendAEAD = a.nextSendAEAD @@ -126,6 +133,7 @@ func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp proto return nil, qerr.Error(qerr.ProtocolViolation, "keys updated too quickly") } a.rollKeys() + a.logger.Debugf("Peer updated keys to %s", a.keyPhase) a.firstRcvdWithCurrentKey = pn return dec, err } @@ -134,8 +142,11 @@ func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp proto dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) if err != nil { err = ErrDecryptionFailed - } else if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { - a.firstRcvdWithCurrentKey = pn + } else { + a.numRcvdWithCurrentKey++ + if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { + a.firstRcvdWithCurrentKey = pn + } } return dec, err } @@ -144,13 +155,42 @@ func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byt if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { a.firstSentWithCurrentKey = pn } + a.numSentWithCurrentKey++ binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) // The AEAD we're using here will be the qtls.aeadAESGCM13. // It uses the nonce provided here and XOR it with the IV. return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) } +func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) { + a.largestAcked = pn +} + +func (a *updatableAEAD) updateAllowed() bool { + return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + a.largestAcked != protocol.InvalidPacketNumber && + a.largestAcked >= a.firstSentWithCurrentKey +} + +func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { + if !a.updateAllowed() { + return false + } + if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase.Next()) + return true + } + if a.numSentWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase.Next()) + return true + } + return false +} + func (a *updatableAEAD) KeyPhase() protocol.KeyPhase { + if a.shouldInitiateKeyUpdate() { + a.rollKeys() + } return a.keyPhase } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index b29b3d83..64de2234 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -95,79 +95,117 @@ var _ = Describe("Updatable AEAD", func() { }) Context("key updates", func() { - It("updates keys", func() { - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted0 := server.Seal(nil, msg, 0x1337, ad) - server.rollKeys() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted1 := server.Seal(nil, msg, 0x1337, ad) - Expect(encrypted0).ToNot(Equal(encrypted1)) - // expect opening to fail. The client didn't roll keys yet - _, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - client.rollKeys() - decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) + Context("receiving key updates", func() { + It("updates keys", func() { + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted0 := server.Seal(nil, msg, 0x1337, ad) + server.rollKeys() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted1 := server.Seal(nil, msg, 0x1337, ad) + Expect(encrypted0).ToNot(Equal(encrypted1)) + // expect opening to fail. The client didn't roll keys yet + _, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + client.rollKeys() + decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + }) + + It("updates the keys when receiving a packet with the next key phase", func() { + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + // send one packet at key phase zero + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _ = server.Seal(nil, msg, 0x1, ad) + // now received a message at key phase one + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x43, ad) + decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("opens a reordered packet with the old keys after an update", func() { + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("errors when the peer starts with key phase 1", func() { + client.rollKeys() + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase")) + }) + + It("errors when the peer updates keys too frequently", func() { + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + _, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // now receive a packet at key phase one, before having sent any packets + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x42, ad) + _, err = server.Open(nil, encrypted1, 0x42, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) + }) }) - It("updates the keys when receiving a packet with the next key phase", func() { - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - // send one packet at key phase zero - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _ = server.Seal(nil, msg, 0x1, ad) - // now received a message at key phase one - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x43, ad) - decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) + Context("initiating key updates", func() { + const keyUpdateInterval = 20 - It("opens a reordered packet with the old keys after an update", func() { - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) + BeforeEach(func() { + Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + server.keyUpdateInterval = keyUpdateInterval + }) - It("errors when the peer starts with key phase 1", func() { - client.rollKeys() - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase")) - }) + It("initiates a key update after sealing the maximum number of packets", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.SetLargestAcked(0) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) - It("errors when the peer updates keys too frequently", func() { - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase one, before having sent any packets - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, 0x42, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) + It("initiates a key update after opening the maximum number of packets", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, pn, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, 1, ad) + server.SetLargestAcked(1) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) }) }) }) diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 73c13d07..8a6d3444 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -139,3 +139,6 @@ const MaxAckDelay = 25 * time.Millisecond // MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity. // This is the value that should be advertised to the peer. const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity + +// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key udpate. +const KeyUpdateInterval = 100 * 1000