forked from quic-go/quic-go
calculate PublicHeader length for packets with VersionFlag and ResetFlag
This commit is contained in:
@@ -74,7 +74,7 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea
|
||||
DiversificationNonce: p.cryptoSetup.DiversificationNonce(),
|
||||
}
|
||||
|
||||
publicHeaderLength, err := responsePublicHeader.GetLength()
|
||||
publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set")
|
||||
errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
|
||||
errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported")
|
||||
errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0")
|
||||
errGetLengthOnlyForRegularPackets = errors.New("PublicHeader: GetLength can only be called for regular packets")
|
||||
errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set")
|
||||
errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
|
||||
errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported")
|
||||
errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0")
|
||||
errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets")
|
||||
)
|
||||
|
||||
// The PublicHeader of a QUIC packet
|
||||
@@ -192,20 +192,35 @@ func ParsePublicHeader(b io.ByteReader, packetSentBy protocol.Perspective) (*Pub
|
||||
|
||||
// GetLength gets the length of the publicHeader in bytes
|
||||
// can only be called for regular packets
|
||||
func (h *PublicHeader) GetLength() (protocol.ByteCount, error) {
|
||||
if h.VersionFlag || h.ResetFlag {
|
||||
return 0, errGetLengthOnlyForRegularPackets
|
||||
func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) {
|
||||
if h.VersionFlag && h.ResetFlag {
|
||||
return 0, errResetAndVersionFlagSet
|
||||
}
|
||||
|
||||
if h.VersionFlag && pers == protocol.PerspectiveServer {
|
||||
return 0, errGetLengthNotForVersionNegotiation
|
||||
}
|
||||
|
||||
length := protocol.ByteCount(1) // 1 byte for public flags
|
||||
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
|
||||
return 0, errPacketNumberLenNotSet
|
||||
|
||||
if h.hasPacketNumber(pers) {
|
||||
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
|
||||
return 0, errPacketNumberLenNotSet
|
||||
}
|
||||
length += protocol.ByteCount(h.PacketNumberLen)
|
||||
}
|
||||
|
||||
if !h.TruncateConnectionID {
|
||||
length += 8 // 8 bytes for the connection ID
|
||||
}
|
||||
|
||||
// Version Number in packets sent by the client
|
||||
if h.VersionFlag {
|
||||
length += 4
|
||||
}
|
||||
|
||||
length += protocol.ByteCount(len(h.DiversificationNonce))
|
||||
length += protocol.ByteCount(h.PacketNumberLen)
|
||||
|
||||
return length, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -269,14 +269,17 @@ var _ = Describe("Public Header", func() {
|
||||
Context("GetLength", func() {
|
||||
It("errors when calling GetLength for Version Negotiation packets", func() {
|
||||
hdr := PublicHeader{VersionFlag: true}
|
||||
_, err := hdr.GetLength()
|
||||
Expect(err).To(MatchError(errGetLengthOnlyForRegularPackets))
|
||||
_, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).To(MatchError(errGetLengthNotForVersionNegotiation))
|
||||
})
|
||||
|
||||
It("errors when calling GetLength for Public Reset packets", func() {
|
||||
hdr := PublicHeader{ResetFlag: true}
|
||||
_, err := hdr.GetLength()
|
||||
Expect(err).To(MatchError(errGetLengthOnlyForRegularPackets))
|
||||
It("errors when calling GetLength for packets that have the VersionFlag and the ResetFlag set", func() {
|
||||
hdr := PublicHeader{
|
||||
ResetFlag: true,
|
||||
VersionFlag: true,
|
||||
}
|
||||
_, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).To(MatchError(errResetAndVersionFlagSet))
|
||||
})
|
||||
|
||||
It("errors when PacketNumberLen is not set", func() {
|
||||
@@ -284,7 +287,7 @@ var _ = Describe("Public Header", func() {
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
}
|
||||
_, err := hdr.GetLength()
|
||||
_, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).To(MatchError(errPacketNumberLenNotSet))
|
||||
})
|
||||
|
||||
@@ -294,11 +297,25 @@ var _ = Describe("Public Header", func() {
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
length, err := hdr.GetLength()
|
||||
length, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 6))) // 1 byte public flag, 8 bytes connectionID, and packet number
|
||||
})
|
||||
|
||||
It("gets the lengths of a packet sent by the client with the VersionFlag set", func() {
|
||||
hdr := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
TruncateConnectionID: true,
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
VersionFlag: true,
|
||||
VersionNumber: protocol.Version36,
|
||||
}
|
||||
length, err := hdr.GetLength(protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(length).To(Equal(protocol.ByteCount(1 + 4 + 6))) // 1 byte public flag, 4 version number, and packet number
|
||||
})
|
||||
|
||||
It("gets the length of a packet with longest packet number length and truncated connectionID", func() {
|
||||
hdr := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
@@ -306,7 +323,7 @@ var _ = Describe("Public Header", func() {
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
length, err := hdr.GetLength()
|
||||
length, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(length).To(Equal(protocol.ByteCount(1 + 6))) // 1 byte public flag, and packet number
|
||||
})
|
||||
@@ -317,7 +334,7 @@ var _ = Describe("Public Header", func() {
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
length, err := hdr.GetLength()
|
||||
length, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 2))) // 1 byte public flag, 8 byte connectionID, and packet number
|
||||
})
|
||||
@@ -327,9 +344,19 @@ var _ = Describe("Public Header", func() {
|
||||
DiversificationNonce: []byte("foo"),
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
length, err := hdr.GetLength()
|
||||
length, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 3 + 1)))
|
||||
Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 3 + 1))) // 1 byte public flag, 8 byte connectionID, 3 byte DiversificationNonce, 1 byte PacketNumber
|
||||
})
|
||||
|
||||
It("gets the length of a PublicReset", func() {
|
||||
hdr := PublicHeader{
|
||||
ResetFlag: true,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
}
|
||||
length, err := hdr.GetLength(protocol.PerspectiveServer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(length).To(Equal(protocol.ByteCount(1 + 8))) // 1 byte public flag, 8 byte connectionID
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user