diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index 1f956003..dbcdcb4c 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -24,7 +24,7 @@ type ReceivedPacketHandler interface { ReceivedPacket(packetNumber protocol.PacketNumber, entropyBit bool) error ReceivedStopWaiting(*frames.StopWaitingFrame) error - DequeueAckFrame() (*frames.AckFrame, error) + GetAckFrame(dequeue bool) (*frames.AckFrame, error) } // StopWaitingManager manages StopWaitings for sent packets diff --git a/ackhandler/received_packet_handler.go b/ackhandler/received_packet_handler.go index d0a411d4..f5efe839 100644 --- a/ackhandler/received_packet_handler.go +++ b/ackhandler/received_packet_handler.go @@ -23,6 +23,7 @@ type receivedPacketHandler struct { highestInOrderObservedEntropy EntropyAccumulator largestObserved protocol.PacketNumber packetHistory map[protocol.PacketNumber]packetHistoryEntry + currentAckFrame *frames.AckFrame stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent } @@ -43,6 +44,7 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe } h.stateChanged = true + h.currentAckFrame = nil if packetNumber > h.largestObserved { h.largestObserved = packetNumber @@ -100,25 +102,31 @@ func (h *receivedPacketHandler) getNackRanges() ([]frames.NackRange, EntropyAccu return ranges, entropy } -func (h *receivedPacketHandler) DequeueAckFrame() (*frames.AckFrame, error) { +func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, error) { if !h.stateChanged { return nil, nil } - h.stateChanged = false + if dequeue { + h.stateChanged = false + } + + if h.currentAckFrame != nil { + return h.currentAckFrame, nil + } p, ok := h.packetHistory[h.largestObserved] if !ok { - return nil, errors.New("bla") + return nil, ErrMapAccess } packetReceivedTime := p.TimeReceived nackRanges, entropy := h.getNackRanges() - ack := frames.AckFrame{ + h.currentAckFrame = &frames.AckFrame{ LargestObserved: h.largestObserved, Entropy: byte(entropy), NackRanges: nackRanges, PacketReceivedTime: packetReceivedTime, } - return &ack, nil + return h.currentAckFrame, nil } diff --git a/ackhandler/received_packet_handler_test.go b/ackhandler/received_packet_handler_test.go index 883daad2..b8913914 100644 --- a/ackhandler/received_packet_handler_test.go +++ b/ackhandler/received_packet_handler_test.go @@ -249,7 +249,7 @@ var _ = Describe("receivedPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) err = handler.ReceivedPacket(protocol.PacketNumber(2), true) Expect(err).ToNot(HaveOccurred()) - ack, err := handler.DequeueAckFrame() + ack, err := handler.GetAckFrame(true) Expect(err).ToNot(HaveOccurred()) Expect(ack.LargestObserved).To(Equal(protocol.PacketNumber(2))) Expect(ack.Entropy).To(Equal(byte(entropy))) @@ -263,7 +263,7 @@ var _ = Describe("receivedPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) err = handler.ReceivedPacket(protocol.PacketNumber(4), true) Expect(err).ToNot(HaveOccurred()) - ack, err := handler.DequeueAckFrame() + ack, err := handler.GetAckFrame(true) Expect(err).ToNot(HaveOccurred()) Expect(ack.LargestObserved).To(Equal(protocol.PacketNumber(4))) Expect(ack.Entropy).To(Equal(byte(entropy))) @@ -275,19 +275,66 @@ var _ = Describe("receivedPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) err = handler.ReceivedPacket(protocol.PacketNumber(2), false) Expect(err).ToNot(HaveOccurred()) - Expect(handler.DequeueAckFrame()).ToNot(BeNil()) - Expect(handler.DequeueAckFrame()).To(BeNil()) + ack, err := handler.GetAckFrame(true) + Expect(err).ToNot(HaveOccurred()) + Expect(ack).ToNot(BeNil()) + ack, err = handler.GetAckFrame(true) + Expect(err).ToNot(HaveOccurred()) + Expect(ack).To(BeNil()) }) - It("generates an ACK when an out-of-order packet arrives", func() { + It("does not dequeue an ACK frame if told so", func() { + err := handler.ReceivedPacket(protocol.PacketNumber(2), false) + Expect(err).ToNot(HaveOccurred()) + ack, err := handler.GetAckFrame(false) + Expect(err).ToNot(HaveOccurred()) + Expect(ack).ToNot(BeNil()) + ack, err = handler.GetAckFrame(false) + Expect(err).ToNot(HaveOccurred()) + Expect(ack).ToNot(BeNil()) + ack, err = handler.GetAckFrame(false) + Expect(err).ToNot(HaveOccurred()) + Expect(ack).ToNot(BeNil()) + }) + + It("returns a cached ACK frame if the ACK was not dequeued", func() { + err := handler.ReceivedPacket(protocol.PacketNumber(2), false) + Expect(err).ToNot(HaveOccurred()) + ack, err := handler.GetAckFrame(false) + Expect(err).ToNot(HaveOccurred()) + Expect(ack).ToNot(BeNil()) + ack2, err := handler.GetAckFrame(false) + Expect(err).ToNot(HaveOccurred()) + Expect(ack2).ToNot(BeNil()) + Expect(&ack).To(Equal(&ack2)) + }) + + It("generates a new ACK (and deletes the cached one) when a new packet arrives", func() { + err := handler.ReceivedPacket(protocol.PacketNumber(1), false) + Expect(err).ToNot(HaveOccurred()) + ack, _ := handler.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestObserved).To(Equal(protocol.PacketNumber(1))) + err = handler.ReceivedPacket(protocol.PacketNumber(3), false) + Expect(err).ToNot(HaveOccurred()) + ack, _ = handler.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestObserved).To(Equal(protocol.PacketNumber(3))) + }) + + It("generates a new ACK when an out-of-order packet arrives", func() { err := handler.ReceivedPacket(protocol.PacketNumber(1), false) Expect(err).ToNot(HaveOccurred()) err = handler.ReceivedPacket(protocol.PacketNumber(3), false) Expect(err).ToNot(HaveOccurred()) - Expect(handler.DequeueAckFrame()).ToNot(BeNil()) + ack, _ := handler.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(len(ack.NackRanges)).To(Equal(1)) err = handler.ReceivedPacket(protocol.PacketNumber(2), false) Expect(err).ToNot(HaveOccurred()) - Expect(handler.DequeueAckFrame()).ToNot(BeNil()) + ack, _ = handler.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(len(ack.NackRanges)).To(Equal(0)) }) }) }) diff --git a/session.go b/session.go index 62dcc34c..8948df49 100644 --- a/session.go +++ b/session.go @@ -403,7 +403,7 @@ func (s *Session) sendPacket() error { stopWaitingFrame := s.stopWaitingManager.GetStopWaitingFrame() - ack, err := s.receivedPacketHandler.DequeueAckFrame() + ack, err := s.receivedPacketHandler.GetAckFrame(true) if err != nil { return err }