refactor how padding is added in the packet packer

This commit is contained in:
Marten Seemann
2019-04-22 11:50:19 +09:00
parent 109bb3fe62
commit 3d22d56ed8
2 changed files with 36 additions and 49 deletions

View File

@@ -419,61 +419,49 @@ func (p *packetPacker) writeAndSealPacket(
payload payload,
encLevel protocol.EncryptionLevel,
sealer handshake.Sealer,
) (*packedPacket, error) {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen)
if encLevel != protocol.Encryption1RTT {
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
header.Token = p.token
headerLen := header.GetLength(p.version)
header.Length = pnLen + protocol.MinInitialPacketSize - headerLen
paddingLen = protocol.ByteCount(protocol.MinInitialPacketSize-sealer.Overhead()) - headerLen - payload.length
} else {
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length
}
} else if payload.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length
}
return p.writeAndSealPacketWithPadding(header, payload, paddingLen, encLevel, sealer)
}
func (p *packetPacker) writeAndSealPacketWithPadding(
header *wire.ExtendedHeader,
payload payload,
paddingLen protocol.ByteCount,
encLevel protocol.EncryptionLevel,
sealer handshake.Sealer,
) (*packedPacket, error) {
packetBuffer := getPacketBuffer()
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
frames := payload.frames
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
if header.IsLongHeader {
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
header.Token = p.token
}
if addPaddingForInitial {
headerLen := header.GetLength(p.version)
header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen
} else {
// long header packets always use 4 byte packet number, so we never need to pad short payloads
header.Length = protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) + payload.length
}
}
if err := header.Write(buffer, p.version); err != nil {
return nil, err
}
payloadOffset := buffer.Len()
// write all frames but the last one
for _, frame := range frames[:len(frames)-1] {
if paddingLen > 0 {
buffer.Write(bytes.Repeat([]byte{0}, int(paddingLen)))
}
for _, frame := range frames {
if err := frame.Write(buffer, p.version); err != nil {
return nil, err
}
}
lastFrame := frames[len(frames)-1]
if addPaddingForInitial {
// when appending padding, we need to make sure that the last STREAM frames has the data length set
if sf, ok := lastFrame.(*wire.StreamFrame); ok {
sf.DataLenPresent = true
}
} else {
payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version))
if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 {
// Pad the packet such that packet number length + payload length is 4 bytes.
// This is needed to enable the peer to get a 16 byte sample for header protection.
buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
}
}
if err := lastFrame.Write(buffer, p.version); err != nil {
return nil, err
}
if addPaddingForInitial {
paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
if paddingLen > 0 {
buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
}
}
if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)