From 6c394acde703828361cb50e756600dcb18e9d63c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 22 Oct 2017 00:15:09 +0700 Subject: [PATCH] never use a 6 byte packet number According to the IETF draft, the only packet number lengths are 1, 2 and 4 bytes. With the given formula for the packet number derivation, we would have only sent a 6 byte packet number if the difference between the largest unacked and the current packet number exceeded 2^31, so this would never have happened anyway. --- internal/protocol/packet_number.go | 12 ++++-------- internal/protocol/packet_number_test.go | 18 ++++++++---------- packet_packer.go | 2 +- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/internal/protocol/packet_number.go b/internal/protocol/packet_number.go index c4f468ad..4bc8bfc9 100644 --- a/internal/protocol/packet_number.go +++ b/internal/protocol/packet_number.go @@ -27,18 +27,14 @@ func delta(a, b PacketNumber) PacketNumber { return a - b } -// GetPacketNumberLengthForPublicHeader gets the length of the packet number for the public header +// GetPacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForPublicHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen { +func GetPacketNumberLengthForHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen { diff := uint64(packetNumber - leastUnacked) - if diff < (2 << (uint8(PacketNumberLen2)*8 - 2)) { + if diff < (1 << (uint8(PacketNumberLen2)*8 - 1)) { return PacketNumberLen2 } - if diff < (2 << (uint8(PacketNumberLen4)*8 - 2)) { - return PacketNumberLen4 - } - // we do not check if there are less than 2^46 packets in flight, since flow control and congestion control will limit this number *a lot* sooner - return PacketNumberLen6 + return PacketNumberLen4 } // GetPacketNumberLength gets the minimum length needed to fully represent the packet number diff --git a/internal/protocol/packet_number_test.go b/internal/protocol/packet_number_test.go index aa1886a2..5f84c0ee 100644 --- a/internal/protocol/packet_number_test.go +++ b/internal/protocol/packet_number_test.go @@ -129,17 +129,17 @@ var _ = Describe("packet number calculation", func() { Context("shortening a packet number for the publicHeader", func() { Context("shortening", func() { It("sends out low packet numbers as 2 byte", func() { - length := GetPacketNumberLengthForPublicHeader(4, 2) + length := GetPacketNumberLengthForHeader(4, 2) Expect(length).To(Equal(PacketNumberLen2)) }) It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { - length := GetPacketNumberLengthForPublicHeader(0xDEADBEEF, 0xDEADBEEF-1) + length := GetPacketNumberLengthForHeader(0xDEADBEEF, 0xDEADBEEF-1) Expect(length).To(Equal(PacketNumberLen2)) }) It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForPublicHeader(40000, 2) + length := GetPacketNumberLengthForHeader(40000, 2) Expect(length).To(Equal(PacketNumberLen4)) }) }) @@ -149,7 +149,7 @@ var _ = Describe("packet number calculation", func() { for i := uint64(1); i < 10000; i++ { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForPublicHeader(packetNumber, leastUnacked) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) @@ -161,7 +161,7 @@ var _ = Describe("packet number calculation", func() { for i := uint64(1); i < 10000; i++ { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(i / 2) - length := GetPacketNumberLengthForPublicHeader(packetNumber, leastUnacked) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) @@ -171,10 +171,10 @@ var _ = Describe("packet number calculation", func() { It("also works for larger packet numbers", func() { increment := uint64(1 << (8 - 3)) - for i := uint64(1); i < (2 << 46); i += increment { + for i := uint64(1); i < (2 << 31); i += increment { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForPublicHeader(packetNumber, leastUnacked) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) @@ -185,8 +185,6 @@ var _ = Describe("packet number calculation", func() { increment = 1 << (2*8 - 3) case PacketNumberLen4: increment = 1 << (4*8 - 3) - case PacketNumberLen6: - increment = 1 << (6*8 - 3) } } }) @@ -195,7 +193,7 @@ var _ = Describe("packet number calculation", func() { for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(i - 1000) - length := GetPacketNumberLengthForPublicHeader(packetNumber, leastUnacked) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) diff --git a/packet_packer.go b/packet_packer.go index efc102d6..57ff3532 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -264,7 +264,7 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) { func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *wire.PublicHeader { pnum := p.packetNumberGenerator.Peek() - packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked) + packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) publicHeader := &wire.PublicHeader{ ConnectionID: p.connectionID, PacketNumber: pnum,