From d1562e5f8731ee1a9e8918674f58277334dc6656 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 20 Feb 2017 14:45:44 +0700 Subject: [PATCH] add an OpenStreamSync method to the streamsMap --- streams_map.go | 38 +++++++++++++++++++++++++++------- streams_map_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/streams_map.go b/streams_map.go index b894fc4b..8c58d67e 100644 --- a/streams_map.go +++ b/streams_map.go @@ -25,8 +25,10 @@ type streamsMap struct { nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID nextStreamOrErrCond sync.Cond - closeErr error - nextStreamToAccept protocol.StreamID + openStreamCond sync.Cond + + closeErr error + nextStreamToAccept protocol.StreamID newStream newStreamLambda @@ -52,6 +54,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect connectionParameters: connectionParameters, } sm.nextStreamOrErrCond.L = &sm.mutex + sm.openStreamCond.L = &sm.mutex if pers == protocol.PerspectiveClient { sm.nextStream = 1 @@ -139,11 +142,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { return s, nil } -// OpenStream opens the next available stream -func (m *streamsMap) OpenStream() (*stream, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - +func (m *streamsMap) openStreamImpl() (*stream, error) { id := m.nextStream if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { return nil, qerr.TooManyOpenStreams @@ -165,6 +164,30 @@ func (m *streamsMap) OpenStream() (*stream, error) { return s, nil } +// OpenStream opens the next available stream +func (m *streamsMap) OpenStream() (*stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.openStreamImpl() +} + +func (m *streamsMap) OpenStreamSync() (*stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.openStreamCond.Wait() + } +} + // AcceptStream returns the next stream opened by the peer // it blocks until a new stream is opened func (m *streamsMap) AcceptStream() (utils.Stream, error) { @@ -290,6 +313,7 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { } delete(m.streams, id) + m.openStreamCond.Signal() return nil } diff --git a/streams_map_test.go b/streams_map_test.go index 025b7b11..af183f9d 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -175,6 +175,13 @@ var _ = Describe("Streams Map", func() { Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) }) + It("errors if the stream can't be created", func() { + testErr := errors.New("test error") + m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr } + _, err := m.OpenStream() + Expect(err).To(MatchError(testErr)) + }) + Context("counting streams", func() { var maxNumStreams int @@ -210,6 +217,49 @@ var _ = Describe("Streams Map", func() { } }) }) + + Context("opening streams synchronously", func() { + var maxNumStreams int + + BeforeEach(func() { + maxNumStreams = int(cpm.GetMaxOutgoingStreams()) + }) + + openMaxNumStreams := func() { + for i := 1; i <= maxNumStreams; i++ { + _, err := m.OpenStream() + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + } + + It("waits until another stream is closed", func() { + openMaxNumStreams() + var returned bool + var str *stream + go func() { + defer GinkgoRecover() + var err error + str, err = m.OpenStreamSync() + Expect(err).ToNot(HaveOccurred()) + returned = true + }() + + Consistently(func() bool { return returned }).Should(BeFalse()) + err := m.RemoveStream(6) + Expect(err).ToNot(HaveOccurred()) + Eventually(func() bool { return returned }).Should(BeTrue()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxNumStreams + 2))) + }) + + It("errors if the stream can't be created", func() { + testErr := errors.New("test error") + m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr } + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + }) + }) }) Context("accepting streams", func() {