improve util functions

This commit is contained in:
Lucas Clemente
2016-04-08 19:28:14 +02:00
parent 671557542b
commit 3492497230
3 changed files with 39 additions and 12 deletions

View File

@@ -6,6 +6,7 @@ import (
)
// A StreamFrame of QUIC
// TODO: Maybe remove unneeded stuff, e.g. lengths?
type StreamFrame struct {
FinBit bool
DataLengthPresent bool
@@ -33,26 +34,22 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) {
}
frame.StreamIDLength = typeByte&0x03 + 1
sid, err := readUint64(r, frame.StreamIDLength)
sid, err := readUintN(r, frame.StreamIDLength)
if err != nil {
return nil, err
}
frame.StreamID = uint32(sid)
frame.Offset, err = readUint64(r, frame.OffsetLength)
frame.Offset, err = readUintN(r, frame.OffsetLength)
if err != nil {
return nil, err
}
if frame.DataLengthPresent {
var b1, b2 byte
if b1, err = r.ReadByte(); err != nil {
frame.DataLength, err = readUint16(r)
if err != nil {
return nil, err
}
if b2, err = r.ReadByte(); err != nil {
return nil, err
}
frame.DataLength = uint16(b1) + uint16(b2)<<8
}
if frame.DataLength == 0 {

View File

@@ -47,7 +47,7 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
}
// Connection ID
header.ConnectionID, err = readUint64(b, header.ConnectionIDLength)
header.ConnectionID, err = readUintN(b, header.ConnectionIDLength)
if err != nil {
return nil, err
}
@@ -55,7 +55,7 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
// Version (optional)
if header.VersionFlag {
var v uint64
v, err = readUint64(b, 4)
v, err = readUintN(b, 4)
if err != nil {
return nil, err
}
@@ -63,7 +63,7 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
}
// Packet number
header.PacketNumber, err = readUint64(b, header.PacketNumberLength)
header.PacketNumber, err = readUintN(b, header.PacketNumberLength)
if err != nil {
return nil, err
}

View File

@@ -2,7 +2,7 @@ package quic
import "io"
func readUint64(b io.ByteReader, length uint8) (uint64, error) {
func readUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64
for i := uint8(0); i < length; i++ {
bt, err := b.ReadByte()
@@ -13,3 +13,33 @@ func readUint64(b io.ByteReader, length uint8) (uint64, error) {
}
return res, nil
}
func readUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
func readUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
return uint16(b1) + uint16(b2)<<8, nil
}