From 65663c3314197100fbc300fdadae6e817445f5c5 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Mon, 8 Aug 2016 14:31:33 +0200 Subject: [PATCH] implement GetOrOpenStream in streamsMap --- packet_packer_test.go | 2 +- session.go | 18 +++++++++------ stream_framer_test.go | 2 +- streams_map.go | 41 ++++++++++++++++++++++++++++++-- streams_map_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 105 insertions(+), 12 deletions(-) diff --git a/packet_packer_test.go b/packet_packer_test.go index f56727ad..7a799f93 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -23,7 +23,7 @@ var _ = Describe("Packet packer", func() { fcm.sendWindowSizes[5] = protocol.MaxByteCount fcm.sendWindowSizes[7] = protocol.MaxByteCount - streamFramer = newStreamFramer(newStreamsMap(), fcm) + streamFramer = newStreamFramer(newStreamsMap(nil), fcm) packer = &packetPacker{ cryptoSetup: &handshake.CryptoSetup{}, diff --git a/session.go b/session.go index bc9ae48f..d6d7141c 100644 --- a/session.go +++ b/session.go @@ -108,13 +108,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol } session := &Session{ - connectionID: connectionID, - version: v, - conn: conn, - streamCallback: streamCallback, - closeCallback: closeCallback, - // streams: make(map[protocol.StreamID]*stream), - streamsMap: newStreamsMap(), + connectionID: connectionID, + version: v, + conn: conn, + streamCallback: streamCallback, + closeCallback: closeCallback, sentPacketHandler: sentPacketHandler, receivedPacketHandler: receivedPacketHandler, stopWaitingManager: stopWaitingManager, @@ -129,6 +127,8 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol lastNetworkActivityTime: time.Now(), } + session.streamsMap = newStreamsMap(session.newStream) + cryptoStream, _ := session.OpenStream(1) var err error session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) @@ -637,6 +637,10 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { return stream, nil } +func (s *Session) newStream(id protocol.StreamID) (*stream, error) { + return nil, errors.New("not implemented") +} + // garbageCollectStreams goes through all streams and removes EOF'ed streams // from the streams map. func (s *Session) garbageCollectStreams() { diff --git a/stream_framer_test.go b/stream_framer_test.go index e8e742f0..ffddb1a4 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() { stream1 = &stream{streamID: 10} stream2 = &stream{streamID: 11} - streamsMap = newStreamsMap() + streamsMap = newStreamsMap(nil) streamsMap.PutStream(stream1) streamsMap.PutStream(stream2) diff --git a/streams_map.go b/streams_map.go index fc271fb8..bc4ce3bb 100644 --- a/streams_map.go +++ b/streams_map.go @@ -6,27 +6,34 @@ import ( "sync" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" +) + +const ( + maxNumStreams = int(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier) ) type streamsMap struct { streams map[protocol.StreamID]*stream openStreams []protocol.StreamID mutex sync.RWMutex + newStream newStreamLambda roundRobinIndex int } type streamLambda func(*stream) (bool, error) +type newStreamLambda func(protocol.StreamID) (*stream, error) var ( errMapAccess = errors.New("streamsMap: Error accessing the streams map") ) -func newStreamsMap() *streamsMap { - maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier) +func newStreamsMap(newStream newStreamLambda) *streamsMap { return &streamsMap{ streams: map[protocol.StreamID]*stream{}, openStreams: make([]protocol.StreamID, 0, maxNumStreams), + newStream: newStream, } } @@ -40,6 +47,33 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) { return s, true } +// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. +func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { + m.mutex.RLock() + s, ok := m.streams[id] + m.mutex.RUnlock() + if ok { + return s, nil // s may be nil + } + // ... we don't have an existing stream, try opening a new one + m.mutex.Lock() + defer m.mutex.Unlock() + // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). + s, ok = m.streams[id] + if ok { + return s, nil + } + if len(m.openStreams) == maxNumStreams { + return nil, qerr.TooManyOpenStreams + } + s, err := m.newStream(id) + if err != nil { + return nil, err + } + m.putStreamImpl(s) + return s, nil +} + func (m *streamsMap) Iterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -94,7 +128,10 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { func (m *streamsMap) PutStream(s *stream) error { m.mutex.Lock() defer m.mutex.Unlock() + return m.putStreamImpl(s) +} +func (m *streamsMap) putStreamImpl(s *stream) error { id := s.StreamID() if _, ok := m.streams[id]; ok { return fmt.Errorf("a stream with ID %d already exists", id) diff --git a/streams_map_test.go b/streams_map_test.go index b949ec02..220a80c3 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -14,7 +15,7 @@ var _ = Describe("Streams Map", func() { ) BeforeEach(func() { - m = newStreamsMap() + m = newStreamsMap(nil) }) It("returns an error for non-existent streams", func() { @@ -52,6 +53,57 @@ var _ = Describe("Streams Map", func() { }) }) + Context("getting and creating streams", func() { + BeforeEach(func() { + m.newStream = func(id protocol.StreamID) (*stream, error) { + return &stream{streamID: id}, nil + } + }) + + It("gets new streams", func() { + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) + }) + + It("gets existing streams", func() { + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + s, err = m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) + }) + + It("returns nil for closed streams", func() { + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(5) + Expect(err).NotTo(HaveOccurred()) + s, err = m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(BeNil()) + }) + + Context("counting streams", func() { + It("errors when too many streams are opened", func() { + for i := 0; i < maxNumStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(i)) + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := 2; i < 10*maxNumStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(i)) + Expect(err).NotTo(HaveOccurred()) + m.RemoveStream(protocol.StreamID(i)) + } + }) + }) + }) + Context("deleting streams", func() { BeforeEach(func() { for i := 1; i <= 5; i++ {