diff --git a/streams_map.go b/streams_map.go index 8c58d67e..0f102b10 100644 --- a/streams_map.go +++ b/streams_map.go @@ -25,7 +25,7 @@ type streamsMap struct { nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID nextStreamOrErrCond sync.Cond - openStreamCond sync.Cond + openStreamOrErrCond sync.Cond closeErr error nextStreamToAccept protocol.StreamID @@ -54,7 +54,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect connectionParameters: connectionParameters, } sm.nextStreamOrErrCond.L = &sm.mutex - sm.openStreamCond.L = &sm.mutex + sm.openStreamOrErrCond.L = &sm.mutex if pers == protocol.PerspectiveClient { sm.nextStream = 1 @@ -177,6 +177,9 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) { defer m.mutex.Unlock() for { + if m.closeErr != nil { + return nil, m.closeErr + } str, err := m.openStreamImpl() if err == nil { return str, err @@ -184,7 +187,7 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) { if err != nil && err != qerr.TooManyOpenStreams { return nil, err } - m.openStreamCond.Wait() + m.openStreamOrErrCond.Wait() } } @@ -313,7 +316,7 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { } delete(m.streams, id) - m.openStreamCond.Signal() + m.openStreamOrErrCond.Signal() return nil } @@ -321,5 +324,6 @@ func (m *streamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err m.nextStreamOrErrCond.Broadcast() + m.openStreamOrErrCond.Broadcast() m.mutex.Unlock() } diff --git a/streams_map_test.go b/streams_map_test.go index af183f9d..92ec2799 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -259,6 +259,29 @@ var _ = Describe("Streams Map", func() { _, err := m.OpenStreamSync() Expect(err).To(MatchError(testErr)) }) + + It("stops waiting when an error is registered", func() { + openMaxNumStreams() + testErr := errors.New("test error") + var err error + var returned bool + go func() { + _, err = m.OpenStreamSync() + returned = true + }() + + Consistently(func() bool { return returned }).Should(BeFalse()) + m.CloseWithError(testErr) + Eventually(func() bool { return returned }).Should(BeTrue()) + Expect(err).To(MatchError(testErr)) + }) + + It("immediately returns when OpenStreamSync is called after an error was registered", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + }) }) })