diff --git a/public_header.go b/public_header.go index dcc823b7..0cd11c05 100644 --- a/public_header.go +++ b/public_header.go @@ -55,21 +55,12 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { } // Version (optional) + if header.VersionFlag { - var b1, b2, b3, b4 uint8 - if b1, err = b.ReadByte(); err != nil { + header.QuicVersion, err = utils.ReadUint32BigEndian(b) + if err != nil { return nil, err } - if b2, err = b.ReadByte(); err != nil { - return nil, err - } - if b3, err = b.ReadByte(); err != nil { - return nil, err - } - if b4, err = b.ReadByte(); err != nil { - return nil, err - } - header.QuicVersion = uint32(b4) + uint32(b3)<<8 + uint32(b2)<<16 + uint32(b1)<<24 } // Packet number diff --git a/utils/utils.go b/utils/utils.go index 5438dbfe..1c2418ce 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,6 +2,7 @@ package utils import ( "bytes" + "encoding/binary" "io" ) @@ -21,34 +22,43 @@ func ReadUintN(b io.ByteReader, length uint8) (uint64, error) { // ReadUint32 reads a uint32 func ReadUint32(b io.ByteReader) (uint32, error) { - var b1, b2, b3, b4 uint8 - var err error - if b1, err = b.ReadByte(); err != nil { + slice, err := readNBytes(b, 4) + if err != nil { return 0, err } - if b2, err = b.ReadByte(); err != nil { + return binary.LittleEndian.Uint32(slice), nil +} + +// ReadUint32BigEndian reads a uint32 Big Endian +func ReadUint32BigEndian(b io.ByteReader) (uint32, error) { + slice, err := readNBytes(b, 4) + if err != nil { return 0, err } - if b3, err = b.ReadByte(); err != nil { - return 0, err - } - if b4, err = b.ReadByte(); err != nil { - return 0, err - } - return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil + return binary.BigEndian.Uint32(slice), nil } // ReadUint16 reads a uint16 func ReadUint16(b io.ByteReader) (uint16, error) { - var b1, b2 uint8 + slice, err := readNBytes(b, 2) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint16(slice), nil +} + +func readNBytes(b io.ByteReader, n int) ([]byte, error) { + slice := make([]byte, n, n) + var val uint8 var err error - if b1, err = b.ReadByte(); err != nil { - return 0, err + for i := 0; i < n; i++ { + val, err = b.ReadByte() + if err != nil { + return []byte{}, err + } + slice[i] = val } - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - return uint16(b1) + uint16(b2)<<8, nil + return slice, nil } // WriteUint64 writes a uint64 @@ -71,6 +81,14 @@ func WriteUint32(b *bytes.Buffer, i uint32) { b.WriteByte(uint8((i >> 24) & 0xff)) } +// WriteUint32BigEndian writes a uint32 +func WriteUint32BigEndian(b *bytes.Buffer, i uint32) { + b.WriteByte(uint8((i >> 24) & 0xff)) + b.WriteByte(uint8((i >> 16) & 0xff)) + b.WriteByte(uint8((i >> 8) & 0xff)) + b.WriteByte(uint8(i & 0xff)) +} + // WriteUint16 writes a uint16 func WriteUint16(b *bytes.Buffer, i uint16) { b.WriteByte(uint8(i & 0xff)) diff --git a/utils/utils_test.go b/utils/utils_test.go index 3291c76e..d45550f4 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -23,21 +23,6 @@ var _ = Describe("Utils", func() { }) }) - Context("WriteUint32", func() { - It("outputs 4 bytes", func() { - b := &bytes.Buffer{} - WriteUint32(b, uint32(1)) - Expect(b.Len()).To(Equal(4)) - }) - - It("outputs a little endian", func() { - num := uint32(0xEFAC3512) - b := &bytes.Buffer{} - WriteUint32(b, num) - Expect(b.Bytes()).To(Equal([]byte{0x12, 0x35, 0xAC, 0xEF})) - }) - }) - Context("ReadUint16", func() { It("reads a little endian", func() { b := []byte{0x13, 0xEF} @@ -68,6 +53,21 @@ var _ = Describe("Utils", func() { }) }) + Context("ReadUint32BigEndian", func() { + It("reads a big endian", func() { + b := []byte{0x12, 0x35, 0xAB, 0xFF} + val, err := ReadUint32BigEndian(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint32(0x1235ABFF))) + }) + + It("throws an error if less than 4 bytes are passed", func() { + b := []byte{0x13, 0x34, 0xEA} + _, err := ReadUint32(bytes.NewReader(b)) + Expect(err).To(HaveOccurred()) + }) + }) + Context("WriteUint16", func() { It("outputs 2 bytes", func() { b := &bytes.Buffer{} @@ -82,4 +82,34 @@ var _ = Describe("Utils", func() { Expect(b.Bytes()).To(Equal([]byte{0x11, 0xFF})) }) }) + + Context("WriteUint32", func() { + It("outputs 4 bytes", func() { + b := &bytes.Buffer{} + WriteUint32(b, uint32(1)) + Expect(b.Len()).To(Equal(4)) + }) + + It("outputs a little endian", func() { + num := uint32(0xEFAC3512) + b := &bytes.Buffer{} + WriteUint32(b, num) + Expect(b.Bytes()).To(Equal([]byte{0x12, 0x35, 0xAC, 0xEF})) + }) + }) + + Context("WriteUint32BigEndian", func() { + It("outputs 4 bytes", func() { + b := &bytes.Buffer{} + WriteUint32BigEndian(b, uint32(1)) + Expect(b.Len()).To(Equal(4)) + }) + + It("outputs a big endian", func() { + num := uint32(0xEFAC3512) + b := &bytes.Buffer{} + WriteUint32BigEndian(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12})) + }) + }) })