forked from quic-go/quic-go
add a context to Session.Open{Uni}StreamSync
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -12,10 +13,12 @@ import (
|
||||
type outgoingItemsMap struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
openQueue []chan struct{}
|
||||
|
||||
streams map[protocol.StreamNum]item
|
||||
|
||||
openQueue map[uint64]chan struct{}
|
||||
lowestInQueue uint64
|
||||
highestInQueue uint64
|
||||
|
||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
|
||||
@@ -32,6 +35,7 @@ func newOutgoingItemsMap(
|
||||
) *outgoingItemsMap {
|
||||
return &outgoingItemsMap{
|
||||
streams: make(map[protocol.StreamNum]item),
|
||||
openQueue: make(map[uint64]chan struct{}),
|
||||
maxStream: protocol.InvalidStreamNum,
|
||||
nextStream: 1,
|
||||
newStream: newStream,
|
||||
@@ -55,7 +59,7 @@ func (m *outgoingItemsMap) OpenStream() (item, error) {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
|
||||
func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -63,17 +67,32 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
waitChan := make(chan struct{}, 1)
|
||||
m.openQueue = append(m.openQueue, waitChan)
|
||||
queuePos := m.highestInQueue
|
||||
m.highestInQueue++
|
||||
if len(m.openQueue) == 0 {
|
||||
m.lowestInQueue = queuePos
|
||||
}
|
||||
m.openQueue[queuePos] = waitChan
|
||||
m.maybeSendBlockedFrame()
|
||||
|
||||
for {
|
||||
m.mutex.Unlock()
|
||||
<-waitChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.mutex.Lock()
|
||||
delete(m.openQueue, queuePos)
|
||||
return nil, ctx.Err()
|
||||
case <-waitChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
|
||||
if m.closeErr != nil {
|
||||
@@ -84,7 +103,7 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
|
||||
continue
|
||||
}
|
||||
str := m.openStream()
|
||||
m.openQueue = m.openQueue[1:]
|
||||
delete(m.openQueue, queuePos)
|
||||
m.unblockOpenSync()
|
||||
return str, nil
|
||||
}
|
||||
@@ -157,9 +176,15 @@ func (m *outgoingItemsMap) unblockOpenSync() {
|
||||
if len(m.openQueue) == 0 {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.openQueue[0] <- struct{}{}:
|
||||
default:
|
||||
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
|
||||
c, ok := m.openQueue[qp]
|
||||
if !ok { // entry was deleted because the context was canceled
|
||||
continue
|
||||
}
|
||||
close(c)
|
||||
m.openQueue[qp] = nil
|
||||
m.lowestInQueue = qp + 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +195,9 @@ func (m *outgoingItemsMap) CloseWithError(err error) {
|
||||
str.closeForShutdown(err)
|
||||
}
|
||||
for _, c := range m.openQueue {
|
||||
close(c)
|
||||
if c != nil {
|
||||
close(c)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user