forked from quic-go/quic-go
implement GetOrOpenStream in streamsMap
This commit is contained in:
@@ -23,7 +23,7 @@ var _ = Describe("Packet packer", func() {
|
||||
fcm.sendWindowSizes[5] = protocol.MaxByteCount
|
||||
fcm.sendWindowSizes[7] = protocol.MaxByteCount
|
||||
|
||||
streamFramer = newStreamFramer(newStreamsMap(), fcm)
|
||||
streamFramer = newStreamFramer(newStreamsMap(nil), fcm)
|
||||
|
||||
packer = &packetPacker{
|
||||
cryptoSetup: &handshake.CryptoSetup{},
|
||||
|
||||
18
session.go
18
session.go
@@ -108,13 +108,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
connectionID: connectionID,
|
||||
version: v,
|
||||
conn: conn,
|
||||
streamCallback: streamCallback,
|
||||
closeCallback: closeCallback,
|
||||
// streams: make(map[protocol.StreamID]*stream),
|
||||
streamsMap: newStreamsMap(),
|
||||
connectionID: connectionID,
|
||||
version: v,
|
||||
conn: conn,
|
||||
streamCallback: streamCallback,
|
||||
closeCallback: closeCallback,
|
||||
sentPacketHandler: sentPacketHandler,
|
||||
receivedPacketHandler: receivedPacketHandler,
|
||||
stopWaitingManager: stopWaitingManager,
|
||||
@@ -129,6 +127,8 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
lastNetworkActivityTime: time.Now(),
|
||||
}
|
||||
|
||||
session.streamsMap = newStreamsMap(session.newStream)
|
||||
|
||||
cryptoStream, _ := session.OpenStream(1)
|
||||
var err error
|
||||
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
|
||||
@@ -637,6 +637,10 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (s *Session) newStream(id protocol.StreamID) (*stream, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// garbageCollectStreams goes through all streams and removes EOF'ed streams
|
||||
// from the streams map.
|
||||
func (s *Session) garbageCollectStreams() {
|
||||
|
||||
@@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() {
|
||||
stream1 = &stream{streamID: 10}
|
||||
stream2 = &stream{streamID: 11}
|
||||
|
||||
streamsMap = newStreamsMap()
|
||||
streamsMap = newStreamsMap(nil)
|
||||
streamsMap.PutStream(stream1)
|
||||
streamsMap.PutStream(stream2)
|
||||
|
||||
|
||||
@@ -6,27 +6,34 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
const (
|
||||
maxNumStreams = int(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
|
||||
)
|
||||
|
||||
type streamsMap struct {
|
||||
streams map[protocol.StreamID]*stream
|
||||
openStreams []protocol.StreamID
|
||||
mutex sync.RWMutex
|
||||
newStream newStreamLambda
|
||||
|
||||
roundRobinIndex int
|
||||
}
|
||||
|
||||
type streamLambda func(*stream) (bool, error)
|
||||
type newStreamLambda func(protocol.StreamID) (*stream, error)
|
||||
|
||||
var (
|
||||
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
||||
)
|
||||
|
||||
func newStreamsMap() *streamsMap {
|
||||
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
|
||||
func newStreamsMap(newStream newStreamLambda) *streamsMap {
|
||||
return &streamsMap{
|
||||
streams: map[protocol.StreamID]*stream{},
|
||||
openStreams: make([]protocol.StreamID, 0, maxNumStreams),
|
||||
newStream: newStream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +47,33 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) {
|
||||
return s, true
|
||||
}
|
||||
|
||||
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
|
||||
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||
m.mutex.RLock()
|
||||
s, ok := m.streams[id]
|
||||
m.mutex.RUnlock()
|
||||
if ok {
|
||||
return s, nil // s may be nil
|
||||
}
|
||||
// ... we don't have an existing stream, try opening a new one
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
|
||||
s, ok = m.streams[id]
|
||||
if ok {
|
||||
return s, nil
|
||||
}
|
||||
if len(m.openStreams) == maxNumStreams {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
}
|
||||
s, err := m.newStream(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.putStreamImpl(s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -94,7 +128,10 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||
func (m *streamsMap) PutStream(s *stream) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.putStreamImpl(s)
|
||||
}
|
||||
|
||||
func (m *streamsMap) putStreamImpl(s *stream) error {
|
||||
id := s.StreamID()
|
||||
if _, ok := m.streams[id]; ok {
|
||||
return fmt.Errorf("a stream with ID %d already exists", id)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -14,7 +15,7 @@ var _ = Describe("Streams Map", func() {
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
m = newStreamsMap()
|
||||
m = newStreamsMap(nil)
|
||||
})
|
||||
|
||||
It("returns an error for non-existent streams", func() {
|
||||
@@ -52,6 +53,57 @@ var _ = Describe("Streams Map", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("getting and creating streams", func() {
|
||||
BeforeEach(func() {
|
||||
m.newStream = func(id protocol.StreamID) (*stream, error) {
|
||||
return &stream{streamID: id}, nil
|
||||
}
|
||||
})
|
||||
|
||||
It("gets new streams", func() {
|
||||
s, err := m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||
})
|
||||
|
||||
It("gets existing streams", func() {
|
||||
s, err := m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
s, err = m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||
})
|
||||
|
||||
It("returns nil for closed streams", func() {
|
||||
s, err := m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = m.RemoveStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
s, err = m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s).To(BeNil())
|
||||
})
|
||||
|
||||
Context("counting streams", func() {
|
||||
It("errors when too many streams are opened", func() {
|
||||
for i := 0; i < maxNumStreams; i++ {
|
||||
_, err := m.GetOrOpenStream(protocol.StreamID(i))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
_, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams))
|
||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||
})
|
||||
|
||||
It("does not error when many streams are opened and closed", func() {
|
||||
for i := 2; i < 10*maxNumStreams; i++ {
|
||||
_, err := m.GetOrOpenStream(protocol.StreamID(i))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
m.RemoveStream(protocol.StreamID(i))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("deleting streams", func() {
|
||||
BeforeEach(func() {
|
||||
for i := 1; i <= 5; i++ {
|
||||
|
||||
Reference in New Issue
Block a user