diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index bf45394a..7e777eb9 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -242,10 +242,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { + a.rollKeys(time.Now()) + a.logger.Debugf("Initiating key update to key phase %s", a.keyPhase) if a.tracer != nil { a.tracer.UpdatedKey(a.keyPhase, false) } - a.rollKeys(time.Now()) } return a.keyPhase.Bit() } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 7ec8cb61..3ed89caf 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -7,6 +7,7 @@ import ( "os" "time" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" @@ -37,24 +38,31 @@ var _ = Describe("Updatable AEAD", func() { cs := cipherSuites[i] Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() { - getPeers := func(rttStats *utils.RTTStats) (client, server *updatableAEAD) { + var ( + client, server *updatableAEAD + clientTracer, serverTracer *mocklogging.MockConnectionTracer + rttStats *utils.RTTStats + ) + + BeforeEach(func() { + clientTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) rand.Read(trafficSecret2) - client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) - server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) + rttStats = utils.NewRTTStats() + client = newUpdatableAEAD(rttStats, clientTracer, utils.DefaultLogger) + server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) server.SetWriteKey(cs, trafficSecret2) - return - } + }) Context("header protection", func() { It("encrypts and decrypts the header", func() { - server, client := getPeers(&utils.RTTStats{}) var lastFiveBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) @@ -76,12 +84,8 @@ var _ = Describe("Updatable AEAD", func() { Context("message encryption", func() { var msg, ad []byte - var server, client *updatableAEAD - var rttStats *utils.RTTStats BeforeEach(func() { - rttStats = &utils.RTTStats{} - server, client = getPeers(rttStats) msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ad = []byte("Donec in velit neque.") }) @@ -144,6 +148,7 @@ var _ = Describe("Updatable AEAD", func() { // now received a message at key phase one client.rollKeys(now) encrypted1 := client.Seal(nil, msg, 0x43, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) Expect(decrypted).To(Equal(msg)) @@ -163,6 +168,7 @@ var _ = Describe("Updatable AEAD", func() { client.rollKeys(now) encrypted1 := client.Seal(nil, msg, 0x44, ad) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) @@ -188,6 +194,7 @@ var _ = Describe("Updatable AEAD", func() { client.rollKeys(now) encrypted1 := client.Seal(nil, msg, 0x44, ad) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) @@ -233,6 +240,7 @@ var _ = Describe("Updatable AEAD", func() { // no update allowed before receiving an acknowledgement for the current key phase Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.SetLargestAcked(0) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) @@ -248,6 +256,7 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, 1, ad) server.SetLargestAcked(1) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) })