diff --git a/streams_map.go b/streams_map.go index 14b74c69..bf36f4f5 100644 --- a/streams_map.go +++ b/streams_map.go @@ -98,14 +98,7 @@ func (m *streamsMap) Iterate(fn streamLambda) error { defer m.mutex.Unlock() for _, streamID := range m.openStreams { - str, ok := m.streams[streamID] - if !ok { - return errMapAccess - } - if str == nil { - return fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID) - } - cont, err := fn(str) + cont, err := m.iterateFunc(streamID, fn) if err != nil { return err } @@ -116,6 +109,9 @@ func (m *streamsMap) Iterate(fn streamLambda) error { return nil } +// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false +// It uses a round-robin-like scheduling to ensure that every stream is considered fairly +// It prioritizes the crypto- and the header-stream (StreamIDs 1 and 3) func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -123,16 +119,24 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { numStreams := len(m.openStreams) startIndex := m.roundRobinIndex + for _, i := range []protocol.StreamID{1, 3} { + cont, err := m.iterateFunc(i, fn) + if err != nil && err != errMapAccess { + return err + } + if !cont { + return nil + } + } + for i := 0; i < numStreams; i++ { streamID := m.openStreams[(i+startIndex)%numStreams] - str, ok := m.streams[streamID] - if !ok { - return errMapAccess + + if streamID == 1 || streamID == 3 { + continue } - if str == nil { - return fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID) - } - cont, err := fn(str) + + cont, err := m.iterateFunc(streamID, fn) if err != nil { return err } @@ -144,6 +148,17 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { return nil } +func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { + str, ok := m.streams[streamID] + if !ok { + return true, errMapAccess + } + if str == nil { + return false, fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID) + } + return fn(str) +} + func (m *streamsMap) putStream(s *stream) error { id := s.StreamID() if _, ok := m.streams[id]; ok { diff --git a/streams_map_test.go b/streams_map_test.go index 59268c74..9edd56b4 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -281,14 +281,14 @@ var _ = Describe("Streams Map", func() { }) Context("RoundRobinIterate", func() { - // create 5 streams, ids 1 to 5 + // create 5 streams, ids 4 to 8 var lambdaCalledForStream []protocol.StreamID var numIterations int BeforeEach(func() { lambdaCalledForStream = lambdaCalledForStream[:0] numIterations = 0 - for i := 1; i <= 5; i++ { + for i := 4; i <= 8; i++ { err := m.putStream(&stream{streamID: protocol.StreamID(i)}) Expect(err).NotTo(HaveOccurred()) } @@ -303,7 +303,7 @@ var _ = Describe("Streams Map", func() { err := m.RoundRobinIterate(fn) Expect(err).ToNot(HaveOccurred()) Expect(numIterations).To(Equal(5)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) Expect(m.roundRobinIndex).To(BeZero()) }) @@ -313,11 +313,11 @@ var _ = Describe("Streams Map", func() { numIterations++ return true, nil } - m.roundRobinIndex = 3 + m.roundRobinIndex = 3 // pointing to stream 7 err := m.RoundRobinIterate(fn) Expect(err).ToNot(HaveOccurred()) Expect(numIterations).To(Equal(5)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 1, 2, 3})) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6})) Expect(m.roundRobinIndex).To(Equal(3)) }) @@ -325,7 +325,7 @@ var _ = Describe("Streams Map", func() { fn := func(str *stream) (bool, error) { lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) numIterations++ - if str.StreamID() == 2 { + if str.StreamID() == 5 { return false, nil } return true, nil @@ -333,14 +333,14 @@ var _ = Describe("Streams Map", func() { err := m.RoundRobinIterate(fn) Expect(err).ToNot(HaveOccurred()) Expect(numIterations).To(Equal(2)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2})) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5})) Expect(m.roundRobinIndex).To(Equal(1)) numIterations = 0 lambdaCalledForStream = lambdaCalledForStream[:0] fn2 := func(str *stream) (bool, error) { lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) numIterations++ - if str.StreamID() == 4 { + if str.StreamID() == 7 { return false, nil } return true, nil @@ -348,25 +348,50 @@ var _ = Describe("Streams Map", func() { err = m.RoundRobinIterate(fn2) Expect(err).ToNot(HaveOccurred()) Expect(numIterations).To(Equal(3)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{2, 3, 4})) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{5, 6, 7})) }) It("adjust the RoundRobinIndex when deleting an element in front", func() { - m.roundRobinIndex = 3 // stream 4 - m.RemoveStream(2) + m.roundRobinIndex = 3 // stream 7 + m.RemoveStream(5) Expect(m.roundRobinIndex).To(Equal(2)) }) It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() { - m.roundRobinIndex = 1 // stream 2 - m.RemoveStream(4) + m.roundRobinIndex = 1 // stream 5 + m.RemoveStream(7) Expect(m.roundRobinIndex).To(Equal(1)) }) It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() { - m.roundRobinIndex = 3 // stream 4 - m.RemoveStream(4) + m.roundRobinIndex = 3 // stream 7 + m.RemoveStream(7) Expect(m.roundRobinIndex).To(Equal(3)) }) + + Context("Prioritizing crypto- and header streams", func() { + BeforeEach(func() { + err := m.putStream(&stream{streamID: 1}) + Expect(err).NotTo(HaveOccurred()) + err = m.putStream(&stream{streamID: 3}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("gets crypto- and header stream first, then picks up at the round-robin position", func() { + m.roundRobinIndex = 3 // stream 7 + fn := func(str *stream) (bool, error) { + if numIterations >= 3 { + return false, nil + } + lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) + numIterations++ + return true, nil + } + err := m.RoundRobinIterate(fn) + Expect(err).ToNot(HaveOccurred()) + Expect(numIterations).To(Equal(3)) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 3, 7})) + }) + }) }) })