diff --git a/packet_packer.go b/packet_packer.go index 7c6785a7..291b64ac 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -392,23 +392,24 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Extend } func (p *packetPacker) writeAndSealPacket( - header *wire.ExtendedHeader, frames []wire.Frame, + header *wire.ExtendedHeader, + frames []wire.Frame, sealer handshake.Sealer, ) ([]byte, error) { raw := *getPacketBuffer() buffer := bytes.NewBuffer(raw[:0]) - addPadding := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial + addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial - // the length is only needed for Long Headers if header.IsLongHeader { if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { header.Token = p.token } - if addPadding { + if addPaddingForInitial { headerLen := header.GetLength(p.version) header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen } else { + // long header packets always use 4 byte packet number, so we never need to pad short payloads length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) for _, frame := range frames { length += frame.Length(p.version) @@ -422,19 +423,31 @@ func (p *packetPacker) writeAndSealPacket( } payloadStartIndex := buffer.Len() - // the Initial packet needs to be padded, so the last STREAM frame must have the data length present - if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { - lastFrame := frames[len(frames)-1] - if sf, ok := lastFrame.(*wire.StreamFrame); ok { - sf.DataLenPresent = true - } - } - for _, frame := range frames { + // write all frames but the last one + for _, frame := range frames[:len(frames)-1] { if err := frame.Write(buffer, p.version); err != nil { return nil, err } } - if addPadding { + lastFrame := frames[len(frames)-1] + if addPaddingForInitial { + // when appending padding, we need to make sure that the last STREAM frames has the data length set + if sf, ok := lastFrame.(*wire.StreamFrame); ok { + sf.DataLenPresent = true + } + } else { + payloadLen := buffer.Len() - payloadStartIndex + int(lastFrame.Length(p.version)) + if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 { + // Pad the packet such that packet number length + payload length is 4 bytes. + // This is needed to enable the peer to get a 16 byte sample for header protection. + buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) + } + } + if err := lastFrame.Write(buffer, p.version); err != nil { + return nil, err + } + + if addPaddingForInitial { paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len() if paddingLen > 0 { buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) diff --git a/packet_packer_test.go b/packet_packer_test.go index 141edcec..267eb319 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -69,7 +69,7 @@ var _ = Describe("Packet packer", func() { sealer = mocks.NewMockSealer(mockCtrl) sealer.EXPECT().Overhead().Return(7).AnyTimes() sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { - return append(src, bytes.Repeat([]byte{0}, 7)...) + return append(src, bytes.Repeat([]byte{0}, sealer.Overhead())...) }).AnyTimes() token = []byte("initial token") @@ -711,6 +711,42 @@ var _ = Describe("Packet packer", func() { Expect(cf.Data).To(Equal([]byte("foobar"))) }) + It("pads if payload length + packet number length is smaller than 4", func() { + f := &wire.StreamFrame{ + StreamID: 0x10, // small stream ID, such that only a single byte is consumed + FinBit: true, + } + Expect(f.Length(packer.version)).To(BeEquivalentTo(2)) + pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + initialStream.EXPECT().HasData() + handshakeStream.EXPECT().HasData() + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()) + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}) + packet, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // cut off the tag that the mock sealer added + packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] + hdr, err := wire.ParseHeader(bytes.NewReader(packet.raw), len(packer.destConnID)) + Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(packet.raw) + extHdr, err := hdr.ParseExtended(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(r.Len()).To(Equal(4 - 1 /* packet number length */)) + // the first byte of the payload should be a PADDING frame... + firstPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(firstPayloadByte).To(Equal(byte(0))) + // ... followed by the stream frame + frame, err := wire.ParseNextFrame(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(r.Len()).To(BeZero()) + }) + It("sets the correct length for an Initial packet", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))