use a slice instead of a map to track waiting OpenStreamSync calls (#5036)

Using a slice is faster than using a map, unless thousands
of OpenStreamSync calls are cancelled at the same time.
This commit is contained in:
Marten Seemann
2025-04-19 10:17:53 +08:00
committed by GitHub
parent cda52a1e36
commit b645ce35a2

View File

@@ -2,6 +2,7 @@ package quic
import ( import (
"context" "context"
"slices"
"sync" "sync"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@@ -19,9 +20,7 @@ type outgoingStreamsMap[T outgoingStream] struct {
streamType protocol.StreamType streamType protocol.StreamType
streams map[protocol.StreamNum]T streams map[protocol.StreamNum]T
openQueue map[uint64]chan struct{} openQueue []chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
@@ -41,7 +40,6 @@ func newOutgoingStreamsMap[T outgoingStream](
return &outgoingStreamsMap[T]{ return &outgoingStreamsMap[T]{
streamType: streamType, streamType: streamType,
streams: make(map[protocol.StreamNum]T), streams: make(map[protocol.StreamNum]T),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum, maxStream: protocol.InvalidStreamNum,
nextStream: 1, nextStream: 1,
newStream: newStream, newStream: newStream,
@@ -72,22 +70,15 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
if m.closeErr != nil { if m.closeErr != nil {
return *new(T), m.closeErr return *new(T), m.closeErr
} }
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return *new(T), err return *new(T), err
} }
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil return m.openStream(), nil
} }
waitChan := make(chan struct{}, 1) waitChan := make(chan struct{}, 1)
queuePos := m.highestInQueue m.openQueue = append(m.openQueue, waitChan)
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame() m.maybeSendBlockedFrame()
for { for {
@@ -95,12 +86,14 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
m.mutex.Lock() m.mutex.Lock()
delete(m.openQueue, queuePos) m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
return c == waitChan
})
return *new(T), ctx.Err() return *new(T), ctx.Err()
case <-waitChan: case <-waitChan:
} }
m.mutex.Lock()
m.mutex.Lock()
if m.closeErr != nil { if m.closeErr != nil {
return *new(T), m.closeErr return *new(T), m.closeErr
} }
@@ -109,8 +102,7 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
continue continue
} }
str := m.openStream() str := m.openStream()
delete(m.openQueue, queuePos) m.openQueue = m.openQueue[1:]
m.lowestInQueue = queuePos + 1
m.unblockOpenSync() m.unblockOpenSync()
return str, nil return str, nil
} }
@@ -200,19 +192,15 @@ func (m *outgoingStreamsMap[T]) unblockOpenSync() {
if len(m.openQueue) == 0 { if len(m.openQueue) == 0 {
return return
} }
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { if m.nextStream > m.maxStream {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
// It's sufficient to only unblock OpenStreamSync once.
select {
case c <- struct{}{}:
default:
}
return return
} }
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
// It's sufficient to only unblock OpenStreamSync once.
select {
case m.openQueue[0] <- struct{}{}:
default:
}
} }
func (m *outgoingStreamsMap[T]) CloseWithError(err error) { func (m *outgoingStreamsMap[T]) CloseWithError(err error) {