From d6ac6300a443a636804399a1cd690ebafddc603d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 2 Sep 2023 09:26:49 +0700 Subject: [PATCH] feed ECN feedback into the congestion controller --- internal/ackhandler/ecn.go | 9 ++ internal/ackhandler/mock_ecn_handler_test.go | 87 +++++++++++ internal/ackhandler/mockgen.go | 3 + internal/ackhandler/sent_packet_handler.go | 11 +- .../ackhandler/sent_packet_handler_test.go | 147 ++++++++++++++++++ 5 files changed, 251 insertions(+), 6 deletions(-) create mode 100644 internal/ackhandler/mock_ecn_handler_test.go diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index 0fd39098c..ff0dfe652 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -21,6 +21,13 @@ const ( // must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type const numECNTestingPackets = 10 +type ecnHandler interface { + SentPacket(protocol.PacketNumber, protocol.ECN) + Mode() protocol.ECN + HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) + LostPacket(protocol.PacketNumber) +} + // The ecnTracker performs ECN validation of a path. // Once failed, it doesn't do any re-validation of the path. // It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces. @@ -42,6 +49,8 @@ type ecnTracker struct { logger utils.Logger } +var _ ecnHandler = &ecnTracker{} + func newECNTracker(logger utils.Logger, tracer logging.ConnectionTracer) *ecnTracker { return &ecnTracker{ firstTestingPacket: protocol.InvalidPacketNumber, diff --git a/internal/ackhandler/mock_ecn_handler_test.go b/internal/ackhandler/mock_ecn_handler_test.go new file mode 100644 index 000000000..28f350c17 --- /dev/null +++ b/internal/ackhandler/mock_ecn_handler_test.go @@ -0,0 +1,87 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: ECNHandler) + +// Package ackhandler is a generated GoMock package. +package ackhandler + +import ( + reflect "reflect" + + protocol "github.com/quic-go/quic-go/internal/protocol" + gomock "go.uber.org/mock/gomock" +) + +// MockECNHandler is a mock of ECNHandler interface. +type MockECNHandler struct { + ctrl *gomock.Controller + recorder *MockECNHandlerMockRecorder +} + +// MockECNHandlerMockRecorder is the mock recorder for MockECNHandler. +type MockECNHandlerMockRecorder struct { + mock *MockECNHandler +} + +// NewMockECNHandler creates a new mock instance. +func NewMockECNHandler(ctrl *gomock.Controller) *MockECNHandler { + mock := &MockECNHandler{ctrl: ctrl} + mock.recorder = &MockECNHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockECNHandler) EXPECT() *MockECNHandlerMockRecorder { + return m.recorder +} + +// HandleNewlyAcked mocks base method. +func (m *MockECNHandler) HandleNewlyAcked(arg0 []*packet, arg1, arg2, arg3 int64) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleNewlyAcked", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HandleNewlyAcked indicates an expected call of HandleNewlyAcked. +func (mr *MockECNHandlerMockRecorder) HandleNewlyAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNewlyAcked", reflect.TypeOf((*MockECNHandler)(nil).HandleNewlyAcked), arg0, arg1, arg2, arg3) +} + +// LostPacket mocks base method. +func (m *MockECNHandler) LostPacket(arg0 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockECNHandlerMockRecorder) LostPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockECNHandler)(nil).LostPacket), arg0) +} + +// Mode mocks base method. +func (m *MockECNHandler) Mode() protocol.ECN { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Mode") + ret0, _ := ret[0].(protocol.ECN) + return ret0 +} + +// Mode indicates an expected call of Mode. +func (mr *MockECNHandlerMockRecorder) Mode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mode", reflect.TypeOf((*MockECNHandler)(nil).Mode)) +} + +// SentPacket mocks base method. +func (m *MockECNHandler) SentPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockECNHandlerMockRecorder) SentPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockECNHandler)(nil).SentPacket), arg0, arg1) +} diff --git a/internal/ackhandler/mockgen.go b/internal/ackhandler/mockgen.go index b9eb8a88e..dbf6ee2d1 100644 --- a/internal/ackhandler/mockgen.go +++ b/internal/ackhandler/mockgen.go @@ -4,3 +4,6 @@ package ackhandler //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker" type SentPacketTracker = sentPacketTracker + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler" +type ECNHandler = ecnHandler diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 90f804223..b498ce4b7 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -93,7 +93,7 @@ type sentPacketHandler struct { alarm time.Time enableECN bool - ecnTracker *ecnTracker + ecnTracker ecnHandler perspective protocol.Perspective @@ -346,17 +346,16 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - var ecnCongestionDetected bool // Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked. if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked { - ecnCongestionDetected = h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) + congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) + if congested { + h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight) + } } pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) - // TODO: inform the congestion controller - _ = ecnCongestionDetected - if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index acadcf373..63f1b8341 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -1429,4 +1429,151 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.rttStats.SmoothedRTT()).To(BeZero()) }) }) + + Context("ECN handling", func() { + var ecnHandler *MockECNHandler + var cong *mocks.MockSendAlgorithmWithDebugInfos + + JustBeforeEach(func() { + cong = mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl) + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + cong.EXPECT().OnPacketAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + cong.EXPECT().MaybeExitSlowStart().AnyTimes() + ecnHandler = NewMockECNHandler(mockCtrl) + lostPackets = nil + rttStats := utils.NewRTTStats() + rttStats.UpdateRTT(time.Hour, 0, time.Now()) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger) + handler.ecnTracker = ecnHandler + handler.congestion = cong + }) + + It("informs about sent packets", func() { + // Check that only 1-RTT packets are reported + handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false) + handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false) + + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1) + handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + }) + + It("informs about sent packets", func() { + // Check that only 1-RTT packets are reported + handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false) + handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false) + + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1) + handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + }) + + It("informs about lost packets", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + cong.EXPECT().OnCongestionEvent(gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(10)) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(11)) + ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(12)) + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 16, Smallest: 13}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("processes ACKs", func() { + // Check that we only care about 1-RTT packets. + handler.SentPacket(time.Now(), 100, -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 100}}}, protocol.EncryptionInitial, time.Now()) + Expect(err).ToNot(HaveOccurred()) + + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(5)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(10))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(11))) + Expect(packets[2].PacketNumber).To(Equal(protocol.PacketNumber(12))) + Expect(packets[3].PacketNumber).To(Equal(protocol.PacketNumber(14))) + Expect(packets[4].PacketNumber).To(Equal(protocol.PacketNumber(15))) + return false + }) + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{ + {Largest: 15, Smallest: 14}, + {Largest: 12, Smallest: 10}, + }, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores reordered ACKs", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(2)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(12))) + return false + }) + _, err := handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 12, Smallest: 11}}, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + // acknowledge packet 10 now, but don't increase the largest acked + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 12, Smallest: 10}}, + ECT0: 1, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores ACKs that don't increase the largest acked", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool { + Expect(packets).To(HaveLen(1)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11))) + return false + }) + _, err := handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 11, Smallest: 11}}, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + _, err = handler.ReceivedAck(&wire.AckFrame{ + AckRanges: []wire.AckRange{{Largest: 11, Smallest: 10}}, + ECT0: 1, + ECNCE: 3, + }, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("informs the congestion controller about CE events", func() { + for i := 10; i < 20; i++ { + ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT0) + handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT0, 1200, false) + } + ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(0), int64(0), int64(0)).Return(true) + cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(15), gomock.Any(), gomock.Any()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 15, Smallest: 10}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + }) })