From 5d8293716e9575f1b8ea1b2d69bcb6f175ac212d Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Tue, 26 Apr 2016 19:04:42 +0200 Subject: [PATCH] add includeStreamFrames option to PacketPacker --- packet_packer.go | 18 ++++++++++------- packet_packer_test.go | 45 ++++++++++++++++++++++++++++++------------- session.go | 4 ++-- 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index d0e44b4b..2ed308c0 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -34,25 +34,25 @@ func (p *packetPacker) AddStreamFrame(f frames.StreamFrame) { p.mutex.Unlock() } -func (p *packetPacker) PackPacket(controlFrames []frames.Frame) (*packedPacket, error) { +func (p *packetPacker) PackPacket(controlFrames []frames.Frame, includeStreamFrames bool) (*packedPacket, error) { // TODO: save controlFrames as a member variable, makes it easier to handle in the unlikely event that there are more controlFrames than you can put into on packet p.mutex.Lock() defer p.mutex.Unlock() // TODO: Split up? - if len(p.queuedStreamFrames) == 0 { - return nil, nil - } - currentPacketNumber := protocol.PacketNumber(atomic.AddUint64( (*uint64)(&p.lastPacketNumber), 1, )) - payloadFrames, err := p.composeNextPacket(controlFrames) + payloadFrames, err := p.composeNextPacket(controlFrames, includeStreamFrames) if err != nil { return nil, err } + if len(payloadFrames) == 0 { + return nil, nil + } + payload, err := p.getPayload(payloadFrames, currentPacketNumber) if err != nil { return nil, err @@ -99,7 +99,7 @@ func (p *packetPacker) getPayload(frames []frames.Frame, currentPacketNumber pro return payload.Bytes(), nil } -func (p *packetPacker) composeNextPacket(controlFrames []frames.Frame) ([]frames.Frame, error) { +func (p *packetPacker) composeNextPacket(controlFrames []frames.Frame, includeStreamFrames bool) ([]frames.Frame, error) { payloadLength := 0 var payloadFrames []frames.Frame @@ -111,6 +111,10 @@ func (p *packetPacker) composeNextPacket(controlFrames []frames.Frame) ([]frames controlFrames = controlFrames[1:] } + if !includeStreamFrames { + return payloadFrames, nil + } + for len(p.queuedStreamFrames) > 0 { frame := p.queuedStreamFrames[0] diff --git a/packet_packer_test.go b/packet_packer_test.go index be0725ec..d4e15dd4 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -21,7 +21,7 @@ var _ = Describe("Packet packer", func() { }) It("returns nil when no packet is queued", func() { - p, err := packer.PackPacket([]frames.Frame{}) + p, err := packer.PackPacket([]frames.Frame{}, true) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) @@ -32,7 +32,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } packer.AddStreamFrame(f) - p, err := packer.PackPacket([]frames.Frame{}) + p, err := packer.PackPacket([]frames.Frame{}, true) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) b := &bytes.Buffer{} @@ -41,6 +41,25 @@ var _ = Describe("Packet packer", func() { Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) }) + It("does not pack stream frames if includeStreamFrames=false", func() { + f := frames.StreamFrame{ + StreamID: 5, + Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, + } + packer.AddStreamFrame(f) + p, err := packer.PackPacket([]frames.Frame{}, false) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + + It("packs only control frames", func() { + p, err := packer.PackPacket([]frames.Frame{&frames.ConnectionCloseFrame{}}, false) + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(p.frames)).To(Equal(1)) + Expect(p.raw).NotTo(HaveLen(0)) + }) + It("packs multiple stream frames into single packet", func() { f1 := frames.StreamFrame{ StreamID: 5, @@ -52,7 +71,7 @@ var _ = Describe("Packet packer", func() { } packer.AddStreamFrame(f1) packer.AddStreamFrame(f2) - p, err := packer.PackPacket([]frames.Frame{}) + p, err := packer.PackPacket([]frames.Frame{}, true) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) b := &bytes.Buffer{} @@ -72,10 +91,10 @@ var _ = Describe("Packet packer", func() { // packer.AddFrame(f) // counter++ // } - // payloadFrames, err := packer.composeNextPacket([]frames.Frame{}) + // payloadFrames, err := packer.composeNextPacket([]frames.Frame{}, true) // Expect(err).ToNot(HaveOccurred()) // Expect(len(payloadFrames)).To(Equal(maxFramesPerPacket)) - // payloadFrames, err = packer.composeNextPacket([]frames.Frame{}) + // payloadFrames, err = packer.composeNextPacket([]frames.Frame{}, true) // Expect(err).ToNot(HaveOccurred()) // Expect(len(payloadFrames)).To(Equal(counter - maxFramesPerPacket)) // }) @@ -88,10 +107,10 @@ var _ = Describe("Packet packer", func() { Offset: 1, } packer.AddStreamFrame(f) - payloadFrames, err := packer.composeNextPacket([]frames.Frame{}) + payloadFrames, err := packer.composeNextPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(len(payloadFrames)).To(Equal(1)) - payloadFrames, err = packer.composeNextPacket([]frames.Frame{}) + payloadFrames, err = packer.composeNextPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(len(payloadFrames)).To(Equal(0)) }) @@ -108,13 +127,13 @@ var _ = Describe("Packet packer", func() { } packer.AddStreamFrame(f1) packer.AddStreamFrame(f2) - p, err := packer.PackPacket([]frames.Frame{}) + p, err := packer.PackPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize)) - p, err = packer.PackPacket([]frames.Frame{}) + p, err = packer.PackPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize)) - p, err = packer.PackPacket([]frames.Frame{}) + p, err = packer.PackPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) }) @@ -125,7 +144,7 @@ var _ = Describe("Packet packer", func() { Offset: 1, } packer.AddStreamFrame(f) - p, err := packer.PackPacket([]frames.Frame{}) + p, err := packer.PackPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(len(p.raw)).To(Equal(protocol.MaxPacketSize)) @@ -137,10 +156,10 @@ var _ = Describe("Packet packer", func() { Offset: 1, } packer.AddStreamFrame(f) - payloadFrames, err := packer.composeNextPacket([]frames.Frame{}) + payloadFrames, err := packer.composeNextPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(len(payloadFrames)).To(Equal(1)) - payloadFrames, err = packer.composeNextPacket([]frames.Frame{}) + payloadFrames, err = packer.composeNextPacket([]frames.Frame{}, true) Expect(err).ToNot(HaveOccurred()) Expect(len(payloadFrames)).To(Equal(1)) }) diff --git a/session.go b/session.go index eae77601..29013368 100644 --- a/session.go +++ b/session.go @@ -231,7 +231,7 @@ func (s *Session) sendPacket() error { if ack != nil { controlFrames = append(controlFrames, ack) } - packet, err := s.packer.PackPacket(controlFrames) + packet, err := s.packer.PackPacket(controlFrames, true) if err != nil { return err @@ -252,7 +252,7 @@ func (s *Session) sendPacket() error { return nil } -// QueueFrame queues a frame for sending to the client +// QueueStreamFrame queues a frame for sending to the client func (s *Session) QueueStreamFrame(frame *frames.StreamFrame) error { s.packer.AddStreamFrame(*frame) return nil