use a slice instead of a map to track waiting OpenStreamSync calls (#5036)

Using a slice is faster than using a map, unless thousands
of OpenStreamSync calls are cancelled at the same time.
This commit is contained in:
Marten Seemann
2025-04-19 10:17:53 +08:00
committed by GitHub
parent cda52a1e36
commit b645ce35a2

View File

@@ -2,6 +2,7 @@ package quic
import (
"context"
"slices"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
@@ -19,9 +20,7 @@ type outgoingStreamsMap[T outgoingStream] struct {
streamType protocol.StreamType
streams map[protocol.StreamNum]T
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
openQueue []chan struct{}
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
@@ -41,7 +40,6 @@ func newOutgoingStreamsMap[T outgoingStream](
return &outgoingStreamsMap[T]{
streamType: streamType,
streams: make(map[protocol.StreamNum]T),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@@ -72,22 +70,15 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
if m.closeErr != nil {
return *new(T), m.closeErr
}
if err := ctx.Err(); err != nil {
return *new(T), err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.openQueue = append(m.openQueue, waitChan)
m.maybeSendBlockedFrame()
for {
@@ -95,12 +86,14 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
return c == waitChan
})
return *new(T), ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
m.mutex.Lock()
if m.closeErr != nil {
return *new(T), m.closeErr
}
@@ -109,8 +102,7 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
continue
}
str := m.openStream()
delete(m.openQueue, queuePos)
m.lowestInQueue = queuePos + 1
m.openQueue = m.openQueue[1:]
m.unblockOpenSync()
return str, nil
}
@@ -200,19 +192,15 @@ func (m *outgoingStreamsMap[T]) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
// It's sufficient to only unblock OpenStreamSync once.
select {
case c <- struct{}{}:
default:
}
if m.nextStream > m.maxStream {
return
}
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
// It's sufficient to only unblock OpenStreamSync once.
select {
case m.openQueue[0] <- struct{}{}:
default:
}
}
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {