From 1ebd359b206151d8b54a108ffc8fda517a3f4a89 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 28 Mar 2023 22:50:21 +0900 Subject: [PATCH] handshake: remove unnecessary member variable from updatableAEAD --- internal/handshake/updatable_aead.go | 8 ++++---- internal/handshake/updatable_aead_test.go | 9 +++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 89a9dcd62..6fa4f76f9 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -27,7 +27,6 @@ type updatableAEAD struct { firstPacketNumber protocol.PacketNumber handshakeConfirmed bool - keyUpdateInterval uint64 invalidPacketLimit uint64 invalidPacketCount uint64 @@ -74,7 +73,6 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, largestAcked: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber, - keyUpdateInterval: KeyUpdateInterval, rttStats: rttStats, tracer: tracer, logger: logger, @@ -116,6 +114,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) } +// SetReadKey sets the read key. // For the client, this function is called before SetWriteKey. // For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -129,6 +128,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [ a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) } +// SetWriteKey sets the write key. // For the client, this function is called after SetReadKey. // For the server, this function is called before SetWriteKey. func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -284,11 +284,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { if !a.updateAllowed() { return false } - if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + if a.numRcvdWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true } - if a.numSentWithCurrentKey >= a.keyUpdateInterval { + if a.numSentWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) return true } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index b42f3dce6..f9ac53a91 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -283,13 +283,18 @@ var _ = Describe("Updatable AEAD", func() { Context("initiating key updates", func() { const keyUpdateInterval = 20 + var origKeyUpdateInterval uint64 BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval + origKeyUpdateInterval = KeyUpdateInterval + KeyUpdateInterval = keyUpdateInterval server.SetHandshakeConfirmed() }) + AfterEach(func() { + KeyUpdateInterval = origKeyUpdateInterval + }) + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i)