diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index d79a35c4..65e0042f 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -100,6 +100,14 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, } func (a *updatableAEAD) rollKeys() { + if a.prevRcvAEAD != nil { + a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) + if a.tracer != nil { + a.tracer.DroppedKey(a.keyPhase - 1) + } + a.prevRcvAEADExpiry = time.Time{} + } + a.keyPhase++ a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber a.firstSentWithCurrentKey = protocol.InvalidPacketNumber diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index a16c9c01..96c65bf5 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -7,6 +7,8 @@ import ( "os" "time" + "github.com/golang/mock/gomock" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -40,13 +42,12 @@ var _ = Describe("Updatable AEAD", func() { Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() { var ( - client, server *updatableAEAD - clientTracer, serverTracer *mocklogging.MockConnectionTracer - rttStats *utils.RTTStats + client, server *updatableAEAD + serverTracer *mocklogging.MockConnectionTracer + rttStats *utils.RTTStats ) BeforeEach(func() { - clientTracer = mocklogging.NewMockConnectionTracer(mockCtrl) serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) @@ -54,7 +55,7 @@ var _ = Describe("Updatable AEAD", func() { rand.Read(trafficSecret2) rttStats = utils.NewRTTStats() - client = newUpdatableAEAD(rttStats, clientTracer, utils.DefaultLogger) + client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) @@ -378,6 +379,78 @@ var _ = Describe("Updatable AEAD", func() { _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) }) + + It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + const nextPN = keyUpdateInterval + 1 + // Send and receive an acknowledgement for a packet in key phase 1. + // We are now running a timer to drop the keys with 3 PTO. + server.Seal(nil, msg, nextPN, ad) + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) + now := time.Now() + _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(nextPN)) + // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. + // This mean that we need to drop the keys for key phase 0 immediately. + client.rollKeys() + dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), + ) + _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("drops keys early when we initiate another key update within the 3 PTO period", func() { + // send so many packets that we initiate the first key update + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // send so many packets that we initiate the next key update + for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + client.rollKeys() + b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) + now := time.Now() + _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), + ) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + // We haven't received an ACK for a packet sent in key phase 2 yet. + // Make sure we canceled the timer to drop the previous key phase. + b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) + _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + }) }) Context("reading the key update env", func() {