replace WriteUintN and ReadUintN by explicit functions

This commit is contained in:
Marten Seemann
2019-08-20 15:59:39 +07:00
parent 2133d01956
commit ccb197939d
4 changed files with 92 additions and 70 deletions

View File

@@ -26,6 +26,23 @@ var _ = Describe("Big Endian encoding / decoding", func() {
})
})
Context("ReadUint24", func() {
It("reads a big endian", func() {
b := []byte{0x13, 0xbe, 0xef}
val, err := BigEndian.ReadUint24(bytes.NewReader(b))
Expect(err).ToNot(HaveOccurred())
Expect(val).To(Equal(uint32(0x13beef)))
})
It("throws an error if less than 3 bytes are passed", func() {
b := []byte{0x13, 0xbe, 0xef}
for i := 0; i < len(b); i++ {
_, err := BigEndian.ReadUint24(bytes.NewReader(b[:i]))
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("ReadUint32", func() {
It("reads a big endian", func() {
b := []byte{0x12, 0x35, 0xAB, 0xFF}
@@ -58,6 +75,21 @@ var _ = Describe("Big Endian encoding / decoding", func() {
})
})
Context("WriteUint24", func() {
It("outputs 3 bytes", func() {
b := &bytes.Buffer{}
BigEndian.WriteUint24(b, uint32(1))
Expect(b.Len()).To(Equal(3))
})
It("outputs a big endian", func() {
num := uint32(0xff11aa)
b := &bytes.Buffer{}
BigEndian.WriteUint24(b, num)
Expect(b.Bytes()).To(Equal([]byte{0xff, 0x11, 0xaa}))
})
})
Context("WriteUint32", func() {
It("outputs 4 bytes", func() {
b := &bytes.Buffer{}
@@ -72,60 +104,4 @@ var _ = Describe("Big Endian encoding / decoding", func() {
Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12}))
})
})
Context("WriteUintN", func() {
It("writes n bytes", func() {
expected := []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}
m := map[uint8]uint64{
0: 0x0,
1: 0x01,
2: 0x0102,
3: 0x010203,
4: 0x01020304,
5: 0x0102030405,
6: 0x010203040506,
7: 0x01020304050607,
8: 0x0102030405060708,
}
for n, val := range m {
b := &bytes.Buffer{}
BigEndian.WriteUintN(b, n, val)
Expect(b.Bytes()).To(Equal(expected[:n]))
}
})
It("cuts off the higher order bytes", func() {
b := &bytes.Buffer{}
BigEndian.WriteUintN(b, 2, 0xdeadbeef)
Expect(b.Bytes()).To(Equal([]byte{0xbe, 0xef}))
})
})
Context("ReadUintN", func() {
It("reads n bytes", func() {
m := map[uint8]uint64{
0: 0x0,
1: 0x01,
2: 0x0102,
3: 0x010203,
4: 0x01020304,
5: 0x0102030405,
6: 0x010203040506,
7: 0x01020304050607,
8: 0x0102030405060708,
}
for n, expected := range m {
b := bytes.NewReader([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})
i, err := BigEndian.ReadUintN(b, n)
Expect(err).ToNot(HaveOccurred())
Expect(i).To(Equal(expected))
}
})
It("errors", func() {
b := bytes.NewReader([]byte{0x1, 0x2})
_, err := BigEndian.ReadUintN(b, 3)
Expect(err).To(HaveOccurred())
})
})
})

View File

@@ -7,11 +7,11 @@ import (
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
ReadUintN(b io.ByteReader, length uint8) (uint64, error)
ReadUint32(io.ByteReader) (uint32, error)
ReadUint24(io.ByteReader) (uint32, error)
ReadUint16(io.ByteReader) (uint16, error)
WriteUintN(b *bytes.Buffer, length uint8, value uint64)
WriteUint32(*bytes.Buffer, uint32)
WriteUint24(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
}

View File

@@ -44,6 +44,22 @@ func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
// ReadUint24 reads a uint24
func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) {
var b1, b2, b3 uint8
var err error
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil
}
// ReadUint16 reads a uint16
func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
@@ -57,17 +73,16 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
return uint16(b1) + uint16(b2)<<8, nil
}
func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) {
for j := length; j > 0; j-- {
b.WriteByte(uint8(i >> (8 * (j - 1))))
}
}
// WriteUint32 writes a uint32
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint24 writes a uint24
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint16 writes a uint16
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i >> 8), uint8(i)})

View File

@@ -73,11 +73,34 @@ func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNum
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
n, err := b.ReadByte()
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(pn)
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen2:
n, err := utils.BigEndian.ReadUint16(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen3:
n, err := utils.BigEndian.ReadUint24(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen4:
n, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
return nil
}
@@ -151,10 +174,18 @@ func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNum
}
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber))
return nil
}