handshake: initiate the first key update after 100 packets

This commit is contained in:
Marten Seemann
2023-03-28 23:31:07 +09:00
parent 1ebd359b20
commit b03585160f
4 changed files with 31 additions and 22 deletions

View File

@@ -19,6 +19,10 @@ import (
// It's a package-level variable to allow modifying it for testing purposes.
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
// It's a package-level variable to allow modifying it for testing purposes.
var FirstKeyUpdateInterval uint64 = 100
type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13
@@ -284,6 +288,12 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
if !a.updateAllowed() {
return false
}
// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
if a.keyPhase == 0 {
if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
return true
}
}
if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true

View File

@@ -282,21 +282,25 @@ var _ = Describe("Updatable AEAD", func() {
})
Context("initiating key updates", func() {
const firstKeyUpdateInterval = 5
const keyUpdateInterval = 20
var origKeyUpdateInterval uint64
var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64
BeforeEach(func() {
origKeyUpdateInterval = KeyUpdateInterval
origFirstKeyUpdateInterval = FirstKeyUpdateInterval
KeyUpdateInterval = keyUpdateInterval
FirstKeyUpdateInterval = firstKeyUpdateInterval
server.SetHandshakeConfirmed()
})
AfterEach(func() {
KeyUpdateInterval = origKeyUpdateInterval
FirstKeyUpdateInterval = origFirstKeyUpdateInterval
})
It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
@@ -328,7 +332,7 @@ var _ = Describe("Updatable AEAD", func() {
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
// First make sure that we update our keys.
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
@@ -336,7 +340,7 @@ var _ = Describe("Updatable AEAD", func() {
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// Now that our keys are updated, send a packet using the new keys.
const nextPN = keyUpdateInterval + 1
const nextPN = firstKeyUpdateInterval + 1
server.Seal(nil, msg, nextPN, ad)
// We haven't decrypted any packet in the new key phase yet.
// This means that the ACK must have been sent in the old key phase.
@@ -348,7 +352,7 @@ var _ = Describe("Updatable AEAD", func() {
It("doesn't error before actually sending a packet in the new key phase", func() {
// First make sure that we update our keys.
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
@@ -366,7 +370,7 @@ var _ = Describe("Updatable AEAD", func() {
})
It("initiates a key update after opening the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted := client.Seal(nil, msg, pn, ad)
@@ -399,7 +403,7 @@ var _ = Describe("Updatable AEAD", func() {
It("drops keys 3 PTOs after a key update", func() {
now := time.Now()
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
@@ -435,7 +439,7 @@ var _ = Describe("Updatable AEAD", func() {
data1 := client.Seal(nil, msg, 1, ad)
_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
@@ -451,7 +455,7 @@ var _ = Describe("Updatable AEAD", func() {
})
It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
@@ -488,7 +492,7 @@ var _ = Describe("Updatable AEAD", func() {
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
server.SetHandshakeConfirmed()
// send so many packets that we initiate the first key update
for i := 0; i < keyUpdateInterval; i++ {
for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)