diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 63979acf..94877953 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -47,6 +47,7 @@ type updatableAEAD struct { keyPhase protocol.KeyPhase largestAcked protocol.PacketNumber + firstPacketNumber protocol.PacketNumber keyUpdateInterval uint64 // Time when the keys should be dropped. Keys are dropped on the next call to Open(). @@ -83,6 +84,7 @@ var _ ShortHeaderSealer = &updatableAEAD{} func newUpdatableAEAD(rttStats *congestion.RTTStats, logger utils.Logger) *updatableAEAD { return &updatableAEAD{ + firstPacketNumber: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber, @@ -199,6 +201,9 @@ func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byt if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { a.firstSentWithCurrentKey = pn } + if a.firstPacketNumber == protocol.InvalidPacketNumber { + a.firstPacketNumber = pn + } a.numSentWithCurrentKey++ binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) // The AEAD we're using here will be the qtls.aeadAESGCM13. @@ -249,3 +254,7 @@ func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes [ func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) } + +func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { + return a.firstPacketNumber +} diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 54d071de..65a9b789 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -75,6 +75,13 @@ var _ = Describe("Updatable AEAD", func() { Expect(opened).To(Equal(msg)) }) + It("saves the first packet number", func() { + client.Seal(nil, msg, 0x1337, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) + client.Seal(nil, msg, 0x1338, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) + }) + It("fails to open a message if the associated data is not the same", func() { encrypted := client.Seal(nil, msg, 0x1337, ad) _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))