From 74ed3f7037bd7bb0efdfd74196e1e4e09168bc50 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 30 Jun 2018 12:48:42 +0700 Subject: [PATCH] remove writing of special Public Header packets Version Negotiation Packets and Public Resets are written separately, so we don't need to have code for that in the Public Header. --- internal/wire/public_header.go | 49 ++++++--------------- internal/wire/public_header_test.go | 67 +---------------------------- 2 files changed, 15 insertions(+), 101 deletions(-) diff --git a/internal/wire/public_header.go b/internal/wire/public_header.go index 1ece84c7c..176f7e439 100644 --- a/internal/wire/public_header.go +++ b/internal/wire/public_header.go @@ -12,19 +12,14 @@ import ( ) var ( - errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") - errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") - errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes") + errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") + errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes") ) // writePublicHeader writes a Public Header. func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error { - if h.VersionFlag && pers == protocol.PerspectiveServer { - return errors.New("PublicHeader: Writing of Version Negotiation Packets not supported") - } - if h.VersionFlag && h.ResetFlag { - return errResetAndVersionFlagSet + if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) { + return errors.New("PublicHeader: Can only write regular packets") } if h.SrcConnectionID.Len() != 0 { return errors.New("PublicHeader: SrcConnectionID must not be set") @@ -49,16 +44,13 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ } publicFlagByte |= 0x04 } - // only set PacketNumberLen bits if a packet number will be written - if h.hasPacketNumber(pers) { - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - publicFlagByte |= 0x00 - case protocol.PacketNumberLen2: - publicFlagByte |= 0x10 - case protocol.PacketNumberLen4: - publicFlagByte |= 0x20 - } + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + publicFlagByte |= 0x00 + case protocol.PacketNumberLen2: + publicFlagByte |= 0x10 + case protocol.PacketNumberLen4: + publicFlagByte |= 0x20 } b.WriteByte(publicFlagByte) @@ -71,10 +63,6 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ if len(h.DiversificationNonce) > 0 { b.Write(h.DiversificationNonce) } - // if we're a server, and the VersionFlag is set, we must not include anything else in the packet - if !h.hasPacketNumber(pers) { - return nil - } switch h.PacketNumberLen { case protocol.PacketNumberLen1: @@ -197,23 +185,14 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea // getPublicHeaderLength gets the length of the publicHeader in bytes. // It can only be called for regular packets. func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.ByteCount, error) { - if h.VersionFlag && h.ResetFlag { - return 0, errResetAndVersionFlagSet - } - if h.VersionFlag && pers == protocol.PerspectiveServer { - return 0, errGetLengthNotForVersionNegotiation - } - length := protocol.ByteCount(1) // 1 byte for public flags if h.PacketNumberLen == protocol.PacketNumberLen6 { return 0, errInvalidPacketNumberLen6 } - if h.hasPacketNumber(pers) { - if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { - return 0, errPacketNumberLenNotSet - } - length += protocol.ByteCount(h.PacketNumberLen) + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { + return 0, errPacketNumberLenNotSet } + length += protocol.ByteCount(h.PacketNumberLen) length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes // Version Number in packets sent by the client if h.VersionFlag { diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index 848096290..8d3855ee4 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -258,16 +258,6 @@ var _ = Describe("Public Header", func() { })) }) - It("throws an error if both Reset Flag and Version Flag are set", func() { - b := &bytes.Buffer{} - hdr := Header{ - VersionFlag: true, - ResetFlag: true, - } - err := hdr.writePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) - Expect(err).To(MatchError(errResetAndVersionFlagSet)) - }) - It("doesn't write Version Negotiation Packets", func() { b := &bytes.Buffer{} hdr := Header{ @@ -277,7 +267,7 @@ var _ = Describe("Public Header", func() { PacketNumberLen: protocol.PacketNumberLen6, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(err).To(MatchError("PublicHeader: Writing of Version Negotiation Packets not supported")) + Expect(err).To(MatchError("PublicHeader: Can only write regular packets")) }) It("writes packets with Version Flag, as a client", func() { @@ -300,52 +290,7 @@ var _ = Describe("Public Header", func() { Expect(b.Bytes()[12:13]).To(Equal([]byte{0x42})) }) - Context("PublicReset packets", func() { - It("sets the Reset Flag", func() { - b := &bytes.Buffer{} - hdr := Header{ - ResetFlag: true, - DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, - } - err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - // must be the first assertion - Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID - firstByte, _ := b.ReadByte() - Expect((firstByte & 0x02) >> 1).To(Equal(uint8(1))) - }) - - It("doesn't add a packet number for headers with Reset Flag sent as a client", func() { - b := &bytes.Buffer{} - hdr := Header{ - ResetFlag: true, - DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, - PacketNumber: 2, - PacketNumberLen: protocol.PacketNumberLen6, - } - err := hdr.writePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - // must be the first assertion - Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID - }) - }) - Context("getting the length", func() { - It("errors when calling getPublicHeaderLength for Version Negotiation packets", func() { - hdr := Header{VersionFlag: true} - _, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) - Expect(err).To(MatchError(errGetLengthNotForVersionNegotiation)) - }) - - It("errors when calling getPublicHeaderLength for packets that have the VersionFlag and the ResetFlag set", func() { - hdr := Header{ - ResetFlag: true, - VersionFlag: true, - } - _, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) - Expect(err).To(MatchError(errResetAndVersionFlagSet)) - }) - It("errors when PacketNumberLen is not set", func() { hdr := Header{ DestConnectionID: connID, @@ -408,16 +353,6 @@ var _ = Describe("Public Header", func() { Expect(err).NotTo(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 3 + 1))) // 1 byte public flag, 3 byte DiversificationNonce, 1 byte PacketNumber }) - - It("gets the length of a PublicReset", func() { - hdr := Header{ - ResetFlag: true, - DestConnectionID: connID, - } - length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) - Expect(err).NotTo(HaveOccurred()) - Expect(length).To(Equal(protocol.ByteCount(1 + 8))) // 1 byte public flag, 8 byte connectionID - }) }) Context("packet number length", func() {