diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index 7dec882a..e100264a 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -27,5 +27,6 @@ type ReceivedPacketHandler interface { ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedStopWaiting(*frames.StopWaitingFrame) error + GetAlarmTimeout() time.Time GetAckFrame() *frames.AckFrame } diff --git a/ackhandler/received_packet_handler.go b/ackhandler/received_packet_handler.go index 88002575..c5e9dc29 100644 --- a/ackhandler/received_packet_handler.go +++ b/ackhandler/received_packet_handler.go @@ -23,16 +23,14 @@ type receivedPacketHandler struct { retransmittablePacketsReceivedSinceLastAck int ackQueued bool ackAlarm time.Time - ackAlarmResetCallback func(time.Time) lastAck *frames.AckFrame } // NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler { +func NewReceivedPacketHandler() ReceivedPacketHandler { return &receivedPacketHandler{ - packetHistory: newReceivedPacketHistory(), - ackAlarmResetCallback: ackAlarmResetCallback, - ackSendDelay: protocol.AckSendDelay, + packetHistory: newReceivedPacketHistory(), + ackSendDelay: protocol.AckSendDelay, } } @@ -69,7 +67,6 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) } func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { - var ackAlarmSet bool h.packetsReceivedSinceLastAck++ if shouldInstigateAck { @@ -104,7 +101,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber } else { if h.ackAlarm.IsZero() { h.ackAlarm = time.Now().Add(h.ackSendDelay) - ackAlarmSet = true } } } @@ -112,11 +108,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber if h.ackQueued { // cancel the ack alarm h.ackAlarm = time.Time{} - ackAlarmSet = false - } - - if ackAlarmSet { - h.ackAlarmResetCallback(h.ackAlarm) } } @@ -144,3 +135,5 @@ func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame { return ack } + +func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm } diff --git a/ackhandler/received_packet_handler_test.go b/ackhandler/received_packet_handler_test.go index cece2df3..3a78be3d 100644 --- a/ackhandler/received_packet_handler_test.go +++ b/ackhandler/received_packet_handler_test.go @@ -12,17 +12,11 @@ import ( var _ = Describe("receivedPacketHandler", func() { var ( - handler *receivedPacketHandler - ackAlarmCallbackCalled bool + handler *receivedPacketHandler ) - ackAlarmCallback := func(time.Time) { - ackAlarmCallbackCalled = true - } - BeforeEach(func() { - ackAlarmCallbackCalled = false - handler = NewReceivedPacketHandler(ackAlarmCallback).(*receivedPacketHandler) + handler = NewReceivedPacketHandler().(*receivedPacketHandler) }) Context("accepting packets", func() { @@ -135,14 +129,13 @@ var _ = Describe("receivedPacketHandler", func() { } Expect(handler.GetAckFrame()).ToNot(BeNil()) Expect(handler.ackQueued).To(BeFalse()) - ackAlarmCallbackCalled = false } It("always queues an ACK for the first packet", func() { err := handler.ReceivedPacket(1, false) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) - Expect(ackAlarmCallbackCalled).To(BeFalse()) + Expect(handler.GetAlarmTimeout()).To(BeZero()) }) It("only queues one ACK for many non-retransmittable packets", func() { @@ -155,7 +148,7 @@ var _ = Describe("receivedPacketHandler", func() { err := handler.ReceivedPacket(10+protocol.MaxPacketsReceivedBeforeAckSend, false) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) - Expect(ackAlarmCallbackCalled).To(BeFalse()) + Expect(handler.GetAlarmTimeout()).To(BeZero()) }) It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() { @@ -163,12 +156,11 @@ var _ = Describe("receivedPacketHandler", func() { err := handler.ReceivedPacket(11, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) - Expect(ackAlarmCallbackCalled).To(BeTrue()) - ackAlarmCallbackCalled = false + Expect(handler.GetAlarmTimeout()).NotTo(BeZero()) err = handler.ReceivedPacket(12, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) - Expect(ackAlarmCallbackCalled).To(BeFalse()) + Expect(handler.GetAlarmTimeout()).To(BeZero()) }) It("only sets the timer when receiving a retransmittable packets", func() { @@ -181,7 +173,7 @@ var _ = Describe("receivedPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) Expect(handler.ackAlarm).ToNot(BeZero()) - Expect(ackAlarmCallbackCalled).To(BeTrue()) + Expect(handler.GetAlarmTimeout()).NotTo(BeZero()) }) It("queues an ACK if it was reported missing before", func() { diff --git a/session.go b/session.go index be2247e7..c2416017 100644 --- a/session.go +++ b/session.go @@ -98,8 +98,6 @@ type session struct { // it receives at most 3 handshake events: 2 when the encryption level changes, and one error handshakeChan chan<- handshakeEvent - nextAckScheduledTime time.Time - connectionParameters handshake.ConnectionParametersManager lastRcvdPacketNumber protocol.PacketNumber @@ -178,7 +176,7 @@ func (s *session) setup( s.config.MaxReceiveStreamFlowControlWindow, s.config.MaxReceiveConnectionFlowControlWindow) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler() s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) @@ -332,8 +330,8 @@ func (s *session) WaitUntilClosed() { func (s *session) maybeResetTimer() { deadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) - if !s.nextAckScheduledTime.IsZero() { - deadline = utils.MinTime(deadline, s.nextAckScheduledTime) + if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { + deadline = utils.MinTime(deadline, ackAlarm) } if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() { deadline = utils.MinTime(deadline, lossTime) @@ -656,7 +654,6 @@ func (s *session) sendPacket() error { } windowUpdateFrames = nil ack = nil - s.nextAckScheduledTime = time.Time{} } } @@ -817,11 +814,6 @@ func (s *session) getWindowUpdateFrames() []*frames.WindowUpdateFrame { return res } -func (s *session) ackAlarmChanged(t time.Time) { - s.nextAckScheduledTime = t - s.maybeResetTimer() -} - func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() } diff --git a/session_test.go b/session_test.go index c66379e4..6c07b986 100644 --- a/session_test.go +++ b/session_test.go @@ -101,6 +101,7 @@ var _ ackhandler.SentPacketHandler = &mockSentPacketHandler{} type mockReceivedPacketHandler struct { nextAckFrame *frames.AckFrame + ackAlarm time.Time } func (m *mockReceivedPacketHandler) GetAckFrame() *frames.AckFrame { @@ -114,6 +115,7 @@ func (m *mockReceivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketN func (m *mockReceivedPacketHandler) ReceivedStopWaiting(*frames.StopWaitingFrame) error { panic("not implemented") } +func (m *mockReceivedPacketHandler) GetAlarmTimeout() time.Time { return m.ackAlarm } var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{} @@ -1187,12 +1189,11 @@ var _ = Describe("Session", func() { }) It("sets the timer to the ack timer", func() { - rph := &mockReceivedPacketHandler{} + rph := &mockReceivedPacketHandler{ackAlarm: time.Now().Add(10 * time.Millisecond)} rph.nextAckFrame = &frames.AckFrame{LargestAcked: 0x1337} sess.receivedPacketHandler = rph go sess.run() defer sess.Close(nil) - sess.ackAlarmChanged(time.Now().Add(10 * time.Millisecond)) time.Sleep(10 * time.Millisecond) Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))