From b03585160fb27abf0d03e51484430cd036e99177 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 28 Mar 2023 23:31:07 +0900 Subject: [PATCH] handshake: initiate the first key update after 100 packets --- integrationtests/self/key_update_test.go | 17 ++++++---------- internal/handshake/updatable_aead.go | 10 ++++++++++ internal/handshake/updatable_aead_test.go | 24 +++++++++++++---------- interop/client/main.go | 2 +- 4 files changed, 31 insertions(+), 22 deletions(-) diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index 8ff44e06b..21b8aeefb 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -51,12 +51,14 @@ func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader } var _ = Describe("Key Update tests", func() { - var server quic.Listener + It("downloads a large file", func() { + origKeyUpdateInterval := handshake.KeyUpdateInterval + defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }() + handshake.KeyUpdateInterval = 1 // update keys as frequently as possible - runServer := func() { - var err error - server, err = quic.ListenAddr("localhost:0", getTLSConfig(), nil) + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) Expect(err).ToNot(HaveOccurred()) + defer server.Close() go func() { defer GinkgoRecover() @@ -68,14 +70,7 @@ var _ = Describe("Key Update tests", func() { _, err = str.Write(PRDataLong) Expect(err).ToNot(HaveOccurred()) }() - } - It("downloads a large file", func() { - origKeyUpdateInterval := handshake.KeyUpdateInterval - defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }() - handshake.KeyUpdateInterval = 1 // update keys as frequently as possible - - runServer() conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 6fa4f76f9..ac01acdb1 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -19,6 +19,10 @@ import ( // It's a package-level variable to allow modifying it for testing purposes. var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval +// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update. +// It's a package-level variable to allow modifying it for testing purposes. +var FirstKeyUpdateInterval uint64 = 100 + type updatableAEAD struct { suite *qtls.CipherSuiteTLS13 @@ -284,6 +288,12 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { if !a.updateAllowed() { return false } + // Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism. + if a.keyPhase == 0 { + if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval { + return true + } + } if a.numRcvdWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index f9ac53a91..3ab7ddb40 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -282,21 +282,25 @@ var _ = Describe("Updatable AEAD", func() { }) Context("initiating key updates", func() { + const firstKeyUpdateInterval = 5 const keyUpdateInterval = 20 - var origKeyUpdateInterval uint64 + var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64 BeforeEach(func() { origKeyUpdateInterval = KeyUpdateInterval + origFirstKeyUpdateInterval = FirstKeyUpdateInterval KeyUpdateInterval = keyUpdateInterval + FirstKeyUpdateInterval = firstKeyUpdateInterval server.SetHandshakeConfirmed() }) AfterEach(func() { KeyUpdateInterval = origKeyUpdateInterval + FirstKeyUpdateInterval = origFirstKeyUpdateInterval }) It("initiates a key update after sealing the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -328,7 +332,7 @@ var _ = Describe("Updatable AEAD", func() { It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -336,7 +340,7 @@ var _ = Describe("Updatable AEAD", func() { serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) // Now that our keys are updated, send a packet using the new keys. - const nextPN = keyUpdateInterval + 1 + const nextPN = firstKeyUpdateInterval + 1 server.Seal(nil, msg, nextPN, ad) // We haven't decrypted any packet in the new key phase yet. // This means that the ACK must have been sent in the old key phase. @@ -348,7 +352,7 @@ var _ = Describe("Updatable AEAD", func() { It("doesn't error before actually sending a packet in the new key phase", func() { // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -366,7 +370,7 @@ var _ = Describe("Updatable AEAD", func() { }) It("initiates a key update after opening the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) encrypted := client.Seal(nil, msg, pn, ad) @@ -399,7 +403,7 @@ var _ = Describe("Updatable AEAD", func() { It("drops keys 3 PTOs after a key update", func() { now := time.Now() - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -435,7 +439,7 @@ var _ = Describe("Updatable AEAD", func() { data1 := client.Seal(nil, msg, 1, ad) _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -451,7 +455,7 @@ var _ = Describe("Updatable AEAD", func() { }) It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -488,7 +492,7 @@ var _ = Describe("Updatable AEAD", func() { It("drops keys early when we initiate another key update within the 3 PTO period", func() { server.SetHandshakeConfirmed() // send so many packets that we initiate the first key update - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) diff --git a/interop/client/main.go b/interop/client/main.go index 39c7edb05..02cac5d03 100644 --- a/interop/client/main.go +++ b/interop/client/main.go @@ -88,7 +88,7 @@ func runTestcase(testcase string) error { switch testcase { case "handshake", "transfer", "retry": case "keyupdate": - handshake.KeyUpdateInterval = 100 + handshake.FirstKeyUpdateInterval = 100 case "chacha20": tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256} case "multiconnect":