From 0df44e46e544fd31eea602bbfdd7cabcbae7aca5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 12 May 2016 20:00:54 +0700 Subject: [PATCH] shorten LargestObserved in ACK frames fixes #82 --- frames/ack_frame.go | 24 +++++- frames/ack_frame_test.go | 57 ++++++++++++- protocol/packet_number.go | 14 ++++ protocol/packet_number_test.go | 144 +++++++++++++++++++-------------- session_test.go | 11 ++- 5 files changed, 178 insertions(+), 72 deletions(-) diff --git a/frames/ack_frame.go b/frames/ack_frame.go index f9c43fb41..ab02f077b 100644 --- a/frames/ack_frame.go +++ b/frames/ack_frame.go @@ -24,7 +24,13 @@ type AckFrame struct { // Write writes an ACK frame. func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x40 | 0x0C) + largestObservedLen := protocol.GetPacketNumberLength(f.LargestObserved) + + typeByte := uint8(0x40) + + if largestObservedLen != protocol.PacketNumberLen1 { + typeByte ^= (uint8(largestObservedLen / 2)) << 2 + } if f.HasNACK() { typeByte |= (0x20 | 0x03) @@ -34,7 +40,18 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error b.WriteByte(typeByte) b.WriteByte(f.Entropy) - utils.WriteUint48(b, uint64(f.LargestObserved)) // TODO: send the correct length + + switch largestObservedLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(f.LargestObserved)) + case protocol.PacketNumberLen2: + utils.WriteUint16(b, uint16(f.LargestObserved)) + case protocol.PacketNumberLen4: + utils.WriteUint32(b, uint32(f.LargestObserved)) + case protocol.PacketNumberLen6: + utils.WriteUint48(b, uint64(f.LargestObserved)) + } + utils.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) b.WriteByte(0x01) // Just one timestamp b.WriteByte(0x00) // Delta Largest observed @@ -99,7 +116,8 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error // MinLength of a written frame func (f *AckFrame) MinLength() (protocol.ByteCount, error) { - l := 1 + 1 + 6 + 2 + 1 + 1 + 4 + l := 1 + 1 + 2 + 1 + 1 + 4 // 1 TypeByte, 1 Entropy, 2 ACK delay time, 1 Num Timestamp, 1 Delta Largest Observed, 4 FirstTimestamp + l += int(protocol.GetPacketNumberLength(f.LargestObserved)) l += (1 + 2) * 0 /* TODO: num_timestamps */ if f.HasNACK() { l += 1 + (6+1)*len(f.NackRanges) diff --git a/frames/ack_frame_test.go b/frames/ack_frame_test.go index 2cef07ac5..9cdb71121 100644 --- a/frames/ack_frame_test.go +++ b/frames/ack_frame_test.go @@ -253,8 +253,8 @@ var _ = Describe("AckFrame", func() { err := frame.Write(b, 32) Expect(err).ToNot(HaveOccurred()) // check all values except the DelayTime - Expect(b.Bytes()[0:8]).To(Equal([]byte{0x4c, 0x02, 0x01, 0, 0, 0, 0, 0})) - Expect(b.Bytes()[10:]).To(Equal([]byte{1, 0, 0, 0, 0, 0})) + Expect(b.Bytes()[0:3]).To(Equal([]byte{0x40, 0x02, 0x01})) + Expect(b.Bytes()[5:]).To(Equal([]byte{1, 0, 0, 0, 0, 0})) }) It("calculates the DelayTime", func() { @@ -267,7 +267,7 @@ var _ = Describe("AckFrame", func() { delayTime := frame.DelayTime var b2 bytes.Buffer utils.WriteUfloat16(&b2, uint64(delayTime/time.Microsecond)) - Expect(b.Bytes()[8:10]).To(Equal(b2.Bytes())) + Expect(b.Bytes()[3:5]).To(Equal(b2.Bytes())) }) It("writes a frame with one NACK range", func() { @@ -411,6 +411,48 @@ var _ = Describe("AckFrame", func() { Expect(missingPacketBytes[22:28]).To(Equal([]byte{0, 0, 0, 0, 0, 0})) // missingPacketSequenceNumberDelta #4 Expect(missingPacketBytes[28]).To(Equal(uint8(0xFF))) // rangeLength #4 }) + + Context("LargestObserved length", func() { + It("writes a 1 byte LargestObserved value", func() { + frame := AckFrame{ + LargestObserved: 7, + } + err := frame.Write(b, 32) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x4C).To(Equal(uint8(0x40))) + Expect(b.Bytes()[2]).To(Equal(uint8(7))) + }) + + It("writes a 2 byte LargestObserved value", func() { + frame := AckFrame{ + LargestObserved: 0x1337, + } + err := frame.Write(b, 32) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x4C).To(Equal(uint8(0x44))) + Expect(b.Bytes()[2:4]).To(Equal([]byte{0x37, 0x13})) + }) + + It("writes a 4 byte LargestObserved value", func() { + frame := AckFrame{ + LargestObserved: 0xDECAFBAD, + } + err := frame.Write(b, 32) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x4C).To(Equal(uint8(0x48))) + Expect(b.Bytes()[2:6]).To(Equal([]byte{0xAD, 0xFB, 0xCA, 0xDE})) + }) + + It("writes a 6 byte LargestObserved value", func() { + frame := AckFrame{ + LargestObserved: 0xDEADBEEFCAFE, + } + err := frame.Write(b, 32) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x4C).To(Equal(uint8(0x4C))) + Expect(b.Bytes()[2:8]).To(Equal([]byte{0xFE, 0xCA, 0xEF, 0xBE, 0xAD, 0xDE})) + }) + }) }) Context("min length", func() { @@ -423,6 +465,15 @@ var _ = Describe("AckFrame", func() { Expect(f.MinLength()).To(Equal(protocol.ByteCount(b.Len()))) }) + It("has proper min length with a large LargestObserved", func() { + f := &AckFrame{ + Entropy: 2, + LargestObserved: 0xDEADBEEFCAFE, + } + f.Write(b, 2) + Expect(f.MinLength()).To(Equal(protocol.ByteCount(b.Len()))) + }) + It("has proper min length with nack ranges", func() { f := &AckFrame{ Entropy: 2, diff --git a/protocol/packet_number.go b/protocol/packet_number.go index 4ca004323..fb9e55883 100644 --- a/protocol/packet_number.go +++ b/protocol/packet_number.go @@ -42,3 +42,17 @@ func GetPacketNumberLengthForPublicHeader(packetNumber PacketNumber, highestAcke // we do not check if there are less than 2^46 packets in flight, since flow control and congestion control will limit this number *a lot* sooner return PacketNumberLen6 } + +// GetPacketNumberLength gets the minimum length needed to fully represent the packet number +func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen { + if packetNumber < (1 << (uint8(PacketNumberLen1) * 8)) { + return PacketNumberLen1 + } + if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) { + return PacketNumberLen2 + } + if packetNumber < (1 << (uint8(PacketNumberLen4) * 8)) { + return PacketNumberLen4 + } + return PacketNumberLen6 +} diff --git a/protocol/packet_number_test.go b/protocol/packet_number_test.go index 1313ad4e7..5330d5495 100644 --- a/protocol/packet_number_test.go +++ b/protocol/packet_number_test.go @@ -126,80 +126,100 @@ var _ = Describe("packet number calculation", func() { } }) - Context("shortening a packet number", func() { - It("sends out low packet numbers as 1 byte", func() { - length := GetPacketNumberLengthForPublicHeader(4, 2) - Expect(length).To(Equal(PacketNumberLen1)) + Context("shortening a packet number for the PublicHeader", func() { + Context("shortening", func() { + It("sends out low packet numbers as 1 byte", func() { + length := GetPacketNumberLengthForPublicHeader(4, 2) + Expect(length).To(Equal(PacketNumberLen1)) + }) + + It("sends out high packet numbers as 1 byte, if all ACKs are received", func() { + length := GetPacketNumberLengthForPublicHeader(0xDEADBEEF, 0xDEADBEEF-1) + Expect(length).To(Equal(PacketNumberLen1)) + }) + + It("sends out higher packet numbers as 2 bytes, if a lot of ACKs are missing", func() { + length := GetPacketNumberLengthForPublicHeader(200, 2) + Expect(length).To(Equal(PacketNumberLen2)) + }) }) - It("sends out high packet numbers as 1 byte, if all ACKs are received", func() { - length := GetPacketNumberLengthForPublicHeader(0xDEADBEEF, 0xDEADBEEF-1) - Expect(length).To(Equal(PacketNumberLen1)) - }) + Context("self-consistency", func() { + It("works for small packet numbers", func() { + for i := uint64(1); i < 10000; i++ { + packetNumber := PacketNumber(i) + highestAcked := PacketNumber(1) + length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - It("sends out higher packet numbers as 2 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForPublicHeader(200, 2) - Expect(length).To(Equal(PacketNumberLen2)) + inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) + Expect(inferedPacketNumber).To(Equal(packetNumber)) + } + }) + + It("works for small packet numbers and increasing ACKed packets", func() { + for i := uint64(1); i < 10000; i++ { + packetNumber := PacketNumber(i) + highestAcked := PacketNumber(i / 2) + length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) + + inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) + Expect(inferedPacketNumber).To(Equal(packetNumber)) + } + }) + + It("also works for larger packet numbers", func() { + increment := uint64(1) + for i := uint64(1); i < (2 << 46); i += increment { + packetNumber := PacketNumber(i) + highestAcked := PacketNumber(1) + length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) + + inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) + Expect(inferedPacketNumber).To(Equal(packetNumber)) + + switch length { + case PacketNumberLen2: + increment = 100 + case PacketNumberLen4: + increment = 50000 + case PacketNumberLen6: + increment = 100000000 + } + } + }) + + It("works for packet numbers larger than 2^48", func() { + for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 45) { + packetNumber := PacketNumber(i) + highestAcked := PacketNumber(i - 1000) + length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) + wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) + + inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) + Expect(inferedPacketNumber).To(Equal(packetNumber)) + } + }) }) }) - Context("self-consistency", func() { - It("works for small packet numbers", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - highestAcked := PacketNumber(1) - length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) - Expect(inferedPacketNumber).To(Equal(packetNumber)) - } + Context("determining the minimum length of a packet number", func() { + It("1 byte", func() { + Expect(GetPacketNumberLength(0xFF)).To(Equal(PacketNumberLen1)) }) - It("works for small packet numbers and increasing ACKed packets", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - highestAcked := PacketNumber(i / 2) - length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) - Expect(inferedPacketNumber).To(Equal(packetNumber)) - } + It("2 byte", func() { + Expect(GetPacketNumberLength(0xFFFF)).To(Equal(PacketNumberLen2)) }) - It("also works for larger packet numbers", func() { - increment := uint64(1) - for i := uint64(1); i < (2 << 46); i += increment { - packetNumber := PacketNumber(i) - highestAcked := PacketNumber(1) - length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) - Expect(inferedPacketNumber).To(Equal(packetNumber)) - - switch length { - case PacketNumberLen2: - increment = 100 - case PacketNumberLen4: - increment = 50000 - case PacketNumberLen6: - increment = 100000000 - } - } + It("4 byte", func() { + Expect(GetPacketNumberLength(0xFFFFFFFF)).To(Equal(PacketNumberLen4)) }) - It("works for packet numbers larger than 2^48", func() { - for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 45) { - packetNumber := PacketNumber(i) - highestAcked := PacketNumber(i - 1000) - length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber)) - Expect(inferedPacketNumber).To(Equal(packetNumber)) - } + It("6 byte", func() { + Expect(GetPacketNumberLength(0xFFFFFFFFFFFF)).To(Equal(PacketNumberLen6)) }) }) }) diff --git a/session_test.go b/session_test.go index 73694ace2..b48bf60ed 100644 --- a/session_test.go +++ b/session_test.go @@ -317,12 +317,15 @@ var _ = Describe("Session", func() { }) It("sends ack frames", func() { - session.receivedPacketHandler.ReceivedPacket(1, true) + packetNumber := protocol.PacketNumber(0x0135) + var entropy ackhandler.EntropyAccumulator + session.receivedPacketHandler.ReceivedPacket(packetNumber, true) + entropy.Add(packetNumber, true) err := session.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(conn.written).To(HaveLen(1)) - // test for the beginning of an ACK frame: TypeByte until LargestObserved - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x4c, 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0}))) + // test for the beginning of an ACK frame: Entropy until LargestObserved + Expect(conn.written[0]).To(ContainSubstring(string([]byte{byte(entropy), 0x35, 0x01}))) }) It("sends queued stream frames", func() { @@ -335,7 +338,7 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) Expect(conn.written).To(HaveLen(1)) // test for the beginning of an ACK frame: TypeByte until LargestObserved - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x4c, 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0}))) + Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x40, 0x2, 0x1}))) Expect(conn.written[0]).To(ContainSubstring(string("foobar"))) })