pad small packets, such that len(packet number) + len(payload) >= 4

This commit is contained in:
Marten Seemann
2018-11-28 13:42:47 +07:00
parent d981364ec6
commit cf957bb3d0
2 changed files with 63 additions and 14 deletions

View File

@@ -392,23 +392,24 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Extend
}
func (p *packetPacker) writeAndSealPacket(
header *wire.ExtendedHeader, frames []wire.Frame,
header *wire.ExtendedHeader,
frames []wire.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := *getPacketBuffer()
buffer := bytes.NewBuffer(raw[:0])
addPadding := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
// the length is only needed for Long Headers
if header.IsLongHeader {
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
header.Token = p.token
}
if addPadding {
if addPaddingForInitial {
headerLen := header.GetLength(p.version)
header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen
} else {
// long header packets always use 4 byte packet number, so we never need to pad short payloads
length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen)
for _, frame := range frames {
length += frame.Length(p.version)
@@ -422,19 +423,31 @@ func (p *packetPacker) writeAndSealPacket(
}
payloadStartIndex := buffer.Len()
// the Initial packet needs to be padded, so the last STREAM frame must have the data length present
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
lastFrame := frames[len(frames)-1]
if sf, ok := lastFrame.(*wire.StreamFrame); ok {
sf.DataLenPresent = true
}
}
for _, frame := range frames {
// write all frames but the last one
for _, frame := range frames[:len(frames)-1] {
if err := frame.Write(buffer, p.version); err != nil {
return nil, err
}
}
if addPadding {
lastFrame := frames[len(frames)-1]
if addPaddingForInitial {
// when appending padding, we need to make sure that the last STREAM frames has the data length set
if sf, ok := lastFrame.(*wire.StreamFrame); ok {
sf.DataLenPresent = true
}
} else {
payloadLen := buffer.Len() - payloadStartIndex + int(lastFrame.Length(p.version))
if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 {
// Pad the packet such that packet number length + payload length is 4 bytes.
// This is needed to enable the peer to get a 16 byte sample for header protection.
buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
}
}
if err := lastFrame.Write(buffer, p.version); err != nil {
return nil, err
}
if addPaddingForInitial {
paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
if paddingLen > 0 {
buffer.Write(bytes.Repeat([]byte{0}, paddingLen))

View File

@@ -69,7 +69,7 @@ var _ = Describe("Packet packer", func() {
sealer = mocks.NewMockSealer(mockCtrl)
sealer.EXPECT().Overhead().Return(7).AnyTimes()
sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte {
return append(src, bytes.Repeat([]byte{0}, 7)...)
return append(src, bytes.Repeat([]byte{0}, sealer.Overhead())...)
}).AnyTimes()
token = []byte("initial token")
@@ -711,6 +711,42 @@ var _ = Describe("Packet packer", func() {
Expect(cf.Data).To(Equal([]byte("foobar")))
})
It("pads if payload length + packet number length is smaller than 4", func() {
f := &wire.StreamFrame{
StreamID: 0x10, // small stream ID, such that only a single byte is consumed
FinBit: true,
}
Expect(f.Length(packer.version)).To(BeEquivalentTo(2))
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1)
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
ackFramer.EXPECT().GetAckFrame()
initialStream.EXPECT().HasData()
handshakeStream.EXPECT().HasData()
framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any())
framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f})
packet, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
// cut off the tag that the mock sealer added
packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()]
hdr, err := wire.ParseHeader(bytes.NewReader(packet.raw), len(packer.destConnID))
Expect(err).ToNot(HaveOccurred())
r := bytes.NewReader(packet.raw)
extHdr, err := hdr.ParseExtended(r, packer.version)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1))
Expect(r.Len()).To(Equal(4 - 1 /* packet number length */))
// the first byte of the payload should be a PADDING frame...
firstPayloadByte, err := r.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(firstPayloadByte).To(Equal(byte(0)))
// ... followed by the stream frame
frame, err := wire.ParseNextFrame(r, packer.version)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
Expect(r.Len()).To(BeZero())
})
It("sets the correct length for an Initial packet", func() {
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))