diff --git a/public_header.go b/public_header.go index 2b20a08a9..0dc4baefb 100644 --- a/public_header.go +++ b/public_header.go @@ -89,6 +89,10 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe return nil } + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { + return errPacketNumberLenNotSet + } + switch h.PacketNumberLen { case protocol.PacketNumberLen1: b.WriteByte(uint8(h.PacketNumber)) diff --git a/public_header_test.go b/public_header_test.go index aee201dfb..8a49eca70 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -94,7 +94,8 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.Write(b, protocol.Version35, protocol.PerspectiveServer) + err := hdr.Write(b, protocol.Version35, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 2, 0, 0, 0, 0, 0})) }) @@ -105,10 +106,21 @@ var _ = Describe("Public Header", func() { PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.Write(b, protocol.Version35, protocol.PerspectiveClient) + err := hdr.Write(b, protocol.Version35, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x37, 0x13, 0, 0, 0, 0})) }) + It("refuses to write a Public Header if the PacketNumberLen is not set", func() { + hdr := PublicHeader{ + ConnectionID: 1, + PacketNumber: 2, + } + b := &bytes.Buffer{} + err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + Expect(err).To(MatchError(errPacketNumberLenNotSet)) + }) + It("truncates the connection ID", func() { b := &bytes.Buffer{} hdr := PublicHeader{ @@ -173,7 +185,8 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID firstByte, _ := b.ReadByte() @@ -190,7 +203,8 @@ var _ = Describe("Public Header", func() { PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveClient) + err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8 + 4 + 6)) // 1 FlagByte + 8 ConnectionID + 4 version number + 6 PacketNumber firstByte, _ := b.ReadByte() @@ -210,7 +224,8 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID firstByte, _ := b.ReadByte() @@ -225,7 +240,8 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveClient) + err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID })