forked from quic-go/quic-go
unblock OpenStreamSync in streamsMap when an error is registered
This commit is contained in:
@@ -25,7 +25,7 @@ type streamsMap struct {
|
|||||||
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||||
highestStreamOpenedByPeer protocol.StreamID
|
highestStreamOpenedByPeer protocol.StreamID
|
||||||
nextStreamOrErrCond sync.Cond
|
nextStreamOrErrCond sync.Cond
|
||||||
openStreamCond sync.Cond
|
openStreamOrErrCond sync.Cond
|
||||||
|
|
||||||
closeErr error
|
closeErr error
|
||||||
nextStreamToAccept protocol.StreamID
|
nextStreamToAccept protocol.StreamID
|
||||||
@@ -54,7 +54,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect
|
|||||||
connectionParameters: connectionParameters,
|
connectionParameters: connectionParameters,
|
||||||
}
|
}
|
||||||
sm.nextStreamOrErrCond.L = &sm.mutex
|
sm.nextStreamOrErrCond.L = &sm.mutex
|
||||||
sm.openStreamCond.L = &sm.mutex
|
sm.openStreamOrErrCond.L = &sm.mutex
|
||||||
|
|
||||||
if pers == protocol.PerspectiveClient {
|
if pers == protocol.PerspectiveClient {
|
||||||
sm.nextStream = 1
|
sm.nextStream = 1
|
||||||
@@ -177,6 +177,9 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) {
|
|||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
if m.closeErr != nil {
|
||||||
|
return nil, m.closeErr
|
||||||
|
}
|
||||||
str, err := m.openStreamImpl()
|
str, err := m.openStreamImpl()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return str, err
|
return str, err
|
||||||
@@ -184,7 +187,7 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) {
|
|||||||
if err != nil && err != qerr.TooManyOpenStreams {
|
if err != nil && err != qerr.TooManyOpenStreams {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m.openStreamCond.Wait()
|
m.openStreamOrErrCond.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -313,7 +316,7 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
delete(m.streams, id)
|
delete(m.streams, id)
|
||||||
m.openStreamCond.Signal()
|
m.openStreamOrErrCond.Signal()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,5 +324,6 @@ func (m *streamsMap) CloseWithError(err error) {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
m.closeErr = err
|
m.closeErr = err
|
||||||
m.nextStreamOrErrCond.Broadcast()
|
m.nextStreamOrErrCond.Broadcast()
|
||||||
|
m.openStreamOrErrCond.Broadcast()
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -259,6 +259,29 @@ var _ = Describe("Streams Map", func() {
|
|||||||
_, err := m.OpenStreamSync()
|
_, err := m.OpenStreamSync()
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("stops waiting when an error is registered", func() {
|
||||||
|
openMaxNumStreams()
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
var err error
|
||||||
|
var returned bool
|
||||||
|
go func() {
|
||||||
|
_, err = m.OpenStreamSync()
|
||||||
|
returned = true
|
||||||
|
}()
|
||||||
|
|
||||||
|
Consistently(func() bool { return returned }).Should(BeFalse())
|
||||||
|
m.CloseWithError(testErr)
|
||||||
|
Eventually(func() bool { return returned }).Should(BeTrue())
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("immediately returns when OpenStreamSync is called after an error was registered", func() {
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
m.CloseWithError(testErr)
|
||||||
|
_, err := m.OpenStreamSync()
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user