diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index 96ebbdcd7..b2045b426 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -133,6 +133,12 @@ const NumCachedCertificates = 128 // 2. it reduces the head-of-line blocking, when a packet is lost const MinStreamFrameSize ByteCount = 128 +// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write +// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. +// The MaxAckFrameSize should be large enough to encode many ACK range, +// but must ensure that a maximum size ACK frame fits into one packet. +const MaxAckFrameSize ByteCount = 1000 + // MinPacingDelay is the minimum duration that is used for packet pacing // If the packet packing frequency is higher, multiple packets might be sent at once. // Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth. diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go index 9d3d7bb86..7b99ef40d 100644 --- a/internal/wire/ack_frame.go +++ b/internal/wire/ack_frame.go @@ -99,31 +99,22 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error return f.writeLegacy(b, version) } - largestAcked := f.AckRanges[0].Largest - lowestInFirstRange := f.AckRanges[0].Smallest - b.WriteByte(0x0d) - utils.WriteVarInt(b, uint64(largestAcked)) + utils.WriteVarInt(b, uint64(f.LargestAcked())) utils.WriteVarInt(b, encodeAckDelay(f.DelayTime)) - // TODO: limit the number of ACK ranges, such that the frame doesn't grow larger than an upper bound - utils.WriteVarInt(b, uint64(len(f.AckRanges)-1)) + numRanges := f.numEncodableAckRanges() + utils.WriteVarInt(b, uint64(numRanges-1)) // write the first range - utils.WriteVarInt(b, uint64(largestAcked-lowestInFirstRange)) + _, firstRange := f.encodeAckRange(0) + utils.WriteVarInt(b, firstRange) // write all the other range - if f.HasMissingRanges() { - var lowest protocol.PacketNumber - for i, ackRange := range f.AckRanges { - if i == 0 { - lowest = lowestInFirstRange - continue - } - utils.WriteVarInt(b, uint64(lowest-ackRange.Largest-2)) - utils.WriteVarInt(b, uint64(ackRange.Largest-ackRange.Smallest)) - lowest = ackRange.Smallest - } + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + utils.WriteVarInt(b, gap) + utils.WriteVarInt(b, len) } return nil } @@ -135,28 +126,48 @@ func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { } largestAcked := f.AckRanges[0].Largest + numRanges := f.numEncodableAckRanges() + length := 1 + utils.VarIntLen(uint64(largestAcked)) + utils.VarIntLen(encodeAckDelay(f.DelayTime)) - length += utils.VarIntLen(uint64(len(f.AckRanges) - 1)) + length += utils.VarIntLen(uint64(numRanges - 1)) lowestInFirstRange := f.AckRanges[0].Smallest length += utils.VarIntLen(uint64(largestAcked - lowestInFirstRange)) - if !f.HasMissingRanges() { - return length - } - var lowest protocol.PacketNumber - for i, ackRange := range f.AckRanges { - if i == 0 { - lowest = ackRange.Smallest - continue - } - length += utils.VarIntLen(uint64(lowest - ackRange.Largest - 2)) - length += utils.VarIntLen(uint64(ackRange.Largest - ackRange.Smallest)) - lowest = ackRange.Smallest + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + length += utils.VarIntLen(gap) + length += utils.VarIntLen(len) } return length } +// gets the number of ACK ranges that can be encoded +// such that the resulting frame is smaller than the maximum ACK frame size +func (f *AckFrame) numEncodableAckRanges() int { + length := 1 + utils.VarIntLen(uint64(f.LargestAcked())) + utils.VarIntLen(encodeAckDelay(f.DelayTime)) + length += 2 // assume that the number of ranges will consume 2 bytes + for i := 1; i < len(f.AckRanges); i++ { + gap, len := f.encodeAckRange(i) + rangeLen := utils.VarIntLen(gap) + utils.VarIntLen(len) + if length+rangeLen > protocol.MaxAckFrameSize { + // Writing range i would exceed the MaxAckFrameSize. + // So encode one range less than that. + return i - 1 + } + length += rangeLen + } + return len(f.AckRanges) +} + +func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { + if i == 0 { + return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) + } + return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), + uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) +} + // HasMissingRanges returns if this frame reports any missing packets func (f *AckFrame) HasMissingRanges() bool { return len(f.AckRanges) > 1 diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go index 6e8204b2c..5c3821173 100644 --- a/internal/wire/ack_frame_test.go +++ b/internal/wire/ack_frame_test.go @@ -220,6 +220,29 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(frame.HasMissingRanges()).To(BeTrue()) Expect(b.Len()).To(BeZero()) }) + + It("limits the maximum size of the ACK frame", func() { + buf := &bytes.Buffer{} + const numRanges = 1000 + ackRanges := make([]AckRange, numRanges) + for i := protocol.PacketNumber(1); i <= numRanges; i++ { + ackRanges[numRanges-i] = AckRange{Smallest: 2 * i, Largest: 2 * i} + } + f := &AckFrame{AckRanges: ackRanges} + Expect(f.validateAckRanges()).To(BeTrue()) + err := f.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) + // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize + Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) + Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.HasMissingRanges()).To(BeTrue()) + Expect(b.Len()).To(BeZero()) + Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges + }) }) Context("ACK range validator", func() {