diff --git a/stream_frame_queue.go b/stream_frame_queue.go index 2f7f58820..1070bed27 100644 --- a/stream_frame_queue.go +++ b/stream_frame_queue.go @@ -76,10 +76,17 @@ func (q *streamFrameQueue) Pop(maxLength protocol.ByteCount) (*frames.StreamFram var streamID protocol.StreamID var err error - if len(q.prioFrames) > 0 { + for len(q.prioFrames) > 0 { frame = q.prioFrames[0] + if frame == nil { + q.prioFrames = q.prioFrames[1:] + continue + } isPrioFrame = true - } else { + break + } + + if !isPrioFrame { streamID, err = q.getNextStream() if err != nil { return nil, err @@ -115,6 +122,34 @@ func (q *streamFrameQueue) Pop(maxLength protocol.ByteCount) (*frames.StreamFram return frame, nil } +func (q *streamFrameQueue) RemoveStream(streamID protocol.StreamID) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for i, frame := range q.prioFrames { + if frame.StreamID == streamID { + q.byteLen -= frame.DataLen() + q.len-- + q.prioFrames[i] = nil + } + } + + frameQueue, ok := q.frameMap[streamID] + if ok { + for _, frame := range frameQueue { + q.byteLen -= frame.DataLen() + q.len-- + } + delete(q.frameMap, streamID) + } + + for i, s := range q.activeStreams { + if s == streamID { + q.activeStreams[i] = 0 + } + } +} + // front returns the next element without modifying the queue // has to be called from a function that has already acquired the mutex func (q *streamFrameQueue) getNextStream() (protocol.StreamID, error) { @@ -124,17 +159,22 @@ func (q *streamFrameQueue) getNextStream() (protocol.StreamID, error) { var counter int for counter < len(q.activeStreams) { + counter++ streamID := q.activeStreams[q.activeStreamsPosition] + q.activeStreamsPosition = (q.activeStreamsPosition + 1) % len(q.activeStreams) + + if streamID == 0 { // this happens if the stream was deleted + continue + } + frameQueue, ok := q.frameMap[streamID] if !ok { return 0, errMapAccess } + if len(frameQueue) > 0 { - q.activeStreamsPosition = (q.activeStreamsPosition + 1) % len(q.activeStreams) return streamID, nil } - q.activeStreamsPosition = (q.activeStreamsPosition + 1) % len(q.activeStreams) - counter++ } return 0, nil diff --git a/stream_frame_queue_test.go b/stream_frame_queue_test.go index 33b728dcd..d34a1e8ae 100644 --- a/stream_frame_queue_test.go +++ b/stream_frame_queue_test.go @@ -65,6 +65,22 @@ var _ = Describe("streamFrameQueue", func() { queue.Pop(1000) Expect(queue.Len()).To(Equal(0)) }) + + It("reduces the length when deleting a stream for which a prio frame was queued", func() { + queue.Push(prioFrame1, true) + queue.Push(prioFrame2, true) + Expect(queue.Len()).To(Equal(2)) + queue.RemoveStream(prioFrame1.StreamID) + Expect(queue.Len()).To(Equal(1)) + }) + + It("reduces the length when deleting a stream for which a normal frame was queued", func() { + queue.Push(frame1, false) + queue.Push(frame2, false) + Expect(queue.Len()).To(Equal(2)) + queue.RemoveStream(frame1.StreamID) + Expect(queue.Len()).To(Equal(1)) + }) }) Context("Queue Byte Length", func() { @@ -88,6 +104,22 @@ var _ = Describe("streamFrameQueue", func() { queue.Pop(1000) Expect(queue.ByteLen()).To(Equal(protocol.ByteCount(0))) }) + + It("reduces the byte length when deleting a stream for which a prio frame was queued", func() { + queue.Push(prioFrame1, true) + queue.Push(prioFrame2, true) + Expect(queue.ByteLen()).To(Equal(prioFrame1.DataLen() + prioFrame2.DataLen())) + queue.RemoveStream(prioFrame1.StreamID) + Expect(queue.ByteLen()).To(Equal(prioFrame2.DataLen())) + }) + + It("reduces the byte length when deleting a stream for which a normal frame was queued", func() { + queue.Push(frame1, false) + queue.Push(frame2, false) + Expect(queue.ByteLen()).To(Equal(frame1.DataLen() + frame2.DataLen())) + queue.RemoveStream(frame1.StreamID) + Expect(queue.ByteLen()).To(Equal(frame2.DataLen())) + }) }) Context("Pushing", func() { @@ -147,6 +179,22 @@ var _ = Describe("streamFrameQueue", func() { streamID, err := queue.getNextStream() Expect(err).ToNot(HaveOccurred()) Expect(streamID).To(Equal(frame2.StreamID)) + streamID, err = queue.getNextStream() + Expect(err).ToNot(HaveOccurred()) + Expect(streamID).To(Equal(frame1.StreamID)) + }) + + It("gets the next frame if a stream was deleted", func() { + queue.Push(frame2, false) + queue.Push(frame1, false) + Expect(queue.activeStreams).To(ContainElement(frame1.StreamID)) + Expect(queue.activeStreams).To(ContainElement(frame2.StreamID)) + queue.RemoveStream(frame2.StreamID) + Expect(queue.activeStreams).To(ContainElement(frame1.StreamID)) + Expect(queue.activeStreams).ToNot(ContainElement(frame2.StreamID)) + streamID, err := queue.getNextStream() + Expect(err).ToNot(HaveOccurred()) + Expect(streamID).To(Equal(frame1.StreamID)) }) }) @@ -333,4 +381,47 @@ var _ = Describe("streamFrameQueue", func() { }) }) }) + + Context("deleting streams", func() { + It("deletes prioFrames", func() { + queue.Push(prioFrame1, true) + queue.Push(prioFrame2, true) + queue.RemoveStream(prioFrame1.StreamID) + frame, err := queue.Pop(1000) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(prioFrame2)) + frame, err = queue.Pop(1000) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeNil()) + }) + + It("deletes the map entry", func() { + queue.Push(frame1, false) + queue.Push(frame2, false) + Expect(queue.frameMap).To(HaveKey(frame1.StreamID)) + queue.RemoveStream(frame1.StreamID) + Expect(queue.frameMap).ToNot(HaveKey(frame1.StreamID)) + }) + + It("gets a normal frame, when the stream of the prio frame was deleted", func() { + queue.Push(prioFrame1, true) + queue.Push(frame1, true) + queue.RemoveStream(prioFrame1.StreamID) + frame, err := queue.Pop(1000) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(frame1)) + frame, err = queue.Pop(1000) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeNil()) + }) + + It("deletes frames", func() { + queue.Push(frame1, false) + queue.Push(frame2, false) + queue.RemoveStream(frame1.StreamID) + frame, err := queue.Pop(1000) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(frame2)) + }) + }) })