diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index 9cf20a0b..e84171e3 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -7,7 +7,10 @@ import ( type packetNumberGenerator interface { Peek() protocol.PacketNumber - Pop() protocol.PacketNumber + // Pop pops the packet number. + // It reports if the packet number (before the one just popped) was skipped. + // It never skips more than one packet number in a row. + Pop() (skipped bool, _ protocol.PacketNumber) } type sequentialPacketNumberGenerator struct { @@ -24,10 +27,10 @@ func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { return p.next } -func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { +func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { next := p.next p.next++ - return next + return false, next } // The skippingPacketNumberGenerator generates the packet number for the next packet @@ -56,21 +59,26 @@ func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol } func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { + if p.next == p.nextToSkip { + return p.next + 1 + } return p.next } -func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { +func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { next := p.next - p.next++ // generate a new packet number for the next packet if p.next == p.nextToSkip { - p.next++ + next++ + p.next += 2 p.generateNewSkip() + return true, next } - return next + p.next++ // generate a new packet number for the next packet + return false, next } func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped - p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) + p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) p.period = utils.Min(2*p.period, p.maxPeriod) } diff --git a/internal/ackhandler/packet_number_generator_test.go b/internal/ackhandler/packet_number_generator_test.go index b24228ec..4384c8f8 100644 --- a/internal/ackhandler/packet_number_generator_test.go +++ b/internal/ackhandler/packet_number_generator_test.go @@ -18,7 +18,9 @@ var _ = Describe("Sequential Packet Number Generator", func() { for i := initialPN; i < initialPN+1000; i++ { Expect(png.Peek()).To(Equal(i)) Expect(png.Peek()).To(Equal(i)) - Expect(png.Pop()).To(Equal(i)) + skipNext, pn := png.Pop() + Expect(skipNext).To(BeFalse()) + Expect(pn).To(Equal(i)) } }) }) @@ -34,29 +36,39 @@ var _ = Describe("Skipping Packet Number Generator", func() { It("can be initialized to return any first packet number", func() { png := newSkippingPacketNumberGenerator(12345, initialPeriod, maxPeriod) - Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) + _, pn := png.Pop() + Expect(pn).To(Equal(protocol.PacketNumber(12345))) }) It("allows peeking", func() { png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod).(*skippingPacketNumberGenerator) - png.nextToSkip = 1000 Expect(png.Peek()).To(Equal(initialPN)) Expect(png.Peek()).To(Equal(initialPN)) - Expect(png.Pop()).To(Equal(initialPN)) - Expect(png.Peek()).To(Equal(initialPN + 1)) - Expect(png.Peek()).To(Equal(initialPN + 1)) + skipped, pn := png.Pop() + Expect(pn).To(Equal(initialPN)) + next := initialPN + 1 + if skipped { + next++ + } + Expect(png.Peek()).To(Equal(next)) + Expect(png.Peek()).To(Equal(next)) }) It("skips a packet number", func() { png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) var last protocol.PacketNumber var skipped bool - for i := 0; i < 1000; i++ { - num := png.Pop() - if num > last+1 { + for i := 0; i < int(maxPeriod); i++ { + didSkip, num := png.Pop() + if didSkip { skipped = true + _, nextNum := png.Pop() + Expect(nextNum).To(Equal(num + 1)) break } + if i != 0 { + Expect(num).To(Equal(last + 1)) + } last = num } Expect(skipped).To(BeTrue()) @@ -69,17 +81,15 @@ var _ = Describe("Skipping Packet Number Generator", func() { for i := 0; i < rep; i++ { png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) - last := initialPN lastSkip := initialPN for len(periods[i]) < len(expectedPeriods) { - next := png.Pop() - if next > last+1 { - skipped := next - 1 + skipNext, next := png.Pop() + if skipNext { + skipped := next + 1 Expect(skipped).To(BeNumerically(">", lastSkip+1)) periods[i] = append(periods[i], skipped-lastSkip-1) lastSkip = skipped } - last = next } } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 88378d6d..4451b0e5 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -183,7 +183,7 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { // When 0-RTT is rejected, all application data sent so far becomes invalid. // Delete the packets from the history and remove them from bytes_in_flight. h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { - if p.EncryptionLevel != protocol.Encryption0RTT { + if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket { return false, nil } h.removeFromBytesInFlight(p) @@ -236,12 +236,6 @@ func (h *sentPacketHandler) SentPacket(p *Packet) { } pnSpace := h.getPacketNumberSpace(p.EncryptionLevel) - if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { - for pn := utils.Max(0, pnSpace.largestSent+1); pn < p.PacketNumber; pn++ { - h.logger.Debugf("Skipping packet number %d", pn) - } - } - pnSpace.largestSent = p.PacketNumber isAckEliciting := len(p.StreamFrames) > 0 || len(p.Frames) > 0 @@ -258,7 +252,7 @@ func (h *sentPacketHandler) SentPacket(p *Packet) { if isAckEliciting { pnSpace.history.SentAckElicitingPacket(p) } else { - pnSpace.history.SentNonAckElicitingPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) + pnSpace.history.SentNonAckElicitingPacket(p.PacketNumber) putPacket(p) p = nil //nolint:ineffassign // This is just to be on the safe side. } @@ -689,7 +683,8 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.ptoMode = SendPTOHandshake case protocol.Encryption1RTT: // skip a packet number in order to elicit an immediate ACK - _ = h.PopPacketNumber(protocol.Encryption1RTT) + pn := h.PopPacketNumber(protocol.Encryption1RTT) + h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn) h.ptoMode = SendPTOAppData default: return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) @@ -709,7 +704,16 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) } func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { - return h.getPacketNumberSpace(encLevel).pns.Pop() + pnSpace := h.getPacketNumberSpace(encLevel) + skipped, pn := pnSpace.pns.Pop() + if skipped { + skippedPN := pn - 1 + pnSpace.history.SkippedPacket(skippedPN) + if h.logger.Debug() { + h.logger.Debugf("Skipping packet number %d", skippedPN) + } + } + return pn } func (h *sentPacketHandler) SendMode(now time.Time) SendMode { @@ -835,8 +839,8 @@ func (h *sentPacketHandler) ResetForRetry() error { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } - h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false) - h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true) + h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false) + h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true) oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 3ee41661..e12c42fa 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -166,18 +166,18 @@ var _ = Describe("SentPacketHandler", func() { }) It("says if a 1-RTT packet was acknowledged", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100, EncryptionLevel: protocol.Encryption0RTT})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 101, EncryptionLevel: protocol.Encryption0RTT})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102, EncryptionLevel: protocol.Encryption1RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10, EncryptionLevel: protocol.Encryption0RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11, EncryptionLevel: protocol.Encryption0RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 12, EncryptionLevel: protocol.Encryption1RTT})) acked1RTT, err := handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 101}}}, + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}}, protocol.Encryption1RTT, time.Now(), ) Expect(err).ToNot(HaveOccurred()) Expect(acked1RTT).To(BeFalse()) acked1RTT, err = handler.ReceivedAck( - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 101, Largest: 102}}}, + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 11, Largest: 12}}}, protocol.Encryption1RTT, time.Now(), ) @@ -199,13 +199,14 @@ var _ = Describe("SentPacketHandler", func() { }) It("rejects ACKs that acknowledge a skipped packet number", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102})) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 102}}} + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) + handler.appDataPackets.history.SkippedPacket(11) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 12})) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 12}}} _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) Expect(err).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "received an ACK for skipped packet number: 101 (1-RTT)", + ErrorMessage: "received an ACK for skipped packet number: 11 (1-RTT)", })) }) @@ -277,7 +278,7 @@ var _ = Describe("SentPacketHandler", func() { var acked bool ping := &wire.PingFrame{} handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 13, + PacketNumber: 10, Frames: []Frame{{ Frame: ping, OnAcked: func(f wire.Frame) { Expect(f).To(Equal(ping)) @@ -285,7 +286,7 @@ var _ = Describe("SentPacketHandler", func() { }, }}, })) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 10}}} _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(acked).To(BeTrue()) @@ -428,21 +429,21 @@ var _ = Describe("SentPacketHandler", func() { JustBeforeEach(func() { morePackets := []*Packet{ { - PacketNumber: 13, + PacketNumber: 10, LargestAcked: 100, Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, Length: 1, EncryptionLevel: protocol.Encryption1RTT, }, { - PacketNumber: 14, + PacketNumber: 11, LargestAcked: 200, Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, Length: 1, EncryptionLevel: protocol.Encryption1RTT, }, { - PacketNumber: 15, + PacketNumber: 12, Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, Length: 1, EncryptionLevel: protocol.Encryption1RTT, @@ -454,15 +455,15 @@ var _ = Describe("SentPacketHandler", func() { }) It("determines which ACK we have received an ACK for", func() { - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}} + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 12}}} _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) }) It("doesn't do anything when the acked packet didn't contain an ACK", func() { - ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 15, Largest: 15}}} + ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 10}}} + ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 12, Largest: 12}}} _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) @@ -472,8 +473,8 @@ var _ = Describe("SentPacketHandler", func() { }) It("doesn't decrease the value", func() { - ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}} - ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} + ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 11, Largest: 11}}} + ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 10}}} _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) @@ -701,6 +702,7 @@ var _ = Describe("SentPacketHandler", func() { handler.SetHandshakeConfirmed() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) + handler.appDataPackets.pns.(*skippingPacketNumberGenerator).next = 3 Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) @@ -751,7 +753,7 @@ var _ = Describe("SentPacketHandler", func() { handler.SetHandshakeConfirmed() var lostPackets []protocol.PacketNumber handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 1, + PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), SendTime: time.Now().Add(-time.Hour), Frames: []Frame{ {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, @@ -759,51 +761,36 @@ var _ = Describe("SentPacketHandler", func() { })) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).ToNot(Equal(SendPTOAppData)) }) - It("skips a packet number for 1-RTT PTOs", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() - var lostPackets []protocol.PacketNumber - pn := handler.PopPacketNumber(protocol.Encryption1RTT) - handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: pn, - SendTime: time.Now().Add(-time.Hour), - Frames: []Frame{ - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, - }, - })) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - // The packet number generator might have introduced another skipped a packet number. - Expect(handler.PopPacketNumber(protocol.Encryption1RTT)).To(BeNumerically(">=", pn+2)) - }) - It("only counts ack-eliciting packets as probe packets", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeConfirmed() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), + SendTime: time.Now().Add(-time.Hour), + })) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - for p := protocol.PacketNumber(3); p < 30; p++ { - handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: p})) + for i := 0; i < 30; i++ { + handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) } - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 30})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).ToNot(Equal(SendPTOAppData)) }) It("gets two probe packets if PTO expires", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeConfirmed() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) updateRTT(time.Hour) Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) @@ -811,16 +798,16 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP Expect(handler.ptoCount).To(BeEquivalentTo(1)) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO Expect(handler.ptoCount).To(BeEquivalentTo(2)) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) }) @@ -844,7 +831,7 @@ var _ = Describe("SentPacketHandler", func() { It("doesn't send 1-RTT probe packets before the handshake completes", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) updateRTT(time.Hour) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) @@ -858,11 +845,12 @@ var _ = Describe("SentPacketHandler", func() { It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeConfirmed() - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) + pn := handler.PopPacketNumber(protocol.Encryption1RTT) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: pn, SendTime: time.Now().Add(-time.Hour)})) updateRTT(time.Second) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: pn, Largest: pn}}} _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) @@ -870,7 +858,10 @@ var _ = Describe("SentPacketHandler", func() { It("handles ACKs for the original packet", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), + SendTime: time.Now().Add(-time.Hour), + })) updateRTT(time.Second) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) @@ -1210,10 +1201,11 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { perspective = protocol.PerspectiveClient }) It("deletes Initials, as a client", func() { - for i := protocol.PacketNumber(0); i < 6; i++ { + for i := 0; i < 6; i++ { handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: i, + PacketNumber: handler.PopPacketNumber(protocol.EncryptionInitial), EncryptionLevel: protocol.EncryptionInitial, + Length: 1, })) } Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) @@ -1221,13 +1213,14 @@ var _ = Describe("SentPacketHandler", func() { // DropPackets should be ignored for clients and the Initial packet number space. // It has to be possible to send another Initial packets after this function was called. handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 10, + PacketNumber: handler.PopPacketNumber(protocol.EncryptionInitial), EncryptionLevel: protocol.EncryptionInitial, + Length: 1, })) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7))) // Sending a Handshake packet triggers dropping of Initials. handler.SentPacket(ackElicitingPacket(&Packet{ - PacketNumber: 1, + PacketNumber: handler.PopPacketNumber(protocol.EncryptionHandshake), EncryptionLevel: protocol.EncryptionHandshake, })) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) @@ -1260,6 +1253,7 @@ var _ = Describe("SentPacketHandler", func() { It("doesn't retransmit 0-RTT packets when 0-RTT keys are dropped", func() { for i := protocol.PacketNumber(0); i < 6; i++ { if i == 3 { + handler.appDataPackets.history.SkippedPacket(3) continue } handler.SentPacket(ackElicitingPacket(&Packet{ diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index f5a4cd83..af175e11 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -2,7 +2,6 @@ package ackhandler import ( "fmt" - "time" "github.com/quic-go/quic-go/internal/protocol" ) @@ -22,8 +21,25 @@ func newSentPacketHistory() *sentPacketHistory { } } -func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { - h.maybeAddSkippedPacketsBefore(pn, encLevel, t) +func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) { + if h.highestPacketNumber != protocol.InvalidPacketNumber { + if pn != h.highestPacketNumber+1 { + panic("non-sequential packet number use") + } + } +} + +func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) { + h.checkSequentialPacketNumberUse(pn) + h.highestPacketNumber = pn + h.packets = append(h.packets, &Packet{ + PacketNumber: pn, + skippedPacket: true, + }) +} + +func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) { + h.checkSequentialPacketNumberUse(pn) h.highestPacketNumber = pn if len(h.packets) > 0 { h.packets = append(h.packets, nil) @@ -31,30 +47,12 @@ func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber, } func (h *sentPacketHistory) SentAckElicitingPacket(p *Packet) { - h.maybeAddSkippedPacketsBefore(p.PacketNumber, p.EncryptionLevel, p.SendTime) + h.checkSequentialPacketNumberUse(p.PacketNumber) + h.highestPacketNumber = p.PacketNumber h.packets = append(h.packets, p) if p.outstanding() { h.numOutstanding++ } - h.highestPacketNumber = p.PacketNumber -} - -func (h *sentPacketHistory) maybeAddSkippedPacketsBefore(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { - if pn <= h.highestPacketNumber { - panic("non-sequential packet number use") - } - var start protocol.PacketNumber - if h.highestPacketNumber != protocol.InvalidPacketNumber { - start = h.highestPacketNumber + 1 - } - for p := start; p < pn; p++ { - h.packets = append(h.packets, &Packet{ - PacketNumber: p, - EncryptionLevel: encLevel, - SendTime: t, - skippedPacket: true, - }) - } } // Iterate iterates through all packets. diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go index 68e52240..8e6454d4 100644 --- a/internal/ackhandler/sent_packet_history_test.go +++ b/internal/ackhandler/sent_packet_history_test.go @@ -55,15 +55,17 @@ var _ = Describe("SentPacketHistory", func() { It("saves non-ack-eliciting packets", func() { now := time.Now() - hist.SentNonAckElicitingPacket(0, protocol.Encryption1RTT, now) + hist.SentNonAckElicitingPacket(0) hist.SentAckElicitingPacket(&Packet{PacketNumber: 1, SendTime: now}) - hist.SentNonAckElicitingPacket(2, protocol.Encryption1RTT, now) + hist.SentNonAckElicitingPacket(2) hist.SentAckElicitingPacket(&Packet{PacketNumber: 3, SendTime: now}) expectInHistory([]protocol.PacketNumber{1, 3}) }) It("saves sent packets, with skipped packet number", func() { + hist.SkippedPacket(0) hist.SentAckElicitingPacket(&Packet{PacketNumber: 1}) + hist.SkippedPacket(2) hist.SentAckElicitingPacket(&Packet{PacketNumber: 3}) hist.SentAckElicitingPacket(&Packet{PacketNumber: 4}) expectInHistory([]protocol.PacketNumber{1, 3, 4}) @@ -72,7 +74,8 @@ var _ = Describe("SentPacketHistory", func() { It("doesn't save non-ack-eliciting packets", func() { hist.SentAckElicitingPacket(&Packet{PacketNumber: 1}) - hist.SentNonAckElicitingPacket(3, protocol.EncryptionLevel(0), time.Time{}) + hist.SkippedPacket(2) + hist.SentNonAckElicitingPacket(3) hist.SentAckElicitingPacket(&Packet{PacketNumber: 4}) expectInHistory([]protocol.PacketNumber{1, 4}) }) @@ -103,6 +106,7 @@ var _ = Describe("SentPacketHistory", func() { It("doesn't regard path MTU packets as outstanding", func() { hist.SentAckElicitingPacket(&Packet{PacketNumber: 2}) + hist.SkippedPacket(3) hist.SentAckElicitingPacket(&Packet{PacketNumber: 4, IsPathMTUProbePacket: true}) front := hist.FirstOutstanding() Expect(front).ToNot(BeNil()) @@ -119,8 +123,11 @@ var _ = Describe("SentPacketHistory", func() { expectInHistory([]protocol.PacketNumber{0, 1, 3}) }) - It("also remove skipped packets before the removed packet", func() { + It("also removes skipped packets before the removed packet", func() { + hist.SkippedPacket(0) hist.SentAckElicitingPacket(&Packet{PacketNumber: 1}) + hist.SkippedPacket(2) + hist.SkippedPacket(3) hist.SentAckElicitingPacket(&Packet{PacketNumber: 4}) expectSkippedInHistory([]protocol.PacketNumber{0, 2, 3}) Expect(hist.Remove(4)).To(Succeed()) @@ -131,15 +138,23 @@ var _ = Describe("SentPacketHistory", func() { expectSkippedInHistory(nil) }) + It("panics on non-sequential packet number use", func() { + hist.SentAckElicitingPacket(&Packet{PacketNumber: 100}) + Expect(func() { hist.SentAckElicitingPacket(&Packet{PacketNumber: 102}) }).To(Panic()) + }) + It("removes and adds packets", func() { hist.SentAckElicitingPacket(&Packet{PacketNumber: 0}) hist.SentAckElicitingPacket(&Packet{PacketNumber: 1}) + hist.SkippedPacket(2) + hist.SkippedPacket(3) hist.SentAckElicitingPacket(&Packet{PacketNumber: 4}) - hist.SentAckElicitingPacket(&Packet{PacketNumber: 8}) + hist.SkippedPacket(5) + hist.SentAckElicitingPacket(&Packet{PacketNumber: 6}) Expect(hist.Remove(0)).To(Succeed()) Expect(hist.Remove(1)).To(Succeed()) - hist.SentAckElicitingPacket(&Packet{PacketNumber: 9}) - expectInHistory([]protocol.PacketNumber{4, 8, 9}) + hist.SentAckElicitingPacket(&Packet{PacketNumber: 7}) + expectInHistory([]protocol.PacketNumber{4, 6, 7}) }) It("removes the last packet, then adds more", func() { @@ -161,8 +176,14 @@ var _ = Describe("SentPacketHistory", func() { Context("iterating", func() { BeforeEach(func() { + hist.SkippedPacket(0) hist.SentAckElicitingPacket(&Packet{PacketNumber: 1}) + hist.SkippedPacket(2) + hist.SkippedPacket(3) hist.SentAckElicitingPacket(&Packet{PacketNumber: 4}) + hist.SkippedPacket(5) + hist.SkippedPacket(6) + hist.SkippedPacket(7) hist.SentAckElicitingPacket(&Packet{PacketNumber: 8}) }) diff --git a/packet_packer.go b/packet_packer.go index ca925b65..8454c510 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -827,18 +827,16 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire } payloadOffset := protocol.ByteCount(len(raw)) - pn := p.pnManager.PopPacketNumber(encLevel) - if pn != header.PacketNumber { - return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") - } - raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) if err != nil { return nil, err } - raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) + raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber { + return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber) + } return &longHeaderPacket{ header: header, ack: pl.ack, @@ -875,10 +873,6 @@ func (p *packetPacker) appendShortHeaderPacket( } payloadOffset := protocol.ByteCount(len(raw)) - if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) { - return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") - } - raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) if err != nil { return nil, nil, err @@ -913,6 +907,9 @@ func (p *packetPacker) appendShortHeaderPacket( ap.SendTime = now ap.IsPathMTUProbePacket = isMTUProbePacket + if newPN := p.pnManager.PopPacketNumber(protocol.Encryption1RTT); newPN != pn { + return nil, nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, newPN) + } return ap, pl.ack, nil }