diff --git a/ackhandler/received_packet_handler.go b/ackhandler/received_packet_handler.go index 909e7854..3ca1b8d1 100644 --- a/ackhandler/received_packet_handler.go +++ b/ackhandler/received_packet_handler.go @@ -28,7 +28,8 @@ type receivedPacketHandler struct { currentAckFrame *frames.AckFrame stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent - packetHistory *receivedPacketHistory + packetHistory *receivedPacketHistory + receivedTimes map[protocol.PacketNumber]time.Time lowestInReceivedTimes protocol.PacketNumber } @@ -72,8 +73,6 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe h.receivedTimes[packetNumber] = time.Now() - h.garbageCollect() - if uint32(len(h.receivedTimes)) > protocol.MaxTrackedReceivedPackets { return errTooManyOutstandingReceivedPackets } @@ -88,6 +87,7 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) } h.ignorePacketsBelow = f.LeastUnacked - 1 + h.garbageCollectReceivedTimes() // ignore if StopWaiting is unneeded, since all packets below have already been received if h.largestInOrderObserved >= f.LeastUnacked { @@ -99,8 +99,6 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) h.packetHistory.DeleteBelow(f.LeastUnacked) - h.garbageCollect() - return nil } @@ -136,9 +134,9 @@ func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, err return h.currentAckFrame, nil } -func (h *receivedPacketHandler) garbageCollect() { - for i := h.lowestInReceivedTimes; i < h.largestInOrderObserved; i++ { +func (h *receivedPacketHandler) garbageCollectReceivedTimes() { + for i := h.lowestInReceivedTimes; i <= h.ignorePacketsBelow; i++ { delete(h.receivedTimes, i) } - h.lowestInReceivedTimes = h.largestInOrderObserved + h.lowestInReceivedTimes = h.ignorePacketsBelow } diff --git a/ackhandler/received_packet_handler_test.go b/ackhandler/received_packet_handler_test.go index 8e3b3736..a84e5a25 100644 --- a/ackhandler/received_packet_handler_test.go +++ b/ackhandler/received_packet_handler_test.go @@ -249,21 +249,42 @@ var _ = Describe("receivedPacketHandler", func() { }) Context("Garbage Collector", func() { - PIt("only keeps packets with packet numbers higher than the highestInOrderObserved in packetHistory", func() { - handler.ReceivedPacket(1, false) - handler.ReceivedPacket(2, false) - handler.ReceivedPacket(4, false) + It("garbage collects receivedTimes after receiving a StopWaiting, if there are no missing packets", func() { + for i := 1; i <= 4; i++ { + err := handler.ReceivedPacket(protocol.PacketNumber(i), false) + Expect(err).ToNot(HaveOccurred()) + } + err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 3}) + Expect(err).ToNot(HaveOccurred()) Expect(handler.receivedTimes).ToNot(HaveKey(protocol.PacketNumber(1))) - Expect(handler.receivedTimes).To(HaveKey(protocol.PacketNumber(2))) + Expect(handler.receivedTimes).ToNot(HaveKey(protocol.PacketNumber(2))) + Expect(handler.receivedTimes).To(HaveKey(protocol.PacketNumber(3))) Expect(handler.receivedTimes).To(HaveKey(protocol.PacketNumber(4))) }) - It("garbage collects packetHistory after receiving a StopWaiting", func() { - handler.ReceivedPacket(1, false) - handler.ReceivedPacket(2, false) - handler.ReceivedPacket(4, false) - swf := frames.StopWaitingFrame{LeastUnacked: 4} - handler.ReceivedStopWaiting(&swf) + It("garbage collects the receivedTimes after receiving multiple StopWaitings", func() { + for i := 1; i <= 9; i++ { + err := handler.ReceivedPacket(protocol.PacketNumber(i), false) + Expect(err).ToNot(HaveOccurred()) + } + err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 4}) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 8}) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.receivedTimes).To(HaveLen(2)) // packets 8 and 9 + Expect(handler.receivedTimes).To(HaveKey(protocol.PacketNumber(8))) + Expect(handler.receivedTimes).To(HaveKey(protocol.PacketNumber(9))) + }) + + It("garbage collects receivedTimes after receiving a StopWaiting, if there are missing packets", func() { + err := handler.ReceivedPacket(1, false) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedPacket(2, false) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedPacket(4, false) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 4}) + Expect(err).ToNot(HaveOccurred()) Expect(handler.receivedTimes).ToNot(HaveKey(protocol.PacketNumber(1))) Expect(handler.receivedTimes).ToNot(HaveKey(protocol.PacketNumber(2))) Expect(handler.receivedTimes).To(HaveKey(protocol.PacketNumber(4)))