From 854940cecc16431a67279b7bbc99c5705fcbbf27 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 9 Sep 2020 11:38:44 +0700 Subject: [PATCH] don't drop keys for key phase N before receiving a N+1-protected packet --- internal/handshake/updatable_aead.go | 30 +++++++++------ internal/handshake/updatable_aead_test.go | 46 +++++++++++++++++++---- 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index e21430b3..36caf398 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -95,14 +95,13 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, } } -func (a *updatableAEAD) rollKeys(now time.Time) { +func (a *updatableAEAD) rollKeys() { a.keyPhase++ a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber a.firstSentWithCurrentKey = protocol.InvalidPacketNumber a.numRcvdWithCurrentKey = 0 a.numSentWithCurrentKey = 0 a.prevRcvAEAD = a.rcvAEAD - a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true)) a.rcvAEAD = a.nextRcvAEAD a.sendAEAD = a.nextSendAEAD @@ -112,6 +111,10 @@ func (a *updatableAEAD) rollKeys(now time.Time) { a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret) } +func (a *updatableAEAD) startKeyDropTimer(now time.Time) { + a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true)) +} + func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) } @@ -147,7 +150,7 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret } func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { - if a.prevRcvAEAD != nil && rcvTime.After(a.prevRcvAEADExpiry) { + if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { a.prevRcvAEAD = nil a.prevRcvAEADExpiry = time.Time{} if a.tracer != nil { @@ -187,7 +190,10 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { return nil, qerr.NewError(qerr.ProtocolViolation, "keys updated too quickly") } - a.rollKeys(rcvTime) + a.rollKeys() + // The peer initiated this key update. It's safe to drop the keys for the previous generation now. + // Start a timer to drop the previous key generation. + a.startKeyDropTimer(rcvTime) a.logger.Debugf("Peer updated keys to %s", a.keyPhase) if a.tracer != nil { a.tracer.UpdatedKey(a.keyPhase, true) @@ -199,12 +205,14 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac // It uses the nonce provided here and XOR it with the IV. dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) if err != nil { - err = ErrDecryptionFailed - } else { - a.numRcvdWithCurrentKey++ - if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { - a.firstRcvdWithCurrentKey = pn - } + return dec, ErrDecryptionFailed + } + a.numRcvdWithCurrentKey++ + if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { + // We initiated the key updated, and now we received the first packet protected with the new key phase. + // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. + a.startKeyDropTimer(rcvTime) + a.firstRcvdWithCurrentKey = pn } return dec, err } @@ -250,7 +258,7 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { - a.rollKeys(time.Now()) + a.rollKeys() a.logger.Debugf("Initiating key update to key phase %s", a.keyPhase) if a.tracer != nil { a.tracer.UpdatedKey(a.keyPhase, false) diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 1e8c92e2..9cfbd505 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -122,14 +122,14 @@ var _ = Describe("Updatable AEAD", func() { now := time.Now() Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) encrypted0 := server.Seal(nil, msg, 0x1337, ad) - server.rollKeys(now) + server.rollKeys() Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) encrypted1 := server.Seal(nil, msg, 0x1337, ad) Expect(encrypted0).ToNot(Equal(encrypted1)) // expect opening to fail. The client didn't roll keys yet _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) Expect(err).To(MatchError(ErrDecryptionFailed)) - client.rollKeys(now) + client.rollKeys() decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) Expect(decrypted).To(Equal(msg)) @@ -146,7 +146,7 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) _ = server.Seal(nil, msg, 0x1, ad) // now received a message at key phase one - client.rollKeys(now) + client.rollKeys() encrypted1 := client.Seal(nil, msg, 0x43, ad) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) @@ -165,7 +165,7 @@ var _ = Describe("Updatable AEAD", func() { // send one packet at key phase zero _ = server.Seal(nil, msg, 0x1, ad) // now receive a packet with key phase 1 - client.rollKeys(now) + client.rollKeys() encrypted1 := client.Seal(nil, msg, 0x44, ad) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) @@ -191,7 +191,7 @@ var _ = Describe("Updatable AEAD", func() { // send one packet at key phase zero _ = server.Seal(nil, msg, 0x1, ad) // now receive a packet with key phase 1 - client.rollKeys(now) + client.rollKeys() encrypted1 := client.Seal(nil, msg, 0x44, ad) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) @@ -205,14 +205,14 @@ var _ = Describe("Updatable AEAD", func() { }) It("errors when the peer starts with key phase 1", func() { - client.rollKeys(time.Now()) + client.rollKeys() encrypted := client.Seal(nil, msg, 0x1337, ad) _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial key phase")) }) It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { - client.rollKeys(time.Now()) + client.rollKeys() encrypted := client.Seal(nil, msg, 0x1337, ad) encrypted = encrypted[:len(encrypted)-1] _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) @@ -225,7 +225,7 @@ var _ = Describe("Updatable AEAD", func() { _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) // now receive a packet at key phase one, before having sent any packets - client.rollKeys(time.Now()) + client.rollKeys() encrypted1 := client.Seal(nil, msg, 0x42, ad) _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) @@ -268,6 +268,36 @@ var _ = Describe("Updatable AEAD", func() { serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + server.SetLargestAcked(pn) + } + // Now we've initiated the first key update. + // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there + threePTO := 3 * rttStats.PTO(false) + dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) + _, err := server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // Now receive a packet with key phase 1. + // This should start the timer to drop the keys after 3 PTOs. + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) + t := now.Add(threePTO).Add(time.Second) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + // Make sure the keys are still here. + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) }) Context("reading the key update env", func() {