From 89d0ae9810a0edcd2a5195df70681e0035da3280 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Wed, 5 Apr 2017 21:34:08 +0200 Subject: [PATCH] Simplify ackhandler and congestion by splitting up OnCongestionEvent No functional change. --- ackhandler/sent_packet_handler.go | 31 +++++----------- ackhandler/sent_packet_handler_test.go | 51 ++++++++++++++------------ congestion/congestion_vector.go | 12 ------ congestion/cubic_sender.go | 19 ++-------- congestion/cubic_sender_test.go | 17 ++------- congestion/interface.go | 4 +- 6 files changed, 49 insertions(+), 85 deletions(-) delete mode 100644 congestion/congestion_vector.go diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 5d0048e3..e7984723 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -152,24 +152,23 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime) + if rttUpdated { + h.congestion.MaybeExitSlowStart() + } + ackedPackets, err := h.determineNewlyAckedPackets(ackFrame) if err != nil { return err } if len(ackedPackets) > 0 { - var ackedPacketsCongestion congestion.PacketVector for _, p := range ackedPackets { h.onPacketAcked(p) - ackedPacketsCongestion = append(ackedPacketsCongestion, congestion.PacketInfo{ - Number: p.Value.PacketNumber, - Length: p.Value.Length, - }) + h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) } - h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, ackedPacketsCongestion, nil) } - h.detectLostPackets(rttUpdated) + h.detectLostPackets() h.updateLossDetectionAlarm() h.garbageCollectSkippedPackets() @@ -249,8 +248,7 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() { } } -// TODO(lucas-clemente): Introducing congestion.MaybeExitSlowStart() would allow us to call through for each packet and eliminate both the rttUpdated param and the packet slices passed to the congestion -func (h *sentPacketHandler) detectLostPackets(rttUpdated bool) { +func (h *sentPacketHandler) detectLostPackets() { h.lossTime = time.Time{} now := time.Now() @@ -275,15 +273,10 @@ func (h *sentPacketHandler) detectLostPackets(rttUpdated bool) { } if len(lostPackets) > 0 { - var lostPacketsCongestion congestion.PacketVector for _, p := range lostPackets { h.queuePacketForRetransmission(p) - lostPacketsCongestion = append(lostPacketsCongestion, congestion.PacketInfo{ - Number: p.Value.PacketNumber, - Length: p.Value.Length, - }) + h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) } - h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, nil, lostPacketsCongestion) } } @@ -292,7 +285,7 @@ func (h *sentPacketHandler) OnAlarm() { // TODO(#497): TLP if !h.lossTime.IsZero() { // Early retransmit or time loss detection - h.detectLostPackets(false /* rttUpdated */) + h.detectLostPackets() } else { // RTO h.retransmitOldestTwoPackets() @@ -349,13 +342,9 @@ func (h *sentPacketHandler) retransmitOldestTwoPackets() { func (h *sentPacketHandler) queueRTO(el *PacketElement) { packet := &el.Value - packetsLost := congestion.PacketVector{congestion.PacketInfo{ - Number: packet.PacketNumber, - Length: packet.Length, - }} utils.Debugf("\tQueueing packet 0x%x for retransmission (RTO)", packet.PacketNumber) h.queuePacketForRetransmission(el) - h.congestion.OnCongestionEvent(false, h.bytesInFlight, nil, packetsLost) + h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight) h.congestion.OnRetransmissionTimeout(true) } diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index 7ad3d305..cb23b81a 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -11,10 +11,12 @@ import ( ) type mockCongestion struct { - nCalls int argsOnPacketSent []interface{} - argsOnCongestionEvent []interface{} + maybeExitSlowStart bool onRetransmissionTimeout bool + getCongestionWindow bool + packetsAcked [][]interface{} + packetsLost [][]interface{} } func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { @@ -22,23 +24,20 @@ func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.Byt } func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { - m.nCalls++ m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable} return false } func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount { - m.nCalls++ + m.getCongestionWindow = true return protocol.DefaultTCPMSS } -func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) { - m.nCalls++ - m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets} +func (m *mockCongestion) MaybeExitSlowStart() { + m.maybeExitSlowStart = true } func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { - m.nCalls++ m.onRetransmissionTimeout = true } @@ -50,6 +49,14 @@ func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not i func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") } func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") } +func (m *mockCongestion) OnPacketAcked(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) { + m.packetsAcked = append(m.packetsAcked, []interface{}{n, l, bif}) +} + +func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) { + m.packetsLost = append(m.packetsLost, []interface{}{n, l, bif}) +} + var _ = Describe("SentPacketHandler", func() { var ( handler *sentPacketHandler @@ -627,38 +634,36 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.SentPacket(p) Expect(err).NotTo(HaveOccurred()) - Expect(cong.nCalls).To(Equal(1)) Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42))) Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1))) Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42))) Expect(cong.argsOnPacketSent[4]).To(BeTrue()) }) - It("should call OnCongestionEvent for ACKs", func() { + It("should call MaybeExitSlowStart and OnPacketAcked", func() { handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1}) - Expect(cong.nCalls).To(Equal(2)) err := handler.ReceivedAck(&frames.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, time.Now()) Expect(err).NotTo(HaveOccurred()) - Expect(cong.nCalls).To(Equal(3)) - Expect(cong.argsOnCongestionEvent[0]).To(BeTrue()) - Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1))) - Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{Number: 1, Length: 1}})) - Expect(cong.argsOnCongestionEvent[3]).To(BeEmpty()) + Expect(cong.maybeExitSlowStart).To(BeTrue()) + Expect(cong.packetsAcked).To(BeEquivalentTo([][]interface{}{ + {protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)}, + })) + Expect(cong.packetsLost).To(BeEmpty()) }) - It("should call OnCongestionEvent for losses", func() { + It("should call MaybeExitSlowStart and OnPacketLost", func() { handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1}) handler.SentPacket(&Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 1}) - Expect(cong.nCalls).To(Equal(3)) handler.OnAlarm() // RTO, meaning 2 lost packets - Expect(cong.nCalls).To(Equal(3 + 4 /* 2* (OnCongestionEvent+OnRTO)*/)) + Expect(cong.maybeExitSlowStart).To(BeFalse()) Expect(cong.onRetransmissionTimeout).To(BeTrue()) - Expect(cong.argsOnCongestionEvent[0]).To(BeFalse()) - Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1))) - Expect(cong.argsOnCongestionEvent[2]).To(BeEmpty()) - Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{Number: 2, Length: 1}})) + Expect(cong.packetsAcked).To(BeEmpty()) + Expect(cong.packetsLost).To(BeEquivalentTo([][]interface{}{ + {protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)}, + {protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)}, + })) }) It("allows or denies sending based on congestion", func() { diff --git a/congestion/congestion_vector.go b/congestion/congestion_vector.go deleted file mode 100644 index 9c6ebb4e..00000000 --- a/congestion/congestion_vector.go +++ /dev/null @@ -1,12 +0,0 @@ -package congestion - -import "github.com/lucas-clemente/quic-go/protocol" - -// PacketInfo combines packet number and length of a packet for congestion calculation -type PacketInfo struct { - Number protocol.PacketNumber - Length protocol.ByteCount -} - -// PacketVector is passed to the congestion algorithm -type PacketVector []PacketInfo diff --git a/congestion/cubic_sender.go b/congestion/cubic_sender.go index a34a69df..34947d34 100644 --- a/congestion/cubic_sender.go +++ b/congestion/cubic_sender.go @@ -125,24 +125,13 @@ func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber { return c.slowstartThreshold } -// OnCongestionEvent indicates an update to the congestion state, caused either by an incoming -// ack or loss event timeout. |rttUpdated| indicates whether a new -// latest_rtt sample has been taken, |byte_in_flight| the bytes in flight -// prior to the congestion event. |ackedPackets| and |lostPackets| are -// any packets considered acked or lost as a result of the congestion event. -func (c *cubicSender) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector) { - if rttUpdated && c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) { +func (c *cubicSender) MaybeExitSlowStart() { + if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) { c.ExitSlowstart() } - for _, i := range lostPackets { - c.onPacketLost(i.Number, i.Length, bytesInFlight) - } - for _, i := range ackedPackets { - c.onPacketAcked(i.Number, i.Length, bytesInFlight) - } } -func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) if c.InRecovery() { // PRR is used when in recovery. @@ -155,7 +144,7 @@ func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ack } } -func (c *cubicSender) onPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // already sent should be treated as a single loss event, since it's expected. if packetNumber <= c.largestSentAtLastCutback { diff --git a/congestion/cubic_sender_test.go b/congestion/cubic_sender_test.go index a7256c87..8afd426e 100644 --- a/congestion/cubic_sender_test.go +++ b/congestion/cubic_sender_test.go @@ -59,35 +59,26 @@ var _ = Describe("Cubic Sender", func() { // Normal is that TCP acks every other segment. AckNPacketsLen := func(n int, packetLength protocol.ByteCount) { rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now()) - var ackedPackets PacketVector - var lostPackets PacketVector + sender.MaybeExitSlowStart() for i := 0; i < n; i++ { ackedPacketNumber++ - ackedPackets = append(ackedPackets, PacketInfo{Number: ackedPacketNumber, Length: packetLength}) + sender.OnPacketAcked(ackedPacketNumber, packetLength, bytesInFlight) } - sender.OnCongestionEvent(true, bytesInFlight, ackedPackets, lostPackets) bytesInFlight -= protocol.ByteCount(n) * packetLength clock.Advance(time.Millisecond) } LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) { - var ackedPackets PacketVector - var lostPackets PacketVector for i := 0; i < n; i++ { ackedPacketNumber++ - lostPackets = append(lostPackets, PacketInfo{Number: ackedPacketNumber, Length: packetLength}) + sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight) } - sender.OnCongestionEvent(false, bytesInFlight, ackedPackets, lostPackets) bytesInFlight -= protocol.ByteCount(n) * packetLength } // Does not increment acked_packet_number_. LosePacket := func(number protocol.PacketNumber) { - var ackedPackets PacketVector - var lostPackets PacketVector = PacketVector([]PacketInfo{ - {Number: number, Length: protocol.DefaultTCPMSS}, - }) - sender.OnCongestionEvent(false, bytesInFlight, ackedPackets, lostPackets) + sender.OnPacketLost(number, protocol.DefaultTCPMSS, bytesInFlight) bytesInFlight -= protocol.DefaultTCPMSS } diff --git a/congestion/interface.go b/congestion/interface.go index 593c09e7..bbce0a63 100644 --- a/congestion/interface.go +++ b/congestion/interface.go @@ -11,7 +11,9 @@ type SendAlgorithm interface { TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool GetCongestionWindow() protocol.ByteCount - OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector) + MaybeExitSlowStart() + OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) + OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) SetNumEmulatedConnections(n int) OnRetransmissionTimeout(packetsRetransmitted bool) OnConnectionMigration()