implement GetOrOpenStream in streamsMap

This commit is contained in:
Lucas Clemente
2016-08-08 14:31:33 +02:00
parent 77580dbf96
commit 65663c3314
5 changed files with 105 additions and 12 deletions

View File

@@ -23,7 +23,7 @@ var _ = Describe("Packet packer", func() {
fcm.sendWindowSizes[5] = protocol.MaxByteCount
fcm.sendWindowSizes[7] = protocol.MaxByteCount
streamFramer = newStreamFramer(newStreamsMap(), fcm)
streamFramer = newStreamFramer(newStreamsMap(nil), fcm)
packer = &packetPacker{
cryptoSetup: &handshake.CryptoSetup{},

View File

@@ -108,13 +108,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
}
session := &Session{
connectionID: connectionID,
version: v,
conn: conn,
streamCallback: streamCallback,
closeCallback: closeCallback,
// streams: make(map[protocol.StreamID]*stream),
streamsMap: newStreamsMap(),
connectionID: connectionID,
version: v,
conn: conn,
streamCallback: streamCallback,
closeCallback: closeCallback,
sentPacketHandler: sentPacketHandler,
receivedPacketHandler: receivedPacketHandler,
stopWaitingManager: stopWaitingManager,
@@ -129,6 +127,8 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
lastNetworkActivityTime: time.Now(),
}
session.streamsMap = newStreamsMap(session.newStream)
cryptoStream, _ := session.OpenStream(1)
var err error
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
@@ -637,6 +637,10 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
return stream, nil
}
func (s *Session) newStream(id protocol.StreamID) (*stream, error) {
return nil, errors.New("not implemented")
}
// garbageCollectStreams goes through all streams and removes EOF'ed streams
// from the streams map.
func (s *Session) garbageCollectStreams() {

View File

@@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: 10}
stream2 = &stream{streamID: 11}
streamsMap = newStreamsMap()
streamsMap = newStreamsMap(nil)
streamsMap.PutStream(stream1)
streamsMap.PutStream(stream2)

View File

@@ -6,27 +6,34 @@ import (
"sync"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
maxNumStreams = int(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
)
type streamsMap struct {
streams map[protocol.StreamID]*stream
openStreams []protocol.StreamID
mutex sync.RWMutex
newStream newStreamLambda
roundRobinIndex int
}
type streamLambda func(*stream) (bool, error)
type newStreamLambda func(protocol.StreamID) (*stream, error)
var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
)
func newStreamsMap() *streamsMap {
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
func newStreamsMap(newStream newStreamLambda) *streamsMap {
return &streamsMap{
streams: map[protocol.StreamID]*stream{},
openStreams: make([]protocol.StreamID, 0, maxNumStreams),
newStream: newStream,
}
}
@@ -40,6 +47,33 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) {
return s, true
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil // s may be nil
}
// ... we don't have an existing stream, try opening a new one
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
s, ok = m.streams[id]
if ok {
return s, nil
}
if len(m.openStreams) == maxNumStreams {
return nil, qerr.TooManyOpenStreams
}
s, err := m.newStream(id)
if err != nil {
return nil, err
}
m.putStreamImpl(s)
return s, nil
}
func (m *streamsMap) Iterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -94,7 +128,10 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
func (m *streamsMap) PutStream(s *stream) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.putStreamImpl(s)
}
func (m *streamsMap) putStreamImpl(s *stream) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)

View File

@@ -4,6 +4,7 @@ import (
"errors"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@@ -14,7 +15,7 @@ var _ = Describe("Streams Map", func() {
)
BeforeEach(func() {
m = newStreamsMap()
m = newStreamsMap(nil)
})
It("returns an error for non-existent streams", func() {
@@ -52,6 +53,57 @@ var _ = Describe("Streams Map", func() {
})
})
Context("getting and creating streams", func() {
BeforeEach(func() {
m.newStream = func(id protocol.StreamID) (*stream, error) {
return &stream{streamID: id}, nil
}
})
It("gets new streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
})
It("gets existing streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
})
It("returns nil for closed streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(5)
Expect(err).NotTo(HaveOccurred())
s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s).To(BeNil())
})
Context("counting streams", func() {
It("errors when too many streams are opened", func() {
for i := 0; i < maxNumStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(i))
Expect(err).NotTo(HaveOccurred())
}
_, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams))
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("does not error when many streams are opened and closed", func() {
for i := 2; i < 10*maxNumStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(i))
Expect(err).NotTo(HaveOccurred())
m.RemoveStream(protocol.StreamID(i))
}
})
})
})
Context("deleting streams", func() {
BeforeEach(func() {
for i := 1; i <= 5; i++ {