diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 98a04dd7..cc6bbd7d 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -41,7 +41,7 @@ type sentPacketHandler struct { // once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101 lowestPacketNotConfirmedAcked protocol.PacketNumber - packetHistory *PacketList + packetHistory *sentPacketHistory stopWaitingManager stopWaitingManager retransmissionQueue []*Packet @@ -76,7 +76,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { ) return &sentPacketHandler{ - packetHistory: NewPacketList(), + packetHistory: newSentPacketHistory(), stopWaitingManager: stopWaitingManager{}, rttStats: rttStats, congestion: congestion, @@ -84,8 +84,8 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { } func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber { - if f := h.packetHistory.Front(); f != nil { - return f.Value.PacketNumber + if p := h.packetHistory.Front(); p != nil { + return p.PacketNumber } return h.largestAcked + 1 } @@ -97,12 +97,15 @@ func (h *sentPacketHandler) SetHandshakeComplete() { queue = append(queue, packet) } } - for el := h.packetHistory.Front(); el != nil; { - next := el.Next() - if el.Value.EncryptionLevel != protocol.EncryptionForwardSecure { - h.packetHistory.Remove(el) + var handshakePackets []*Packet + h.packetHistory.Iterate(func(p *Packet) (bool, error) { + if p.EncryptionLevel != protocol.EncryptionForwardSecure { + handshakePackets = append(handshakePackets, p) } - el = next + return true, nil + }) + for _, p := range handshakePackets { + h.packetHistory.Remove(p.PacketNumber) } h.retransmissionQueue = queue h.handshakeComplete = true @@ -133,7 +136,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) { packet.sendTime = now packet.largestAcked = largestAcked h.bytesInFlight += packet.Length - h.packetHistory.PushBack(*packet) + h.packetHistory.SentPacket(packet) } h.congestion.OnPacketSent( now, @@ -164,9 +167,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number") } - rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime) - - if rttUpdated { + if rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime); rttUpdated { h.congestion.MaybeExitSlowStart() } @@ -175,20 +176,18 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe return err } - if len(ackedPackets) > 0 { - for _, p := range ackedPackets { - if encLevel < p.Value.EncryptionLevel { - return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel) - } - // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 - // It is safe to ignore the corner case of packets that just acked packet 0, because - // the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send. - if p.Value.largestAcked != 0 { - h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1) - } - h.onPacketAcked(p) - h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) + for _, p := range ackedPackets { + if encLevel < p.EncryptionLevel { + return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel) } + // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 + // It is safe to ignore the corner case of packets that just acked packet 0, because + // the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send. + if p.largestAcked != 0 { + h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1) + } + h.onPacketAcked(p) + h.congestion.OnPacketAcked(p.PacketNumber, p.Length, h.bytesInFlight) } h.detectLostPackets(rcvTime) @@ -204,56 +203,56 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu return h.lowestPacketNotConfirmedAcked } -func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) { - var ackedPackets []*PacketElement +func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) { + var ackedPackets []*Packet ackRangeIndex := 0 - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - packet := el.Value - packetNumber := packet.PacketNumber - + err := h.packetHistory.Iterate(func(p *Packet) (bool, error) { // Ignore packets below the LowestAcked - if packetNumber < ackFrame.LowestAcked { - continue + if p.PacketNumber < ackFrame.LowestAcked { + return true, nil } // Break after LargestAcked is reached - if packetNumber > ackFrame.LargestAcked { - break + if p.PacketNumber > ackFrame.LargestAcked { + return false, nil } if ackFrame.HasMissingRanges() { ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] - for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 { + for p.PacketNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 { ackRangeIndex++ ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] } - if packetNumber >= ackRange.First { // packet i contained in ACK range - if packetNumber > ackRange.Last { - return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last) + if p.PacketNumber >= ackRange.First { // packet i contained in ACK range + if p.PacketNumber > ackRange.Last { + return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.First, ackRange.Last) } - ackedPackets = append(ackedPackets, el) + ackedPackets = append(ackedPackets, p) } } else { - ackedPackets = append(ackedPackets, el) + ackedPackets = append(ackedPackets, p) } - } - return ackedPackets, nil + return true, nil + }) + return ackedPackets, err } func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool { - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - packet := el.Value - if packet.PacketNumber == largestAcked { - h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime) - return true + var rttUpdated bool + h.packetHistory.Iterate(func(p *Packet) (bool, error) { + if p.PacketNumber == largestAcked { + h.rttStats.UpdateRTT(rcvTime.Sub(p.sendTime), ackDelay, rcvTime) + rttUpdated = true + return false, nil } // Packets are sorted by number, so we can stop searching - if packet.PacketNumber > largestAcked { - break + if p.PacketNumber > largestAcked { + return false, nil } - } - return false + return true, nil + }) + return rttUpdated } func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) { @@ -281,28 +280,25 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time) { maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT) - var lostPackets []*PacketElement - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - packet := el.Value - + var lostPackets []*Packet + h.packetHistory.Iterate(func(packet *Packet) (bool, error) { if packet.PacketNumber > h.largestAcked { - break + return false, nil } timeSinceSent := now.Sub(packet.sendTime) if timeSinceSent > delayUntilLost { - lostPackets = append(lostPackets, el) + lostPackets = append(lostPackets, packet) } else if h.lossTime.IsZero() { // Note: This conditional is only entered once per call h.lossTime = now.Add(delayUntilLost - timeSinceSent) } - } + return true, nil + }) - if len(lostPackets) > 0 { - for _, p := range lostPackets { - h.queuePacketForRetransmission(p) - h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) - } + for _, p := range lostPackets { + h.queuePacketForRetransmission(p) + h.congestion.OnPacketLost(p.PacketNumber, p.Length, h.bytesInFlight) } } @@ -329,12 +325,12 @@ func (h *sentPacketHandler) GetAlarmTimeout() time.Time { return h.alarm } -func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) { - h.bytesInFlight -= packetElement.Value.Length +func (h *sentPacketHandler) onPacketAcked(p *Packet) { + h.bytesInFlight -= p.Length h.rtoCount = 0 h.handshakeCount = 0 // TODO(#497): h.tlpCount = 0 - h.packetHistory.Remove(packetElement) + h.packetHistory.Remove(p.PacketNumber) } func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { @@ -405,36 +401,31 @@ func (h *sentPacketHandler) retransmitOldestTwoPackets() { } } -func (h *sentPacketHandler) queueRTO(el *PacketElement) { - packet := &el.Value - utils.Debugf( - "\tQueueing packet 0x%x for retransmission (RTO), %d outstanding", - packet.PacketNumber, - h.packetHistory.Len(), - ) - h.queuePacketForRetransmission(el) - h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight) +func (h *sentPacketHandler) queueRTO(p *Packet) { + utils.Debugf("\tQueueing packet 0x%x for retransmission (RTO), %d outstanding", p.PacketNumber, h.packetHistory.Len()) + h.queuePacketForRetransmission(p) + h.congestion.OnPacketLost(p.PacketNumber, p.Length, h.bytesInFlight) h.congestion.OnRetransmissionTimeout(true) } func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() { - var handshakePackets []*PacketElement - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { - handshakePackets = append(handshakePackets, el) + var handshakePackets []*Packet + h.packetHistory.Iterate(func(p *Packet) (bool, error) { + if p.EncryptionLevel < protocol.EncryptionForwardSecure { + handshakePackets = append(handshakePackets, p) } - } - for _, el := range handshakePackets { - h.queuePacketForRetransmission(el) + return true, nil + }) + for _, p := range handshakePackets { + h.queuePacketForRetransmission(p) } } -func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) { - packet := &packetElement.Value - h.bytesInFlight -= packet.Length - h.retransmissionQueue = append(h.retransmissionQueue, packet) - h.packetHistory.Remove(packetElement) - h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber) +func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) { + h.bytesInFlight -= p.Length + h.retransmissionQueue = append(h.retransmissionQueue, p) + h.packetHistory.Remove(p.PacketNumber) + h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber) } func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 00a63823..44a8875c 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -50,15 +50,20 @@ var _ = Describe("SentPacketHandler", func() { } }) - getPacketElement := func(p protocol.PacketNumber) *PacketElement { - for el := handler.packetHistory.Front(); el != nil; el = el.Next() { - if el.Value.PacketNumber == p { - return el - } + getPacket := func(pn protocol.PacketNumber) *Packet { + if el, ok := handler.packetHistory.packetMap[pn]; ok { + return &el.Value } return nil } + expectInPacketHistory := func(expected []protocol.PacketNumber) { + ExpectWithOffset(1, handler.packetHistory.Len()).To(Equal(len(expected))) + for _, p := range expected { + ExpectWithOffset(1, handler.packetHistory.packetMap).To(HaveKey(p)) + } + } + It("determines the packet number length", func() { handler.largestAcked = 0x1337 Expect(handler.GetPacketNumberLen(0x1338)).To(Equal(protocol.PacketNumberLen2)) @@ -72,8 +77,7 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(&packet1) handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(2))) + expectInPacketHistory([]protocol.PacketNumber{1, 2}) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(3))) Expect(handler.skippedPackets).To(BeEmpty()) }) @@ -85,8 +89,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.lastSentPacketNumber).To(BeZero()) handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(0))) - Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) + expectInPacketHistory([]protocol.PacketNumber{0, 1}) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(3))) Expect(handler.skippedPackets).To(BeEmpty()) }) @@ -94,7 +97,7 @@ var _ = Describe("SentPacketHandler", func() { It("stores the sent time", func() { packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} handler.SentPacket(&packet) - Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) + Expect(getPacket(1).sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) }) It("does not store non-retransmittable packets", func() { @@ -109,10 +112,7 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(&packet1) handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(3))) - el := handler.packetHistory.Front() - Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) - el = el.Next() - Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(3))) + expectInPacketHistory([]protocol.PacketNumber{1, 3}) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(3))) Expect(handler.skippedPackets).To(HaveLen(1)) Expect(handler.skippedPackets[0]).To(Equal(protocol.PacketNumber(2))) @@ -199,14 +199,6 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets)))) }) - expectInPacketHistory := func(expected []protocol.PacketNumber) { - var packets []protocol.PacketNumber - for el := handler.packetHistory.Front(); el != nil; el = el.Next() { - packets = append(packets, el.Value.PacketNumber) - } - ExpectWithOffset(1, packets).To(Equal(expected)) - } - Context("ACK validation", func() { It("accepts ACKs sent in packet 0", func() { ack := wire.AckFrame{ @@ -300,12 +292,11 @@ var _ = Describe("SentPacketHandler", func() { err := handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(5))) - el := handler.packetHistory.Front() - for i := 6; i <= 10; i++ { - Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(i))) - el = el.Next() + Expect(handler.packetHistory.Len()).To(Equal(6)) // 6, 7, 8, 9, 10, 12 + for i := protocol.PacketNumber(6); i <= 10; i++ { + Expect(getPacket(i)).ToNot(BeNil()) } - Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(12))) + Expect(getPacket(12)).ToNot(BeNil()) }) It("rejects an ACK that acks packets with a higher encryption level", func() { @@ -330,13 +321,11 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).ToNot(HaveOccurred()) - el := handler.packetHistory.Front() - Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(0))) - el = el.Next() - Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(9))) - el = el.Next() - Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(10))) - Expect(el.Next().Value.PacketNumber).To(Equal(protocol.PacketNumber(12))) + Expect(handler.packetHistory.Len()).To(Equal(4)) // 0, 9, 10, 12 + Expect(getPacket(0)).ToNot(BeNil()) + Expect(getPacket(9)).ToNot(BeNil()) + Expect(getPacket(10)).ToNot(BeNil()) + Expect(getPacket(12)).ToNot(BeNil()) }) It("acks packet 0", func() { @@ -346,7 +335,8 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).ToNot(HaveOccurred()) - expectInPacketHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12}) + Expect(getPacket(0)).To(BeNil()) + Expect(handler.packetHistory.Len()).To(Equal(11)) // 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12 }) It("handles an ACK frame with one missing packet range", func() { @@ -465,9 +455,9 @@ var _ = Describe("SentPacketHandler", func() { It("computes the RTT", func() { now := time.Now() // First, fake the sent times of the first, second and last packet - getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute) - getPacketElement(2).Value.sendTime = now.Add(-5 * time.Minute) - getPacketElement(6).Value.sendTime = now.Add(-1 * time.Minute) + getPacket(1).sendTime = now.Add(-10 * time.Minute) + getPacket(2).sendTime = now.Add(-5 * time.Minute) + getPacket(6).sendTime = now.Add(-1 * time.Minute) // Now, check that the proper times are used when calculating the deltas err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).NotTo(HaveOccurred()) @@ -484,7 +474,7 @@ var _ = Describe("SentPacketHandler", func() { now := time.Now() // make sure the rttStats have a min RTT, so that the delay is used handler.rttStats.UpdateRTT(5*time.Minute, 0, time.Now()) - getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute) + getPacket(1).sendTime = now.Add(-10 * time.Minute) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, DelayTime: 5 * time.Minute}, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) @@ -557,9 +547,9 @@ var _ = Describe("SentPacketHandler", func() { }) It("dequeues a packet for retransmission", func() { - getPacketElement(1).Value.sendTime = time.Now().Add(-time.Hour) + getPacket(1).sendTime = time.Now().Add(-time.Hour) handler.OnAlarm() - Expect(getPacketElement(1)).To(BeNil()) + Expect(getPacket(1)).To(BeNil()) Expect(handler.retransmissionQueue).To(HaveLen(1)) Expect(handler.retransmissionQueue[0].PacketNumber).To(Equal(protocol.PacketNumber(1))) packet := handler.DequeuePacketForRetransmission() @@ -573,7 +563,7 @@ var _ = Describe("SentPacketHandler", func() { if i == 2 { // packet 2 was already acked in BeforeEach continue } - handler.queuePacketForRetransmission(getPacketElement(i)) + handler.queuePacketForRetransmission(getPacket(i)) } Expect(handler.retransmissionQueue).To(HaveLen(6)) handler.SetHandshakeComplete() @@ -595,7 +585,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("gets a STOP_WAITING frame after queueing a retransmission", func() { - handler.queuePacketForRetransmission(getPacketElement(5)) + handler.queuePacketForRetransmission(getPacket(5)) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6})) }) }) @@ -626,10 +616,8 @@ var _ = Describe("SentPacketHandler", func() { } handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - - handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour) + getPacket(2).sendTime = time.Now().Add(-time.Hour) handler.OnAlarm() - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0))) }) @@ -804,7 +792,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(time.Until(handler.lossTime)).To(BeNumerically("~", time.Hour*9/8, time.Minute)) Expect(time.Until(handler.GetAlarmTimeout())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) - handler.packetHistory.Front().Value.sendTime = time.Now().Add(-2 * time.Hour) + getPacket(1).sendTime = time.Now().Add(-2 * time.Hour) handler.OnAlarm() Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil()) }) @@ -856,7 +844,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(p).ToNot(BeNil()) Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(4))) Expect(handler.packetHistory.Len()).To(Equal(1)) - Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(3))) + Expect(getPacket(3)).ToNot(BeNil()) Expect(handler.handshakeCount).To(BeEquivalentTo(1)) // make sure the exponential backoff is used Expect(handler.computeHandshakeTimeout()).To(BeNumerically("~", 2*handshakeTimeout, time.Minute)) diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go new file mode 100644 index 00000000..eeb24709 --- /dev/null +++ b/internal/ackhandler/sent_packet_history.go @@ -0,0 +1,59 @@ +package ackhandler + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type sentPacketHistory struct { + packetList *PacketList + packetMap map[protocol.PacketNumber]*PacketElement +} + +func newSentPacketHistory() *sentPacketHistory { + return &sentPacketHistory{ + packetList: NewPacketList(), + packetMap: make(map[protocol.PacketNumber]*PacketElement), + } +} + +func (h *sentPacketHistory) SentPacket(p *Packet) { + el := h.packetList.PushBack(*p) + h.packetMap[p.PacketNumber] = el +} + +// Iterate iterates through all packets. +// The callback must not modify the history. +func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { + cont := true + for el := h.packetList.Front(); cont && el != nil; el = el.Next() { + var err error + cont, err = cb(&el.Value) + if err != nil { + return err + } + } + return nil +} + +func (h *sentPacketHistory) Front() *Packet { + if h.Len() == 0 { + return nil + } + return &h.packetList.Front().Value +} + +func (h *sentPacketHistory) Len() int { + return len(h.packetMap) +} + +func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { + el, ok := h.packetMap[p] + if !ok { + return fmt.Errorf("packet %d not found in sent packet history", p) + } + h.packetList.Remove(el) + delete(h.packetMap, p) + return nil +} diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go new file mode 100644 index 00000000..a88101fa --- /dev/null +++ b/internal/ackhandler/sent_packet_history_test.go @@ -0,0 +1,112 @@ +package ackhandler + +import ( + "errors" + + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("SentPacketHistory", func() { + var hist *sentPacketHistory + + expectInHistory := func(packetNumbers []protocol.PacketNumber) { + ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers))) + ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers))) + i := 0 + hist.Iterate(func(p *Packet) (bool, error) { + pn := packetNumbers[i] + ExpectWithOffset(1, p.PacketNumber).To(Equal(pn)) + ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn)) + i++ + return true, nil + }) + } + + BeforeEach(func() { + hist = newSentPacketHistory() + }) + + It("saves sent packets", func() { + hist.SentPacket(&Packet{PacketNumber: 1}) + hist.SentPacket(&Packet{PacketNumber: 3}) + hist.SentPacket(&Packet{PacketNumber: 4}) + expectInHistory([]protocol.PacketNumber{1, 3, 4}) + }) + + It("gets the length", func() { + hist.SentPacket(&Packet{PacketNumber: 1}) + hist.SentPacket(&Packet{PacketNumber: 10}) + Expect(hist.Len()).To(Equal(2)) + }) + + It("gets nil, if there's no front packet", func() { + Expect(hist.Front()).To(BeNil()) + }) + + It("gets the front packet", func() { + hist.SentPacket(&Packet{PacketNumber: 2}) + hist.SentPacket(&Packet{PacketNumber: 3}) + front := hist.Front() + Expect(front).ToNot(BeNil()) + Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) + }) + + It("removes packets", func() { + hist.SentPacket(&Packet{PacketNumber: 1}) + hist.SentPacket(&Packet{PacketNumber: 4}) + hist.SentPacket(&Packet{PacketNumber: 8}) + err := hist.Remove(4) + Expect(err).ToNot(HaveOccurred()) + expectInHistory([]protocol.PacketNumber{1, 8}) + }) + + It("errors when trying to remove a non existing packet", func() { + hist.SentPacket(&Packet{PacketNumber: 1}) + err := hist.Remove(2) + Expect(err).To(MatchError("packet 2 not found in sent packet history")) + }) + + Context("iterating", func() { + BeforeEach(func() { + hist.SentPacket(&Packet{PacketNumber: 10}) + hist.SentPacket(&Packet{PacketNumber: 14}) + hist.SentPacket(&Packet{PacketNumber: 18}) + }) + + It("iterates over all packets", func() { + var iterations []protocol.PacketNumber + err := hist.Iterate(func(p *Packet) (bool, error) { + iterations = append(iterations, p.PacketNumber) + return true, nil + }) + Expect(err).ToNot(HaveOccurred()) + Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18})) + }) + + It("stops iterating", func() { + var iterations []protocol.PacketNumber + err := hist.Iterate(func(p *Packet) (bool, error) { + iterations = append(iterations, p.PacketNumber) + return p.PacketNumber != 14, nil + }) + Expect(err).ToNot(HaveOccurred()) + Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14})) + }) + + It("returns the error", func() { + testErr := errors.New("test error") + var iterations []protocol.PacketNumber + err := hist.Iterate(func(p *Packet) (bool, error) { + iterations = append(iterations, p.PacketNumber) + if p.PacketNumber == 14 { + return false, testErr + } + return true, nil + }) + Expect(err).To(MatchError(testErr)) + Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14})) + }) + }) +})