From 5d999f39277ddd1637a858e11b8b399fe13bef15 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 21 Apr 2019 20:37:18 +0900 Subject: [PATCH 1/4] handle ACKs separately in the sent packet handler packet struct --- internal/ackhandler/ack_eliciting.go | 11 ----------- internal/ackhandler/ack_eliciting_test.go | 9 --------- internal/ackhandler/packet.go | 1 + internal/ackhandler/sent_packet_handler.go | 10 ++++------ .../ackhandler/sent_packet_handler_test.go | 9 ++++----- packet_packer.go | 18 ++++++++++++++++-- 6 files changed, 25 insertions(+), 33 deletions(-) diff --git a/internal/ackhandler/ack_eliciting.go b/internal/ackhandler/ack_eliciting.go index 0e8af8616..23940d982 100644 --- a/internal/ackhandler/ack_eliciting.go +++ b/internal/ackhandler/ack_eliciting.go @@ -2,17 +2,6 @@ package ackhandler import "github.com/lucas-clemente/quic-go/internal/wire" -// Returns a new slice with all non-ack-eliciting frames deleted. -func stripNonAckElicitingFrames(fs []wire.Frame) []wire.Frame { - res := make([]wire.Frame, 0, len(fs)) - for _, f := range fs { - if IsFrameAckEliciting(f) { - res = append(res, f) - } - } - return res -} - // IsFrameAckEliciting returns true if the frame is ack-eliciting. func IsFrameAckEliciting(f wire.Frame) bool { switch f.(type) { diff --git a/internal/ackhandler/ack_eliciting_test.go b/internal/ackhandler/ack_eliciting_test.go index 3236a496a..65263af68 100644 --- a/internal/ackhandler/ack_eliciting_test.go +++ b/internal/ackhandler/ack_eliciting_test.go @@ -27,15 +27,6 @@ var _ = Describe("ack-eliciting frames", func() { Expect(IsFrameAckEliciting(f)).To(Equal(e)) }) - It("stripping non-ack-elicinting frames works for "+fName, func() { - s := []wire.Frame{f} - if e { - Expect(stripNonAckElicitingFrames(s)).To(Equal([]wire.Frame{f})) - } else { - Expect(stripNonAckElicitingFrames(s)).To(BeEmpty()) - } - }) - It("HasAckElicitingFrames works for "+fName, func() { Expect(HasAckElicitingFrames([]wire.Frame{f})).To(Equal(e)) }) diff --git a/internal/ackhandler/packet.go b/internal/ackhandler/packet.go index 9673a85c7..5dfa608ba 100644 --- a/internal/ackhandler/packet.go +++ b/internal/ackhandler/packet.go @@ -11,6 +11,7 @@ import ( type Packet struct { PacketNumber protocol.PacketNumber PacketType protocol.PacketType + Ack *wire.AckFrame Frames []wire.Frame Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 9ca04c72e..7c3e21538 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -166,14 +166,12 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-elicit pnSpace.largestSent = packet.PacketNumber - if len(packet.Frames) > 0 { - if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok { - packet.largestAcked = ackFrame.LargestAcked() - } + if packet.Ack != nil { + packet.largestAcked = packet.Ack.LargestAcked() } + packet.Ack = nil // no need to save the ACK - packet.Frames = stripNonAckElicitingFrames(packet.Frames) - isAckEliciting := len(packet.Frames) != 0 + isAckEliciting := len(packet.Frames) > 0 if isAckEliciting { if packet.EncryptionLevel != protocol.Encryption1RTT { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index bb0c7fc83..1c2aeb5ad 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -29,9 +29,8 @@ func ackElicitingPacket(p *Packet) *Packet { func nonAckElicitingPacket(p *Packet) *Packet { p = ackElicitingPacket(p) - p.Frames = []wire.Frame{ - &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, - } + p.Frames = nil + p.Ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} return p } @@ -313,8 +312,8 @@ var _ = Describe("SentPacketHandler", func() { ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 80, Largest: 100}}} ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 50, Largest: 200}}} morePackets := []*Packet{ - {PacketNumber: 13, Frames: []wire.Frame{ack1, &streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, - {PacketNumber: 14, Frames: []wire.Frame{ack2, &streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, + {PacketNumber: 13, Ack: ack1, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, + {PacketNumber: 14, Ack: ack2, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, {PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, } for _, packet := range morePackets { diff --git a/packet_packer.go b/packet_packer.go index 71a515778..bc994b540 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -52,10 +52,24 @@ func (p *packedPacket) IsAckEliciting() bool { } func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { + var frames []wire.Frame + var ack *wire.AckFrame + if len(p.frames) > 0 { + var ok bool + ack, ok = p.frames[0].(*wire.AckFrame) + if ok { + // make a copy, so that the ACK can be garbage collected + frames = make([]wire.Frame, len(p.frames)-1) + copy(frames, p.frames[1:]) + } else { + frames = p.frames + } + } return &ackhandler.Packet{ PacketNumber: p.header.PacketNumber, PacketType: p.header.Type, - Frames: p.frames, + Ack: ack, + Frames: frames, Length: protocol.ByteCount(len(p.raw)), EncryptionLevel: p.EncryptionLevel(), SendTime: time.Now(), @@ -330,7 +344,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wir var length protocol.ByteCount var frames []wire.Frame - // ACKs need to go first, so that the sentPacketHandler will recognize them + // ACKs need to go first, so we recognize them in packedPacket.ToAckHandlerPacket() if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil { frames = append(frames, ack) length += ack.Length(p.version) From 109bb3fe628e64bb2b62ad631190bc84a7c873ae Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 22 Apr 2019 10:04:36 +0900 Subject: [PATCH 2/4] pass the length of the packet being packet around in the packet packer --- framer.go | 6 ++--- framer_test.go | 46 ++++++++++++++++++++++----------- mock_frame_source_test.go | 5 ++-- packet_packer.go | 54 ++++++++++++++++++++++++--------------- packet_packer_test.go | 24 +++++++++-------- 5 files changed, 84 insertions(+), 51 deletions(-) diff --git a/framer.go b/framer.go index fbfe9bb76..d5be6fc9a 100644 --- a/framer.go +++ b/framer.go @@ -12,7 +12,7 @@ type framer interface { AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame + AppendStreamFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) } type framerI struct { @@ -73,7 +73,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { +func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { var length protocol.ByteCount f.mutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet @@ -105,5 +105,5 @@ func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCo length += frame.Length(f.version) } f.mutex.Unlock() - return frames + return frames, length } diff --git a/framer_test.go b/framer_test.go index 214528d72..075bb6d26 100644 --- a/framer_test.go +++ b/framer_test.go @@ -85,8 +85,9 @@ var _ = Describe("Stream Framer", func() { } stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) - fs := framer.AppendStreamFrames(nil, 1000) + fs, length := framer.AppendStreamFrames(nil, 1000) Expect(fs).To(Equal([]wire.Frame{f})) + Expect(length).To(Equal(f.Length(version))) }) It("appends to a frame slice", func() { @@ -99,8 +100,9 @@ var _ = Describe("Stream Framer", func() { framer.AddActiveStream(id1) mdf := &wire.MaxDataFrame{ByteOffset: 1337} frames := []wire.Frame{mdf} - fs := framer.AppendStreamFrames(frames, 1000) + fs, length := framer.AppendStreamFrames(frames, 1000) Expect(fs).To(Equal([]wire.Frame{mdf, f})) + Expect(length).To(Equal(f.Length(version))) }) It("skips a stream that was reported active, but was completed shortly after", func() { @@ -113,7 +115,8 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f})) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f})) }) It("skips a stream that was reported active, but doesn't have any data", func() { @@ -127,7 +130,8 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f})) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f})) }) It("pops from a stream multiple times, if it has enough data", func() { @@ -137,10 +141,13 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true) stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id1) // only add it once - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f1})) - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f2})) + frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f1})) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f2})) // no further calls to popStreamFrame, after popStreamFrame said there's no more data - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(BeNil()) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(BeNil()) }) It("re-queues a stream at the end, if it has enough data", func() { @@ -155,11 +162,14 @@ var _ = Describe("Stream Framer", func() { framer.AddActiveStream(id1) // only add it once framer.AddActiveStream(id2) // first a frame from stream 1 - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f11})) + frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f11})) // then a frame from stream 2 - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f2})) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f2})) // then another frame from stream 1 - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f12})) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f12})) }) It("only dequeues data from each stream once per packet", func() { @@ -172,7 +182,9 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f1, f2})) + frames, length := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f1, f2})) + Expect(length).To(Equal(f1.Length(version) + f2.Length(version))) }) It("returns multiple normal frames in the order they were reported active", func() { @@ -184,7 +196,8 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id2) framer.AddActiveStream(id1) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f2, f1})) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f2, f1})) }) It("only asks a stream for data once, even if it was reported active multiple times", func() { @@ -193,12 +206,14 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function framer.AddActiveStream(id1) framer.AddActiveStream(id1) - Expect(framer.AppendStreamFrames(nil, 1000)).To(HaveLen(1)) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(HaveLen(1)) }) It("does not pop empty frames", func() { - fs := framer.AppendStreamFrames(nil, 500) + fs, length := framer.AppendStreamFrames(nil, 500) Expect(fs).To(BeEmpty()) + Expect(length).To(BeZero()) }) It("pops frames that have the minimum size", func() { @@ -222,8 +237,9 @@ var _ = Describe("Stream Framer", func() { } stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false) framer.AddActiveStream(id1) - fs := framer.AppendStreamFrames(nil, 500) + fs, length := framer.AppendStreamFrames(nil, 500) Expect(fs).To(Equal([]wire.Frame{f})) + Expect(length).To(Equal(f.Length(version))) }) }) }) diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index e2f682a8a..676da023a 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -51,11 +51,12 @@ func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1 interface{ } // AppendStreamFrames mocks base method -func (m *MockFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) []wire.Frame { +func (m *MockFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) ret0, _ := ret[0].([]wire.Frame) - return ret0 + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 } // AppendStreamFrames indicates an expected call of AppendStreamFrames diff --git a/packet_packer.go b/packet_packer.go index bc994b540..69684556a 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,6 +25,11 @@ type packer interface { ChangeDestConnectionID(protocol.ConnectionID) } +type payload struct { + frames []wire.Frame + length protocol.ByteCount +} + type packedPacket struct { header *wire.ExtendedHeader raw []byte @@ -104,7 +109,7 @@ type sealingManager interface { } type frameSource interface { - AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame + AppendStreamFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) } @@ -165,10 +170,13 @@ func newPacketPacker( // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { - frames := []wire.Frame{ccf} + payload := payload{ + frames: []wire.Frame{ccf}, + length: ccf.Length(p.version), + } encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - return p.writeAndSealPacket(header, frames, encLevel, sealer) + return p.writeAndSealPacket(header, payload, encLevel, sealer) } func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { @@ -176,11 +184,14 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { if ack == nil { return nil, nil } + payload := payload{ + frames: []wire.Frame{ack}, + length: ack.Length(p.version), + } // TODO(#1534): only pack ACKs with the right encryption level encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - frames := []wire.Frame{ack} - return p.writeAndSealPacket(header, frames, encLevel, sealer) + return p.writeAndSealPacket(header, payload, encLevel, sealer) } // PackRetransmission packs a retransmission @@ -247,7 +258,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { sf.DataLenPresent = false } - p, err := p.writeAndSealPacket(header, frames, encLevel, sealer) + p, err := p.writeAndSealPacket(header, payload{frames: frames, length: length}, encLevel, sealer) if err != nil { return nil, err } @@ -275,19 +286,21 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { } maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen - frames, err := p.composeNextPacket(maxSize) + payload, err := p.composeNextPacket(maxSize) if err != nil { return nil, err } // Check if we have enough frames to send - if len(frames) == 0 { + if len(payload.frames) == 0 { return nil, nil } // check if this packet only contains an ACK - if !ackhandler.HasAckElicitingFrames(frames) { + if !ackhandler.HasAckElicitingFrames(payload.frames) { if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { - frames = append(frames, &wire.PingFrame{}) + ping := &wire.PingFrame{} + payload.frames = append(payload.frames, ping) + payload.length += ping.Length(p.version) p.numNonAckElicitingAcks = 0 } else { p.numNonAckElicitingAcks++ @@ -296,7 +309,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.numNonAckElicitingAcks = 0 } - return p.writeAndSealPacket(header, frames, encLevel, sealer) + return p.writeAndSealPacket(header, payload, encLevel, sealer) } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { @@ -336,11 +349,12 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { if hasData { cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) frames = append(frames, cf) + length += cf.Length(p.version) } - return p.writeAndSealPacket(hdr, frames, encLevel, sealer) + return p.writeAndSealPacket(hdr, payload{frames: frames, length: length}, encLevel, sealer) } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) { +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (payload, error) { var length protocol.ByteCount var frames []wire.Frame @@ -360,14 +374,15 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wir // the length is encoded to either 1 or 2 bytes maxFrameSize++ - frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length) + frames, lengthAdded = p.framer.AppendStreamFrames(frames, maxFrameSize-length) if len(frames) > 0 { lastFrame := frames[len(frames)-1] if sf, ok := lastFrame.(*wire.StreamFrame); ok { sf.DataLenPresent = false } + length += lengthAdded } - return frames, nil + return payload{frames: frames, length: length}, nil } func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { @@ -401,12 +416,13 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Extend func (p *packetPacker) writeAndSealPacket( header *wire.ExtendedHeader, - frames []wire.Frame, + payload payload, encLevel protocol.EncryptionLevel, sealer handshake.Sealer, ) (*packedPacket, error) { packetBuffer := getPacketBuffer() buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) + frames := payload.frames addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial @@ -419,11 +435,7 @@ func (p *packetPacker) writeAndSealPacket( header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen } else { // long header packets always use 4 byte packet number, so we never need to pad short payloads - length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) - for _, frame := range frames { - length += frame.Length(p.version) - } - header.Length = length + header.Length = protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) + payload.length } } diff --git a/packet_packer_test.go b/packet_packer_test.go index fae6389a3..222f4ccbf 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -37,19 +37,23 @@ var _ = Describe("Packet packer", func() { ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) } + appendFrames := func(fs, frames []wire.Frame) ([]wire.Frame, protocol.ByteCount) { + var length protocol.ByteCount + for _, f := range frames { + length += f.Length(packer.version) + } + return append(fs, frames...), length + } + expectAppendStreamFrames := func(frames ...wire.Frame) { - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { - return append(fs, frames...) + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + return appendFrames(fs, frames) }) } expectAppendControlFrames := func(frames ...wire.Frame) { framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - var length protocol.ByteCount - for _, f := range frames { - length += f.Length(packer.version) - } - return append(fs, frames...), length + return appendFrames(fs, frames) }) } @@ -311,9 +315,9 @@ var _ = Describe("Packet packer", func() { maxSize = maxLen return fs, 444 }), - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { Expect(maxLen).To(Equal(maxSize - 444 + 1 /* data length of the STREAM frame */)) - return nil + return fs, 0 }), ) _, err := packer.PackPacket() @@ -803,7 +807,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData() framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()) - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}) + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}, f.Length(packer.version)) packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added From 3d22d56ed806fc904a47267be3617420deba2604 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 22 Apr 2019 11:50:19 +0900 Subject: [PATCH 3/4] refactor how padding is added in the packet packer --- packet_packer.go | 70 ++++++++++++++++++------------------------- packet_packer_test.go | 15 +++++----- 2 files changed, 36 insertions(+), 49 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 69684556a..6f0dca02b 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -419,61 +419,49 @@ func (p *packetPacker) writeAndSealPacket( payload payload, encLevel protocol.EncryptionLevel, sealer handshake.Sealer, +) (*packedPacket, error) { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(header.PacketNumberLen) + + if encLevel != protocol.Encryption1RTT { + if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { + header.Token = p.token + headerLen := header.GetLength(p.version) + header.Length = pnLen + protocol.MinInitialPacketSize - headerLen + paddingLen = protocol.ByteCount(protocol.MinInitialPacketSize-sealer.Overhead()) - headerLen - payload.length + } else { + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + } + } else if payload.length < 4-pnLen { + paddingLen = 4 - pnLen - payload.length + } + return p.writeAndSealPacketWithPadding(header, payload, paddingLen, encLevel, sealer) +} + +func (p *packetPacker) writeAndSealPacketWithPadding( + header *wire.ExtendedHeader, + payload payload, + paddingLen protocol.ByteCount, + encLevel protocol.EncryptionLevel, + sealer handshake.Sealer, ) (*packedPacket, error) { packetBuffer := getPacketBuffer() buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) frames := payload.frames - addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial - - if header.IsLongHeader { - if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { - header.Token = p.token - } - if addPaddingForInitial { - headerLen := header.GetLength(p.version) - header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen - } else { - // long header packets always use 4 byte packet number, so we never need to pad short payloads - header.Length = protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) + payload.length - } - } - if err := header.Write(buffer, p.version); err != nil { return nil, err } payloadOffset := buffer.Len() - // write all frames but the last one - for _, frame := range frames[:len(frames)-1] { + if paddingLen > 0 { + buffer.Write(bytes.Repeat([]byte{0}, int(paddingLen))) + } + for _, frame := range frames { if err := frame.Write(buffer, p.version); err != nil { return nil, err } } - lastFrame := frames[len(frames)-1] - if addPaddingForInitial { - // when appending padding, we need to make sure that the last STREAM frames has the data length set - if sf, ok := lastFrame.(*wire.StreamFrame); ok { - sf.DataLenPresent = true - } - } else { - payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version)) - if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 { - // Pad the packet such that packet number length + payload length is 4 bytes. - // This is needed to enable the peer to get a 16 byte sample for header protection. - buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) - } - } - if err := lastFrame.Write(buffer, p.version); err != nil { - return nil, err - } - - if addPaddingForInitial { - paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len() - if paddingLen > 0 { - buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) - } - } if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) diff --git a/packet_packer_test.go b/packet_packer_test.go index 222f4ccbf..59fa2fb8e 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -34,7 +34,7 @@ var _ = Describe("Packet packer", func() { r := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) + ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) } appendFrames := func(fs, frames []wire.Frame) ([]wire.Frame, protocol.ByteCount) { @@ -823,7 +823,7 @@ var _ = Describe("Packet packer", func() { firstPayloadByte, err := r.ReadByte() Expect(err).ToNot(HaveOccurred()) Expect(firstPayloadByte).To(Equal(byte(0))) - // ... followed by the stream frame + // ... followed by the STREAM frame frameParser := wire.NewFrameParser(packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -865,23 +865,22 @@ var _ = Describe("Packet packer", func() { }) Context("retransmitions", func() { - sf := &wire.StreamFrame{Data: []byte("foobar")} + cf := &wire.CryptoFrame{Data: []byte("foo")} It("packs a retransmission with the right encryption level", func() { - f := &wire.CryptoFrame{Data: []byte("foo")} pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) packet := &ackhandler.Packet{ PacketType: protocol.PacketTypeHandshake, EncryptionLevel: protocol.EncryptionInitial, - Frames: []wire.Frame{f}, + Frames: []wire.Frame{cf}, } p, err := packer.PackRetransmission(packet) Expect(err).ToNot(HaveOccurred()) Expect(p).To(HaveLen(1)) Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p[0].frames).To(Equal([]wire.Frame{f})) + Expect(p[0].frames).To(Equal([]wire.Frame{cf})) Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) }) @@ -895,13 +894,13 @@ var _ = Describe("Packet packer", func() { packet := &ackhandler.Packet{ PacketType: protocol.PacketTypeInitial, EncryptionLevel: protocol.EncryptionInitial, - Frames: []wire.Frame{sf}, + Frames: []wire.Frame{cf}, } packets, err := packer.PackRetransmission(packet) Expect(err).ToNot(HaveOccurred()) Expect(packets).To(HaveLen(1)) p := packets[0] - Expect(p.frames).To(Equal([]wire.Frame{sf})) + Expect(p.frames).To(Equal([]wire.Frame{cf})) Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p.header.Token).To(Equal(token)) From f112edd894a3244cd94a4c132bc2780e3a67caf8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 22 Apr 2019 12:45:53 +0900 Subject: [PATCH 4/4] handle ACK frames separately when packing packets --- packet_packer.go | 77 +++++++++++++++++++------------------------ packet_packer_test.go | 29 ++++++++-------- 2 files changed, 50 insertions(+), 56 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 6f0dca02b..b52a29c4f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -27,12 +27,14 @@ type packer interface { type payload struct { frames []wire.Frame + ack *wire.AckFrame length protocol.ByteCount } type packedPacket struct { header *wire.ExtendedHeader raw []byte + ack *wire.AckFrame frames []wire.Frame buffer *packetBuffer @@ -57,24 +59,11 @@ func (p *packedPacket) IsAckEliciting() bool { } func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { - var frames []wire.Frame - var ack *wire.AckFrame - if len(p.frames) > 0 { - var ok bool - ack, ok = p.frames[0].(*wire.AckFrame) - if ok { - // make a copy, so that the ACK can be garbage collected - frames = make([]wire.Frame, len(p.frames)-1) - copy(frames, p.frames[1:]) - } else { - frames = p.frames - } - } return &ackhandler.Packet{ PacketNumber: p.header.PacketNumber, PacketType: p.header.Type, - Ack: ack, - Frames: frames, + Ack: p.ack, + Frames: p.frames, Length: protocol.ByteCount(len(p.raw)), EncryptionLevel: p.EncryptionLevel(), SendTime: time.Now(), @@ -185,7 +174,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { return nil, nil } payload := payload{ - frames: []wire.Frame{ack}, + ack: ack, length: ack.Length(p.version), } // TODO(#1534): only pack ACKs with the right encryption level @@ -291,12 +280,11 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { return nil, err } - // Check if we have enough frames to send - if len(payload.frames) == 0 { + // check if we have anything to send + if len(payload.frames) == 0 && payload.ack == nil { return nil, nil } - // check if this packet only contains an ACK - if !ackhandler.HasAckElicitingFrames(payload.frames) { + if len(payload.frames) == 0 { // the packet only contains an ACK if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { ping := &wire.PingFrame{} payload.frames = append(payload.frames, ping) @@ -338,35 +326,32 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { return nil, err } + var payload payload + if ack != nil { + payload.ack = ack + payload.length = ack.Length(p.version) + } hdr := p.getHeader(encLevel) hdrLen := hdr.GetLength(p.version) - var length protocol.ByteCount - frames := make([]wire.Frame, 0, 2) - if ack != nil { - frames = append(frames, ack) - length += ack.Length(p.version) - } if hasData { - cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) - frames = append(frames, cf) - length += cf.Length(p.version) + cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) + payload.frames = []wire.Frame{cf} + payload.length += cf.Length(p.version) } - return p.writeAndSealPacket(hdr, payload{frames: frames, length: length}, encLevel, sealer) + return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (payload, error) { - var length protocol.ByteCount - var frames []wire.Frame + var payload payload // ACKs need to go first, so we recognize them in packedPacket.ToAckHandlerPacket() if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil { - frames = append(frames, ack) - length += ack.Length(p.version) + payload.ack = ack + payload.length += ack.Length(p.version) } - var lengthAdded protocol.ByteCount - frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length) - length += lengthAdded + frames, lengthAdded := p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field // this leads to a properly sized packet in all cases, since we do all the packet length calculations with STREAM frames that have the DataLen set @@ -374,15 +359,16 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo // the length is encoded to either 1 or 2 bytes maxFrameSize++ - frames, lengthAdded = p.framer.AppendStreamFrames(frames, maxFrameSize-length) + frames, lengthAdded = p.framer.AppendStreamFrames(frames, maxFrameSize-payload.length) if len(frames) > 0 { lastFrame := frames[len(frames)-1] if sf, ok := lastFrame.(*wire.StreamFrame); ok { sf.DataLenPresent = false } - length += lengthAdded + payload.frames = append(payload.frames, frames...) + payload.length += lengthAdded } - return payload{frames: frames, length: length}, nil + return payload, nil } func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { @@ -447,17 +433,21 @@ func (p *packetPacker) writeAndSealPacketWithPadding( ) (*packedPacket, error) { packetBuffer := getPacketBuffer() buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) - frames := payload.frames if err := header.Write(buffer, p.version); err != nil { return nil, err } payloadOffset := buffer.Len() + if payload.ack != nil { + if err := payload.ack.Write(buffer, p.version); err != nil { + return nil, err + } + } if paddingLen > 0 { buffer.Write(bytes.Repeat([]byte{0}, int(paddingLen))) } - for _, frame := range frames { + for _, frame := range payload.frames { if err := frame.Write(buffer, p.version); err != nil { return nil, err } @@ -485,7 +475,8 @@ func (p *packetPacker) writeAndSealPacketWithPadding( return &packedPacket{ header: header, raw: raw, - frames: frames, + ack: payload.ack, + frames: payload.frames, buffer: packetBuffer, }, nil } diff --git a/packet_packer_test.go b/packet_packer_test.go index 59fa2fb8e..b5a9f6152 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -272,7 +272,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.frames[0]).To(Equal(ack)) + Expect(p.ack).To(Equal(ack)) }) It("packs a CONNECTION_CLOSE", func() { @@ -340,7 +340,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) p, err := packer.MaybePackAckPacket() Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{ack})) + Expect(p.ack).To(Equal(ack)) }) }) @@ -356,7 +356,8 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) + Expect(p.ack).ToNot(BeNil()) + Expect(p.frames).To(BeEmpty()) } } @@ -382,7 +383,8 @@ var _ = Describe("Packet packer", func() { p, err = packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) + Expect(p.ack).ToNot(BeNil()) + Expect(p.frames).To(BeEmpty()) }) It("waits until there's something to send before adding a PING frame", func() { @@ -402,14 +404,15 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + Expect(p.ack).To(Equal(ack)) + Expect(p.frames).To(Equal([]wire.Frame{&wire.PingFrame{}})) }) - It("doesn't send a PING if it already sent another ack-elicitng frame", func() { + It("doesn't send a PING if it already sent another ack-eliciting frame", func() { sendMaxNumNonAckElicitingAcks() pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) @@ -746,7 +749,7 @@ var _ = Describe("Packet packer", func() { checkLength(p.raw) }) - It("sends a Initial packet containing only an ACK", func() { + It("sends an Initial packet containing only an ACK", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) initialStream.EXPECT().HasData() @@ -755,7 +758,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{ack})) + Expect(p.ack).To(Equal(ack)) }) It("sends a Handshake packet containing only an ACK", func() { @@ -769,7 +772,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{ack})) + Expect(p.ack).To(Equal(ack)) }) It("pads Initial packets to the required minimum packet size", func() { @@ -860,8 +863,8 @@ var _ = Describe("Packet packer", func() { packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) - Expect(packet.frames).To(HaveLen(2)) - Expect(packet.frames[0]).To(Equal(ack)) + Expect(packet.ack).To(Equal(ack)) + Expect(packet.frames).To(HaveLen(1)) }) Context("retransmitions", func() {