forked from quic-go/quic-go
add an Iterate function to the StreamsMap
This commit is contained in:
@@ -13,6 +13,8 @@ type streamsMap struct {
|
|||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type streamLambda func(*stream) (bool, error)
|
||||||
|
|
||||||
func newStreamsMap() *streamsMap {
|
func newStreamsMap() *streamsMap {
|
||||||
return &streamsMap{
|
return &streamsMap{
|
||||||
streams: map[protocol.StreamID]*stream{},
|
streams: map[protocol.StreamID]*stream{},
|
||||||
@@ -29,6 +31,22 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) {
|
|||||||
return s, true
|
return s, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
for _, str := range m.streams {
|
||||||
|
cont, err := fn(str)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !cont {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *streamsMap) PutStream(s *stream) error {
|
func (m *streamsMap) PutStream(s *stream) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -40,9 +58,8 @@ func (m *streamsMap) PutStream(s *stream) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attention: this function must only be called if a mutex has been acquired previously
|
||||||
func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
s, ok := m.streams[id]
|
s, ok := m.streams[id]
|
||||||
if !ok || s == nil {
|
if !ok || s == nil {
|
||||||
return fmt.Errorf("attempted to remove non-existing stream: %d", id)
|
return fmt.Errorf("attempted to remove non-existing stream: %d", id)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@@ -57,4 +59,55 @@ var _ = Describe("Streams Map", func() {
|
|||||||
m.RemoveStream(5)
|
m.RemoveStream(5)
|
||||||
Expect(m.NumberOfStreams()).To(Equal(0))
|
Expect(m.NumberOfStreams()).To(Equal(0))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("Lambda", func() {
|
||||||
|
// create 5 streams, ids 1 to 3
|
||||||
|
BeforeEach(func() {
|
||||||
|
for i := 1; i <= 3; i++ {
|
||||||
|
err := m.PutStream(&stream{streamID: protocol.StreamID(i)})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("executes the lambda exactly once for every stream", func() {
|
||||||
|
var numIterations int
|
||||||
|
callbackCalled := make(map[protocol.StreamID]bool)
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
callbackCalled[str.StreamID()] = true
|
||||||
|
numIterations++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
err := m.Iterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(callbackCalled).To(HaveKey(protocol.StreamID(1)))
|
||||||
|
Expect(callbackCalled).To(HaveKey(protocol.StreamID(2)))
|
||||||
|
Expect(callbackCalled).To(HaveKey(protocol.StreamID(3)))
|
||||||
|
Expect(numIterations).To(Equal(3))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("stops iterating when the callback returns false", func() {
|
||||||
|
var numIterations int
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
numIterations++
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
err := m.Iterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
// due to map access randomization, we don't know for which stream the callback was executed
|
||||||
|
// but it must only be executed once
|
||||||
|
Expect(numIterations).To(Equal(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the error, if the lambda returns one", func() {
|
||||||
|
var numIterations int
|
||||||
|
expectedError := errors.New("test")
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
numIterations++
|
||||||
|
return true, expectedError
|
||||||
|
}
|
||||||
|
err := m.Iterate(fn)
|
||||||
|
Expect(err).To(MatchError(expectedError))
|
||||||
|
Expect(numIterations).To(Equal(1))
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user