diff --git a/streams_map.go b/streams_map.go index 5c26cff0f..262b0fdb2 100644 --- a/streams_map.go +++ b/streams_map.go @@ -13,6 +13,8 @@ type streamsMap struct { mutex sync.RWMutex } +type streamLambda func(*stream) (bool, error) + func newStreamsMap() *streamsMap { return &streamsMap{ streams: map[protocol.StreamID]*stream{}, @@ -29,6 +31,22 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) { return s, true } +func (m *streamsMap) Iterate(fn streamLambda) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, str := range m.streams { + cont, err := fn(str) + if err != nil { + return err + } + if !cont { + break + } + } + return nil +} + func (m *streamsMap) PutStream(s *stream) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -40,9 +58,8 @@ func (m *streamsMap) PutStream(s *stream) error { return nil } +// Attention: this function must only be called if a mutex has been acquired previously func (m *streamsMap) RemoveStream(id protocol.StreamID) error { - m.mutex.Lock() - defer m.mutex.Unlock() s, ok := m.streams[id] if !ok || s == nil { return fmt.Errorf("attempted to remove non-existing stream: %d", id) diff --git a/streams_map_test.go b/streams_map_test.go index 01ff9302c..f55816ab0 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -1,6 +1,8 @@ package quic import ( + "errors" + "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -57,4 +59,55 @@ var _ = Describe("Streams Map", func() { m.RemoveStream(5) Expect(m.NumberOfStreams()).To(Equal(0)) }) + + Context("Lambda", func() { + // create 5 streams, ids 1 to 3 + BeforeEach(func() { + for i := 1; i <= 3; i++ { + err := m.PutStream(&stream{streamID: protocol.StreamID(i)}) + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("executes the lambda exactly once for every stream", func() { + var numIterations int + callbackCalled := make(map[protocol.StreamID]bool) + fn := func(str *stream) (bool, error) { + callbackCalled[str.StreamID()] = true + numIterations++ + return true, nil + } + err := m.Iterate(fn) + Expect(err).ToNot(HaveOccurred()) + Expect(callbackCalled).To(HaveKey(protocol.StreamID(1))) + Expect(callbackCalled).To(HaveKey(protocol.StreamID(2))) + Expect(callbackCalled).To(HaveKey(protocol.StreamID(3))) + Expect(numIterations).To(Equal(3)) + }) + + It("stops iterating when the callback returns false", func() { + var numIterations int + fn := func(str *stream) (bool, error) { + numIterations++ + return false, nil + } + err := m.Iterate(fn) + Expect(err).ToNot(HaveOccurred()) + // due to map access randomization, we don't know for which stream the callback was executed + // but it must only be executed once + Expect(numIterations).To(Equal(1)) + }) + + It("returns the error, if the lambda returns one", func() { + var numIterations int + expectedError := errors.New("test") + fn := func(str *stream) (bool, error) { + numIterations++ + return true, expectedError + } + err := m.Iterate(fn) + Expect(err).To(MatchError(expectedError)) + Expect(numIterations).To(Equal(1)) + }) + }) })