From af6c19617cd34a24b8b890bdb01b40d6e5b8fa39 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 28 Mar 2023 22:46:57 +0900 Subject: [PATCH 1/3] handshake: add benchmark tests for packet sealing, opening and rolling keys --- internal/handshake/updatable_aead_test.go | 61 ++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 35a82ecd..b42f3dce 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -4,15 +4,15 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "testing" "time" - "github.com/golang/mock/gomock" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -525,3 +525,60 @@ var _ = Describe("Updatable AEAD", func() { }) } }) + +func getClientAndServer() (client, server *updatableAEAD) { + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) + + cs := cipherSuites[0] + rttStats := utils.NewRTTStats() + client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) + server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) + client.SetReadKey(cs, trafficSecret2) + client.SetWriteKey(cs, trafficSecret1) + server.SetReadKey(cs, trafficSecret1) + server.SetWriteKey(cs, trafficSecret2) + return +} + +func BenchmarkPacketEncryption(b *testing.B) { + client, _ := getClientAndServer() + const l = 1200 + src := make([]byte, l) + rand.Read(src) + ad := make([]byte, 32) + rand.Read(ad) + + for i := 0; i < b.N; i++ { + src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad) + } +} + +func BenchmarkPacketDecryption(b *testing.B) { + client, server := getClientAndServer() + const l = 1200 + src := make([]byte, l) + dst := make([]byte, l) + rand.Read(src) + ad := make([]byte, 32) + rand.Read(ad) + src = client.Seal(src[:0], src[:l], 1337, ad) + + for i := 0; i < b.N; i++ { + if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil { + b.Fatalf("opening failed: %v", err) + } + } +} + +func BenchmarkRollKeys(b *testing.B) { + client, _ := getClientAndServer() + for i := 0; i < b.N; i++ { + client.rollKeys() + } + if int(client.keyPhase) != b.N { + b.Fatal("didn't roll keys often enough") + } +} From 1ebd359b206151d8b54a108ffc8fda517a3f4a89 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 28 Mar 2023 22:50:21 +0900 Subject: [PATCH 2/3] handshake: remove unnecessary member variable from updatableAEAD --- internal/handshake/updatable_aead.go | 8 ++++---- internal/handshake/updatable_aead_test.go | 9 +++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 89a9dcd6..6fa4f76f 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -27,7 +27,6 @@ type updatableAEAD struct { firstPacketNumber protocol.PacketNumber handshakeConfirmed bool - keyUpdateInterval uint64 invalidPacketLimit uint64 invalidPacketCount uint64 @@ -74,7 +73,6 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, largestAcked: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber, - keyUpdateInterval: KeyUpdateInterval, rttStats: rttStats, tracer: tracer, logger: logger, @@ -116,6 +114,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) } +// SetReadKey sets the read key. // For the client, this function is called before SetWriteKey. // For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -129,6 +128,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [ a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) } +// SetWriteKey sets the write key. // For the client, this function is called after SetReadKey. // For the server, this function is called before SetWriteKey. func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -284,11 +284,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { if !a.updateAllowed() { return false } - if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + if a.numRcvdWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true } - if a.numSentWithCurrentKey >= a.keyUpdateInterval { + if a.numSentWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) return true } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index b42f3dce..f9ac53a9 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -283,13 +283,18 @@ var _ = Describe("Updatable AEAD", func() { Context("initiating key updates", func() { const keyUpdateInterval = 20 + var origKeyUpdateInterval uint64 BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval + origKeyUpdateInterval = KeyUpdateInterval + KeyUpdateInterval = keyUpdateInterval server.SetHandshakeConfirmed() }) + AfterEach(func() { + KeyUpdateInterval = origKeyUpdateInterval + }) + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i) From b03585160fb27abf0d03e51484430cd036e99177 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 28 Mar 2023 23:31:07 +0900 Subject: [PATCH 3/3] handshake: initiate the first key update after 100 packets --- integrationtests/self/key_update_test.go | 17 ++++++---------- internal/handshake/updatable_aead.go | 10 ++++++++++ internal/handshake/updatable_aead_test.go | 24 +++++++++++++---------- interop/client/main.go | 2 +- 4 files changed, 31 insertions(+), 22 deletions(-) diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index 8ff44e06..21b8aeef 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -51,12 +51,14 @@ func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader } var _ = Describe("Key Update tests", func() { - var server quic.Listener + It("downloads a large file", func() { + origKeyUpdateInterval := handshake.KeyUpdateInterval + defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }() + handshake.KeyUpdateInterval = 1 // update keys as frequently as possible - runServer := func() { - var err error - server, err = quic.ListenAddr("localhost:0", getTLSConfig(), nil) + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) Expect(err).ToNot(HaveOccurred()) + defer server.Close() go func() { defer GinkgoRecover() @@ -68,14 +70,7 @@ var _ = Describe("Key Update tests", func() { _, err = str.Write(PRDataLong) Expect(err).ToNot(HaveOccurred()) }() - } - It("downloads a large file", func() { - origKeyUpdateInterval := handshake.KeyUpdateInterval - defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }() - handshake.KeyUpdateInterval = 1 // update keys as frequently as possible - - runServer() conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 6fa4f76f..ac01acdb 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -19,6 +19,10 @@ import ( // It's a package-level variable to allow modifying it for testing purposes. var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval +// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update. +// It's a package-level variable to allow modifying it for testing purposes. +var FirstKeyUpdateInterval uint64 = 100 + type updatableAEAD struct { suite *qtls.CipherSuiteTLS13 @@ -284,6 +288,12 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { if !a.updateAllowed() { return false } + // Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism. + if a.keyPhase == 0 { + if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval { + return true + } + } if a.numRcvdWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index f9ac53a9..3ab7ddb4 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -282,21 +282,25 @@ var _ = Describe("Updatable AEAD", func() { }) Context("initiating key updates", func() { + const firstKeyUpdateInterval = 5 const keyUpdateInterval = 20 - var origKeyUpdateInterval uint64 + var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64 BeforeEach(func() { origKeyUpdateInterval = KeyUpdateInterval + origFirstKeyUpdateInterval = FirstKeyUpdateInterval KeyUpdateInterval = keyUpdateInterval + FirstKeyUpdateInterval = firstKeyUpdateInterval server.SetHandshakeConfirmed() }) AfterEach(func() { KeyUpdateInterval = origKeyUpdateInterval + FirstKeyUpdateInterval = origFirstKeyUpdateInterval }) It("initiates a key update after sealing the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -328,7 +332,7 @@ var _ = Describe("Updatable AEAD", func() { It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -336,7 +340,7 @@ var _ = Describe("Updatable AEAD", func() { serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) // Now that our keys are updated, send a packet using the new keys. - const nextPN = keyUpdateInterval + 1 + const nextPN = firstKeyUpdateInterval + 1 server.Seal(nil, msg, nextPN, ad) // We haven't decrypted any packet in the new key phase yet. // This means that the ACK must have been sent in the old key phase. @@ -348,7 +352,7 @@ var _ = Describe("Updatable AEAD", func() { It("doesn't error before actually sending a packet in the new key phase", func() { // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -366,7 +370,7 @@ var _ = Describe("Updatable AEAD", func() { }) It("initiates a key update after opening the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) encrypted := client.Seal(nil, msg, pn, ad) @@ -399,7 +403,7 @@ var _ = Describe("Updatable AEAD", func() { It("drops keys 3 PTOs after a key update", func() { now := time.Now() - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -435,7 +439,7 @@ var _ = Describe("Updatable AEAD", func() { data1 := client.Seal(nil, msg, 1, ad) _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -451,7 +455,7 @@ var _ = Describe("Updatable AEAD", func() { }) It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) @@ -488,7 +492,7 @@ var _ = Describe("Updatable AEAD", func() { It("drops keys early when we initiate another key update within the 3 PTO period", func() { server.SetHandshakeConfirmed() // send so many packets that we initiate the first key update - for i := 0; i < keyUpdateInterval; i++ { + for i := 0; i < firstKeyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) diff --git a/interop/client/main.go b/interop/client/main.go index 39c7edb0..02cac5d0 100644 --- a/interop/client/main.go +++ b/interop/client/main.go @@ -88,7 +88,7 @@ func runTestcase(testcase string) error { switch testcase { case "handshake", "transfer", "retry": case "keyupdate": - handshake.KeyUpdateInterval = 100 + handshake.FirstKeyUpdateInterval = 100 case "chacha20": tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256} case "multiconnect":