add a context to Session.Open{Uni}StreamSync

This commit is contained in:
Marten Seemann
2019-06-07 16:19:56 +08:00
parent e63a991950
commit 2b8cece60a
20 changed files with 218 additions and 104 deletions

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -12,10 +13,12 @@ import (
type outgoingItemsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]item
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@@ -32,6 +35,7 @@ func newOutgoingItemsMap(
) *outgoingItemsMap {
return &outgoingItemsMap{
streams: make(map[protocol.StreamNum]item),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@@ -55,7 +59,7 @@ func (m *outgoingItemsMap) OpenStream() (item, error) {
return m.openStream(), nil
}
func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -63,17 +67,32 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@@ -84,7 +103,7 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@@ -157,9 +176,15 @@ func (m *outgoingItemsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@@ -170,7 +195,9 @@ func (m *outgoingItemsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}