implement GetOrOpenStream in streamsMap

This commit is contained in:
Lucas Clemente
2016-08-08 14:31:33 +02:00
parent 77580dbf96
commit 65663c3314
5 changed files with 105 additions and 12 deletions

View File

@@ -23,7 +23,7 @@ var _ = Describe("Packet packer", func() {
fcm.sendWindowSizes[5] = protocol.MaxByteCount fcm.sendWindowSizes[5] = protocol.MaxByteCount
fcm.sendWindowSizes[7] = protocol.MaxByteCount fcm.sendWindowSizes[7] = protocol.MaxByteCount
streamFramer = newStreamFramer(newStreamsMap(), fcm) streamFramer = newStreamFramer(newStreamsMap(nil), fcm)
packer = &packetPacker{ packer = &packetPacker{
cryptoSetup: &handshake.CryptoSetup{}, cryptoSetup: &handshake.CryptoSetup{},

View File

@@ -108,13 +108,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
} }
session := &Session{ session := &Session{
connectionID: connectionID, connectionID: connectionID,
version: v, version: v,
conn: conn, conn: conn,
streamCallback: streamCallback, streamCallback: streamCallback,
closeCallback: closeCallback, closeCallback: closeCallback,
// streams: make(map[protocol.StreamID]*stream),
streamsMap: newStreamsMap(),
sentPacketHandler: sentPacketHandler, sentPacketHandler: sentPacketHandler,
receivedPacketHandler: receivedPacketHandler, receivedPacketHandler: receivedPacketHandler,
stopWaitingManager: stopWaitingManager, stopWaitingManager: stopWaitingManager,
@@ -129,6 +127,8 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
lastNetworkActivityTime: time.Now(), lastNetworkActivityTime: time.Now(),
} }
session.streamsMap = newStreamsMap(session.newStream)
cryptoStream, _ := session.OpenStream(1) cryptoStream, _ := session.OpenStream(1)
var err error var err error
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
@@ -637,6 +637,10 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
return stream, nil return stream, nil
} }
func (s *Session) newStream(id protocol.StreamID) (*stream, error) {
return nil, errors.New("not implemented")
}
// garbageCollectStreams goes through all streams and removes EOF'ed streams // garbageCollectStreams goes through all streams and removes EOF'ed streams
// from the streams map. // from the streams map.
func (s *Session) garbageCollectStreams() { func (s *Session) garbageCollectStreams() {

View File

@@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: 10} stream1 = &stream{streamID: 10}
stream2 = &stream{streamID: 11} stream2 = &stream{streamID: 11}
streamsMap = newStreamsMap() streamsMap = newStreamsMap(nil)
streamsMap.PutStream(stream1) streamsMap.PutStream(stream1)
streamsMap.PutStream(stream2) streamsMap.PutStream(stream2)

View File

@@ -6,27 +6,34 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
maxNumStreams = int(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
) )
type streamsMap struct { type streamsMap struct {
streams map[protocol.StreamID]*stream streams map[protocol.StreamID]*stream
openStreams []protocol.StreamID openStreams []protocol.StreamID
mutex sync.RWMutex mutex sync.RWMutex
newStream newStreamLambda
roundRobinIndex int roundRobinIndex int
} }
type streamLambda func(*stream) (bool, error) type streamLambda func(*stream) (bool, error)
type newStreamLambda func(protocol.StreamID) (*stream, error)
var ( var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map") errMapAccess = errors.New("streamsMap: Error accessing the streams map")
) )
func newStreamsMap() *streamsMap { func newStreamsMap(newStream newStreamLambda) *streamsMap {
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
return &streamsMap{ return &streamsMap{
streams: map[protocol.StreamID]*stream{}, streams: map[protocol.StreamID]*stream{},
openStreams: make([]protocol.StreamID, 0, maxNumStreams), openStreams: make([]protocol.StreamID, 0, maxNumStreams),
newStream: newStream,
} }
} }
@@ -40,6 +47,33 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) {
return s, true return s, true
} }
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil // s may be nil
}
// ... we don't have an existing stream, try opening a new one
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
s, ok = m.streams[id]
if ok {
return s, nil
}
if len(m.openStreams) == maxNumStreams {
return nil, qerr.TooManyOpenStreams
}
s, err := m.newStream(id)
if err != nil {
return nil, err
}
m.putStreamImpl(s)
return s, nil
}
func (m *streamsMap) Iterate(fn streamLambda) error { func (m *streamsMap) Iterate(fn streamLambda) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -94,7 +128,10 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
func (m *streamsMap) PutStream(s *stream) error { func (m *streamsMap) PutStream(s *stream) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.putStreamImpl(s)
}
func (m *streamsMap) putStreamImpl(s *stream) error {
id := s.StreamID() id := s.StreamID()
if _, ok := m.streams[id]; ok { if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id) return fmt.Errorf("a stream with ID %d already exists", id)

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@@ -14,7 +15,7 @@ var _ = Describe("Streams Map", func() {
) )
BeforeEach(func() { BeforeEach(func() {
m = newStreamsMap() m = newStreamsMap(nil)
}) })
It("returns an error for non-existent streams", func() { It("returns an error for non-existent streams", func() {
@@ -52,6 +53,57 @@ var _ = Describe("Streams Map", func() {
}) })
}) })
Context("getting and creating streams", func() {
BeforeEach(func() {
m.newStream = func(id protocol.StreamID) (*stream, error) {
return &stream{streamID: id}, nil
}
})
It("gets new streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
})
It("gets existing streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
})
It("returns nil for closed streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(5)
Expect(err).NotTo(HaveOccurred())
s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s).To(BeNil())
})
Context("counting streams", func() {
It("errors when too many streams are opened", func() {
for i := 0; i < maxNumStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(i))
Expect(err).NotTo(HaveOccurred())
}
_, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams))
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("does not error when many streams are opened and closed", func() {
for i := 2; i < 10*maxNumStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(i))
Expect(err).NotTo(HaveOccurred())
m.RemoveStream(protocol.StreamID(i))
}
})
})
})
Context("deleting streams", func() { Context("deleting streams", func() {
BeforeEach(func() { BeforeEach(func() {
for i := 1; i <= 5; i++ { for i := 1; i <= 5; i++ {