forked from quic-go/quic-go
@@ -98,14 +98,7 @@ func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, streamID := range m.openStreams {
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
}
|
||||
if str == nil {
|
||||
return fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID)
|
||||
}
|
||||
cont, err := fn(str)
|
||||
cont, err := m.iterateFunc(streamID, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -116,6 +109,9 @@ func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false
|
||||
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly
|
||||
// It prioritizes the crypto- and the header-stream (StreamIDs 1 and 3)
|
||||
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -123,16 +119,24 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||
numStreams := len(m.openStreams)
|
||||
startIndex := m.roundRobinIndex
|
||||
|
||||
for _, i := range []protocol.StreamID{1, 3} {
|
||||
cont, err := m.iterateFunc(i, fn)
|
||||
if err != nil && err != errMapAccess {
|
||||
return err
|
||||
}
|
||||
if !cont {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < numStreams; i++ {
|
||||
streamID := m.openStreams[(i+startIndex)%numStreams]
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
|
||||
if streamID == 1 || streamID == 3 {
|
||||
continue
|
||||
}
|
||||
if str == nil {
|
||||
return fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID)
|
||||
}
|
||||
cont, err := fn(str)
|
||||
|
||||
cont, err := m.iterateFunc(streamID, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -144,6 +148,17 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return true, errMapAccess
|
||||
}
|
||||
if str == nil {
|
||||
return false, fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID)
|
||||
}
|
||||
return fn(str)
|
||||
}
|
||||
|
||||
func (m *streamsMap) putStream(s *stream) error {
|
||||
id := s.StreamID()
|
||||
if _, ok := m.streams[id]; ok {
|
||||
|
||||
@@ -281,14 +281,14 @@ var _ = Describe("Streams Map", func() {
|
||||
})
|
||||
|
||||
Context("RoundRobinIterate", func() {
|
||||
// create 5 streams, ids 1 to 5
|
||||
// create 5 streams, ids 4 to 8
|
||||
var lambdaCalledForStream []protocol.StreamID
|
||||
var numIterations int
|
||||
|
||||
BeforeEach(func() {
|
||||
lambdaCalledForStream = lambdaCalledForStream[:0]
|
||||
numIterations = 0
|
||||
for i := 1; i <= 5; i++ {
|
||||
for i := 4; i <= 8; i++ {
|
||||
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
@@ -303,7 +303,7 @@ var _ = Describe("Streams Map", func() {
|
||||
err := m.RoundRobinIterate(fn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(numIterations).To(Equal(5))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
|
||||
Expect(m.roundRobinIndex).To(BeZero())
|
||||
})
|
||||
|
||||
@@ -313,11 +313,11 @@ var _ = Describe("Streams Map", func() {
|
||||
numIterations++
|
||||
return true, nil
|
||||
}
|
||||
m.roundRobinIndex = 3
|
||||
m.roundRobinIndex = 3 // pointing to stream 7
|
||||
err := m.RoundRobinIterate(fn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(numIterations).To(Equal(5))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 1, 2, 3}))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6}))
|
||||
Expect(m.roundRobinIndex).To(Equal(3))
|
||||
})
|
||||
|
||||
@@ -325,7 +325,7 @@ var _ = Describe("Streams Map", func() {
|
||||
fn := func(str *stream) (bool, error) {
|
||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||
numIterations++
|
||||
if str.StreamID() == 2 {
|
||||
if str.StreamID() == 5 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
@@ -333,14 +333,14 @@ var _ = Describe("Streams Map", func() {
|
||||
err := m.RoundRobinIterate(fn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(numIterations).To(Equal(2))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2}))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5}))
|
||||
Expect(m.roundRobinIndex).To(Equal(1))
|
||||
numIterations = 0
|
||||
lambdaCalledForStream = lambdaCalledForStream[:0]
|
||||
fn2 := func(str *stream) (bool, error) {
|
||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||
numIterations++
|
||||
if str.StreamID() == 4 {
|
||||
if str.StreamID() == 7 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
@@ -348,25 +348,50 @@ var _ = Describe("Streams Map", func() {
|
||||
err = m.RoundRobinIterate(fn2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(numIterations).To(Equal(3))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{2, 3, 4}))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{5, 6, 7}))
|
||||
})
|
||||
|
||||
It("adjust the RoundRobinIndex when deleting an element in front", func() {
|
||||
m.roundRobinIndex = 3 // stream 4
|
||||
m.RemoveStream(2)
|
||||
m.roundRobinIndex = 3 // stream 7
|
||||
m.RemoveStream(5)
|
||||
Expect(m.roundRobinIndex).To(Equal(2))
|
||||
})
|
||||
|
||||
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() {
|
||||
m.roundRobinIndex = 1 // stream 2
|
||||
m.RemoveStream(4)
|
||||
m.roundRobinIndex = 1 // stream 5
|
||||
m.RemoveStream(7)
|
||||
Expect(m.roundRobinIndex).To(Equal(1))
|
||||
})
|
||||
|
||||
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() {
|
||||
m.roundRobinIndex = 3 // stream 4
|
||||
m.RemoveStream(4)
|
||||
m.roundRobinIndex = 3 // stream 7
|
||||
m.RemoveStream(7)
|
||||
Expect(m.roundRobinIndex).To(Equal(3))
|
||||
})
|
||||
|
||||
Context("Prioritizing crypto- and header streams", func() {
|
||||
BeforeEach(func() {
|
||||
err := m.putStream(&stream{streamID: 1})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = m.putStream(&stream{streamID: 3})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("gets crypto- and header stream first, then picks up at the round-robin position", func() {
|
||||
m.roundRobinIndex = 3 // stream 7
|
||||
fn := func(str *stream) (bool, error) {
|
||||
if numIterations >= 3 {
|
||||
return false, nil
|
||||
}
|
||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||
numIterations++
|
||||
return true, nil
|
||||
}
|
||||
err := m.RoundRobinIterate(fn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(numIterations).To(Equal(3))
|
||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 3, 7}))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user