unblock OpenStreamSync in streamsMap when an error is registered

This commit is contained in:
Marten Seemann
2017-02-20 15:25:33 +07:00
parent d1562e5f87
commit 9c46ee6b24
2 changed files with 31 additions and 4 deletions

View File

@@ -25,7 +25,7 @@ type streamsMap struct {
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond nextStreamOrErrCond sync.Cond
openStreamCond sync.Cond openStreamOrErrCond sync.Cond
closeErr error closeErr error
nextStreamToAccept protocol.StreamID nextStreamToAccept protocol.StreamID
@@ -54,7 +54,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect
connectionParameters: connectionParameters, connectionParameters: connectionParameters,
} }
sm.nextStreamOrErrCond.L = &sm.mutex sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex
if pers == protocol.PerspectiveClient { if pers == protocol.PerspectiveClient {
sm.nextStream = 1 sm.nextStream = 1
@@ -177,6 +177,9 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) {
defer m.mutex.Unlock() defer m.mutex.Unlock()
for { for {
if m.closeErr != nil {
return nil, m.closeErr
}
str, err := m.openStreamImpl() str, err := m.openStreamImpl()
if err == nil { if err == nil {
return str, err return str, err
@@ -184,7 +187,7 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) {
if err != nil && err != qerr.TooManyOpenStreams { if err != nil && err != qerr.TooManyOpenStreams {
return nil, err return nil, err
} }
m.openStreamCond.Wait() m.openStreamOrErrCond.Wait()
} }
} }
@@ -313,7 +316,7 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
} }
delete(m.streams, id) delete(m.streams, id)
m.openStreamCond.Signal() m.openStreamOrErrCond.Signal()
return nil return nil
} }
@@ -321,5 +324,6 @@ func (m *streamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
m.closeErr = err m.closeErr = err
m.nextStreamOrErrCond.Broadcast() m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
m.mutex.Unlock() m.mutex.Unlock()
} }

View File

@@ -259,6 +259,29 @@ var _ = Describe("Streams Map", func() {
_, err := m.OpenStreamSync() _, err := m.OpenStreamSync()
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
It("stops waiting when an error is registered", func() {
openMaxNumStreams()
testErr := errors.New("test error")
var err error
var returned bool
go func() {
_, err = m.OpenStreamSync()
returned = true
}()
Consistently(func() bool { return returned }).Should(BeFalse())
m.CloseWithError(testErr)
Eventually(func() bool { return returned }).Should(BeTrue())
Expect(err).To(MatchError(testErr))
})
It("immediately returns when OpenStreamSync is called after an error was registered", func() {
testErr := errors.New("test error")
m.CloseWithError(testErr)
_, err := m.OpenStreamSync()
Expect(err).To(MatchError(testErr))
})
}) })
}) })