diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 59f4647d3..43027dcfb 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -33,7 +33,7 @@ type SentPacketHandler interface { GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen GetAlarmTimeout() time.Time - OnAlarm() + OnAlarm() error } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/internal/ackhandler/packet.go b/internal/ackhandler/packet.go index fb0a5a84a..14c7ea47e 100644 --- a/internal/ackhandler/packet.go +++ b/internal/ackhandler/packet.go @@ -18,6 +18,11 @@ type Packet struct { largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK sendTime time.Time + + queuedForRetransmission bool + retransmittedAs []protocol.PacketNumber + isRetransmission bool // we need a separate bool here because 0 is a valid packet number + retransmissionOf protocol.PacketNumber } // GetFramesForRetransmission gets all the frames for retransmission diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index f4fcb9edb..47c3ddee7 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -84,7 +84,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { } func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber { - if p := h.packetHistory.Front(); p != nil { + if p := h.packetHistory.FirstOutstanding(); p != nil { return p.PacketNumber } return h.largestAcked + 1 @@ -112,6 +112,27 @@ func (h *sentPacketHandler) SetHandshakeComplete() { } func (h *sentPacketHandler) SentPacket(packet *Packet) { + packet.sendTime = time.Now() + if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable { + h.packetHistory.SentPacket(packet) + } + h.updateLossDetectionAlarm(packet.sendTime) +} + +func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) { + now := time.Now() + var p []*Packet + for _, packet := range packets { + packet.sendTime = now + if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable { + p = append(p, packet) + } + } + h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf) + h.updateLossDetectionAlarm(now) +} + +func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ { for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { h.skippedPackets = append(h.skippedPackets, p) if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets { @@ -119,7 +140,6 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) { } } - now := time.Now() h.lastSentPacketNumber = packet.PacketNumber var largestAcked protocol.PacketNumber @@ -133,27 +153,13 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) { isRetransmittable := len(packet.Frames) != 0 if isRetransmittable { - packet.sendTime = now packet.largestAcked = largestAcked h.bytesInFlight += packet.Length - h.packetHistory.SentPacket(packet) } - h.congestion.OnPacketSent( - now, - h.bytesInFlight, - packet.PacketNumber, - packet.Length, - isRetransmittable, - ) + h.congestion.OnPacketSent(packet.sendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable) - h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight)) - h.updateLossDetectionAlarm(now) -} - -func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) { - for _, packet := range packets { - h.SentPacket(packet) - } + h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.sendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight)) + return isRetransmittable } func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { @@ -192,11 +198,17 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe if p.largestAcked != 0 { h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1) } - h.onPacketAcked(p) - h.congestion.OnPacketAcked(p.PacketNumber, p.Length, h.bytesInFlight) + if err := h.onPacketAcked(p); err != nil { + return err + } + if len(p.retransmittedAs) == 0 { + h.congestion.OnPacketAcked(p.PacketNumber, p.Length, h.bytesInFlight) + } } - h.detectLostPackets(rcvTime) + if err := h.detectLostPackets(rcvTime); err != nil { + return err + } h.updateLossDetectionAlarm(rcvTime) h.garbageCollectSkippedPackets() @@ -271,7 +283,7 @@ func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) { } } -func (h *sentPacketHandler) detectLostPackets(now time.Time) { +func (h *sentPacketHandler) detectLostPackets(now time.Time) error { h.lossTime = time.Time{} maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) @@ -282,6 +294,9 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time) { if packet.PacketNumber > h.largestAcked { return false, nil } + if packet.queuedForRetransmission { // don't retransmit packets twice + return true, nil + } timeSinceSent := now.Sub(packet.sendTime) if timeSinceSent > delayUntilLost { @@ -294,40 +309,99 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time) { }) for _, p := range lostPackets { - h.queuePacketForRetransmission(p) + if err := h.queuePacketForRetransmission(p); err != nil { + return err + } h.congestion.OnPacketLost(p.PacketNumber, p.Length, h.bytesInFlight) } + return nil } -func (h *sentPacketHandler) OnAlarm() { +func (h *sentPacketHandler) OnAlarm() error { now := time.Now() // TODO(#497): TLP + var err error if !h.handshakeComplete { - h.queueHandshakePacketsForRetransmission() h.handshakeCount++ + err = h.queueHandshakePacketsForRetransmission() } else if !h.lossTime.IsZero() { // Early retransmit or time loss detection - h.detectLostPackets(now) + err = h.detectLostPackets(now) } else { // RTO - h.retransmitOldestTwoPackets() h.rtoCount++ + err = h.queueRTOs() + } + if err != nil { + return err } - h.updateLossDetectionAlarm(now) + return nil } func (h *sentPacketHandler) GetAlarmTimeout() time.Time { return h.alarm } -func (h *sentPacketHandler) onPacketAcked(p *Packet) { - h.bytesInFlight -= p.Length +func (h *sentPacketHandler) onPacketAcked(p *Packet) error { + // This happens if a packet and its retransmissions is acked in the same ACK. + // As soon as we process the first one, this will remove all the retransmissions, + // so we won't find the retransmitted packet number later. + if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil { + return nil + } h.rtoCount = 0 h.handshakeCount = 0 // TODO(#497): h.tlpCount = 0 - h.packetHistory.Remove(p.PacketNumber) + + // find the first packet, from which on we can delete all retransmissions + // example: packet 10 was retransmitted as packet 11 and 12, and + // packet 12 was then retransmitted as 13. + // When receiving an ACK for packet 13, we can remove packets 12 and 13, + // but still need to keep 10 and 11. + first := p + for first.isRetransmission { + previous := h.packetHistory.GetPacket(first.retransmissionOf) + if previous == nil { + return fmt.Errorf("sent packet handler BUG: retransmitted packet for %d not found (should have been %d)", first.PacketNumber, first.retransmissionOf) + } + // if the retransmission of a packet was split, we can't remove it yet + if len(previous.retransmittedAs) > 1 { + break + } + first = previous + } + if first.isRetransmission { + root := h.packetHistory.GetPacket(first.retransmissionOf) + retransmittedAs := make([]protocol.PacketNumber, 0, len(root.retransmittedAs)-1) + for _, pn := range root.retransmittedAs { + if pn != first.PacketNumber { + retransmittedAs = append(retransmittedAs, pn) + } + } + root.retransmittedAs = retransmittedAs + } + return h.removeAllRetransmissions(first) +} + +func (h *sentPacketHandler) removeAllRetransmissions(p *Packet) error { + if !p.queuedForRetransmission { + // The bytes in flight are reduced when a packet is queued for retransmission. + // When a packet is acked, we only need to reduce it for packets that were not retransmitted. + h.bytesInFlight -= p.Length + } else { + for _, r := range p.retransmittedAs { + packet := h.packetHistory.GetPacket(r) + if packet == nil { + return fmt.Errorf("sent packet handler BUG: removing packet %d (retransmission of %d) not found in history", r, p.PacketNumber) + } + if err := h.removeAllRetransmissions(packet); err != nil { + return err + } + } + } + return h.packetHistory.Remove(p.PacketNumber) } func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { @@ -389,40 +463,45 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) } -func (h *sentPacketHandler) retransmitOldestTwoPackets() { - if p := h.packetHistory.Front(); p != nil { - h.queueRTO(p) - } - if p := h.packetHistory.Front(); p != nil { - h.queueRTO(p) +// retransmit the oldest two packets +func (h *sentPacketHandler) queueRTOs() error { + for i := 0; i < 2; i++ { + if p := h.packetHistory.FirstOutstanding(); p != nil { + utils.Debugf("\tQueueing packet %#x for retransmission (RTO), %d outstanding", p.PacketNumber, h.packetHistory.Len()) + if err := h.queuePacketForRetransmission(p); err != nil { + return err + } + h.congestion.OnPacketLost(p.PacketNumber, p.Length, h.bytesInFlight) + h.congestion.OnRetransmissionTimeout(true) + } } + return nil } -func (h *sentPacketHandler) queueRTO(p *Packet) { - utils.Debugf("\tQueueing packet 0x%x for retransmission (RTO), %d outstanding", p.PacketNumber, h.packetHistory.Len()) - h.queuePacketForRetransmission(p) - h.congestion.OnPacketLost(p.PacketNumber, p.Length, h.bytesInFlight) - h.congestion.OnRetransmissionTimeout(true) -} - -func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() { +func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error { var handshakePackets []*Packet h.packetHistory.Iterate(func(p *Packet) (bool, error) { - if p.EncryptionLevel < protocol.EncryptionForwardSecure { + if !p.queuedForRetransmission && p.EncryptionLevel < protocol.EncryptionForwardSecure { handshakePackets = append(handshakePackets, p) } return true, nil }) for _, p := range handshakePackets { - h.queuePacketForRetransmission(p) + if err := h.queuePacketForRetransmission(p); err != nil { + return err + } } + return nil } -func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) { +func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error { + if _, err := h.packetHistory.QueuePacketForRetransmission(p.PacketNumber); err != nil { + return err + } h.bytesInFlight -= p.Length h.retransmissionQueue = append(h.retransmissionQueue, p) - h.packetHistory.Remove(p.PacketNumber) h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber) + return nil } func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 033578953..cfa4484f3 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -13,10 +13,17 @@ import ( . "github.com/onsi/gomega" ) -func retransmittablePacket(num protocol.PacketNumber) *Packet { +func retransmittablePacket(num protocol.PacketNumber, args ...protocol.ByteCount) *Packet { + length := protocol.ByteCount(1) + if len(args) == 1 { + length = args[0] + } + if len(args) > 1 { + Fail("invalid function parameters") + } return &Packet{ PacketNumber: num, - Length: 1, + Length: length, Frames: []wire.Frame{&wire.PingFrame{}}, EncryptionLevel: protocol.EncryptionForwardSecure, } @@ -415,6 +422,96 @@ var _ = Describe("SentPacketHandler", func() { }) }) + Context("Ack processing, for retransmitted packets", func() { + losePacket := func(pn protocol.PacketNumber) { + p := getPacket(pn) + ExpectWithOffset(1, p).ToNot(BeNil()) + handler.queuePacketForRetransmission(p) + r := handler.DequeuePacketForRetransmission() + ExpectWithOffset(1, r).ToNot(BeNil()) + ExpectWithOffset(1, r.PacketNumber).To(Equal(pn)) + } + + It("sends a packet as retransmission", func() { + // packet 5 was retransmitted as packet 6 + handler.SentPacket(retransmittablePacket(5, 10)) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) + losePacket(5) + Expect(handler.bytesInFlight).To(BeZero()) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(6, 11)}, 5) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11))) + }) + + It("removes all retransmissions when the original packet is acked", func() { + // packet 5 was retransmitted as packet 6, which was then retransmitted as packet 8 + handler.SentPacket(retransmittablePacket(5, 10)) + losePacket(5) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(6, 11)}, 5) + losePacket(6) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(8, 12)}, 6) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(12))) + // ack 5 + err := handler.ReceivedAck(&wire.AckFrame{LowestAcked: 5, LargestAcked: 5}, 1, protocol.EncryptionForwardSecure, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.packetHistory.Len()).To(BeZero()) + Expect(handler.bytesInFlight).To(BeZero()) + }) + + It("removes all retransmissions when the retransmission is acked", func() { + // the retransmission for packet 5 was split into packets 6 and 8 + handler.SentPacket(retransmittablePacket(5, 10)) + losePacket(5) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(6, 11)}, 5) + losePacket(6) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(8, 12)}, 6) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(12))) + // ack 8 + err := handler.ReceivedAck(&wire.AckFrame{LowestAcked: 8, LargestAcked: 8}, 1, protocol.EncryptionForwardSecure, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.packetHistory.Len()).To(BeZero()) + Expect(handler.bytesInFlight).To(BeZero()) + }) + + It("removes split retransmissions when the original packet is acked", func() { + // the retransmission for packet 5 was split into 8 and 9 + handler.SentPacket(retransmittablePacket(5, 10)) + losePacket(5) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(8, 6), retransmittablePacket(9, 7)}, 5) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6 + 7))) + // ack 5 + err := handler.ReceivedAck(&wire.AckFrame{LowestAcked: 5, LargestAcked: 5}, 1, protocol.EncryptionForwardSecure, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.packetHistory.Len()).To(BeZero()) + Expect(handler.bytesInFlight).To(BeZero()) + }) + + It("doesn't remove the original packet if a split retransmission is acked", func() { + // the retransmission for packet 5 was split into 10 and 12 + handler.SentPacket(retransmittablePacket(5, 10)) + losePacket(5) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(10, 6), retransmittablePacket(12, 7)}, 5) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6 + 7))) + // ack 10 + err := handler.ReceivedAck(&wire.AckFrame{LowestAcked: 10, LargestAcked: 10}, 1, protocol.EncryptionForwardSecure, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7))) + expectInPacketHistory([]protocol.PacketNumber{5, 12}) + }) + + It("handles ACKs that ack the original packet as well as the retransmission", func() { + // packet 5 was retransmitted as packet 7 + handler.SentPacket(retransmittablePacket(5, 10)) + losePacket(5) + handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(7, 11)}, 5) + // ack 5 and 7 + ack := createAck([]wire.AckRange{{First: 5, Last: 5}, {First: 7, Last: 7}}) + err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.packetHistory.Len()).To(BeZero()) + Expect(handler.bytesInFlight).To(BeZero()) + }) + }) + Context("Retransmission handling", func() { It("does not dequeue a packet if no ack has been received", func() { handler.SentPacket(retransmittablePacket(1)) @@ -683,7 +780,6 @@ var _ = Describe("SentPacketHandler", func() { p = handler.DequeuePacketForRetransmission() Expect(p).ToNot(BeNil()) Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(4))) - Expect(handler.packetHistory.Len()).To(Equal(1)) Expect(getPacket(3)).ToNot(BeNil()) Expect(handler.handshakeCount).To(BeEquivalentTo(1)) // make sure the exponential backoff is used diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index 486daf7a2..f815b8542 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -9,6 +9,8 @@ import ( type sentPacketHistory struct { packetList *PacketList packetMap map[protocol.PacketNumber]*PacketElement + + firstOutstanding *PacketElement } func newSentPacketHistory() *sentPacketHistory { @@ -19,8 +21,37 @@ func newSentPacketHistory() *sentPacketHistory { } func (h *sentPacketHistory) SentPacket(p *Packet) { + h.sentPacketImpl(p) +} + +func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement { el := h.packetList.PushBack(*p) h.packetMap[p.PacketNumber] = el + if h.firstOutstanding == nil { + h.firstOutstanding = el + } + return el +} + +func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) { + retransmission, ok := h.packetMap[retransmissionOf] + // The retransmitted packet is not present anymore. + // This can happen if it was acked in between dequeueing of the retransmission and sending. + // Just treat the retransmissions as normal packets. + // TODO: This won't happen if we clear packets queued for retransmission on new ACKs. + if !ok { + for _, packet := range packets { + h.sentPacketImpl(packet) + } + return + } + retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets)) + for i, packet := range packets { + retransmission.Value.retransmittedAs[i] = packet.PacketNumber + el := h.sentPacketImpl(packet) + el.Value.isRetransmission = true + el.Value.retransmissionOf = retransmissionOf + } } func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet { @@ -44,11 +75,41 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err return nil } -func (h *sentPacketHistory) Front() *Packet { - if h.Len() == 0 { +// FirstOutStanding returns the first outstanding packet. +// It must not be modified (e.g. retransmitted). +// Use DequeueFirstPacketForRetransmission() to retransmit it. +func (h *sentPacketHistory) FirstOutstanding() *Packet { + if h.firstOutstanding == nil { return nil } - return &h.packetList.Front().Value + return &h.firstOutstanding.Value +} + +// QueuePacketForRetransmission marks a packet for retransmission. +// A packet can only be queued once. +func (h *sentPacketHistory) QueuePacketForRetransmission(pn protocol.PacketNumber) (*Packet, error) { + el, ok := h.packetMap[pn] + if !ok { + return nil, fmt.Errorf("sent packet history: packet %d not found", pn) + } + if el.Value.queuedForRetransmission { + return nil, fmt.Errorf("sent packet history BUG: packet %d already queued for retransmission", pn) + } + el.Value.queuedForRetransmission = true + if el == h.firstOutstanding { + h.readjustFirstOutstanding() + } + return &el.Value, nil +} + +// readjustFirstOutstanding readjusts the pointer to the first outstanding packet. +// This is necessary every time the first outstanding packet is deleted or retransmitted. +func (h *sentPacketHistory) readjustFirstOutstanding() { + el := h.firstOutstanding.Next() + for el != nil && el.Value.queuedForRetransmission { + el = el.Next() + } + h.firstOutstanding = el } func (h *sentPacketHistory) Len() int { @@ -60,6 +121,9 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { if !ok { return fmt.Errorf("packet %d not found in sent packet history", p) } + if el == h.firstOutstanding { + h.readjustFirstOutstanding() + } h.packetList.Remove(el) delete(h.packetMap, p) return nil diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go index 7026f3ea3..746bf70f2 100644 --- a/internal/ackhandler/sent_packet_history_test.go +++ b/internal/ackhandler/sent_packet_history_test.go @@ -41,16 +41,57 @@ var _ = Describe("SentPacketHistory", func() { Expect(hist.Len()).To(Equal(2)) }) - It("gets nil, if there's no front packet", func() { - Expect(hist.Front()).To(BeNil()) - }) + Context("getting the first outstanding packet", func() { + It("gets nil, if there are no packets", func() { + Expect(hist.FirstOutstanding()).To(BeNil()) + }) - It("gets the front packet", func() { - hist.SentPacket(&Packet{PacketNumber: 2}) - hist.SentPacket(&Packet{PacketNumber: 3}) - front := hist.Front() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) + It("gets the first outstanding packet", func() { + hist.SentPacket(&Packet{PacketNumber: 2}) + hist.SentPacket(&Packet{PacketNumber: 3}) + front := hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) + }) + + It("gets the second packet if the first one is retransmitted", func() { + hist.SentPacket(&Packet{PacketNumber: 1}) + hist.SentPacket(&Packet{PacketNumber: 3}) + hist.SentPacket(&Packet{PacketNumber: 4}) + front := hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1))) + // Queue the first packet for retransmission. + // The first outstanding packet should now be 3. + _, err := hist.QueuePacketForRetransmission(1) + Expect(err).ToNot(HaveOccurred()) + front = hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(3))) + }) + + It("gets the third packet if the first two are retransmitted", func() { + hist.SentPacket(&Packet{PacketNumber: 1}) + hist.SentPacket(&Packet{PacketNumber: 3}) + hist.SentPacket(&Packet{PacketNumber: 4}) + front := hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1))) + // Queue the second packet for retransmission. + // The first outstanding packet should still be 3. + _, err := hist.QueuePacketForRetransmission(3) + Expect(err).ToNot(HaveOccurred()) + front = hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1))) + // Queue the first packet for retransmission. + // The first outstanding packet should still be 4. + _, err = hist.QueuePacketForRetransmission(1) + Expect(err).ToNot(HaveOccurred()) + front = hist.FirstOutstanding() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(4))) + }) }) It("gets a packet by packet number", func() { @@ -119,4 +160,59 @@ var _ = Describe("SentPacketHistory", func() { Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14})) }) }) + + Context("retransmissions", func() { + BeforeEach(func() { + for i := protocol.PacketNumber(1); i <= 5; i++ { + hist.SentPacket(&Packet{PacketNumber: i}) + } + }) + + It("gets packets for retransmission", func() { + p, err := hist.QueuePacketForRetransmission(3) + Expect(err).ToNot(HaveOccurred()) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(3))) + Expect(p.queuedForRetransmission).To(BeTrue()) + Expect(p.retransmissionOf).To(BeZero()) + Expect(p.isRetransmission).To(BeFalse()) + Expect(p.retransmittedAs).To(BeNil()) + }) + + It("errors if the packet was already queued for retransmission", func() { + p, err := hist.QueuePacketForRetransmission(5) + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + _, err = hist.QueuePacketForRetransmission(5) + Expect(err).To(MatchError("sent packet history BUG: packet 5 already queued for retransmission")) + }) + + It("errors if the packet doesn't exist", func() { + _, err := hist.QueuePacketForRetransmission(100) + Expect(err).To(MatchError("sent packet history: packet 100 not found")) + }) + + It("adds a sent packets as a retransmission", func() { + hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 2) + expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13}) + Expect(hist.GetPacket(13).isRetransmission).To(BeTrue()) + Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2))) + Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13})) + }) + + It("adds multiple packets sent as a retransmission", func() { + hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}, {PacketNumber: 15}}, 2) + expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13, 15}) + Expect(hist.GetPacket(13).isRetransmission).To(BeTrue()) + Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2))) + Expect(hist.GetPacket(15).retransmissionOf).To(Equal(protocol.PacketNumber(2))) + Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13, 15})) + }) + + It("adds a packet as a normal packet if the retransmitted packet doesn't exist", func() { + hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 7) + expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13}) + Expect(hist.GetPacket(13).isRetransmission).To(BeFalse()) + Expect(hist.GetPacket(13).retransmissionOf).To(BeZero()) + }) + }) }) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 5e22dddd1..b11d999a2 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -98,8 +98,10 @@ func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{ } // OnAlarm mocks base method -func (m *MockSentPacketHandler) OnAlarm() { - m.ctrl.Call(m, "OnAlarm") +func (m *MockSentPacketHandler) OnAlarm() error { + ret := m.ctrl.Call(m, "OnAlarm") + ret0, _ := ret[0].(error) + return ret0 } // OnAlarm indicates an expected call of OnAlarm diff --git a/session.go b/session.go index 6976a3fed..622b668e1 100644 --- a/session.go +++ b/session.go @@ -418,9 +418,11 @@ runLoop: now := time.Now() if timeout := s.sentPacketHandler.GetAlarmTimeout(); !timeout.IsZero() && timeout.Before(now) { - // This could cause packets to be retransmitted, so check it before trying - // to send packets. - s.sentPacketHandler.OnAlarm() + // This could cause packets to be retransmitted. + // Check it before trying to send packets. + if err := s.sentPacketHandler.OnAlarm(); err != nil { + s.closeLocal(err) + } } var pacingDeadline time.Time