forked from quic-go/quic-go
refactor how padding is added in the packet packer
This commit is contained in:
@@ -419,61 +419,49 @@ func (p *packetPacker) writeAndSealPacket(
|
||||
payload payload,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
sealer handshake.Sealer,
|
||||
) (*packedPacket, error) {
|
||||
var paddingLen protocol.ByteCount
|
||||
pnLen := protocol.ByteCount(header.PacketNumberLen)
|
||||
|
||||
if encLevel != protocol.Encryption1RTT {
|
||||
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
|
||||
header.Token = p.token
|
||||
headerLen := header.GetLength(p.version)
|
||||
header.Length = pnLen + protocol.MinInitialPacketSize - headerLen
|
||||
paddingLen = protocol.ByteCount(protocol.MinInitialPacketSize-sealer.Overhead()) - headerLen - payload.length
|
||||
} else {
|
||||
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length
|
||||
}
|
||||
} else if payload.length < 4-pnLen {
|
||||
paddingLen = 4 - pnLen - payload.length
|
||||
}
|
||||
return p.writeAndSealPacketWithPadding(header, payload, paddingLen, encLevel, sealer)
|
||||
}
|
||||
|
||||
func (p *packetPacker) writeAndSealPacketWithPadding(
|
||||
header *wire.ExtendedHeader,
|
||||
payload payload,
|
||||
paddingLen protocol.ByteCount,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
sealer handshake.Sealer,
|
||||
) (*packedPacket, error) {
|
||||
packetBuffer := getPacketBuffer()
|
||||
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
|
||||
frames := payload.frames
|
||||
|
||||
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
|
||||
|
||||
if header.IsLongHeader {
|
||||
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
|
||||
header.Token = p.token
|
||||
}
|
||||
if addPaddingForInitial {
|
||||
headerLen := header.GetLength(p.version)
|
||||
header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen
|
||||
} else {
|
||||
// long header packets always use 4 byte packet number, so we never need to pad short payloads
|
||||
header.Length = protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) + payload.length
|
||||
}
|
||||
}
|
||||
|
||||
if err := header.Write(buffer, p.version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadOffset := buffer.Len()
|
||||
|
||||
// write all frames but the last one
|
||||
for _, frame := range frames[:len(frames)-1] {
|
||||
if paddingLen > 0 {
|
||||
buffer.Write(bytes.Repeat([]byte{0}, int(paddingLen)))
|
||||
}
|
||||
for _, frame := range frames {
|
||||
if err := frame.Write(buffer, p.version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
lastFrame := frames[len(frames)-1]
|
||||
if addPaddingForInitial {
|
||||
// when appending padding, we need to make sure that the last STREAM frames has the data length set
|
||||
if sf, ok := lastFrame.(*wire.StreamFrame); ok {
|
||||
sf.DataLenPresent = true
|
||||
}
|
||||
} else {
|
||||
payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version))
|
||||
if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 {
|
||||
// Pad the packet such that packet number length + payload length is 4 bytes.
|
||||
// This is needed to enable the peer to get a 16 byte sample for header protection.
|
||||
buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
|
||||
}
|
||||
}
|
||||
if err := lastFrame.Write(buffer, p.version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if addPaddingForInitial {
|
||||
paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
|
||||
if paddingLen > 0 {
|
||||
buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
|
||||
}
|
||||
}
|
||||
|
||||
if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
|
||||
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
|
||||
|
||||
@@ -34,7 +34,7 @@ var _ = Describe("Packet packer", func() {
|
||||
r := bytes.NewReader(data)
|
||||
extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen)))
|
||||
ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen)))
|
||||
}
|
||||
|
||||
appendFrames := func(fs, frames []wire.Frame) ([]wire.Frame, protocol.ByteCount) {
|
||||
@@ -823,7 +823,7 @@ var _ = Describe("Packet packer", func() {
|
||||
firstPayloadByte, err := r.ReadByte()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(firstPayloadByte).To(Equal(byte(0)))
|
||||
// ... followed by the stream frame
|
||||
// ... followed by the STREAM frame
|
||||
frameParser := wire.NewFrameParser(packer.version)
|
||||
frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -865,23 +865,22 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
Context("retransmitions", func() {
|
||||
sf := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
cf := &wire.CryptoFrame{Data: []byte("foo")}
|
||||
|
||||
It("packs a retransmission with the right encryption level", func() {
|
||||
f := &wire.CryptoFrame{Data: []byte("foo")}
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil)
|
||||
packet := &ackhandler.Packet{
|
||||
PacketType: protocol.PacketTypeHandshake,
|
||||
EncryptionLevel: protocol.EncryptionInitial,
|
||||
Frames: []wire.Frame{f},
|
||||
Frames: []wire.Frame{cf},
|
||||
}
|
||||
p, err := packer.PackRetransmission(packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(HaveLen(1))
|
||||
Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(p[0].frames).To(Equal([]wire.Frame{f}))
|
||||
Expect(p[0].frames).To(Equal([]wire.Frame{cf}))
|
||||
Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
|
||||
})
|
||||
|
||||
@@ -895,13 +894,13 @@ var _ = Describe("Packet packer", func() {
|
||||
packet := &ackhandler.Packet{
|
||||
PacketType: protocol.PacketTypeInitial,
|
||||
EncryptionLevel: protocol.EncryptionInitial,
|
||||
Frames: []wire.Frame{sf},
|
||||
Frames: []wire.Frame{cf},
|
||||
}
|
||||
packets, err := packer.PackRetransmission(packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packets).To(HaveLen(1))
|
||||
p := packets[0]
|
||||
Expect(p.frames).To(Equal([]wire.Frame{sf}))
|
||||
Expect(p.frames).To(Equal([]wire.Frame{cf}))
|
||||
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
|
||||
Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(p.header.Token).To(Equal(token))
|
||||
|
||||
Reference in New Issue
Block a user