forked from quic-go/quic-go
Merge pull request #2874 from lucas-clemente/fix-accept-stream-race
fix race condition when accepting streams
This commit is contained in:
@@ -22,13 +22,12 @@ type incomingBidiStreamsMap struct {
|
||||
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
|
||||
|
||||
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
|
||||
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
|
||||
maxNumStreams uint64 // maximum number of streams
|
||||
|
||||
newStream func(protocol.StreamNum) streamI
|
||||
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
|
||||
|
||||
closeErr error
|
||||
}
|
||||
@@ -39,7 +38,7 @@ func newIncomingBidiStreamsMap(
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *incomingBidiStreamsMap {
|
||||
return &incomingBidiStreamsMap{
|
||||
newStreamChan: make(chan struct{}),
|
||||
newStreamChan: make(chan struct{}, 1),
|
||||
streams: make(map[protocol.StreamNum]streamI),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
@@ -52,6 +51,12 @@ func newIncomingBidiStreamsMap(
|
||||
}
|
||||
|
||||
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
|
||||
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
|
||||
select {
|
||||
case <-m.newStreamChan:
|
||||
default:
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
|
||||
@@ -20,13 +20,12 @@ type incomingItemsMap struct {
|
||||
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
|
||||
|
||||
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
|
||||
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
|
||||
maxNumStreams uint64 // maximum number of streams
|
||||
|
||||
newStream func(protocol.StreamNum) item
|
||||
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
|
||||
|
||||
closeErr error
|
||||
}
|
||||
@@ -37,7 +36,7 @@ func newIncomingItemsMap(
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *incomingItemsMap {
|
||||
return &incomingItemsMap{
|
||||
newStreamChan: make(chan struct{}),
|
||||
newStreamChan: make(chan struct{}, 1),
|
||||
streams: make(map[protocol.StreamNum]item),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
@@ -50,6 +49,12 @@ func newIncomingItemsMap(
|
||||
}
|
||||
|
||||
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
|
||||
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
|
||||
select {
|
||||
case <-m.newStreamChan:
|
||||
default:
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -257,4 +259,44 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
Expect(m.DeleteStream(1)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("randomized tests", func() {
|
||||
const num = 1000
|
||||
|
||||
BeforeEach(func() { maxNumStreams = num })
|
||||
|
||||
It("opens and accepts streams", func() {
|
||||
rand.Seed(GinkgoRandomSeed())
|
||||
ids := make([]protocol.StreamNum, num)
|
||||
for i := 0; i < num; i++ {
|
||||
ids[i] = protocol.StreamNum(i + 1)
|
||||
}
|
||||
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
|
||||
|
||||
const timeout = 5 * time.Second
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
for i := 0; i < num; i++ {
|
||||
_, err := m.AcceptStream(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for i := 0; i < num; i++ {
|
||||
_, err := m.GetOrOpenStream(ids[i])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
Eventually(done, timeout*3/2).Should(Receive())
|
||||
Eventually(done, timeout*3/2).Should(Receive())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -22,13 +22,12 @@ type incomingUniStreamsMap struct {
|
||||
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
|
||||
|
||||
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
|
||||
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
|
||||
maxNumStreams uint64 // maximum number of streams
|
||||
|
||||
newStream func(protocol.StreamNum) receiveStreamI
|
||||
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
|
||||
|
||||
closeErr error
|
||||
}
|
||||
@@ -39,7 +38,7 @@ func newIncomingUniStreamsMap(
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *incomingUniStreamsMap {
|
||||
return &incomingUniStreamsMap{
|
||||
newStreamChan: make(chan struct{}),
|
||||
newStreamChan: make(chan struct{}, 1),
|
||||
streams: make(map[protocol.StreamNum]receiveStreamI),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
@@ -52,6 +51,12 @@ func newIncomingUniStreamsMap(
|
||||
}
|
||||
|
||||
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
|
||||
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
|
||||
select {
|
||||
case <-m.newStreamChan:
|
||||
default:
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
|
||||
Reference in New Issue
Block a user