diff --git a/streams_map_outgoing.go b/streams_map_outgoing.go index a8d04b04f..602d0f932 100644 --- a/streams_map_outgoing.go +++ b/streams_map_outgoing.go @@ -2,6 +2,7 @@ package quic import ( "context" + "slices" "sync" "github.com/quic-go/quic-go/internal/protocol" @@ -19,9 +20,7 @@ type outgoingStreamsMap[T outgoingStream] struct { streamType protocol.StreamType streams map[protocol.StreamNum]T - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 + openQueue []chan struct{} nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) maxStream protocol.StreamNum // the maximum stream ID we're allowed to open @@ -41,7 +40,6 @@ func newOutgoingStreamsMap[T outgoingStream]( return &outgoingStreamsMap[T]{ streamType: streamType, streams: make(map[protocol.StreamNum]T), - openQueue: make(map[uint64]chan struct{}), maxStream: protocol.InvalidStreamNum, nextStream: 1, newStream: newStream, @@ -72,22 +70,15 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { if m.closeErr != nil { return *new(T), m.closeErr } - if err := ctx.Err(); err != nil { return *new(T), err } - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { return m.openStream(), nil } waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan + m.openQueue = append(m.openQueue, waitChan) m.maybeSendBlockedFrame() for { @@ -95,12 +86,14 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { select { case <-ctx.Done(): m.mutex.Lock() - delete(m.openQueue, queuePos) + m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool { + return c == waitChan + }) return *new(T), ctx.Err() case <-waitChan: } - m.mutex.Lock() + m.mutex.Lock() if m.closeErr != nil { return *new(T), m.closeErr } @@ -109,8 +102,7 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { continue } str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 + m.openQueue = m.openQueue[1:] m.unblockOpenSync() return str, nil } @@ -200,19 +192,15 @@ func (m *outgoingStreamsMap[T]) unblockOpenSync() { if len(m.openQueue) == 0 { return } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. - select { - case c <- struct{}{}: - default: - } + if m.nextStream > m.maxStream { return } + // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. + // It's sufficient to only unblock OpenStreamSync once. + select { + case m.openQueue[0] <- struct{}{}: + default: + } } func (m *outgoingStreamsMap[T]) CloseWithError(err error) {