Merge pull request #2874 from lucas-clemente/fix-accept-stream-race

fix race condition when accepting streams
This commit is contained in:
Marten Seemann
2020-11-13 18:33:34 +07:00
committed by GitHub
4 changed files with 66 additions and 9 deletions

View File

@@ -22,13 +22,12 @@ type incomingBidiStreamsMap struct {
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamNum) streamI
queueMaxStreamID func(*wire.MaxStreamsFrame)
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
closeErr error
}
@@ -39,7 +38,7 @@ func newIncomingBidiStreamsMap(
queueControlFrame func(wire.Frame),
) *incomingBidiStreamsMap {
return &incomingBidiStreamsMap{
newStreamChan: make(chan struct{}),
newStreamChan: make(chan struct{}, 1),
streams: make(map[protocol.StreamNum]streamI),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -52,6 +51,12 @@ func newIncomingBidiStreamsMap(
}
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
select {
case <-m.newStreamChan:
default:
}
m.mutex.Lock()
var num protocol.StreamNum

View File

@@ -20,13 +20,12 @@ type incomingItemsMap struct {
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamNum) item
queueMaxStreamID func(*wire.MaxStreamsFrame)
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
closeErr error
}
@@ -37,7 +36,7 @@ func newIncomingItemsMap(
queueControlFrame func(wire.Frame),
) *incomingItemsMap {
return &incomingItemsMap{
newStreamChan: make(chan struct{}),
newStreamChan: make(chan struct{}, 1),
streams: make(map[protocol.StreamNum]item),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -50,6 +49,12 @@ func newIncomingItemsMap(
}
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
select {
case <-m.newStreamChan:
default:
}
m.mutex.Lock()
var num protocol.StreamNum

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"context"
"errors"
"math/rand"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -257,4 +259,44 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(m.DeleteStream(1)).To(Succeed())
})
})
Context("randomized tests", func() {
const num = 1000
BeforeEach(func() { maxNumStreams = num })
It("opens and accepts streams", func() {
rand.Seed(GinkgoRandomSeed())
ids := make([]protocol.StreamNum, num)
for i := 0; i < num; i++ {
ids[i] = protocol.StreamNum(i + 1)
}
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
const timeout = 5 * time.Second
done := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for i := 0; i < num; i++ {
_, err := m.AcceptStream(ctx)
Expect(err).ToNot(HaveOccurred())
}
done <- struct{}{}
}()
go func() {
defer GinkgoRecover()
for i := 0; i < num; i++ {
_, err := m.GetOrOpenStream(ids[i])
Expect(err).ToNot(HaveOccurred())
}
done <- struct{}{}
}()
Eventually(done, timeout*3/2).Should(Receive())
Eventually(done, timeout*3/2).Should(Receive())
})
})
})

View File

@@ -22,13 +22,12 @@ type incomingUniStreamsMap struct {
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamNum) receiveStreamI
queueMaxStreamID func(*wire.MaxStreamsFrame)
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
closeErr error
}
@@ -39,7 +38,7 @@ func newIncomingUniStreamsMap(
queueControlFrame func(wire.Frame),
) *incomingUniStreamsMap {
return &incomingUniStreamsMap{
newStreamChan: make(chan struct{}),
newStreamChan: make(chan struct{}, 1),
streams: make(map[protocol.StreamNum]receiveStreamI),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -52,6 +51,12 @@ func newIncomingUniStreamsMap(
}
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
select {
case <-m.newStreamChan:
default:
}
m.mutex.Lock()
var num protocol.StreamNum