diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index b22cafd97..2332dcf73 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -103,6 +103,7 @@ func newSentPacketHandler( congestion.DefaultClock{}, rttStats, true, // use Reno + tracer, ) return &sentPacketHandler{ diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index 2279c3312..07f6cbd34 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -5,6 +5,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" ) const ( @@ -57,17 +58,20 @@ type cubicSender struct { initialCongestionWindow protocol.ByteCount initialMaxCongestionWindow protocol.ByteCount + + lastState logging.CongestionState + tracer logging.ConnectionTracer } var _ SendAlgorithm = &cubicSender{} var _ SendAlgorithmWithDebugInfos = &cubicSender{} // NewCubicSender makes a new cubic sender -func NewCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool) *cubicSender { - return newCubicSender(clock, rttStats, reno, initialCongestionWindow, maxCongestionWindow) +func NewCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, tracer logging.ConnectionTracer) *cubicSender { + return newCubicSender(clock, rttStats, reno, initialCongestionWindow, maxCongestionWindow, tracer) } -func newCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) *cubicSender { +func newCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount, tracer logging.ConnectionTracer) *cubicSender { c := &cubicSender{ rttStats: rttStats, largestSentPacketNumber: protocol.InvalidPacketNumber, @@ -82,8 +86,13 @@ func newCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, initialCon cubic: NewCubic(clock), clock: clock, reno: reno, + tracer: tracer, } c.pacer = newPacer(c.BandwidthEstimate) + if c.tracer != nil { + c.lastState = logging.CongestionStateSlowStart + c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) + } return c } @@ -131,6 +140,7 @@ func (c *cubicSender) MaybeExitSlowStart() { if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/maxDatagramSize) { // exit slow start c.slowStartThreshold = c.congestionWindow + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) } } @@ -161,6 +171,7 @@ func (c *cubicSender) OnPacketLost( return } c.lastCutbackExitedSlowstart = c.InSlowStart() + c.maybeTraceStateChange(logging.CongestionStateRecovery) if c.reno { c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta) @@ -189,6 +200,7 @@ func (c *cubicSender) maybeIncreaseCwnd( // the current window. if !c.isCwndLimited(priorInFlight) { c.cubic.OnApplicationLimited() + c.maybeTraceStateChange(logging.CongestionStateApplicationLimited) return } if c.congestionWindow >= c.maxCongestionWindow { @@ -197,9 +209,11 @@ func (c *cubicSender) maybeIncreaseCwnd( if c.InSlowStart() { // TCP slow start, exponential growth, increase by one for each ACK. c.congestionWindow += maxDatagramSize + c.maybeTraceStateChange(logging.CongestionStateSlowStart) return } // Congestion avoidance + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) if c.reno { // Classic Reno congestion avoidance. c.numAckedPackets++ @@ -257,3 +271,10 @@ func (c *cubicSender) OnConnectionMigration() { c.slowStartThreshold = c.initialMaxCongestionWindow c.maxCongestionWindow = c.initialMaxCongestionWindow } + +func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { + if c.tracer == nil || new == c.lastState { + return + } + c.tracer.UpdatedCongestionState(new) +} diff --git a/internal/congestion/cubic_sender_test.go b/internal/congestion/cubic_sender_test.go index 6490a3d87..b341ee7a3 100644 --- a/internal/congestion/cubic_sender_test.go +++ b/internal/congestion/cubic_sender_test.go @@ -40,7 +40,7 @@ var _ = Describe("Cubic Sender", func() { ackedPacketNumber = 0 clock = mockClock{} rttStats = utils.NewRTTStats() - sender = newCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow) + sender = newCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil) }) SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int { @@ -308,7 +308,7 @@ var _ = Describe("Cubic Sender", func() { It("tcp cubic reset epoch on quiescence", func() { const maxCongestionWindow = 50 const maxCongestionWindowBytes = maxCongestionWindow * maxDatagramSize - sender = newCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindowBytes) + sender = newCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindowBytes, nil) numSent := SendAvailableSendWindow() @@ -448,7 +448,7 @@ var _ = Describe("Cubic Sender", func() { }) It("default max cwnd", func() { - sender = newCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindow) + sender = newCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*maxDatagramSize, maxCongestionWindow, nil) defaultMaxCongestionWindowPackets := maxCongestionWindow / maxDatagramSize for i := 1; i < int(defaultMaxCongestionWindowPackets); i++ { @@ -460,7 +460,7 @@ var _ = Describe("Cubic Sender", func() { It("limit cwnd increase in congestion avoidance", func() { // Enable Cubic. - sender = newCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow) + sender = newCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil) numSent := SendAvailableSendWindow() // Make sure we fall out of slow start. diff --git a/internal/mocks/connection_tracer.go b/internal/mocks/connection_tracer.go index 5581db743..aac749ba6 100644 --- a/internal/mocks/connection_tracer.go +++ b/internal/mocks/connection_tracer.go @@ -5,16 +5,15 @@ package mocks import ( - "net" - "reflect" - "time" + net "net" + reflect "reflect" + time "time" - "github.com/lucas-clemente/quic-go/internal/utils" - - "github.com/golang/mock/gomock" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/logging" + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + utils "github.com/lucas-clemente/quic-go/internal/utils" + wire "github.com/lucas-clemente/quic-go/internal/wire" + logging "github.com/lucas-clemente/quic-go/logging" ) // MockConnectionTracer is a mock of ConnectionTracer interface @@ -232,6 +231,18 @@ func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3, arg4) } +// UpdatedCongestionState mocks base method +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + // UpdatedKey mocks base method func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { m.ctrl.T.Helper() diff --git a/logging/interface.go b/logging/interface.go index 8db05b903..e9d4774e5 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -106,6 +106,7 @@ type ConnectionTracer interface { DroppedPacket(PacketType, ByteCount, PacketDropReason) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState(CongestionState) UpdatedPTOCount(value uint32) UpdatedKeyFromTLS(EncryptionLevel, Perspective) UpdatedKey(generation KeyPhase, remote bool) diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index 8cba7d4cd..bf00e18cc 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -5,14 +5,13 @@ package logging import ( - "net" - "reflect" - "time" - - "github.com/golang/mock/gomock" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + utils "github.com/lucas-clemente/quic-go/internal/utils" + wire "github.com/lucas-clemente/quic-go/internal/wire" + net "net" + reflect "reflect" + time "time" ) // MockConnectionTracer is a mock of ConnectionTracer interface @@ -230,6 +229,18 @@ func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3, arg4) } +// UpdatedCongestionState mocks base method +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + // UpdatedKey mocks base method func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { m.ctrl.T.Helper() diff --git a/logging/multiplex.go b/logging/multiplex.go index 1d20f3480..73e0c1286 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -120,6 +120,12 @@ func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, re } } +func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { + for _, t := range m.tracers { + t.UpdatedCongestionState(state) + } +} + func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { for _, t := range m.tracers { t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 807a27750..dddb946f0 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -169,6 +169,12 @@ var _ = Describe("Tracing", func() { tracer.DroppedPacket(PacketTypeInitial, 1337, PacketDropHeaderParseError) }) + It("traces the UpdatedCongestionState event", func() { + tr1.EXPECT().UpdatedCongestionState(CongestionStateRecovery) + tr2.EXPECT().UpdatedCongestionState(CongestionStateRecovery) + tracer.UpdatedCongestionState(CongestionStateRecovery) + }) + It("traces the UpdatedMetrics event", func() { rttStats := &RTTStats{} rttStats.UpdateRTT(time.Second, 0, time.Now()) diff --git a/logging/types.go b/logging/types.go index 3e0bfb4b1..41a776c89 100644 --- a/logging/types.go +++ b/logging/types.go @@ -81,3 +81,16 @@ const ( // This reason is not defined in the qlog draft, but very useful for debugging. TimeoutReasonIdle ) + +type CongestionState uint8 + +const ( + // CongestionStateSlowStart is the slow start phase of Reno / Cubic + CongestionStateSlowStart CongestionState = iota + // CongestionStateCongestionAvoidance is the slow start phase of Reno / Cubic + CongestionStateCongestionAvoidance + // CongestionStateCongestionAvoidance is the recovery phase of Reno / Cubic + CongestionStateRecovery + // CongestionStateApplicationLimited means that the congestion controller is application limited + CongestionStateApplicationLimited +) diff --git a/metrics/metrics.go b/metrics/metrics.go index b0342557e..e151726f6 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -172,6 +172,7 @@ func (t *connTracer) ReceivedPacket(*logging.ExtendedHeader, logging.ByteCount, } func (t *connTracer) BufferedPacket(logging.PacketType) {} func (t *connTracer) DroppedPacket(logging.PacketType, logging.ByteCount, logging.PacketDropReason) {} +func (t *connTracer) UpdatedCongestionState(logging.CongestionState) {} func (t *connTracer) UpdatedMetrics(*logging.RTTStats, logging.ByteCount, logging.ByteCount, int) {} func (t *connTracer) LostPacket(encLevel logging.EncryptionLevel, _ logging.PacketNumber, reason logging.PacketLossReason) { stats.RecordWithTags( diff --git a/qlog/qlog.go b/qlog/qlog.go index 6ec2eff18..03107a775 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -316,6 +316,8 @@ func (t *connectionTracer) LostPacket(encLevel protocol.EncryptionLevel, pn prot t.mutex.Unlock() } +func (t *connectionTracer) UpdatedCongestionState(logging.CongestionState) {} + func (t *connectionTracer) UpdatedPTOCount(value uint32) { t.mutex.Lock() t.recordEvent(time.Now(), &eventUpdatedPTO{Value: value}) diff --git a/session_test.go b/session_test.go index 7a987a59f..bbadd556c 100644 --- a/session_test.go +++ b/session_test.go @@ -91,6 +91,7 @@ var _ = Describe("Session", func() { tracer = mocks.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().UpdatedCongestionState(gomock.Any()) sess = newSession( mconn, sessionRunner, @@ -2154,6 +2155,7 @@ var _ = Describe("Client Session", func() { tracer = mocks.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() + tracer.EXPECT().UpdatedCongestionState(gomock.Any()) sess = newClientSession( mconn, sessionRunner,