diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 7fd071e68..8d15d7c18 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -58,11 +58,9 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro if ackEliciting { h.hasNewAck = true + h.maybeQueueACK(pn, rcvTime, ecn, isMissing) } - if ackEliciting { - h.maybeQueueACK(pn, rcvTime, isMissing) - } - //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE. + //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE. switch ecn { case protocol.ECT0: h.ect0++ @@ -104,7 +102,7 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool { } // maybeQueueACK queues an ACK, if necessary. -func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { +func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, ecn protocol.ECN, wasMissing bool) { // always acknowledge the first packet if h.lastAck == nil { if !h.ackQueued { @@ -143,12 +141,18 @@ func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime h.ackAlarm = rcvTime.Add(h.maxAckDelay) } - // Queue an ACK if there are new missing packets to report. + // queue an ACK if there are new missing packets to report if h.hasNewMissingPackets() { h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") h.ackQueued = true } + // queue an ACK if the packet was ECN-CE marked + if ecn == protocol.ECNCE { + h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.") + h.ackQueued = true + } + if h.ackQueued { // cancel the ack alarm h.ackAlarm = time.Time{} diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index 8c76f207b..b0d7db3c3 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -50,6 +50,7 @@ var _ = Describe("Received Packet Tracker", func() { Context("ACKs", func() { Context("queueing ACKs", func() { + // receives and gets ACKs for packet numbers 1 to 10 (including) receiveAndAck10Packets := func() { for i := 1; i <= 10; i++ { Expect(tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Time{}, true)).To(Succeed()) @@ -126,6 +127,16 @@ var _ = Describe("Received Packet Tracker", func() { Expect(tracker.GetAlarmTimeout()).To(Equal(rcvTime.Add(protocol.MaxAckDelay))) }) + It("queues an ACK if the packet was ECN-CE marked", func() { + receiveAndAck10Packets() + Expect(tracker.ReceivedPacket(11, protocol.ECNCE, time.Now(), true)).To(Succeed()) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(HaveLen(1)) + Expect(ack.AckRanges[0].Largest).To(Equal(protocol.PacketNumber(11))) + Expect(ack.ECNCE).To(BeEquivalentTo(1)) + }) + It("queues an ACK if it was reported missing before", func() { receiveAndAck10Packets() Expect(tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true)).To(Succeed())