fix race condition when stream is created while AcceptStream is waiting

This commit is contained in:
Marten Seemann
2020-11-13 17:46:34 +07:00
parent 629272c713
commit 64daf71e3a
3 changed files with 24 additions and 9 deletions

View File

@@ -22,13 +22,12 @@ type incomingBidiStreamsMap struct {
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamNum) streamI
queueMaxStreamID func(*wire.MaxStreamsFrame)
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
closeErr error
}
@@ -39,7 +38,7 @@ func newIncomingBidiStreamsMap(
queueControlFrame func(wire.Frame),
) *incomingBidiStreamsMap {
return &incomingBidiStreamsMap{
newStreamChan: make(chan struct{}),
newStreamChan: make(chan struct{}, 1),
streams: make(map[protocol.StreamNum]streamI),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -52,6 +51,12 @@ func newIncomingBidiStreamsMap(
}
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
select {
case <-m.newStreamChan:
default:
}
m.mutex.Lock()
var num protocol.StreamNum

View File

@@ -20,13 +20,12 @@ type incomingItemsMap struct {
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamNum) item
queueMaxStreamID func(*wire.MaxStreamsFrame)
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
closeErr error
}
@@ -37,7 +36,7 @@ func newIncomingItemsMap(
queueControlFrame func(wire.Frame),
) *incomingItemsMap {
return &incomingItemsMap{
newStreamChan: make(chan struct{}),
newStreamChan: make(chan struct{}, 1),
streams: make(map[protocol.StreamNum]item),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -50,6 +49,12 @@ func newIncomingItemsMap(
}
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
select {
case <-m.newStreamChan:
default:
}
m.mutex.Lock()
var num protocol.StreamNum

View File

@@ -22,13 +22,12 @@ type incomingUniStreamsMap struct {
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamNum) receiveStreamI
queueMaxStreamID func(*wire.MaxStreamsFrame)
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
closeErr error
}
@@ -39,7 +38,7 @@ func newIncomingUniStreamsMap(
queueControlFrame func(wire.Frame),
) *incomingUniStreamsMap {
return &incomingUniStreamsMap{
newStreamChan: make(chan struct{}),
newStreamChan: make(chan struct{}, 1),
streams: make(map[protocol.StreamNum]receiveStreamI),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -52,6 +51,12 @@ func newIncomingUniStreamsMap(
}
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
select {
case <-m.newStreamChan:
default:
}
m.mutex.Lock()
var num protocol.StreamNum