forked from quic-go/quic-go
fix potential deadlock when OpenStreamSync context is cancelled (#5037)
* fix potential deadlock when OpenStreamSync context is cancelled * fix race condition when closing
This commit is contained in:
@@ -89,6 +89,9 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
|
|||||||
m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
|
m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
|
||||||
return c == waitChan
|
return c == waitChan
|
||||||
})
|
})
|
||||||
|
// If we just received a MAX_STREAMS frame, this might have been the next stream
|
||||||
|
// that could be opened. Make sure we unblock the next OpenStreamSync call.
|
||||||
|
m.maybeUnblockOpenSync()
|
||||||
return *new(T), ctx.Err()
|
return *new(T), ctx.Err()
|
||||||
case <-waitChan:
|
case <-waitChan:
|
||||||
}
|
}
|
||||||
@@ -103,7 +106,7 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
|
|||||||
}
|
}
|
||||||
str := m.openStream()
|
str := m.openStream()
|
||||||
m.openQueue = m.openQueue[1:]
|
m.openQueue = m.openQueue[1:]
|
||||||
m.unblockOpenSync()
|
m.maybeUnblockOpenSync()
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -173,7 +176,7 @@ func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) {
|
|||||||
if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) {
|
if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) {
|
||||||
m.maybeSendBlockedFrame()
|
m.maybeSendBlockedFrame()
|
||||||
}
|
}
|
||||||
m.unblockOpenSync()
|
m.maybeUnblockOpenSync()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSendWindow is called when the peer's transport parameters are received.
|
// UpdateSendWindow is called when the peer's transport parameters are received.
|
||||||
@@ -188,7 +191,7 @@ func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
|
// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
|
||||||
func (m *outgoingStreamsMap[T]) unblockOpenSync() {
|
func (m *outgoingStreamsMap[T]) maybeUnblockOpenSync() {
|
||||||
if len(m.openQueue) == 0 {
|
if len(m.openQueue) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -205,6 +208,8 @@ func (m *outgoingStreamsMap[T]) unblockOpenSync() {
|
|||||||
|
|
||||||
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
|
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.closeErr = err
|
m.closeErr = err
|
||||||
for _, str := range m.streams {
|
for _, str := range m.streams {
|
||||||
str.closeForShutdown(err)
|
str.closeForShutdown(err)
|
||||||
@@ -214,5 +219,5 @@ func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
|
|||||||
close(c)
|
close(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.openQueue = nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
@@ -13,6 +14,9 @@ import (
|
|||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
@@ -534,3 +538,47 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
func TestStreamsMapConcurrent(t *testing.T) {
|
||||||
|
for i := range 5 {
|
||||||
|
t.Run(fmt.Sprintf("iteration %d", i+1), func(t *testing.T) {
|
||||||
|
testStreamsMapConcurrent(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStreamsMapConcurrent(t *testing.T) {
|
||||||
|
m := newOutgoingStreamsMap(
|
||||||
|
protocol.StreamTypeBidi,
|
||||||
|
func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} },
|
||||||
|
func(f wire.Frame) {},
|
||||||
|
)
|
||||||
|
|
||||||
|
const num = 100
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
errChan := make(chan error, num)
|
||||||
|
for range num {
|
||||||
|
go func() {
|
||||||
|
_, err := m.OpenStreamSync(ctx)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(scaleDuration(5 * time.Millisecond))
|
||||||
|
go m.CloseWithError(assert.AnError)
|
||||||
|
go cancel()
|
||||||
|
go m.SetMaxStream(protocol.StreamNum(num / 2))
|
||||||
|
|
||||||
|
for range num {
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
require.True(t, errors.Is(err, assert.AnError) || errors.Is(err, context.Canceled))
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user