Simplify ackhandler and congestion by splitting up OnCongestionEvent

No functional change.
This commit is contained in:
Lucas Clemente
2017-04-05 21:34:08 +02:00
parent fb5cccf799
commit 89d0ae9810
6 changed files with 49 additions and 85 deletions

View File

@@ -152,24 +152,23 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
if rttUpdated {
h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
if len(ackedPackets) > 0 {
var ackedPacketsCongestion congestion.PacketVector
for _, p := range ackedPackets {
h.onPacketAcked(p)
ackedPacketsCongestion = append(ackedPacketsCongestion, congestion.PacketInfo{
Number: p.Value.PacketNumber,
Length: p.Value.Length,
})
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, ackedPacketsCongestion, nil)
}
h.detectLostPackets(rttUpdated)
h.detectLostPackets()
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
@@ -249,8 +248,7 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() {
}
}
// TODO(lucas-clemente): Introducing congestion.MaybeExitSlowStart() would allow us to call through for each packet and eliminate both the rttUpdated param and the packet slices passed to the congestion
func (h *sentPacketHandler) detectLostPackets(rttUpdated bool) {
func (h *sentPacketHandler) detectLostPackets() {
h.lossTime = time.Time{}
now := time.Now()
@@ -275,15 +273,10 @@ func (h *sentPacketHandler) detectLostPackets(rttUpdated bool) {
}
if len(lostPackets) > 0 {
var lostPacketsCongestion congestion.PacketVector
for _, p := range lostPackets {
h.queuePacketForRetransmission(p)
lostPacketsCongestion = append(lostPacketsCongestion, congestion.PacketInfo{
Number: p.Value.PacketNumber,
Length: p.Value.Length,
})
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, nil, lostPacketsCongestion)
}
}
@@ -292,7 +285,7 @@ func (h *sentPacketHandler) OnAlarm() {
// TODO(#497): TLP
if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
h.detectLostPackets(false /* rttUpdated */)
h.detectLostPackets()
} else {
// RTO
h.retransmitOldestTwoPackets()
@@ -349,13 +342,9 @@ func (h *sentPacketHandler) retransmitOldestTwoPackets() {
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
packet := &el.Value
packetsLost := congestion.PacketVector{congestion.PacketInfo{
Number: packet.PacketNumber,
Length: packet.Length,
}}
utils.Debugf("\tQueueing packet 0x%x for retransmission (RTO)", packet.PacketNumber)
h.queuePacketForRetransmission(el)
h.congestion.OnCongestionEvent(false, h.bytesInFlight, nil, packetsLost)
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
h.congestion.OnRetransmissionTimeout(true)
}

View File

@@ -11,10 +11,12 @@ import (
)
type mockCongestion struct {
nCalls int
argsOnPacketSent []interface{}
argsOnCongestionEvent []interface{}
maybeExitSlowStart bool
onRetransmissionTimeout bool
getCongestionWindow bool
packetsAcked [][]interface{}
packetsLost [][]interface{}
}
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
@@ -22,23 +24,20 @@ func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.Byt
}
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
m.nCalls++
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
return false
}
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
m.nCalls++
m.getCongestionWindow = true
return protocol.DefaultTCPMSS
}
func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) {
m.nCalls++
m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets}
func (m *mockCongestion) MaybeExitSlowStart() {
m.maybeExitSlowStart = true
}
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) {
m.nCalls++
m.onRetransmissionTimeout = true
}
@@ -50,6 +49,14 @@ func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not i
func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") }
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") }
func (m *mockCongestion) OnPacketAcked(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
m.packetsAcked = append(m.packetsAcked, []interface{}{n, l, bif})
}
func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
m.packetsLost = append(m.packetsLost, []interface{}{n, l, bif})
}
var _ = Describe("SentPacketHandler", func() {
var (
handler *sentPacketHandler
@@ -627,38 +634,36 @@ var _ = Describe("SentPacketHandler", func() {
}
err := handler.SentPacket(p)
Expect(err).NotTo(HaveOccurred())
Expect(cong.nCalls).To(Equal(1))
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42)))
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42)))
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
})
It("should call OnCongestionEvent for ACKs", func() {
It("should call MaybeExitSlowStart and OnPacketAcked", func() {
handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1})
handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1})
Expect(cong.nCalls).To(Equal(2))
err := handler.ReceivedAck(&frames.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(cong.nCalls).To(Equal(3))
Expect(cong.argsOnCongestionEvent[0]).To(BeTrue())
Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1)))
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{Number: 1, Length: 1}}))
Expect(cong.argsOnCongestionEvent[3]).To(BeEmpty())
Expect(cong.maybeExitSlowStart).To(BeTrue())
Expect(cong.packetsAcked).To(BeEquivalentTo([][]interface{}{
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)},
}))
Expect(cong.packetsLost).To(BeEmpty())
})
It("should call OnCongestionEvent for losses", func() {
It("should call MaybeExitSlowStart and OnPacketLost", func() {
handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1})
handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1})
handler.SentPacket(&Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 1})
Expect(cong.nCalls).To(Equal(3))
handler.OnAlarm() // RTO, meaning 2 lost packets
Expect(cong.nCalls).To(Equal(3 + 4 /* 2* (OnCongestionEvent+OnRTO)*/))
Expect(cong.maybeExitSlowStart).To(BeFalse())
Expect(cong.onRetransmissionTimeout).To(BeTrue())
Expect(cong.argsOnCongestionEvent[0]).To(BeFalse())
Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1)))
Expect(cong.argsOnCongestionEvent[2]).To(BeEmpty())
Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{Number: 2, Length: 1}}))
Expect(cong.packetsAcked).To(BeEmpty())
Expect(cong.packetsLost).To(BeEquivalentTo([][]interface{}{
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)},
{protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)},
}))
})
It("allows or denies sending based on congestion", func() {

View File

@@ -1,12 +0,0 @@
package congestion
import "github.com/lucas-clemente/quic-go/protocol"
// PacketInfo combines packet number and length of a packet for congestion calculation
type PacketInfo struct {
Number protocol.PacketNumber
Length protocol.ByteCount
}
// PacketVector is passed to the congestion algorithm
type PacketVector []PacketInfo

View File

@@ -125,24 +125,13 @@ func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
return c.slowstartThreshold
}
// OnCongestionEvent indicates an update to the congestion state, caused either by an incoming
// ack or loss event timeout. |rttUpdated| indicates whether a new
// latest_rtt sample has been taken, |byte_in_flight| the bytes in flight
// prior to the congestion event. |ackedPackets| and |lostPackets| are
// any packets considered acked or lost as a result of the congestion event.
func (c *cubicSender) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector) {
if rttUpdated && c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
func (c *cubicSender) MaybeExitSlowStart() {
if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
c.ExitSlowstart()
}
for _, i := range lostPackets {
c.onPacketLost(i.Number, i.Length, bytesInFlight)
}
for _, i := range ackedPackets {
c.onPacketAcked(i.Number, i.Length, bytesInFlight)
}
}
func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
// PRR is used when in recovery.
@@ -155,7 +144,7 @@ func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ack
}
}
func (c *cubicSender) onPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {

View File

@@ -59,35 +59,26 @@ var _ = Describe("Cubic Sender", func() {
// Normal is that TCP acks every other segment.
AckNPacketsLen := func(n int, packetLength protocol.ByteCount) {
rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now())
var ackedPackets PacketVector
var lostPackets PacketVector
sender.MaybeExitSlowStart()
for i := 0; i < n; i++ {
ackedPacketNumber++
ackedPackets = append(ackedPackets, PacketInfo{Number: ackedPacketNumber, Length: packetLength})
sender.OnPacketAcked(ackedPacketNumber, packetLength, bytesInFlight)
}
sender.OnCongestionEvent(true, bytesInFlight, ackedPackets, lostPackets)
bytesInFlight -= protocol.ByteCount(n) * packetLength
clock.Advance(time.Millisecond)
}
LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) {
var ackedPackets PacketVector
var lostPackets PacketVector
for i := 0; i < n; i++ {
ackedPacketNumber++
lostPackets = append(lostPackets, PacketInfo{Number: ackedPacketNumber, Length: packetLength})
sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight)
}
sender.OnCongestionEvent(false, bytesInFlight, ackedPackets, lostPackets)
bytesInFlight -= protocol.ByteCount(n) * packetLength
}
// Does not increment acked_packet_number_.
LosePacket := func(number protocol.PacketNumber) {
var ackedPackets PacketVector
var lostPackets PacketVector = PacketVector([]PacketInfo{
{Number: number, Length: protocol.DefaultTCPMSS},
})
sender.OnCongestionEvent(false, bytesInFlight, ackedPackets, lostPackets)
sender.OnPacketLost(number, protocol.DefaultTCPMSS, bytesInFlight)
bytesInFlight -= protocol.DefaultTCPMSS
}

View File

@@ -11,7 +11,9 @@ type SendAlgorithm interface {
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
GetCongestionWindow() protocol.ByteCount
OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector)
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration()