forked from quic-go/quic-go
Simplify ackhandler and congestion by splitting up OnCongestionEvent
No functional change.
This commit is contained in:
@@ -152,24 +152,23 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
|
||||
|
||||
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
|
||||
|
||||
if rttUpdated {
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
|
||||
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(ackedPackets) > 0 {
|
||||
var ackedPacketsCongestion congestion.PacketVector
|
||||
for _, p := range ackedPackets {
|
||||
h.onPacketAcked(p)
|
||||
ackedPacketsCongestion = append(ackedPacketsCongestion, congestion.PacketInfo{
|
||||
Number: p.Value.PacketNumber,
|
||||
Length: p.Value.Length,
|
||||
})
|
||||
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
}
|
||||
h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, ackedPacketsCongestion, nil)
|
||||
}
|
||||
|
||||
h.detectLostPackets(rttUpdated)
|
||||
h.detectLostPackets()
|
||||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
@@ -249,8 +248,7 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(lucas-clemente): Introducing congestion.MaybeExitSlowStart() would allow us to call through for each packet and eliminate both the rttUpdated param and the packet slices passed to the congestion
|
||||
func (h *sentPacketHandler) detectLostPackets(rttUpdated bool) {
|
||||
func (h *sentPacketHandler) detectLostPackets() {
|
||||
h.lossTime = time.Time{}
|
||||
now := time.Now()
|
||||
|
||||
@@ -275,15 +273,10 @@ func (h *sentPacketHandler) detectLostPackets(rttUpdated bool) {
|
||||
}
|
||||
|
||||
if len(lostPackets) > 0 {
|
||||
var lostPacketsCongestion congestion.PacketVector
|
||||
for _, p := range lostPackets {
|
||||
h.queuePacketForRetransmission(p)
|
||||
lostPacketsCongestion = append(lostPacketsCongestion, congestion.PacketInfo{
|
||||
Number: p.Value.PacketNumber,
|
||||
Length: p.Value.Length,
|
||||
})
|
||||
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
}
|
||||
h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, nil, lostPacketsCongestion)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,7 +285,7 @@ func (h *sentPacketHandler) OnAlarm() {
|
||||
// TODO(#497): TLP
|
||||
if !h.lossTime.IsZero() {
|
||||
// Early retransmit or time loss detection
|
||||
h.detectLostPackets(false /* rttUpdated */)
|
||||
h.detectLostPackets()
|
||||
} else {
|
||||
// RTO
|
||||
h.retransmitOldestTwoPackets()
|
||||
@@ -349,13 +342,9 @@ func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
||||
|
||||
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
|
||||
packet := &el.Value
|
||||
packetsLost := congestion.PacketVector{congestion.PacketInfo{
|
||||
Number: packet.PacketNumber,
|
||||
Length: packet.Length,
|
||||
}}
|
||||
utils.Debugf("\tQueueing packet 0x%x for retransmission (RTO)", packet.PacketNumber)
|
||||
h.queuePacketForRetransmission(el)
|
||||
h.congestion.OnCongestionEvent(false, h.bytesInFlight, nil, packetsLost)
|
||||
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
|
||||
h.congestion.OnRetransmissionTimeout(true)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
)
|
||||
|
||||
type mockCongestion struct {
|
||||
nCalls int
|
||||
argsOnPacketSent []interface{}
|
||||
argsOnCongestionEvent []interface{}
|
||||
maybeExitSlowStart bool
|
||||
onRetransmissionTimeout bool
|
||||
getCongestionWindow bool
|
||||
packetsAcked [][]interface{}
|
||||
packetsLost [][]interface{}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
||||
@@ -22,23 +24,20 @@ func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.Byt
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
||||
m.nCalls++
|
||||
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
|
||||
m.nCalls++
|
||||
m.getCongestionWindow = true
|
||||
return protocol.DefaultTCPMSS
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) {
|
||||
m.nCalls++
|
||||
m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets}
|
||||
func (m *mockCongestion) MaybeExitSlowStart() {
|
||||
m.maybeExitSlowStart = true
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
m.nCalls++
|
||||
m.onRetransmissionTimeout = true
|
||||
}
|
||||
|
||||
@@ -50,6 +49,14 @@ func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not i
|
||||
func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") }
|
||||
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") }
|
||||
|
||||
func (m *mockCongestion) OnPacketAcked(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
|
||||
m.packetsAcked = append(m.packetsAcked, []interface{}{n, l, bif})
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
|
||||
m.packetsLost = append(m.packetsLost, []interface{}{n, l, bif})
|
||||
}
|
||||
|
||||
var _ = Describe("SentPacketHandler", func() {
|
||||
var (
|
||||
handler *sentPacketHandler
|
||||
@@ -627,38 +634,36 @@ var _ = Describe("SentPacketHandler", func() {
|
||||
}
|
||||
err := handler.SentPacket(p)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.nCalls).To(Equal(1))
|
||||
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42)))
|
||||
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42)))
|
||||
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should call OnCongestionEvent for ACKs", func() {
|
||||
It("should call MaybeExitSlowStart and OnPacketAcked", func() {
|
||||
handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1})
|
||||
handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1})
|
||||
Expect(cong.nCalls).To(Equal(2))
|
||||
err := handler.ReceivedAck(&frames.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.nCalls).To(Equal(3))
|
||||
Expect(cong.argsOnCongestionEvent[0]).To(BeTrue())
|
||||
Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1)))
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{Number: 1, Length: 1}}))
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(BeEmpty())
|
||||
Expect(cong.maybeExitSlowStart).To(BeTrue())
|
||||
Expect(cong.packetsAcked).To(BeEquivalentTo([][]interface{}{
|
||||
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)},
|
||||
}))
|
||||
Expect(cong.packetsLost).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should call OnCongestionEvent for losses", func() {
|
||||
It("should call MaybeExitSlowStart and OnPacketLost", func() {
|
||||
handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1})
|
||||
handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1})
|
||||
handler.SentPacket(&Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 1})
|
||||
Expect(cong.nCalls).To(Equal(3))
|
||||
handler.OnAlarm() // RTO, meaning 2 lost packets
|
||||
Expect(cong.nCalls).To(Equal(3 + 4 /* 2* (OnCongestionEvent+OnRTO)*/))
|
||||
Expect(cong.maybeExitSlowStart).To(BeFalse())
|
||||
Expect(cong.onRetransmissionTimeout).To(BeTrue())
|
||||
Expect(cong.argsOnCongestionEvent[0]).To(BeFalse())
|
||||
Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1)))
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(BeEmpty())
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{Number: 2, Length: 1}}))
|
||||
Expect(cong.packetsAcked).To(BeEmpty())
|
||||
Expect(cong.packetsLost).To(BeEquivalentTo([][]interface{}{
|
||||
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)},
|
||||
{protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)},
|
||||
}))
|
||||
})
|
||||
|
||||
It("allows or denies sending based on congestion", func() {
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package congestion
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
// PacketInfo combines packet number and length of a packet for congestion calculation
|
||||
type PacketInfo struct {
|
||||
Number protocol.PacketNumber
|
||||
Length protocol.ByteCount
|
||||
}
|
||||
|
||||
// PacketVector is passed to the congestion algorithm
|
||||
type PacketVector []PacketInfo
|
||||
@@ -125,24 +125,13 @@ func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
|
||||
return c.slowstartThreshold
|
||||
}
|
||||
|
||||
// OnCongestionEvent indicates an update to the congestion state, caused either by an incoming
|
||||
// ack or loss event timeout. |rttUpdated| indicates whether a new
|
||||
// latest_rtt sample has been taken, |byte_in_flight| the bytes in flight
|
||||
// prior to the congestion event. |ackedPackets| and |lostPackets| are
|
||||
// any packets considered acked or lost as a result of the congestion event.
|
||||
func (c *cubicSender) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector) {
|
||||
if rttUpdated && c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
|
||||
func (c *cubicSender) MaybeExitSlowStart() {
|
||||
if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
|
||||
c.ExitSlowstart()
|
||||
}
|
||||
for _, i := range lostPackets {
|
||||
c.onPacketLost(i.Number, i.Length, bytesInFlight)
|
||||
}
|
||||
for _, i := range ackedPackets {
|
||||
c.onPacketAcked(i.Number, i.Length, bytesInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
@@ -155,7 +144,7 @@ func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ack
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) onPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
|
||||
@@ -59,35 +59,26 @@ var _ = Describe("Cubic Sender", func() {
|
||||
// Normal is that TCP acks every other segment.
|
||||
AckNPacketsLen := func(n int, packetLength protocol.ByteCount) {
|
||||
rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now())
|
||||
var ackedPackets PacketVector
|
||||
var lostPackets PacketVector
|
||||
sender.MaybeExitSlowStart()
|
||||
for i := 0; i < n; i++ {
|
||||
ackedPacketNumber++
|
||||
ackedPackets = append(ackedPackets, PacketInfo{Number: ackedPacketNumber, Length: packetLength})
|
||||
sender.OnPacketAcked(ackedPacketNumber, packetLength, bytesInFlight)
|
||||
}
|
||||
sender.OnCongestionEvent(true, bytesInFlight, ackedPackets, lostPackets)
|
||||
bytesInFlight -= protocol.ByteCount(n) * packetLength
|
||||
clock.Advance(time.Millisecond)
|
||||
}
|
||||
|
||||
LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) {
|
||||
var ackedPackets PacketVector
|
||||
var lostPackets PacketVector
|
||||
for i := 0; i < n; i++ {
|
||||
ackedPacketNumber++
|
||||
lostPackets = append(lostPackets, PacketInfo{Number: ackedPacketNumber, Length: packetLength})
|
||||
sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight)
|
||||
}
|
||||
sender.OnCongestionEvent(false, bytesInFlight, ackedPackets, lostPackets)
|
||||
bytesInFlight -= protocol.ByteCount(n) * packetLength
|
||||
}
|
||||
|
||||
// Does not increment acked_packet_number_.
|
||||
LosePacket := func(number protocol.PacketNumber) {
|
||||
var ackedPackets PacketVector
|
||||
var lostPackets PacketVector = PacketVector([]PacketInfo{
|
||||
{Number: number, Length: protocol.DefaultTCPMSS},
|
||||
})
|
||||
sender.OnCongestionEvent(false, bytesInFlight, ackedPackets, lostPackets)
|
||||
sender.OnPacketLost(number, protocol.DefaultTCPMSS, bytesInFlight)
|
||||
bytesInFlight -= protocol.DefaultTCPMSS
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,9 @@ type SendAlgorithm interface {
|
||||
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector)
|
||||
MaybeExitSlowStart()
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
||||
SetNumEmulatedConnections(n int)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
OnConnectionMigration()
|
||||
|
||||
Reference in New Issue
Block a user