From 495399ede69176f111e355568fa92ccaa063e5e6 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Mon, 19 Jun 2017 17:26:15 +0200 Subject: [PATCH] Remove PackPacket's control frames parameter With this change it would theoretically be possible for outdated control frames to be sent, but this is quite unlikely in practice. --- packet_packer.go | 3 +- packet_packer_test.go | 73 ++++++++++++++++++++++++------------------- session.go | 12 +++---- 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 88e27d18..05c1c443 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -71,8 +71,7 @@ func (p *packetPacker) RetransmitNonForwardSecurePacket(stopWaitingFrame *frames // PackPacket packs a new packet // the stopWaitingFrame is *guaranteed* to be included in the next packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise -func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { - p.controlFrames = append(p.controlFrames, controlFrames...) +func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { return p.packPacket(stopWaitingFrame, leastUnacked, nil) } diff --git a/packet_packer_test.go b/packet_packer_test.go index 39c9709a..b9ed2bd6 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -79,7 +79,7 @@ var _ = Describe("Packet packer", func() { }) It("returns nil when no packet is queued", func() { - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) @@ -90,7 +90,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} @@ -106,7 +106,7 @@ var _ = Describe("Packet packer", func() { Data: []byte("foobar"), } streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) }) @@ -169,7 +169,9 @@ var _ = Describe("Packet packer", func() { }) It("packs only control frames", func() { - p, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}, &frames.WindowUpdateFrame{}}, 0) + packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) + packer.QueueControlFrameForNextPacket(&frames.WindowUpdateFrame{}) + p, err := packer.PackPacket(nil, 0) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(2)) @@ -177,10 +179,12 @@ var _ = Describe("Packet packer", func() { }) It("increases the packet number", func() { - p1, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0) + packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) + p1, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p1).ToNot(BeNil()) - p2, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0) + packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) + p2, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p2).ToNot(BeNil()) Expect(p2.number).To(BeNumerically(">", p1.number)) @@ -189,7 +193,8 @@ var _ = Describe("Packet packer", func() { It("packs a StopWaitingFrame first", func() { packer.packetNumberGenerator.next = 15 swf := &frames.StopWaitingFrame{LeastUnacked: 10} - p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0) + packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) + p, err := packer.PackPacket(swf, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames).To(HaveLen(2)) @@ -200,21 +205,22 @@ var _ = Describe("Packet packer", func() { packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packer.packetNumberGenerator.next = packetNumber swf := &frames.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} - p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0) + packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) + p, err := packer.PackPacket(swf, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.frames[0].(*frames.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) }) It("does not pack a packet containing only a StopWaitingFrame", func() { swf := &frames.StopWaitingFrame{LeastUnacked: 10} - p, err := packer.PackPacket(swf, []frames.Frame{}, 0) + p, err := packer.PackPacket(swf, 0) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) It("packs a packet if it has queued control frames, but no new control frames", func() { packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) }) @@ -225,7 +231,7 @@ var _ = Describe("Packet packer", func() { packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.connectionID = 0x1337 packer.version = 123 - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient) @@ -239,7 +245,7 @@ var _ = Describe("Packet packer", func() { packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.connectionID = 0x1337 - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient) @@ -286,7 +292,7 @@ var _ = Describe("Packet packer", func() { It("only increases the packet number when there is an actual packet to send", func() { packer.packetNumberGenerator.nextToSkip = 1000 - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) @@ -295,7 +301,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } streamFramer.AddFrameForRetransmission(f) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0) + p, err = packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.number).To(Equal(protocol.PacketNumber(1))) @@ -336,12 +342,12 @@ var _ = Describe("Packet packer", func() { } streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f2) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1))) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0) + p, err = packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) @@ -363,7 +369,7 @@ var _ = Describe("Packet packer", func() { streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f3) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) b := &bytes.Buffer{} @@ -417,23 +423,23 @@ var _ = Describe("Packet packer", func() { } streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f2) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0) + p, err = packer.PackPacket(nil, 0) Expect(p.frames).To(HaveLen(2)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue()) Expect(p.frames[1].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(err).ToNot(HaveOccurred()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0) + p, err = packer.PackPacket(nil, 0) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - p, err = packer.PackPacket(nil, []frames.Frame{}, 0) + p, err = packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -446,7 +452,7 @@ var _ = Describe("Packet packer", func() { minLength, _ := f.MinLength(0) f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) @@ -476,7 +482,7 @@ var _ = Describe("Packet packer", func() { Data: []byte("foobar"), } streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, nil, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).NotTo(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -489,7 +495,7 @@ var _ = Describe("Packet packer", func() { Data: []byte("foobar"), } streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, nil, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.frames[0]).To(Equal(f)) @@ -502,7 +508,7 @@ var _ = Describe("Packet packer", func() { Data: []byte("foobar"), } streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, nil, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -510,7 +516,7 @@ var _ = Describe("Packet packer", func() { It("sends unencrypted stream data on the crypto stream", func() { packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted cryptoStream.dataForWriting = []byte("foobar") - p, err := packer.PackPacket(nil, nil, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) Expect(p.frames).To(HaveLen(1)) @@ -520,7 +526,7 @@ var _ = Describe("Packet packer", func() { It("sends encrypted stream data on the crypto stream", func() { packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure cryptoStream.dataForWriting = []byte("foobar") - p, err := packer.PackPacket(nil, nil, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.frames).To(HaveLen(1)) @@ -531,7 +537,7 @@ var _ = Describe("Packet packer", func() { packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted packer.QueueControlFrameForNextPacket(&frames.AckFrame{}) streamFramer.AddFrameForRetransmission(&frames.StreamFrame{StreamID: 3, Data: []byte("foobar")}) - p, err := packer.PackPacket(nil, nil, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(func() { _ = p.frames[0].(*frames.AckFrame) }).NotTo(Panic()) @@ -580,21 +586,24 @@ var _ = Describe("Packet packer", func() { }) It("returns nil if we only have a single STOP_WAITING", func() { - p, err := packer.PackPacket(&frames.StopWaitingFrame{}, nil, 0) + p, err := packer.PackPacket(&frames.StopWaitingFrame{}, 0) Expect(err).NotTo(HaveOccurred()) Expect(p).To(BeNil()) }) It("packs a single ACK", func() { ack := &frames.AckFrame{LargestAcked: 42} - p, err := packer.PackPacket(nil, []frames.Frame{ack}, 0) + packer.QueueControlFrameForNextPacket(ack) + p, err := packer.PackPacket(nil, 0) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames[0]).To(Equal(ack)) }) It("does not return nil if we only have a single ACK but request it to be sent", func() { - p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0) + ack := &frames.AckFrame{} + packer.QueueControlFrameForNextPacket(ack) + p, err := packer.PackPacket(nil, 0) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) }) @@ -602,7 +611,7 @@ var _ = Describe("Packet packer", func() { It("queues a control frame to be sent in the next packet", func() { wuf := &frames.WindowUpdateFrame{StreamID: 5} packer.QueueControlFrameForNextPacket(wuf) - p, err := packer.PackPacket(nil, nil, 0) + p, err := packer.PackPacket(nil, 0) Expect(err).NotTo(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0]).To(Equal(wuf)) diff --git a/session.go b/session.go index cee4912b..6222cca1 100644 --- a/session.go +++ b/session.go @@ -568,13 +568,11 @@ func (s *session) sendPacket() error { return nil } - var controlFrames []frames.Frame - // get WindowUpdate frames // this call triggers the flow controller to increase the flow control windows, if necessary windowUpdateFrames := s.getWindowUpdateFrames() for _, wuf := range windowUpdateFrames { - controlFrames = append(controlFrames, wuf) + s.packer.QueueControlFrameForNextPacket(wuf) } // check for retransmissions first @@ -617,10 +615,10 @@ func (s *session) sendPacket() error { f := frame.(*frames.WindowUpdateFrame) currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID) if err == nil && f.ByteOffset >= currentOffset { - controlFrames = append(controlFrames, frame) + s.packer.QueueControlFrameForNextPacket(f) } default: - controlFrames = append(controlFrames, frame) + s.packer.QueueControlFrameForNextPacket(frame) } } } @@ -628,14 +626,14 @@ func (s *session) sendPacket() error { ack := s.receivedPacketHandler.GetAckFrame() if ack != nil { - controlFrames = append(controlFrames, ack) + s.packer.QueueControlFrameForNextPacket(ack) } hasRetransmission := s.streamFramer.HasFramesForRetransmission() var stopWaitingFrame *frames.StopWaitingFrame if ack != nil || hasRetransmission { stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) } - packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked()) + packet, err := s.packer.PackPacket(stopWaitingFrame, s.sentPacketHandler.GetLeastUnacked()) if err != nil { return err }