diff --git a/packet_packer.go b/packet_packer.go index 69684556..6f0dca02 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -419,61 +419,49 @@ func (p *packetPacker) writeAndSealPacket( payload payload, encLevel protocol.EncryptionLevel, sealer handshake.Sealer, +) (*packedPacket, error) { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(header.PacketNumberLen) + + if encLevel != protocol.Encryption1RTT { + if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { + header.Token = p.token + headerLen := header.GetLength(p.version) + header.Length = pnLen + protocol.MinInitialPacketSize - headerLen + paddingLen = protocol.ByteCount(protocol.MinInitialPacketSize-sealer.Overhead()) - headerLen - payload.length + } else { + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + } + } else if payload.length < 4-pnLen { + paddingLen = 4 - pnLen - payload.length + } + return p.writeAndSealPacketWithPadding(header, payload, paddingLen, encLevel, sealer) +} + +func (p *packetPacker) writeAndSealPacketWithPadding( + header *wire.ExtendedHeader, + payload payload, + paddingLen protocol.ByteCount, + encLevel protocol.EncryptionLevel, + sealer handshake.Sealer, ) (*packedPacket, error) { packetBuffer := getPacketBuffer() buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) frames := payload.frames - addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial - - if header.IsLongHeader { - if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { - header.Token = p.token - } - if addPaddingForInitial { - headerLen := header.GetLength(p.version) - header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen - } else { - // long header packets always use 4 byte packet number, so we never need to pad short payloads - header.Length = protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) + payload.length - } - } - if err := header.Write(buffer, p.version); err != nil { return nil, err } payloadOffset := buffer.Len() - // write all frames but the last one - for _, frame := range frames[:len(frames)-1] { + if paddingLen > 0 { + buffer.Write(bytes.Repeat([]byte{0}, int(paddingLen))) + } + for _, frame := range frames { if err := frame.Write(buffer, p.version); err != nil { return nil, err } } - lastFrame := frames[len(frames)-1] - if addPaddingForInitial { - // when appending padding, we need to make sure that the last STREAM frames has the data length set - if sf, ok := lastFrame.(*wire.StreamFrame); ok { - sf.DataLenPresent = true - } - } else { - payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version)) - if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 { - // Pad the packet such that packet number length + payload length is 4 bytes. - // This is needed to enable the peer to get a 16 byte sample for header protection. - buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) - } - } - if err := lastFrame.Write(buffer, p.version); err != nil { - return nil, err - } - - if addPaddingForInitial { - paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len() - if paddingLen > 0 { - buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) - } - } if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) diff --git a/packet_packer_test.go b/packet_packer_test.go index 222f4ccb..59fa2fb8 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -34,7 +34,7 @@ var _ = Describe("Packet packer", func() { r := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) + ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) } appendFrames := func(fs, frames []wire.Frame) ([]wire.Frame, protocol.ByteCount) { @@ -823,7 +823,7 @@ var _ = Describe("Packet packer", func() { firstPayloadByte, err := r.ReadByte() Expect(err).ToNot(HaveOccurred()) Expect(firstPayloadByte).To(Equal(byte(0))) - // ... followed by the stream frame + // ... followed by the STREAM frame frameParser := wire.NewFrameParser(packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -865,23 +865,22 @@ var _ = Describe("Packet packer", func() { }) Context("retransmitions", func() { - sf := &wire.StreamFrame{Data: []byte("foobar")} + cf := &wire.CryptoFrame{Data: []byte("foo")} It("packs a retransmission with the right encryption level", func() { - f := &wire.CryptoFrame{Data: []byte("foo")} pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) packet := &ackhandler.Packet{ PacketType: protocol.PacketTypeHandshake, EncryptionLevel: protocol.EncryptionInitial, - Frames: []wire.Frame{f}, + Frames: []wire.Frame{cf}, } p, err := packer.PackRetransmission(packet) Expect(err).ToNot(HaveOccurred()) Expect(p).To(HaveLen(1)) Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p[0].frames).To(Equal([]wire.Frame{f})) + Expect(p[0].frames).To(Equal([]wire.Frame{cf})) Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) }) @@ -895,13 +894,13 @@ var _ = Describe("Packet packer", func() { packet := &ackhandler.Packet{ PacketType: protocol.PacketTypeInitial, EncryptionLevel: protocol.EncryptionInitial, - Frames: []wire.Frame{sf}, + Frames: []wire.Frame{cf}, } packets, err := packer.PackRetransmission(packet) Expect(err).ToNot(HaveOccurred()) Expect(packets).To(HaveLen(1)) p := packets[0] - Expect(p.frames).To(Equal([]wire.Frame{sf})) + Expect(p.frames).To(Equal([]wire.Frame{cf})) Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p.header.Token).To(Equal(token))