From 756a42106506b30e0fea5c538bef804946a53f7f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 8 Nov 2016 09:21:27 +0700 Subject: [PATCH] calculate PublicHeader length for packets with VersionFlag and ResetFlag --- packet_packer.go | 2 +- public_header.go | 37 +++++++++++++++++++++---------- public_header_test.go | 51 +++++++++++++++++++++++++++++++++---------- 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 6ad92cfb..046c659c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -74,7 +74,7 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea DiversificationNonce: p.cryptoSetup.DiversificationNonce(), } - publicHeaderLength, err := responsePublicHeader.GetLength() + publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective) if err != nil { return nil, err } diff --git a/public_header.go b/public_header.go index ff62cb68..39f631ea 100644 --- a/public_header.go +++ b/public_header.go @@ -11,11 +11,11 @@ import ( ) var ( - errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") - errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") - errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") - errGetLengthOnlyForRegularPackets = errors.New("PublicHeader: GetLength can only be called for regular packets") + errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") + errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") + errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") + errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") + errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") ) // The PublicHeader of a QUIC packet @@ -192,20 +192,35 @@ func ParsePublicHeader(b io.ByteReader, packetSentBy protocol.Perspective) (*Pub // GetLength gets the length of the publicHeader in bytes // can only be called for regular packets -func (h *PublicHeader) GetLength() (protocol.ByteCount, error) { - if h.VersionFlag || h.ResetFlag { - return 0, errGetLengthOnlyForRegularPackets +func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) { + if h.VersionFlag && h.ResetFlag { + return 0, errResetAndVersionFlagSet + } + + if h.VersionFlag && pers == protocol.PerspectiveServer { + return 0, errGetLengthNotForVersionNegotiation } length := protocol.ByteCount(1) // 1 byte for public flags - if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { - return 0, errPacketNumberLenNotSet + + if h.hasPacketNumber(pers) { + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { + return 0, errPacketNumberLenNotSet + } + length += protocol.ByteCount(h.PacketNumberLen) } + if !h.TruncateConnectionID { length += 8 // 8 bytes for the connection ID } + + // Version Number in packets sent by the client + if h.VersionFlag { + length += 4 + } + length += protocol.ByteCount(len(h.DiversificationNonce)) - length += protocol.ByteCount(h.PacketNumberLen) + return length, nil } diff --git a/public_header_test.go b/public_header_test.go index 004597bb..b90a3bb2 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -269,14 +269,17 @@ var _ = Describe("Public Header", func() { Context("GetLength", func() { It("errors when calling GetLength for Version Negotiation packets", func() { hdr := PublicHeader{VersionFlag: true} - _, err := hdr.GetLength() - Expect(err).To(MatchError(errGetLengthOnlyForRegularPackets)) + _, err := hdr.GetLength(protocol.PerspectiveServer) + Expect(err).To(MatchError(errGetLengthNotForVersionNegotiation)) }) - It("errors when calling GetLength for Public Reset packets", func() { - hdr := PublicHeader{ResetFlag: true} - _, err := hdr.GetLength() - Expect(err).To(MatchError(errGetLengthOnlyForRegularPackets)) + It("errors when calling GetLength for packets that have the VersionFlag and the ResetFlag set", func() { + hdr := PublicHeader{ + ResetFlag: true, + VersionFlag: true, + } + _, err := hdr.GetLength(protocol.PerspectiveServer) + Expect(err).To(MatchError(errResetAndVersionFlagSet)) }) It("errors when PacketNumberLen is not set", func() { @@ -284,7 +287,7 @@ var _ = Describe("Public Header", func() { ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, } - _, err := hdr.GetLength() + _, err := hdr.GetLength(protocol.PerspectiveServer) Expect(err).To(MatchError(errPacketNumberLenNotSet)) }) @@ -294,11 +297,25 @@ var _ = Describe("Public Header", func() { PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, } - length, err := hdr.GetLength() + length, err := hdr.GetLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 6))) // 1 byte public flag, 8 bytes connectionID, and packet number }) + It("gets the lengths of a packet sent by the client with the VersionFlag set", func() { + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + TruncateConnectionID: true, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen6, + VersionFlag: true, + VersionNumber: protocol.Version36, + } + length, err := hdr.GetLength(protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(length).To(Equal(protocol.ByteCount(1 + 4 + 6))) // 1 byte public flag, 4 version number, and packet number + }) + It("gets the length of a packet with longest packet number length and truncated connectionID", func() { hdr := PublicHeader{ ConnectionID: 0x4cfa9f9b668619f6, @@ -306,7 +323,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, } - length, err := hdr.GetLength() + length, err := hdr.GetLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 6))) // 1 byte public flag, and packet number }) @@ -317,7 +334,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen2, } - length, err := hdr.GetLength() + length, err := hdr.GetLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 2))) // 1 byte public flag, 8 byte connectionID, and packet number }) @@ -327,9 +344,19 @@ var _ = Describe("Public Header", func() { DiversificationNonce: []byte("foo"), PacketNumberLen: protocol.PacketNumberLen1, } - length, err := hdr.GetLength() + length, err := hdr.GetLength(protocol.PerspectiveServer) Expect(err).NotTo(HaveOccurred()) - Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 3 + 1))) + Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 3 + 1))) // 1 byte public flag, 8 byte connectionID, 3 byte DiversificationNonce, 1 byte PacketNumber + }) + + It("gets the length of a PublicReset", func() { + hdr := PublicHeader{ + ResetFlag: true, + ConnectionID: 0x4cfa9f9b668619f6, + } + length, err := hdr.GetLength(protocol.PerspectiveServer) + Expect(err).NotTo(HaveOccurred()) + Expect(length).To(Equal(protocol.ByteCount(1 + 8))) // 1 byte public flag, 8 byte connectionID }) })