From 64daf71e3a84f958588540ae8cfbd0f8fc79baee Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 13 Nov 2020 17:46:34 +0700 Subject: [PATCH 1/2] fix race condition when stream is created while AcceptStream is waiting --- streams_map_incoming_bidi.go | 11 ++++++++--- streams_map_incoming_generic.go | 11 ++++++++--- streams_map_incoming_uni.go | 11 ++++++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index 0825b6d4a..ff5b6c445 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -22,13 +22,12 @@ type incomingBidiStreamsMap struct { streamsToDelete map[protocol.StreamNum]struct{} // used as a set nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened maxStream protocol.StreamNum // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamNum) streamI queueMaxStreamID func(*wire.MaxStreamsFrame) - // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } @@ -39,7 +38,7 @@ func newIncomingBidiStreamsMap( queueControlFrame func(wire.Frame), ) *incomingBidiStreamsMap { return &incomingBidiStreamsMap{ - newStreamChan: make(chan struct{}), + newStreamChan: make(chan struct{}, 1), streams: make(map[protocol.StreamNum]streamI), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -52,6 +51,12 @@ func newIncomingBidiStreamsMap( } func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + m.mutex.Lock() var num protocol.StreamNum diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index f9ae12b72..e7ced406b 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -20,13 +20,12 @@ type incomingItemsMap struct { streamsToDelete map[protocol.StreamNum]struct{} // used as a set nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened maxStream protocol.StreamNum // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamNum) item queueMaxStreamID func(*wire.MaxStreamsFrame) - // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } @@ -37,7 +36,7 @@ func newIncomingItemsMap( queueControlFrame func(wire.Frame), ) *incomingItemsMap { return &incomingItemsMap{ - newStreamChan: make(chan struct{}), + newStreamChan: make(chan struct{}, 1), streams: make(map[protocol.StreamNum]item), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -50,6 +49,12 @@ func newIncomingItemsMap( } func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + m.mutex.Lock() var num protocol.StreamNum diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 70467a76b..df0da8166 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -22,13 +22,12 @@ type incomingUniStreamsMap struct { streamsToDelete map[protocol.StreamNum]struct{} // used as a set nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened maxStream protocol.StreamNum // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamNum) receiveStreamI queueMaxStreamID func(*wire.MaxStreamsFrame) - // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } @@ -39,7 +38,7 @@ func newIncomingUniStreamsMap( queueControlFrame func(wire.Frame), ) *incomingUniStreamsMap { return &incomingUniStreamsMap{ - newStreamChan: make(chan struct{}), + newStreamChan: make(chan struct{}, 1), streams: make(map[protocol.StreamNum]receiveStreamI), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -52,6 +51,12 @@ func newIncomingUniStreamsMap( } func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + m.mutex.Lock() var num protocol.StreamNum From 46991ae0ec1b7c510a62da79e8063cdce9c5a0f1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 13 Nov 2020 17:39:19 +0700 Subject: [PATCH 2/2] add a randomized test for accepting streams --- streams_map_incoming_generic_test.go | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 0bf4a3d1d..376e1811d 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "errors" + "math/rand" + "time" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -257,4 +259,44 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(m.DeleteStream(1)).To(Succeed()) }) }) + + Context("randomized tests", func() { + const num = 1000 + + BeforeEach(func() { maxNumStreams = num }) + + It("opens and accepts streams", func() { + rand.Seed(GinkgoRandomSeed()) + ids := make([]protocol.StreamNum, num) + for i := 0; i < num; i++ { + ids[i] = protocol.StreamNum(i + 1) + } + rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) + + const timeout = 5 * time.Second + done := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for i := 0; i < num; i++ { + _, err := m.AcceptStream(ctx) + Expect(err).ToNot(HaveOccurred()) + } + done <- struct{}{} + }() + + go func() { + defer GinkgoRecover() + for i := 0; i < num; i++ { + _, err := m.GetOrOpenStream(ids[i]) + Expect(err).ToNot(HaveOccurred()) + } + done <- struct{}{} + }() + + Eventually(done, timeout*3/2).Should(Receive()) + Eventually(done, timeout*3/2).Should(Receive()) + }) + }) })