forked from quic-go/quic-go
use a slice instead of a map to track waiting OpenStreamSync calls (#5036)
Using a slice is faster than using a map, unless thousands of OpenStreamSync calls are cancelled at the same time.
This commit is contained in:
@@ -2,6 +2,7 @@ package quic
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
@@ -19,9 +20,7 @@ type outgoingStreamsMap[T outgoingStream] struct {
|
|||||||
streamType protocol.StreamType
|
streamType protocol.StreamType
|
||||||
streams map[protocol.StreamNum]T
|
streams map[protocol.StreamNum]T
|
||||||
|
|
||||||
openQueue map[uint64]chan struct{}
|
openQueue []chan struct{}
|
||||||
lowestInQueue uint64
|
|
||||||
highestInQueue uint64
|
|
||||||
|
|
||||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||||
@@ -41,7 +40,6 @@ func newOutgoingStreamsMap[T outgoingStream](
|
|||||||
return &outgoingStreamsMap[T]{
|
return &outgoingStreamsMap[T]{
|
||||||
streamType: streamType,
|
streamType: streamType,
|
||||||
streams: make(map[protocol.StreamNum]T),
|
streams: make(map[protocol.StreamNum]T),
|
||||||
openQueue: make(map[uint64]chan struct{}),
|
|
||||||
maxStream: protocol.InvalidStreamNum,
|
maxStream: protocol.InvalidStreamNum,
|
||||||
nextStream: 1,
|
nextStream: 1,
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
@@ -72,22 +70,15 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
|
|||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
return *new(T), m.closeErr
|
return *new(T), m.closeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
return *new(T), err
|
return *new(T), err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
|
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
|
||||||
return m.openStream(), nil
|
return m.openStream(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
waitChan := make(chan struct{}, 1)
|
waitChan := make(chan struct{}, 1)
|
||||||
queuePos := m.highestInQueue
|
m.openQueue = append(m.openQueue, waitChan)
|
||||||
m.highestInQueue++
|
|
||||||
if len(m.openQueue) == 0 {
|
|
||||||
m.lowestInQueue = queuePos
|
|
||||||
}
|
|
||||||
m.openQueue[queuePos] = waitChan
|
|
||||||
m.maybeSendBlockedFrame()
|
m.maybeSendBlockedFrame()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -95,12 +86,14 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
delete(m.openQueue, queuePos)
|
m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
|
||||||
|
return c == waitChan
|
||||||
|
})
|
||||||
return *new(T), ctx.Err()
|
return *new(T), ctx.Err()
|
||||||
case <-waitChan:
|
case <-waitChan:
|
||||||
}
|
}
|
||||||
m.mutex.Lock()
|
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
return *new(T), m.closeErr
|
return *new(T), m.closeErr
|
||||||
}
|
}
|
||||||
@@ -109,8 +102,7 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
str := m.openStream()
|
str := m.openStream()
|
||||||
delete(m.openQueue, queuePos)
|
m.openQueue = m.openQueue[1:]
|
||||||
m.lowestInQueue = queuePos + 1
|
|
||||||
m.unblockOpenSync()
|
m.unblockOpenSync()
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
@@ -200,19 +192,15 @@ func (m *outgoingStreamsMap[T]) unblockOpenSync() {
|
|||||||
if len(m.openQueue) == 0 {
|
if len(m.openQueue) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
|
if m.nextStream > m.maxStream {
|
||||||
c, ok := m.openQueue[qp]
|
|
||||||
if !ok { // entry was deleted because the context was canceled
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
|
|
||||||
// It's sufficient to only unblock OpenStreamSync once.
|
|
||||||
select {
|
|
||||||
case c <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
|
||||||
|
// It's sufficient to only unblock OpenStreamSync once.
|
||||||
|
select {
|
||||||
|
case m.openQueue[0] <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
|
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user