diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index 3e851c8e..68415ac6 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -135,7 +135,10 @@ func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) } e.state = ecnStateFailed + return } + // Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked + e.failIfMangled() } // HandleNewlyAcked handles the ECN counts on an ACK frame. @@ -229,12 +232,11 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // Detect mangling (a path remarking all ECN-marked testing packets as CE), // once all 10 testing packets have been sent out. - if e.state == ecnStateUnknown && e.numSentECT0+e.numSentECT1 == e.numAckedECNCE { - if e.tracer != nil && e.tracer.ECNStateUpdated != nil { - e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + if e.state == ecnStateUnknown { + e.failIfMangled() + if e.state == ecnStateFailed { + return false } - e.state = ecnStateFailed - return false } if e.state == ecnStateTesting || e.state == ecnStateUnknown { var ackedTestingPacket bool @@ -259,6 +261,18 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 return e.state == ecnStateCapable && newECNCE > 0 } +// failIfMangled fails ECN validation if all testing packets are lost or CE-marked. +func (e *ecnTracker) failIfMangled() { + numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting) + if e.numSentECT0+e.numSentECT1 > numAckedECNCE { + return + } + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + } + e.state = ecnStateFailed +} + func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN { if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber { return protocol.ECNNon diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go index 78eef8c9..ab648d81 100644 --- a/internal/ackhandler/ecn_test.go +++ b/internal/ackhandler/ecn_test.go @@ -192,7 +192,7 @@ var _ = Describe("ECN tracker", func() { Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 15), 3, 0, 2)).To(BeFalse()) }) - It("detects ECN mangling", func() { + It("detects ECN mangling if all testing packets are marked CE", func() { sendAllTestingPackets() for i := 10; i < 20; i++ { Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) @@ -222,6 +222,39 @@ var _ = Describe("ECN tracker", func() { Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 10)).To(BeFalse()) }) + It("detects ECN mangling, if some testing packets are marked CE, and then others are lost", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3), 0, 0, 4)).To(BeFalse()) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(6, 7, 8, 9), 0, 0, 8)).To(BeFalse()) + // Lose one of the two unacknowledged packets. + ecnTracker.LostPacket(4) + // By losing the last unacknowledged testing packets, we should detect the mangling. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + ecnTracker.LostPacket(5) + }) + + It("detects ECN mangling, if some testing packets are lost, and then others are marked CE", func() { + sendAllTestingPackets() + for i := 10; i < 20; i++ { + Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon)) + ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon) + } + // Lose a few packets. + ecnTracker.LostPacket(0) + ecnTracker.LostPacket(1) + ecnTracker.LostPacket(2) + // ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(3, 4, 5, 6, 7, 8), 0, 0, 6)).To(BeFalse()) + // By CE-marking the last unacknowledged testing packets, we should detect the mangling. + tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 7)).To(BeFalse()) + }) + It("declares congestion", func() { sendAllTestingPackets() for i := 10; i < 20; i++ {