From 64daf71e3a84f958588540ae8cfbd0f8fc79baee Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 13 Nov 2020 17:46:34 +0700 Subject: [PATCH] fix race condition when stream is created while AcceptStream is waiting --- streams_map_incoming_bidi.go | 11 ++++++++--- streams_map_incoming_generic.go | 11 ++++++++--- streams_map_incoming_uni.go | 11 ++++++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index 0825b6d4a..ff5b6c445 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -22,13 +22,12 @@ type incomingBidiStreamsMap struct { streamsToDelete map[protocol.StreamNum]struct{} // used as a set nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened maxStream protocol.StreamNum // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamNum) streamI queueMaxStreamID func(*wire.MaxStreamsFrame) - // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } @@ -39,7 +38,7 @@ func newIncomingBidiStreamsMap( queueControlFrame func(wire.Frame), ) *incomingBidiStreamsMap { return &incomingBidiStreamsMap{ - newStreamChan: make(chan struct{}), + newStreamChan: make(chan struct{}, 1), streams: make(map[protocol.StreamNum]streamI), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -52,6 +51,12 @@ func newIncomingBidiStreamsMap( } func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + m.mutex.Lock() var num protocol.StreamNum diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index f9ae12b72..e7ced406b 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -20,13 +20,12 @@ type incomingItemsMap struct { streamsToDelete map[protocol.StreamNum]struct{} // used as a set nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened maxStream protocol.StreamNum // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamNum) item queueMaxStreamID func(*wire.MaxStreamsFrame) - // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } @@ -37,7 +36,7 @@ func newIncomingItemsMap( queueControlFrame func(wire.Frame), ) *incomingItemsMap { return &incomingItemsMap{ - newStreamChan: make(chan struct{}), + newStreamChan: make(chan struct{}, 1), streams: make(map[protocol.StreamNum]item), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -50,6 +49,12 @@ func newIncomingItemsMap( } func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + m.mutex.Lock() var num protocol.StreamNum diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 70467a76b..df0da8166 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -22,13 +22,12 @@ type incomingUniStreamsMap struct { streamsToDelete map[protocol.StreamNum]struct{} // used as a set nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened maxStream protocol.StreamNum // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamNum) receiveStreamI queueMaxStreamID func(*wire.MaxStreamsFrame) - // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } @@ -39,7 +38,7 @@ func newIncomingUniStreamsMap( queueControlFrame func(wire.Frame), ) *incomingUniStreamsMap { return &incomingUniStreamsMap{ - newStreamChan: make(chan struct{}), + newStreamChan: make(chan struct{}, 1), streams: make(map[protocol.StreamNum]receiveStreamI), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -52,6 +51,12 @@ func newIncomingUniStreamsMap( } func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + m.mutex.Lock() var num protocol.StreamNum