diff --git a/frame.go b/frame.go index 70f810f1c..43dd22b2a 100644 --- a/frame.go +++ b/frame.go @@ -6,6 +6,7 @@ import ( ) // A StreamFrame of QUIC +// TODO: Maybe remove unneeded stuff, e.g. lengths? type StreamFrame struct { FinBit bool DataLengthPresent bool @@ -33,26 +34,22 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { } frame.StreamIDLength = typeByte&0x03 + 1 - sid, err := readUint64(r, frame.StreamIDLength) + sid, err := readUintN(r, frame.StreamIDLength) if err != nil { return nil, err } frame.StreamID = uint32(sid) - frame.Offset, err = readUint64(r, frame.OffsetLength) + frame.Offset, err = readUintN(r, frame.OffsetLength) if err != nil { return nil, err } if frame.DataLengthPresent { - var b1, b2 byte - if b1, err = r.ReadByte(); err != nil { + frame.DataLength, err = readUint16(r) + if err != nil { return nil, err } - if b2, err = r.ReadByte(); err != nil { - return nil, err - } - frame.DataLength = uint16(b1) + uint16(b2)<<8 } if frame.DataLength == 0 { diff --git a/public_header.go b/public_header.go index 88eea0d30..b56619923 100644 --- a/public_header.go +++ b/public_header.go @@ -47,7 +47,7 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { } // Connection ID - header.ConnectionID, err = readUint64(b, header.ConnectionIDLength) + header.ConnectionID, err = readUintN(b, header.ConnectionIDLength) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { // Version (optional) if header.VersionFlag { var v uint64 - v, err = readUint64(b, 4) + v, err = readUintN(b, 4) if err != nil { return nil, err } @@ -63,7 +63,7 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { } // Packet number - header.PacketNumber, err = readUint64(b, header.PacketNumberLength) + header.PacketNumber, err = readUintN(b, header.PacketNumberLength) if err != nil { return nil, err } diff --git a/utils.go b/utils.go index 6f74a9c51..c8293a3d1 100644 --- a/utils.go +++ b/utils.go @@ -2,7 +2,7 @@ package quic import "io" -func readUint64(b io.ByteReader, length uint8) (uint64, error) { +func readUintN(b io.ByteReader, length uint8) (uint64, error) { var res uint64 for i := uint8(0); i < length; i++ { bt, err := b.ReadByte() @@ -13,3 +13,33 @@ func readUint64(b io.ByteReader, length uint8) (uint64, error) { } return res, nil } + +func readUint32(b io.ByteReader) (uint32, error) { + var b1, b2, b3, b4 uint8 + var err error + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil +} + +func readUint16(b io.ByteReader) (uint16, error) { + var b1, b2 uint8 + var err error + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + return uint16(b1) + uint16(b2)<<8, nil +}