fix potential deadlock when OpenStreamSync context is cancelled (#5037)

* fix potential deadlock when OpenStreamSync context is cancelled

* fix race condition when closing
This commit is contained in:
Marten Seemann
2025-04-19 13:48:18 +08:00
committed by GitHub
parent b645ce35a2
commit f4c478b627
2 changed files with 57 additions and 4 deletions

View File

@@ -89,6 +89,9 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
return c == waitChan
})
// If we just received a MAX_STREAMS frame, this might have been the next stream
// that could be opened. Make sure we unblock the next OpenStreamSync call.
m.maybeUnblockOpenSync()
return *new(T), ctx.Err()
case <-waitChan:
}
@@ -103,7 +106,7 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
m.unblockOpenSync()
m.maybeUnblockOpenSync()
return str, nil
}
}
@@ -173,7 +176,7 @@ func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) {
if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) {
m.maybeSendBlockedFrame()
}
m.unblockOpenSync()
m.maybeUnblockOpenSync()
}
// UpdateSendWindow is called when the peer's transport parameters are received.
@@ -188,7 +191,7 @@ func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
}
// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
func (m *outgoingStreamsMap[T]) unblockOpenSync() {
func (m *outgoingStreamsMap[T]) maybeUnblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
@@ -205,6 +208,8 @@ func (m *outgoingStreamsMap[T]) unblockOpenSync() {
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
@@ -214,5 +219,5 @@ func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
close(c)
}
}
m.mutex.Unlock()
m.openQueue = nil
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"sort"
"sync"
"testing"
"time"
"golang.org/x/exp/rand"
@@ -13,6 +14,9 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@@ -534,3 +538,47 @@ var _ = Describe("Streams Map (outgoing)", func() {
})
})
})
func TestStreamsMapConcurrent(t *testing.T) {
for i := range 5 {
t.Run(fmt.Sprintf("iteration %d", i+1), func(t *testing.T) {
testStreamsMapConcurrent(t)
})
}
}
func testStreamsMapConcurrent(t *testing.T) {
m := newOutgoingStreamsMap(
protocol.StreamTypeBidi,
func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} },
func(f wire.Frame) {},
)
const num = 100
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errChan := make(chan error, num)
for range num {
go func() {
_, err := m.OpenStreamSync(ctx)
errChan <- err
}()
}
time.Sleep(scaleDuration(5 * time.Millisecond))
go m.CloseWithError(assert.AnError)
go cancel()
go m.SetMaxStream(protocol.StreamNum(num / 2))
for range num {
select {
case err := <-errChan:
if err != nil {
require.True(t, errors.Is(err, assert.AnError) || errors.Is(err, context.Canceled))
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for error")
}
}
}