diff --git a/public_header.go b/public_header.go index 970dac009..ee214388a 100644 --- a/public_header.go +++ b/public_header.go @@ -10,10 +10,11 @@ import ( ) var ( - errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") - errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errReceivedTruncatedConnectionID = errors.New("PublicHeader: Receiving packets with truncated ConnectionID is not supported") - errInvalidConnectionID = errors.New("PublicHeader: connection ID cannot be 0") + errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") + errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") + errReceivedTruncatedConnectionID = errors.New("PublicHeader: Receiving packets with truncated ConnectionID is not supported") + errInvalidConnectionID = errors.New("PublicHeader: connection ID cannot be 0") + errGetLengthOnlyForRegularPackets = errors.New("PublicHeader: GetLength can only be called for regular packets") ) // The PublicHeader of a QUIC packet @@ -143,3 +144,21 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { return header, nil } + +// GetLength gets the length of the PublicHeader in bytes +// can only be called for regular packets +func (h *PublicHeader) GetLength() (uint8, error) { + if h.VersionFlag || h.ResetFlag { + return 0, errGetLengthOnlyForRegularPackets + } + + length := uint8(1) // 1 byte for public flags + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { + return 0, errPacketNumberLenNotSet + } + if !h.TruncateConnectionID { + length += 8 // 8 bytes for the connection ID + } + length += uint8(h.PacketNumberLen) + return length, nil +} diff --git a/public_header_test.go b/public_header_test.go index b0ecdcd73..808907043 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -148,6 +148,66 @@ var _ = Describe("Public Header", func() { Expect(b.Bytes()).ToNot(ContainSubstring(string([]byte{0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}))) }) + Context("GetLength", func() { + It("errors when calling GetLength for Version Negotiation packets", func() { + publicHeader := PublicHeader{VersionFlag: true} + _, err := publicHeader.GetLength() + Expect(err).To(HaveOccurred()) + Expect(err).To(Equal(errGetLengthOnlyForRegularPackets)) + }) + + It("errors when calling GetLength for Public Reset packets", func() { + publicHeader := PublicHeader{ResetFlag: true} + _, err := publicHeader.GetLength() + Expect(err).To(HaveOccurred()) + Expect(err).To(Equal(errGetLengthOnlyForRegularPackets)) + }) + + It("errors when PacketNumberLen is not set", func() { + publicHeader := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xDECAFBAD, + } + _, err := publicHeader.GetLength() + Expect(err).To(HaveOccurred()) + Expect(err).To(Equal(errPacketNumberLenNotSet)) + }) + + It("gets the length of a packet with longest packet number length and connectionID", func() { + publicHeader := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen6, + } + length, err := publicHeader.GetLength() + Expect(err).ToNot(HaveOccurred()) + Expect(length).To(Equal(uint8(1 + 8 + 6))) // 1 byte public flag, 8 bytes connectionID, and packet number + }) + + It("gets the length of a packet with longest packet number length and truncated connectionID", func() { + publicHeader := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + TruncateConnectionID: true, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen6, + } + length, err := publicHeader.GetLength() + Expect(err).ToNot(HaveOccurred()) + Expect(length).To(Equal(uint8(1 + 6))) // 1 byte public flag, and packet number + }) + + It("gets the length of a packet 2 byte packet number length ", func() { + publicHeader := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen2, + } + length, err := publicHeader.GetLength() + Expect(err).ToNot(HaveOccurred()) + Expect(length).To(Equal(uint8(1 + 8 + 2))) // 1 byte public flag, 8 byte connectionID, and packet number + }) + }) + Context("packet number length", func() { It("doesn't write a header if the packet number length is not set", func() { b := &bytes.Buffer{}