diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index ee7ac5957..1d1b4be4e 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -28,8 +28,8 @@ type SentPacketHandler interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { - ReceivedPacket(packetNumber protocol.PacketNumber) error + ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedStopWaiting(*frames.StopWaitingFrame) error - GetAckFrame(dequeue bool) (*frames.AckFrame, error) + GetAckFrame() *frames.AckFrame } diff --git a/ackhandler/received_packet_handler.go b/ackhandler/received_packet_handler.go index 56ca3efd8..daebbfb2d 100644 --- a/ackhandler/received_packet_handler.go +++ b/ackhandler/received_packet_handler.go @@ -18,24 +18,36 @@ var ( var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") type receivedPacketHandler struct { - largestObserved protocol.PacketNumber - ignorePacketsBelow protocol.PacketNumber - currentAckFrame *frames.AckFrame - stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent + largestObserved protocol.PacketNumber + ignorePacketsBelow protocol.PacketNumber + largestObservedReceivedTime time.Time packetHistory *receivedPacketHistory - largestObservedReceivedTime time.Time + ackSendDelay time.Duration + + packetsReceivedSinceLastAck int + retransmittablePacketsReceivedSinceLastAck int + ackQueued bool + ackAlarm time.Time + ackAlarmResetCallback func(time.Time) + lastAck *frames.AckFrame } // NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler() ReceivedPacketHandler { +func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler { + // create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182 + timer := time.NewTimer(0) + <-timer.C + return &receivedPacketHandler{ - packetHistory: newReceivedPacketHistory(), + packetHistory: newReceivedPacketHistory(), + ackAlarmResetCallback: ackAlarmResetCallback, + ackSendDelay: protocol.AckSendDelay, } } -func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber) error { +func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { if packetNumber == 0 { return errInvalidPacketNumber } @@ -55,14 +67,12 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe return err } - h.stateChanged = true - h.currentAckFrame = nil - if packetNumber > h.largestObserved { h.largestObserved = packetNumber h.largestObservedReceivedTime = time.Now() } + h.maybeQueueAck(packetNumber, shouldInstigateAck) return nil } @@ -78,29 +88,79 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) return nil } -func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, error) { - if !h.stateChanged { - return nil, nil +func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { + var ackAlarmSet bool + h.packetsReceivedSinceLastAck++ + + if shouldInstigateAck { + h.retransmittablePacketsReceivedSinceLastAck++ } - if dequeue { - h.stateChanged = false + // always ack the first packet + if h.lastAck == nil { + h.ackQueued = true } - if h.currentAckFrame != nil { - return h.currentAckFrame, nil + // Always send an ack every 20 packets in order to allow the peer to discard + // information from the SentPacketManager and provide an RTT measurement. + if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend { + h.ackQueued = true + } + + // if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK + // note that it cannot be a duplicate because they're already filtered out by ReceivedPacket() + if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked { + h.ackQueued = true + } + + // check if a new missing range above the previously was created + if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked { + h.ackQueued = true + } + + if !h.ackQueued && shouldInstigateAck { + if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck { + h.ackQueued = true + } else { + if h.ackAlarm.IsZero() { + h.ackAlarm = time.Now().Add(h.ackSendDelay) + ackAlarmSet = true + } + } + } + + if h.ackQueued { + // cancel the ack alarm + h.ackAlarm = time.Time{} + ackAlarmSet = false + } + + if ackAlarmSet { + h.ackAlarmResetCallback(h.ackAlarm) + } +} + +func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame { + if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) { + return nil } ackRanges := h.packetHistory.GetAckRanges() - h.currentAckFrame = &frames.AckFrame{ + ack := &frames.AckFrame{ LargestAcked: h.largestObserved, LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber, PacketReceivedTime: h.largestObservedReceivedTime, } if len(ackRanges) > 1 { - h.currentAckFrame.AckRanges = ackRanges + ack.AckRanges = ackRanges } - return h.currentAckFrame, nil + h.lastAck = ack + h.ackAlarm = time.Time{} + h.ackQueued = false + h.packetsReceivedSinceLastAck = 0 + h.retransmittablePacketsReceivedSinceLastAck = 0 + + return ack } diff --git a/ackhandler/received_packet_handler_test.go b/ackhandler/received_packet_handler_test.go index d1d644f81..6e9d11340 100644 --- a/ackhandler/received_packet_handler_test.go +++ b/ackhandler/received_packet_handler_test.go @@ -12,57 +12,63 @@ import ( var _ = Describe("receivedPacketHandler", func() { var ( - handler *receivedPacketHandler + handler *receivedPacketHandler + ackAlarmCallbackCalled bool ) + ackAlarmCallback := func(time.Time) { + ackAlarmCallbackCalled = true + } + BeforeEach(func() { - handler = NewReceivedPacketHandler().(*receivedPacketHandler) + ackAlarmCallbackCalled = false + handler = NewReceivedPacketHandler(ackAlarmCallback).(*receivedPacketHandler) }) Context("accepting packets", func() { It("handles a packet that arrives late", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(1)) + err := handler.ReceivedPacket(protocol.PacketNumber(1), true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(3)) + err = handler.ReceivedPacket(protocol.PacketNumber(3), true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(2)) + err = handler.ReceivedPacket(protocol.PacketNumber(2), true) Expect(err).ToNot(HaveOccurred()) }) It("rejects packets with packet number 0", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(0)) + err := handler.ReceivedPacket(protocol.PacketNumber(0), true) Expect(err).To(MatchError(errInvalidPacketNumber)) }) It("rejects a duplicate package", func() { for i := 1; i < 5; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i)) + err := handler.ReceivedPacket(protocol.PacketNumber(i), true) Expect(err).ToNot(HaveOccurred()) } - err := handler.ReceivedPacket(4) + err := handler.ReceivedPacket(4, true) Expect(err).To(MatchError(ErrDuplicatePacket)) }) It("ignores a packet with PacketNumber less than the LeastUnacked of a previously received StopWaiting", func() { - err := handler.ReceivedPacket(5) + err := handler.ReceivedPacket(5, true) Expect(err).ToNot(HaveOccurred()) err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(9) + err = handler.ReceivedPacket(9, true) Expect(err).To(MatchError(ErrPacketSmallerThanLastStopWaiting)) }) It("does not ignore a packet with PacketNumber equal to LeastUnacked of a previously received StopWaiting", func() { - err := handler.ReceivedPacket(5) + err := handler.ReceivedPacket(5, true) Expect(err).ToNot(HaveOccurred()) err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(10) + err = handler.ReceivedPacket(10, true) Expect(err).ToNot(HaveOccurred()) }) It("saves the time when each packet arrived", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(3)) + err := handler.ReceivedPacket(protocol.PacketNumber(3), true) Expect(err).ToNot(HaveOccurred()) Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) }) @@ -70,7 +76,7 @@ var _ = Describe("receivedPacketHandler", func() { It("updates the largestObserved and the largestObservedReceivedTime", func() { handler.largestObserved = 3 handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second) - err := handler.ReceivedPacket(5) + err := handler.ReceivedPacket(5, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) @@ -80,27 +86,27 @@ var _ = Describe("receivedPacketHandler", func() { timestamp := time.Now().Add(-1 * time.Second) handler.largestObserved = 5 handler.largestObservedReceivedTime = timestamp - err := handler.ReceivedPacket(4) + err := handler.ReceivedPacket(4, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(handler.largestObservedReceivedTime).To(Equal(timestamp)) }) It("doesn't store more than MaxTrackedReceivedPackets packets", func() { - err := handler.ReceivedPacket(1) + err := handler.ReceivedPacket(1, true) Expect(err).ToNot(HaveOccurred()) for i := protocol.PacketNumber(3); i < 3+protocol.MaxTrackedReceivedPackets-1; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i)) + err := handler.ReceivedPacket(protocol.PacketNumber(i), true) Expect(err).ToNot(HaveOccurred()) } - err = handler.ReceivedPacket(protocol.PacketNumber(protocol.MaxTrackedReceivedPackets) + 10) + err = handler.ReceivedPacket(protocol.PacketNumber(protocol.MaxTrackedReceivedPackets)+10, true) Expect(err).To(MatchError(errTooManyOutstandingReceivedPackets)) }) It("passes on errors from receivedPacketHistory", func() { var err error for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ { - err = handler.ReceivedPacket(2*i + 1) + err = handler.ReceivedPacket(2*i+1, true) // this will eventually return an error // details about when exactly the receivedPacketHistory errors are tested there if err != nil { @@ -120,7 +126,7 @@ var _ = Describe("receivedPacketHandler", func() { It("increase the ignorePacketsBelow number, even if all packets below the LeastUnacked were already acked", func() { for i := 1; i < 20; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i)) + err := handler.ReceivedPacket(protocol.PacketNumber(i), true) Expect(err).ToNot(HaveOccurred()) } err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(12)}) @@ -138,133 +144,185 @@ var _ = Describe("receivedPacketHandler", func() { }) }) - Context("ACK package generation", func() { - It("generates a simple ACK frame", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(1)) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(2)) - Expect(err).ToNot(HaveOccurred()) - ack, err := handler.GetAckFrame(true) - Expect(err).ToNot(HaveOccurred()) - Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(2))) - Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1))) - Expect(ack.AckRanges).To(BeEmpty()) - }) - - It("generates an ACK frame with missing packets", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(1)) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(4)) - Expect(err).ToNot(HaveOccurred()) - ack, err := handler.GetAckFrame(true) - Expect(err).ToNot(HaveOccurred()) - Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(4))) - Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1))) - Expect(ack.AckRanges).To(HaveLen(2)) - Expect(ack.AckRanges[0]).To(Equal(frames.AckRange{FirstPacketNumber: 4, LastPacketNumber: 4})) - Expect(ack.AckRanges[1]).To(Equal(frames.AckRange{FirstPacketNumber: 1, LastPacketNumber: 1})) - }) - - It("does not generate an ACK if an ACK has already been sent for the largest Packet", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(1)) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(2)) - Expect(err).ToNot(HaveOccurred()) - ack, err := handler.GetAckFrame(true) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).ToNot(BeNil()) - ack, err = handler.GetAckFrame(true) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).To(BeNil()) - }) - - It("does not dequeue an ACK frame if told so", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(2)) - Expect(err).ToNot(HaveOccurred()) - ack, err := handler.GetAckFrame(false) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).ToNot(BeNil()) - ack, err = handler.GetAckFrame(false) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).ToNot(BeNil()) - ack, err = handler.GetAckFrame(false) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).ToNot(BeNil()) - }) - - It("returns a cached ACK frame if the ACK was not dequeued", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(2)) - Expect(err).ToNot(HaveOccurred()) - ack, err := handler.GetAckFrame(false) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).ToNot(BeNil()) - ack2, err := handler.GetAckFrame(false) - Expect(err).ToNot(HaveOccurred()) - Expect(ack2).ToNot(BeNil()) - Expect(&ack).To(Equal(&ack2)) - }) - - It("generates a new ACK (and deletes the cached one) when a new packet arrives", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(1)) - Expect(err).ToNot(HaveOccurred()) - ack, _ := handler.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(1))) - err = handler.ReceivedPacket(protocol.PacketNumber(3)) - Expect(err).ToNot(HaveOccurred()) - ack, _ = handler.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(3))) - }) - - It("generates a new ACK when an out-of-order packet arrives", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(1)) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(3)) - Expect(err).ToNot(HaveOccurred()) - ack, _ := handler.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.AckRanges).To(HaveLen(2)) - err = handler.ReceivedPacket(protocol.PacketNumber(2)) - Expect(err).ToNot(HaveOccurred()) - ack, _ = handler.GetAckFrame(true) - Expect(ack).ToNot(BeNil()) - Expect(ack.AckRanges).To(BeEmpty()) - }) - - It("doesn't send old ACK ranges after receiving a StopWaiting", func() { - err := handler.ReceivedPacket(5) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(10) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(11) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(12) - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(11)}) - Expect(err).ToNot(HaveOccurred()) - ack, err := handler.GetAckFrame(true) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12))) - Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(11))) - Expect(ack.HasMissingRanges()).To(BeFalse()) - }) - - It("deletes packets from the packetHistory after receiving a StopWaiting, after continuously received packets", func() { - for i := 1; i <= 12; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i)) - Expect(err).ToNot(HaveOccurred()) + Context("ACKs", func() { + Context("queueing ACKs", func() { + receiveAndAck10Packets := func() { + for i := 1; i <= 10; i++ { + err := handler.ReceivedPacket(protocol.PacketNumber(i), true) + Expect(err).ToNot(HaveOccurred()) + } + Expect(handler.GetAckFrame()).ToNot(BeNil()) + Expect(handler.ackQueued).To(BeFalse()) + ackAlarmCallbackCalled = false } - err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(6)}) - Expect(err).ToNot(HaveOccurred()) - // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame - ack, err := handler.GetAckFrame(true) - Expect(err).ToNot(HaveOccurred()) - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12))) - Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(6))) - Expect(ack.HasMissingRanges()).To(BeFalse()) + + It("always queues an ACK for the first packet", func() { + err := handler.ReceivedPacket(1, false) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeTrue()) + Expect(ackAlarmCallbackCalled).To(BeFalse()) + }) + + It("only queues one ACK for many non-retransmittable packets", func() { + receiveAndAck10Packets() + for i := 11; i < 10+protocol.MaxPacketsReceivedBeforeAckSend; i++ { + err := handler.ReceivedPacket(protocol.PacketNumber(i), false) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeFalse()) + } + err := handler.ReceivedPacket(10+protocol.MaxPacketsReceivedBeforeAckSend, false) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeTrue()) + Expect(ackAlarmCallbackCalled).To(BeFalse()) + }) + + It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() { + receiveAndAck10Packets() + err := handler.ReceivedPacket(11, true) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeFalse()) + Expect(ackAlarmCallbackCalled).To(BeTrue()) + ackAlarmCallbackCalled = false + err = handler.ReceivedPacket(12, true) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeTrue()) + Expect(ackAlarmCallbackCalled).To(BeFalse()) + }) + + It("only sets the timer when receiving a retransmittable packets", func() { + receiveAndAck10Packets() + err := handler.ReceivedPacket(11, false) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeFalse()) + Expect(handler.ackAlarm).To(BeZero()) + err = handler.ReceivedPacket(12, true) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeFalse()) + Expect(handler.ackAlarm).ToNot(BeZero()) + Expect(ackAlarmCallbackCalled).To(BeTrue()) + }) + + It("queues an ACK if it was reported missing before", func() { + receiveAndAck10Packets() + err := handler.ReceivedPacket(11, true) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedPacket(13, true) + Expect(err).ToNot(HaveOccurred()) + ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2 + Expect(ack).ToNot(BeNil()) + Expect(ack.HasMissingRanges()).To(BeTrue()) + Expect(handler.ackQueued).To(BeFalse()) + err = handler.ReceivedPacket(12, false) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeTrue()) + }) + + It("queues an ACK if it creates a new missing range", func() { + receiveAndAck10Packets() + for i := 11; i < 16; i++ { + err := handler.ReceivedPacket(protocol.PacketNumber(i), true) + Expect(err).ToNot(HaveOccurred()) + } + Expect(handler.GetAckFrame()).ToNot(BeNil()) + handler.ReceivedPacket(20, true) // we now know that packets 16 to 19 are missing + Expect(handler.ackQueued).To(BeTrue()) + }) + }) + + Context("ACK generation", func() { + BeforeEach(func() { + handler.ackQueued = true + }) + + It("generates a simple ACK frame", func() { + err := handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedPacket(2, true) + Expect(err).ToNot(HaveOccurred()) + ack := handler.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(2))) + Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1))) + Expect(ack.AckRanges).To(BeEmpty()) + }) + + It("saves the last sent ACK", func() { + err := handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + ack := handler.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(handler.lastAck).To(Equal(ack)) + err = handler.ReceivedPacket(2, true) + Expect(err).ToNot(HaveOccurred()) + handler.ackQueued = true + ack = handler.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(handler.lastAck).To(Equal(ack)) + }) + + It("generates an ACK frame with missing packets", func() { + err := handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + err = handler.ReceivedPacket(4, true) + Expect(err).ToNot(HaveOccurred()) + ack := handler.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(4))) + Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1))) + Expect(ack.AckRanges).To(HaveLen(2)) + Expect(ack.AckRanges[0]).To(Equal(frames.AckRange{FirstPacketNumber: 4, LastPacketNumber: 4})) + Expect(ack.AckRanges[1]).To(Equal(frames.AckRange{FirstPacketNumber: 1, LastPacketNumber: 1})) + }) + + It("deletes packets from the packetHistory after receiving a StopWaiting, after continuously received packets", func() { + for i := 1; i <= 12; i++ { + err := handler.ReceivedPacket(protocol.PacketNumber(i), true) + Expect(err).ToNot(HaveOccurred()) + } + err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(6)}) + Expect(err).ToNot(HaveOccurred()) + // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame + ack := handler.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12))) + Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(6))) + Expect(ack.HasMissingRanges()).To(BeFalse()) + }) + + It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { + err := handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + handler.ackAlarm = time.Now().Add(-time.Minute) + Expect(handler.GetAckFrame()).ToNot(BeNil()) + Expect(handler.packetsReceivedSinceLastAck).To(BeZero()) + Expect(handler.ackAlarm).To(BeZero()) + Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero()) + Expect(handler.ackQueued).To(BeFalse()) + }) + + It("doesn't generate an ACK when none is queued and the timer is not set", func() { + err := handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + handler.ackQueued = false + handler.ackAlarm = time.Time{} + Expect(handler.GetAckFrame()).To(BeNil()) + }) + + It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { + err := handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + handler.ackQueued = false + handler.ackAlarm = time.Now().Add(time.Minute) + Expect(handler.GetAckFrame()).To(BeNil()) + }) + + It("generates an ACK when the timer has expired", func() { + err := handler.ReceivedPacket(1, true) + Expect(err).ToNot(HaveOccurred()) + handler.ackQueued = false + handler.ackAlarm = time.Now().Add(-time.Minute) + Expect(handler.GetAckFrame()).ToNot(BeNil()) + }) }) }) }) diff --git a/ackhandler/received_packet_history.go b/ackhandler/received_packet_history.go index 2c3b2cbd7..d45fe6fd8 100644 --- a/ackhandler/received_packet_history.go +++ b/ackhandler/received_packet_history.go @@ -133,3 +133,13 @@ func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { return ackRanges } + +func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange { + ackRange := frames.AckRange{} + if h.ranges.Len() > 0 { + r := h.ranges.Back().Value + ackRange.FirstPacketNumber = r.Start + ackRange.LastPacketNumber = r.End + } + return ackRange +} diff --git a/ackhandler/received_packet_history_test.go b/ackhandler/received_packet_history_test.go index ed421e250..0f794e1e5 100644 --- a/ackhandler/received_packet_history_test.go +++ b/ackhandler/received_packet_history_test.go @@ -315,4 +315,23 @@ var _ = Describe("receivedPacketHistory", func() { Expect(ackRanges[2]).To(Equal(frames.AckRange{FirstPacketNumber: 1, LastPacketNumber: 2})) }) }) + + Context("Getting the highest ACK range", func() { + It("returns the zero value if there are no ranges", func() { + Expect(hist.GetHighestAckRange()).To(BeZero()) + }) + + It("gets a single ACK range", func() { + hist.ReceivedPacket(4) + hist.ReceivedPacket(5) + Expect(hist.GetHighestAckRange()).To(Equal(frames.AckRange{FirstPacketNumber: 4, LastPacketNumber: 5})) + }) + + It("gets the highest of multiple ACK ranges", func() { + hist.ReceivedPacket(3) + hist.ReceivedPacket(6) + hist.ReceivedPacket(7) + Expect(hist.GetHighestAckRange()).To(Equal(frames.AckRange{FirstPacketNumber: 6, LastPacketNumber: 7})) + }) + }) }) diff --git a/packet_packer.go b/packet_packer.go index dd9c3fa84..ef4080159 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -41,14 +41,14 @@ func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup *handshake. } func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { - return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true, false) + return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true) } -func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, maySendOnlyAck bool) (*packedPacket, error) { - return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false, maySendOnlyAck) +func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { + return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false) } -func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame, maySendOnlyAck bool) (*packedPacket, error) { +func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame bool) (*packedPacket, error) { if len(controlFrames) > 0 { p.controlFrames = append(p.controlFrames, controlFrames...) } @@ -97,18 +97,6 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con if !onlySendOneControlFrame && len(payloadFrames) == 1 && stopWaitingFrame != nil { return nil, nil } - // Don't send out packets that only contain an ACK (plus optional STOP_WAITING), if requested - if !maySendOnlyAck { - if len(payloadFrames) == 1 { - if _, ok := payloadFrames[0].(*frames.AckFrame); ok { - return nil, nil - } - } else if len(payloadFrames) == 2 && stopWaitingFrame != nil { - if _, ok := payloadFrames[1].(*frames.AckFrame); ok { - return nil, nil - } - } - } raw := getPacketBuffer() buffer := bytes.NewBuffer(raw) diff --git a/packet_packer_test.go b/packet_packer_test.go index 9b6f52a60..05b5111b6 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -37,7 +37,7 @@ var _ = Describe("Packet packer", func() { }) It("returns nil when no packet is queued", func() { - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) @@ -48,7 +48,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} @@ -73,14 +73,14 @@ var _ = Describe("Packet packer", func() { ErrorCode: 0x1337, ReasonPhrase: "foobar", } - p, err := packer.packPacket(&frames.StopWaitingFrame{LeastUnacked: 13}, []frames.Frame{&ccf, &frames.WindowUpdateFrame{StreamID: 37}}, 0, true, true) + p, err := packer.packPacket(&frames.StopWaitingFrame{LeastUnacked: 13}, []frames.Frame{&ccf, &frames.WindowUpdateFrame{StreamID: 37}}, 0, true) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0]).To(Equal(&ccf)) }) It("packs only control frames", func() { - p, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) @@ -88,10 +88,10 @@ var _ = Describe("Packet packer", func() { }) It("increases the packet number", func() { - p1, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) + p1, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p1).ToNot(BeNil()) - p2, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) + p2, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p2).ToNot(BeNil()) Expect(p2.number).To(BeNumerically(">", p1.number)) @@ -100,7 +100,7 @@ var _ = Describe("Packet packer", func() { It("packs a StopWaitingFrame first", func() { packer.packetNumberGenerator.next = 15 swf := &frames.StopWaitingFrame{LeastUnacked: 10} - p, err := packer.PackPacket(swf, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) + p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames).To(HaveLen(2)) @@ -111,21 +111,21 @@ var _ = Describe("Packet packer", func() { packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packer.packetNumberGenerator.next = packetNumber swf := &frames.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} - p, err := packer.PackPacket(swf, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) + p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.frames[0].(*frames.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) }) It("does not pack a packet containing only a StopWaitingFrame", func() { swf := &frames.StopWaitingFrame{LeastUnacked: 10} - p, err := packer.PackPacket(swf, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(swf, []frames.Frame{}, 0) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) It("packs a packet if it has queued control frames, but no new control frames", func() { packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) }) @@ -169,7 +169,7 @@ var _ = Describe("Packet packer", func() { It("only increases the packet number when there is an actual packet to send", func() { packer.packetNumberGenerator.nextToSkip = 1000 - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) @@ -178,7 +178,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } streamFramer.AddFrameForRetransmission(f) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err = packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.number).To(Equal(protocol.PacketNumber(1))) @@ -219,12 +219,12 @@ var _ = Describe("Packet packer", func() { } streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f2) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1))) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err = packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) @@ -246,7 +246,7 @@ var _ = Describe("Packet packer", func() { streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f3) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) b := &bytes.Buffer{} @@ -300,23 +300,23 @@ var _ = Describe("Packet packer", func() { } streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f2) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err = packer.PackPacket(nil, []frames.Frame{}, 0) Expect(p.frames).To(HaveLen(2)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue()) Expect(p.frames[1].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(err).ToNot(HaveOccurred()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err = packer.PackPacket(nil, []frames.Frame{}, 0) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err = packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -329,7 +329,7 @@ var _ = Describe("Packet packer", func() { minLength, _ := f.MinLength(0) f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) @@ -395,25 +395,21 @@ var _ = Describe("Packet packer", func() { }) It("returns nil if we only have a single STOP_WAITING", func() { - p, err := packer.PackPacket(&frames.StopWaitingFrame{}, nil, 0, false) + p, err := packer.PackPacket(&frames.StopWaitingFrame{}, nil, 0) Expect(err).NotTo(HaveOccurred()) Expect(p).To(BeNil()) }) - It("returns nil if we only have a single STOP_WAITING and an ACK", func() { - p, err := packer.PackPacket(&frames.StopWaitingFrame{}, []frames.Frame{&frames.AckFrame{}}, 0, false) + It("packs a single ACK", func() { + ack := &frames.AckFrame{LargestAcked: 42} + p, err := packer.PackPacket(nil, []frames.Frame{ack}, 0) Expect(err).NotTo(HaveOccurred()) - Expect(p).To(BeNil()) - }) - - It("returns nil if we only have a single ACK", func() { - p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0, false) - Expect(err).NotTo(HaveOccurred()) - Expect(p).To(BeNil()) + Expect(p).ToNot(BeNil()) + Expect(p.frames[0]).To(Equal(ack)) }) It("does not return nil if we only have a single ACK but request it to be sent", func() { - p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0, true) + p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) }) diff --git a/packet_unpacker.go b/packet_unpacker.go index 3434751b1..cec85ed7d 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -11,10 +11,6 @@ import ( "github.com/lucas-clemente/quic-go/qerr" ) -type unpackedPacket struct { - frames []frames.Frame -} - type packetUnpacker struct { version protocol.VersionNumber aead crypto.AEAD diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index f2ccce781..391f06231 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -12,8 +12,9 @@ const InitialCongestionWindow = 32 // session queues for later until it sends a public reset. const MaxUndecryptablePackets = 10 -// AckSendDelay is the maximal time delay applied to packets containing only ACKs -const AckSendDelay = 5 * time.Millisecond +// AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet +// This is the value Chromium is using +const AckSendDelay = 25 * time.Millisecond // ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data // This is the value that Google servers are using @@ -71,6 +72,12 @@ const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow +// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent +const MaxPacketsReceivedBeforeAckSend = 20 + +// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for +const RetransmittablePacketsBeforeAck = 2 + // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // prevents DoS attacks against the streamFrameSorter const MaxStreamFrameSorterGaps = 1000 diff --git a/session.go b/session.go index 9c39fbf9e..21295e5be 100644 --- a/session.go +++ b/session.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net" - "runtime" "sync/atomic" "time" @@ -77,7 +76,7 @@ type Session struct { undecryptablePackets []*receivedPacket aeadChanged chan struct{} - delayedAckOriginTime time.Time + nextAckScheduledTime time.Time connectionParameters handshake.ConnectionParametersManager @@ -99,12 +98,9 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol connectionParameters := handshake.NewConnectionParamatersManager(v) var sentPacketHandler ackhandler.SentPacketHandler - var receivedPacketHandler ackhandler.ReceivedPacketHandler - rttStats := &congestion.RTTStats{} sentPacketHandler = ackhandler.NewSentPacketHandler(rttStats) - receivedPacketHandler = ackhandler.NewReceivedPacketHandler() flowControlManager := flowcontrol.NewFlowControlManager(connectionParameters, rttStats) now := time.Now() @@ -116,10 +112,9 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol streamCallback: streamCallback, closeCallback: closeCallback, - connectionParameters: connectionParameters, - sentPacketHandler: sentPacketHandler, - receivedPacketHandler: receivedPacketHandler, - flowControlManager: flowControlManager, + connectionParameters: connectionParameters, + sentPacketHandler: sentPacketHandler, + flowControlManager: flowControlManager, receivedPackets: make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets), closeChan: make(chan *qerr.QuicError, 1), @@ -133,6 +128,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol sessionCreationTime: now, } + session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged) session.streamsMap = newStreamsMap(session.newStream, session.connectionParameters) cryptoStream, _ := session.GetOrOpenStream(1) @@ -195,9 +191,6 @@ runLoop: // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. putPacketBuffer(p.publicHeader.Raw) - if s.delayedAckOriginTime.IsZero() { - s.delayedAckOriginTime = p.rcvTime - } case <-s.aeadChanged: s.tryDecryptingQueuedPackets() } @@ -225,8 +218,8 @@ runLoop: func (s *Session) maybeResetTimer() { nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) - if !s.delayedAckOriginTime.IsZero() { - nextDeadline = utils.MinTime(nextDeadline, s.delayedAckOriginTime.Add(protocol.AckSendDelay)) + if !s.nextAckScheduledTime.IsZero() { + nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime) } if rtoTime := s.sentPacketHandler.TimeOfFirstRTO(); !rtoTime.IsZero() { nextDeadline = utils.MinTime(nextDeadline, rtoTime) @@ -291,7 +284,7 @@ func (s *Session) handlePacketImpl(p *receivedPacket) error { // Only do this after decrypting, so we are sure the packet is not attacker-controlled s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) - err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber) + err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, packet.IsRetransmittable()) // ignore duplicate packets if err == ackhandler.ErrDuplicatePacket { utils.Infof("Ignoring packet 0x%x due to ErrDuplicatePacket", hdr.PacketNumber) @@ -514,45 +507,26 @@ func (s *Session) sendPacket() error { if err != nil { return err } - for _, wuf := range windowUpdateFrames { controlFrames = append(controlFrames, wuf) } - - ack, err := s.receivedPacketHandler.GetAckFrame(false) - if err != nil { - return err - } + ack := s.receivedPacketHandler.GetAckFrame() if ack != nil { controlFrames = append(controlFrames, ack) } - - // Check whether we are allowed to send a packet containing only an ACK - maySendOnlyAck := time.Now().Sub(s.delayedAckOriginTime) > protocol.AckSendDelay - if runtime.GOOS == "windows" { - maySendOnlyAck = true - } - hasRetransmission := s.streamFramer.HasFramesForRetransmission() - var stopWaitingFrame *frames.StopWaitingFrame if ack != nil || hasRetransmission { stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) } - packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked(), maySendOnlyAck) + packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked()) if err != nil { return err } if packet == nil { return nil } - - // Pop the ACK frame now that we are sure we're gonna send it - _, err = s.receivedPacketHandler.GetAckFrame(true) - if err != nil { - return err - } - + // send every window update twice for _, f := range windowUpdateFrames { s.packer.QueueControlFrameForNextPacket(f) } @@ -567,13 +541,13 @@ func (s *Session) sendPacket() error { } s.logPacket(packet) - s.delayedAckOriginTime = time.Time{} err = s.conn.write(packet.raw) putPacketBuffer(packet.raw) if err != nil { return err } + s.nextAckScheduledTime = time.Time{} } } @@ -695,6 +669,11 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { return res, nil } +func (s *Session) ackAlarmChanged(t time.Time) { + s.nextAckScheduledTime = t + s.maybeResetTimer() +} + // RemoteAddr returns the net.UDPAddr of the client func (s *Session) RemoteAddr() *net.UDPAddr { return s.conn.RemoteAddr() diff --git a/session_test.go b/session_test.go index 9f781521d..30b04f988 100644 --- a/session_test.go +++ b/session_test.go @@ -88,6 +88,22 @@ func newMockSentPacketHandler() ackhandler.SentPacketHandler { return &mockSentPacketHandler{} } +var _ ackhandler.SentPacketHandler = &mockSentPacketHandler{} + +type mockReceivedPacketHandler struct { + nextAckFrame *frames.AckFrame +} + +func (m *mockReceivedPacketHandler) GetAckFrame() *frames.AckFrame { return m.nextAckFrame } +func (m *mockReceivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { + panic("not implemented") +} +func (m *mockReceivedPacketHandler) ReceivedStopWaiting(*frames.StopWaitingFrame) error { + panic("not implemented") +} + +var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{} + var _ = Describe("Session", func() { var ( session *Session @@ -602,7 +618,7 @@ var _ = Describe("Session", func() { Context("sending packets", func() { It("sends ack frames", func() { packetNumber := protocol.PacketNumber(0x035E) - session.receivedPacketHandler.ReceivedPacket(packetNumber) + session.receivedPacketHandler.ReceivedPacket(packetNumber, true) err := session.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(conn.written).To(HaveLen(1)) @@ -734,6 +750,17 @@ var _ = Describe("Session", func() { s.(*stream).getDataForWriting(1000) // unblock }) + It("sets the timer to the ack timer", func() { + rph := &mockReceivedPacketHandler{} + rph.nextAckFrame = &frames.AckFrame{LargestAcked: 0x1337} + session.receivedPacketHandler = rph + go session.run() + session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond)) + time.Sleep(10 * time.Millisecond) + Eventually(func() int { return len(conn.written) }).ShouldNot(BeZero()) + Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) + }) + Context("bundling of small packets", func() { It("bundles two small frames of different streams into one packet", func() { s1, err := session.GetOrOpenStream(5) @@ -783,7 +810,7 @@ var _ = Describe("Session", func() { It("sends a queued ACK frame only once", func() { packetNumber := protocol.PacketNumber(0x1337) - session.receivedPacketHandler.ReceivedPacket(packetNumber) + session.receivedPacketHandler.ReceivedPacket(packetNumber, true) s, err := session.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) diff --git a/unpacked_packet.go b/unpacked_packet.go new file mode 100644 index 000000000..807920476 --- /dev/null +++ b/unpacked_packet.go @@ -0,0 +1,27 @@ +package quic + +import "github.com/lucas-clemente/quic-go/frames" + +type unpackedPacket struct { + frames []frames.Frame +} + +func (u *unpackedPacket) IsRetransmittable() bool { + for _, f := range u.frames { + switch f.(type) { + case *frames.StreamFrame: + return true + case *frames.RstStreamFrame: + return true + case *frames.WindowUpdateFrame: + return true + case *frames.BlockedFrame: + return true + case *frames.PingFrame: + return true + case *frames.GoawayFrame: + return true + } + } + return false +} diff --git a/unpacked_packet_test.go b/unpacked_packet_test.go new file mode 100644 index 000000000..82112a260 --- /dev/null +++ b/unpacked_packet_test.go @@ -0,0 +1,46 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/frames" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Unpacked packet", func() { + var packet *unpackedPacket + BeforeEach(func() { + packet = &unpackedPacket{} + }) + + It("says that an empty packet is not retransmittable", func() { + Expect(packet.IsRetransmittable()).To(BeFalse()) + }) + + It("detects the frame types", func() { + packet.frames = []frames.Frame{&frames.AckFrame{}} + Expect(packet.IsRetransmittable()).To(BeFalse()) + packet.frames = []frames.Frame{&frames.BlockedFrame{}} + Expect(packet.IsRetransmittable()).To(BeTrue()) + packet.frames = []frames.Frame{&frames.GoawayFrame{}} + Expect(packet.IsRetransmittable()).To(BeTrue()) + packet.frames = []frames.Frame{&frames.PingFrame{}} + Expect(packet.IsRetransmittable()).To(BeTrue()) + packet.frames = []frames.Frame{&frames.StreamFrame{}} + Expect(packet.IsRetransmittable()).To(BeTrue()) + packet.frames = []frames.Frame{&frames.RstStreamFrame{}} + Expect(packet.IsRetransmittable()).To(BeTrue()) + packet.frames = []frames.Frame{&frames.StopWaitingFrame{}} + Expect(packet.IsRetransmittable()).To(BeFalse()) + packet.frames = []frames.Frame{&frames.WindowUpdateFrame{}} + Expect(packet.IsRetransmittable()).To(BeTrue()) + }) + + It("says that a packet is retransmittable if it contains one retransmittable frame", func() { + packet.frames = []frames.Frame{ + &frames.AckFrame{}, + &frames.WindowUpdateFrame{}, + &frames.StopWaitingFrame{}, + } + Expect(packet.IsRetransmittable()).To(BeTrue()) + }) +})