diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 48060cc2d..d188691da 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -266,7 +266,7 @@ func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber { } func (h *sentPacketHandler) GetStopWaitingFrame() *frames.StopWaitingFrame { - return h.stopWaitingManager.GetStopWaitingFrame() + return h.stopWaitingManager.GetStopWaitingFrame(false) } func (h *sentPacketHandler) CongestionAllowsSending() bool { diff --git a/ackhandler/stop_waiting_manager.go b/ackhandler/stop_waiting_manager.go index 76f68adfe..dfd79ae0f 100644 --- a/ackhandler/stop_waiting_manager.go +++ b/ackhandler/stop_waiting_manager.go @@ -9,17 +9,24 @@ import ( type stopWaitingManager struct { largestLeastUnackedSent protocol.PacketNumber nextLeastUnacked protocol.PacketNumber + + lastStopWaitingFrame *frames.StopWaitingFrame } -func (s *stopWaitingManager) GetStopWaitingFrame() *frames.StopWaitingFrame { +func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame { if s.nextLeastUnacked <= s.largestLeastUnackedSent { + if force { + return s.lastStopWaitingFrame + } return nil } s.largestLeastUnackedSent = s.nextLeastUnacked - return &frames.StopWaitingFrame{ + swf := &frames.StopWaitingFrame{ LeastUnacked: s.nextLeastUnacked, } + s.lastStopWaitingFrame = swf + return swf } func (s *stopWaitingManager) ReceivedAck(ack *frames.AckFrame) { diff --git a/ackhandler/stop_waiting_manager_test.go b/ackhandler/stop_waiting_manager_test.go index f5c914476..73fdf1d85 100644 --- a/ackhandler/stop_waiting_manager_test.go +++ b/ackhandler/stop_waiting_manager_test.go @@ -13,35 +13,43 @@ var _ = Describe("StopWaitingManager", func() { }) It("returns nil in the beginning", func() { - Expect(manager.GetStopWaitingFrame()).To(BeNil()) + Expect(manager.GetStopWaitingFrame(false)).To(BeNil()) + Expect(manager.GetStopWaitingFrame(true)).To(BeNil()) }) It("returns a StopWaitingFrame, when a new ACK arrives", func() { manager.ReceivedAck(&frames.AckFrame{LargestAcked: 10}) - Expect(manager.GetStopWaitingFrame()).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 11})) + Expect(manager.GetStopWaitingFrame(false)).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 11})) }) It("does not decrease the LeastUnacked", func() { manager.ReceivedAck(&frames.AckFrame{LargestAcked: 10}) manager.ReceivedAck(&frames.AckFrame{LargestAcked: 9}) - Expect(manager.GetStopWaitingFrame()).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 11})) + Expect(manager.GetStopWaitingFrame(false)).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 11})) }) It("does not send the same StopWaitingFrame twice", func() { manager.ReceivedAck(&frames.AckFrame{LargestAcked: 10}) - Expect(manager.GetStopWaitingFrame()).ToNot(BeNil()) - Expect(manager.GetStopWaitingFrame()).To(BeNil()) + Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil()) + Expect(manager.GetStopWaitingFrame(false)).To(BeNil()) + }) + + It("gets the same StopWaitingFrame twice, if forced", func() { + manager.ReceivedAck(&frames.AckFrame{LargestAcked: 10}) + Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil()) + Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil()) + Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil()) }) It("increases the LeastUnacked when a retransmission is queued", func() { manager.ReceivedAck(&frames.AckFrame{LargestAcked: 10}) manager.QueuedRetransmissionForPacketNumber(20) - Expect(manager.GetStopWaitingFrame()).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 21})) + Expect(manager.GetStopWaitingFrame(false)).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 21})) }) It("does not decrease the LeastUnacked when a retransmission is queued", func() { manager.ReceivedAck(&frames.AckFrame{LargestAcked: 10}) manager.QueuedRetransmissionForPacketNumber(9) - Expect(manager.GetStopWaitingFrame()).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 11})) + Expect(manager.GetStopWaitingFrame(false)).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 11})) }) })