forked from quic-go/quic-go
implement GetOrOpenStream in streamsMap
This commit is contained in:
@@ -23,7 +23,7 @@ var _ = Describe("Packet packer", func() {
|
|||||||
fcm.sendWindowSizes[5] = protocol.MaxByteCount
|
fcm.sendWindowSizes[5] = protocol.MaxByteCount
|
||||||
fcm.sendWindowSizes[7] = protocol.MaxByteCount
|
fcm.sendWindowSizes[7] = protocol.MaxByteCount
|
||||||
|
|
||||||
streamFramer = newStreamFramer(newStreamsMap(), fcm)
|
streamFramer = newStreamFramer(newStreamsMap(nil), fcm)
|
||||||
|
|
||||||
packer = &packetPacker{
|
packer = &packetPacker{
|
||||||
cryptoSetup: &handshake.CryptoSetup{},
|
cryptoSetup: &handshake.CryptoSetup{},
|
||||||
|
|||||||
18
session.go
18
session.go
@@ -108,13 +108,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
|||||||
}
|
}
|
||||||
|
|
||||||
session := &Session{
|
session := &Session{
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
version: v,
|
version: v,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
streamCallback: streamCallback,
|
streamCallback: streamCallback,
|
||||||
closeCallback: closeCallback,
|
closeCallback: closeCallback,
|
||||||
// streams: make(map[protocol.StreamID]*stream),
|
|
||||||
streamsMap: newStreamsMap(),
|
|
||||||
sentPacketHandler: sentPacketHandler,
|
sentPacketHandler: sentPacketHandler,
|
||||||
receivedPacketHandler: receivedPacketHandler,
|
receivedPacketHandler: receivedPacketHandler,
|
||||||
stopWaitingManager: stopWaitingManager,
|
stopWaitingManager: stopWaitingManager,
|
||||||
@@ -129,6 +127,8 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
|||||||
lastNetworkActivityTime: time.Now(),
|
lastNetworkActivityTime: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session.streamsMap = newStreamsMap(session.newStream)
|
||||||
|
|
||||||
cryptoStream, _ := session.OpenStream(1)
|
cryptoStream, _ := session.OpenStream(1)
|
||||||
var err error
|
var err error
|
||||||
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
|
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
|
||||||
@@ -637,6 +637,10 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
|
|||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Session) newStream(id protocol.StreamID) (*stream, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// garbageCollectStreams goes through all streams and removes EOF'ed streams
|
// garbageCollectStreams goes through all streams and removes EOF'ed streams
|
||||||
// from the streams map.
|
// from the streams map.
|
||||||
func (s *Session) garbageCollectStreams() {
|
func (s *Session) garbageCollectStreams() {
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() {
|
|||||||
stream1 = &stream{streamID: 10}
|
stream1 = &stream{streamID: 10}
|
||||||
stream2 = &stream{streamID: 11}
|
stream2 = &stream{streamID: 11}
|
||||||
|
|
||||||
streamsMap = newStreamsMap()
|
streamsMap = newStreamsMap(nil)
|
||||||
streamsMap.PutStream(stream1)
|
streamsMap.PutStream(stream1)
|
||||||
streamsMap.PutStream(stream2)
|
streamsMap.PutStream(stream2)
|
||||||
|
|
||||||
|
|||||||
@@ -6,27 +6,34 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxNumStreams = int(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
|
||||||
)
|
)
|
||||||
|
|
||||||
type streamsMap struct {
|
type streamsMap struct {
|
||||||
streams map[protocol.StreamID]*stream
|
streams map[protocol.StreamID]*stream
|
||||||
openStreams []protocol.StreamID
|
openStreams []protocol.StreamID
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
newStream newStreamLambda
|
||||||
|
|
||||||
roundRobinIndex int
|
roundRobinIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamLambda func(*stream) (bool, error)
|
type streamLambda func(*stream) (bool, error)
|
||||||
|
type newStreamLambda func(protocol.StreamID) (*stream, error)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
||||||
)
|
)
|
||||||
|
|
||||||
func newStreamsMap() *streamsMap {
|
func newStreamsMap(newStream newStreamLambda) *streamsMap {
|
||||||
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
|
|
||||||
return &streamsMap{
|
return &streamsMap{
|
||||||
streams: map[protocol.StreamID]*stream{},
|
streams: map[protocol.StreamID]*stream{},
|
||||||
openStreams: make([]protocol.StreamID, 0, maxNumStreams),
|
openStreams: make([]protocol.StreamID, 0, maxNumStreams),
|
||||||
|
newStream: newStream,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,6 +47,33 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) {
|
|||||||
return s, true
|
return s, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
|
||||||
|
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
s, ok := m.streams[id]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return s, nil // s may be nil
|
||||||
|
}
|
||||||
|
// ... we don't have an existing stream, try opening a new one
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
|
||||||
|
s, ok = m.streams[id]
|
||||||
|
if ok {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
if len(m.openStreams) == maxNumStreams {
|
||||||
|
return nil, qerr.TooManyOpenStreams
|
||||||
|
}
|
||||||
|
s, err := m.newStream(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.putStreamImpl(s)
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *streamsMap) Iterate(fn streamLambda) error {
|
func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -94,7 +128,10 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
|||||||
func (m *streamsMap) PutStream(s *stream) error {
|
func (m *streamsMap) PutStream(s *stream) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
return m.putStreamImpl(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *streamsMap) putStreamImpl(s *stream) error {
|
||||||
id := s.StreamID()
|
id := s.StreamID()
|
||||||
if _, ok := m.streams[id]; ok {
|
if _, ok := m.streams[id]; ok {
|
||||||
return fmt.Errorf("a stream with ID %d already exists", id)
|
return fmt.Errorf("a stream with ID %d already exists", id)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
@@ -14,7 +15,7 @@ var _ = Describe("Streams Map", func() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
m = newStreamsMap()
|
m = newStreamsMap(nil)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns an error for non-existent streams", func() {
|
It("returns an error for non-existent streams", func() {
|
||||||
@@ -52,6 +53,57 @@ var _ = Describe("Streams Map", func() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("getting and creating streams", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
m.newStream = func(id protocol.StreamID) (*stream, error) {
|
||||||
|
return &stream{streamID: id}, nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("gets new streams", func() {
|
||||||
|
s, err := m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("gets existing streams", func() {
|
||||||
|
s, err := m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
s, err = m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns nil for closed streams", func() {
|
||||||
|
s, err := m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
err = m.RemoveStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
s, err = m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("counting streams", func() {
|
||||||
|
It("errors when too many streams are opened", func() {
|
||||||
|
for i := 0; i < maxNumStreams; i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(i))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams))
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not error when many streams are opened and closed", func() {
|
||||||
|
for i := 2; i < 10*maxNumStreams; i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(i))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
m.RemoveStream(protocol.StreamID(i))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("deleting streams", func() {
|
Context("deleting streams", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
for i := 1; i <= 5; i++ {
|
for i := 1; i <= 5; i++ {
|
||||||
|
|||||||
Reference in New Issue
Block a user