diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 38a9aa66..c586495e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -29,8 +29,10 @@ type packetNumberSpace struct { func newPacketNumberSpace(initialPN protocol.PacketNumber) *packetNumberSpace { return &packetNumberSpace{ - history: newSentPacketHistory(), - pns: newPacketNumberGenerator(initialPN, protocol.SkipPacketAveragePeriodLength), + history: newSentPacketHistory(), + pns: newPacketNumberGenerator(initialPN, protocol.SkipPacketAveragePeriodLength), + largestSent: protocol.InvalidPacketNumber, + largestAcked: protocol.InvalidPacketNumber, } } @@ -161,14 +163,15 @@ func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLev func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ { pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel) - if h.logger.Debug() && pnSpace.largestSent != 0 { - for p := pnSpace.largestSent + 1; p < packet.PacketNumber; p++ { + if h.logger.Debug() { + for p := utils.MaxPacketNumber(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ { h.logger.Debugf("Skipping packet number %#x", p) } } pnSpace.largestSent = packet.PacketNumber + packet.largestAcked = protocol.InvalidPacketNumber if packet.Ack != nil { packet.largestAcked = packet.Ack.LargestAcked() } @@ -232,10 +235,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe priorInFlight := h.bytesInFlight for _, p := range ackedPackets { - // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 - // It is safe to ignore the corner case of packets that just acked packet 0, because - // the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send. - if p.largestAcked != 0 && encLevel == protocol.Encryption1RTT { + if p.largestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.largestAcked+1) } if err := h.onPacketAcked(p, rcvTime); err != nil { diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index 017d824c..2a45b354 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -68,6 +68,9 @@ var _ SendAlgorithm = &cubicSender{} func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) *cubicSender { return &cubicSender{ rttStats: rttStats, + largestSentPacketNumber: protocol.InvalidPacketNumber, + largestAckedPacketNumber: protocol.InvalidPacketNumber, + largestSentAtLastCutback: protocol.InvalidPacketNumber, initialCongestionWindow: initialCongestionWindow, initialMaxCongestionWindow: initialMaxCongestionWindow, congestionWindow: initialCongestionWindow, @@ -110,7 +113,7 @@ func (c *cubicSender) OnPacketSent( } func (c *cubicSender) InRecovery() bool { - return c.largestAckedPacketNumber <= c.largestSentAtLastCutback && c.largestAckedPacketNumber != 0 + return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback } func (c *cubicSender) InSlowStart() bool { @@ -282,7 +285,7 @@ func (c *cubicSender) SetNumEmulatedConnections(n int) { // OnRetransmissionTimeout is called on an retransmission timeout func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { - c.largestSentAtLastCutback = 0 + c.largestSentAtLastCutback = protocol.InvalidPacketNumber if !packetsRetransmitted { return } @@ -296,9 +299,9 @@ func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { func (c *cubicSender) OnConnectionMigration() { c.hybridSlowStart.Restart() c.prr = PrrSender{} - c.largestSentPacketNumber = 0 - c.largestAckedPacketNumber = 0 - c.largestSentAtLastCutback = 0 + c.largestSentPacketNumber = protocol.InvalidPacketNumber + c.largestAckedPacketNumber = protocol.InvalidPacketNumber + c.largestSentAtLastCutback = protocol.InvalidPacketNumber c.lastCutbackExitedSlowstart = false c.cubic.Reset() c.numAckedPackets = 0 diff --git a/internal/protocol/packet_number.go b/internal/protocol/packet_number.go index 405a07ac..1fbed569 100644 --- a/internal/protocol/packet_number.go +++ b/internal/protocol/packet_number.go @@ -1,5 +1,12 @@ package protocol +// A PacketNumber in QUIC +type PacketNumber int64 + +// InvalidPacketNumber is a packet number that is never sent. +// In QUIC, 0 is a valid packet number. +const InvalidPacketNumber = -1 + // PacketNumberLen is the length of the packet number in bytes type PacketNumberLen uint8 @@ -34,7 +41,10 @@ func DecodePacketNumber( epochDelta = PacketNumber(1) << 32 } epoch := lastPacketNumber & ^(epochDelta - 1) - prevEpochBegin := epoch - epochDelta + var prevEpochBegin PacketNumber + if epoch > epochDelta { + prevEpochBegin = epoch - epochDelta + } nextEpochBegin := epoch + epochDelta return closestTo( lastPacketNumber+1, diff --git a/internal/protocol/packet_number_test.go b/internal/protocol/packet_number_test.go index 861bbb6c..c5891441 100644 --- a/internal/protocol/packet_number_test.go +++ b/internal/protocol/packet_number_test.go @@ -2,7 +2,6 @@ package protocol import ( "fmt" - "math" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -10,6 +9,10 @@ import ( // Tests taken and extended from chrome var _ = Describe("packet number calculation", func() { + It("InvalidPacketNumber is smaller than all valid packet numbers", func() { + Expect(InvalidPacketNumber).To(BeNumerically("<", 0)) + }) + It("works with the example from the draft", func() { Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32))) }) @@ -25,10 +28,10 @@ var _ = Describe("packet number calculation", func() { epoch := getEpoch(length) epochMask := epoch - 1 wirePacketNumber := expected & epochMask - Expect(DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) + ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) } - for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} { + for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} { length := l Context(fmt.Sprintf("with %d bytes", length), func() { @@ -113,29 +116,6 @@ var _ = Describe("packet number calculation", func() { } }) - It("works near next max", func() { - maxNumber := uint64(math.MaxUint64) - maxEpoch := maxNumber & ^epochMask - - // Cases where the last number was close to the end of the range - for i := uint64(0); i < 10; i++ { - // Subtract 1, because the expected next packet number is 1 more than the - // last packet number. - last := maxNumber - i - 1 - - // Small numbers should not wrap, because they have nowhere to go. - for j := uint64(0); j < 10; j++ { - check(length, maxEpoch+j, last) - } - - // Large numbers should not wrap either. - for j := uint64(0); j < 10; j++ { - num := epoch - 1 - j - check(length, maxEpoch+num, last) - } - } - }) - Context("shortening a packet number for the header", func() { Context("shortening", func() { It("sends out low packet numbers as 2 byte", func() { diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index c3e48fd9..e97b569d 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -5,9 +5,6 @@ import ( "time" ) -// A PacketNumber in QUIC -type PacketNumber uint64 - // The PacketType is the Long Header Type type PacketType uint8