diff --git a/internal/utils/byteoder_big_endian_test.go b/internal/utils/byteoder_big_endian_test.go index 9ea1f5ade..5d0873a95 100644 --- a/internal/utils/byteoder_big_endian_test.go +++ b/internal/utils/byteoder_big_endian_test.go @@ -26,6 +26,23 @@ var _ = Describe("Big Endian encoding / decoding", func() { }) }) + Context("ReadUint24", func() { + It("reads a big endian", func() { + b := []byte{0x13, 0xbe, 0xef} + val, err := BigEndian.ReadUint24(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint32(0x13beef))) + }) + + It("throws an error if less than 3 bytes are passed", func() { + b := []byte{0x13, 0xbe, 0xef} + for i := 0; i < len(b); i++ { + _, err := BigEndian.ReadUint24(bytes.NewReader(b[:i])) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + Context("ReadUint32", func() { It("reads a big endian", func() { b := []byte{0x12, 0x35, 0xAB, 0xFF} @@ -58,6 +75,21 @@ var _ = Describe("Big Endian encoding / decoding", func() { }) }) + Context("WriteUint24", func() { + It("outputs 3 bytes", func() { + b := &bytes.Buffer{} + BigEndian.WriteUint24(b, uint32(1)) + Expect(b.Len()).To(Equal(3)) + }) + + It("outputs a big endian", func() { + num := uint32(0xff11aa) + b := &bytes.Buffer{} + BigEndian.WriteUint24(b, num) + Expect(b.Bytes()).To(Equal([]byte{0xff, 0x11, 0xaa})) + }) + }) + Context("WriteUint32", func() { It("outputs 4 bytes", func() { b := &bytes.Buffer{} @@ -72,60 +104,4 @@ var _ = Describe("Big Endian encoding / decoding", func() { Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12})) }) }) - - Context("WriteUintN", func() { - It("writes n bytes", func() { - expected := []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8} - m := map[uint8]uint64{ - 0: 0x0, - 1: 0x01, - 2: 0x0102, - 3: 0x010203, - 4: 0x01020304, - 5: 0x0102030405, - 6: 0x010203040506, - 7: 0x01020304050607, - 8: 0x0102030405060708, - } - for n, val := range m { - b := &bytes.Buffer{} - BigEndian.WriteUintN(b, n, val) - Expect(b.Bytes()).To(Equal(expected[:n])) - } - }) - - It("cuts off the higher order bytes", func() { - b := &bytes.Buffer{} - BigEndian.WriteUintN(b, 2, 0xdeadbeef) - Expect(b.Bytes()).To(Equal([]byte{0xbe, 0xef})) - }) - }) - - Context("ReadUintN", func() { - It("reads n bytes", func() { - m := map[uint8]uint64{ - 0: 0x0, - 1: 0x01, - 2: 0x0102, - 3: 0x010203, - 4: 0x01020304, - 5: 0x0102030405, - 6: 0x010203040506, - 7: 0x01020304050607, - 8: 0x0102030405060708, - } - for n, expected := range m { - b := bytes.NewReader([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}) - i, err := BigEndian.ReadUintN(b, n) - Expect(err).ToNot(HaveOccurred()) - Expect(i).To(Equal(expected)) - } - }) - - It("errors", func() { - b := bytes.NewReader([]byte{0x1, 0x2}) - _, err := BigEndian.ReadUintN(b, 3) - Expect(err).To(HaveOccurred()) - }) - }) }) diff --git a/internal/utils/byteorder.go b/internal/utils/byteorder.go index 6b92cfa26..d1f528429 100644 --- a/internal/utils/byteorder.go +++ b/internal/utils/byteorder.go @@ -7,11 +7,11 @@ import ( // A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. type ByteOrder interface { - ReadUintN(b io.ByteReader, length uint8) (uint64, error) ReadUint32(io.ByteReader) (uint32, error) + ReadUint24(io.ByteReader) (uint32, error) ReadUint16(io.ByteReader) (uint16, error) - WriteUintN(b *bytes.Buffer, length uint8, value uint64) WriteUint32(*bytes.Buffer, uint32) + WriteUint24(*bytes.Buffer, uint32) WriteUint16(*bytes.Buffer, uint16) } diff --git a/internal/utils/byteorder_big_endian.go b/internal/utils/byteorder_big_endian.go index eede9cd72..d05542e1d 100644 --- a/internal/utils/byteorder_big_endian.go +++ b/internal/utils/byteorder_big_endian.go @@ -44,6 +44,22 @@ func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil } +// ReadUint24 reads a uint24 +func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) { + var b1, b2, b3 uint8 + var err error + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil +} + // ReadUint16 reads a uint16 func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { var b1, b2 uint8 @@ -57,17 +73,16 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { return uint16(b1) + uint16(b2)<<8, nil } -func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) { - for j := length; j > 0; j-- { - b.WriteByte(uint8(i >> (8 * (j - 1)))) - } -} - // WriteUint32 writes a uint32 func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) } +// WriteUint24 writes a uint24 +func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + // WriteUint16 writes a uint16 func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { b.Write([]byte{uint8(i >> 8), uint8(i)}) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 5b85e573f..deaf0e2ec 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -73,11 +73,34 @@ func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNum func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 - pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen)) - if err != nil { - return err + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + n, err := b.ReadByte() + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen2: + n, err := utils.BigEndian.ReadUint16(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen3: + n, err := utils.BigEndian.ReadUint24(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen4: + n, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) } - h.PacketNumber = protocol.PacketNumber(pn) return nil } @@ -151,10 +174,18 @@ func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNum } func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { - if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 { + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen3: + utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + default: return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) } - utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber)) return nil }