diff --git a/ackhandler/received_packet_handler.go b/ackhandler/received_packet_handler.go index b54e15a02..97410f375 100644 --- a/ackhandler/received_packet_handler.go +++ b/ackhandler/received_packet_handler.go @@ -1,15 +1,12 @@ package ackhandler import ( - "errors" "time" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" ) -var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") - type receivedPacketHandler struct { largestObserved protocol.PacketNumber ignoreBelow protocol.PacketNumber @@ -38,10 +35,6 @@ func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHand } func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { - if packetNumber == 0 { - return errInvalidPacketNumber - } - if packetNumber > h.largestObserved { h.largestObserved = packetNumber h.largestObservedReceivedTime = time.Now() diff --git a/ackhandler/received_packet_handler_test.go b/ackhandler/received_packet_handler_test.go index c1d7ecd59..4a87c72d9 100644 --- a/ackhandler/received_packet_handler_test.go +++ b/ackhandler/received_packet_handler_test.go @@ -29,11 +29,6 @@ var _ = Describe("receivedPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) }) - It("rejects packets with packet number 0", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(0), true) - Expect(err).To(MatchError(errInvalidPacketNumber)) - }) - It("saves the time when each packet arrived", func() { err := handler.ReceivedPacket(protocol.PacketNumber(3), true) Expect(err).ToNot(HaveOccurred()) @@ -91,14 +86,11 @@ var _ = Describe("receivedPacketHandler", func() { Expect(handler.GetAlarmTimeout()).To(BeZero()) }) - It("doesn't queue an ACK for non-retransmittable packets", func() { - receiveAndAck10Packets() - handler.version = protocol.Version39 - for i := 11; i < 1000; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i), false) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.ackQueued).To(BeFalse()) - } + It("works with packet number 0", func() { + err := handler.ReceivedPacket(0, false) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeTrue()) + Expect(handler.GetAlarmTimeout()).To(BeZero()) }) It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() { @@ -173,6 +165,16 @@ var _ = Describe("receivedPacketHandler", func() { Expect(ack.AckRanges).To(BeEmpty()) }) + It("generates an ACK for packet number 0", func() { + err := handler.ReceivedPacket(0, true) + Expect(err).ToNot(HaveOccurred()) + ack := handler.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(0))) + Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(0))) + Expect(ack.AckRanges).To(BeEmpty()) + }) + It("saves the last sent ACK", func() { err := handler.ReceivedPacket(1, true) Expect(err).ToNot(HaveOccurred()) @@ -196,9 +198,27 @@ var _ = Describe("receivedPacketHandler", func() { Expect(ack).ToNot(BeNil()) Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(4))) Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1))) - Expect(ack.AckRanges).To(HaveLen(2)) - Expect(ack.AckRanges[0]).To(Equal(wire.AckRange{First: 4, Last: 4})) - Expect(ack.AckRanges[1]).To(Equal(wire.AckRange{First: 1, Last: 1})) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{ + wire.AckRange{First: 4, Last: 4}, + wire.AckRange{First: 1, Last: 1}, + })) + }) + + It("generates an ACK for packet number 0 and other packets", func() { + err := handler.ReceivedPacket(0, true) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedPacket(3, true) + Expect(err).ToNot(HaveOccurred()) + ack := handler.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(3))) + Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(0))) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{ + wire.AckRange{First: 3, Last: 3}, + wire.AckRange{First: 0, Last: 1}, + })) }) It("accepts packets below the lower limit", func() {