diff --git a/session.go b/session.go index 8de0372b..b2f473c7 100644 --- a/session.go +++ b/session.go @@ -565,10 +565,16 @@ func (s *Session) logPacket(packet *packedPacket) { } // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. +// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { return s.streamsMap.GetOrOpenStream(id) } +// OpenStream opens a stream from the server's side +func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) { + return s.streamsMap.OpenStream(id) +} + func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { return s.streamsMap.GetOrOpenStream(id) } diff --git a/session_test.go b/session_test.go index 899e4d1d..9c56f996 100644 --- a/session_test.go +++ b/session_test.go @@ -142,10 +142,10 @@ var _ = Describe("Session", func() { }) It("does not reject existing streams with even StreamIDs", func() { - _, err := session.GetOrOpenStream(4) + _, err := session.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) err = session.handleStreamFrame(&frames.StreamFrame{ - StreamID: 4, + StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) Expect(err).ToNot(HaveOccurred()) @@ -820,16 +820,16 @@ var _ = Describe("Session", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { for i := 2; i <= 110; i++ { - _, err := session.GetOrOpenStream(protocol.StreamID(i)) + _, err := session.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } - _, err := session.GetOrOpenStream(protocol.StreamID(111)) + _, err := session.GetOrOpenStream(protocol.StreamID(301)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("does not error when many streams are opened and closed", func() { for i := 2; i <= 1000; i++ { - s, err := session.GetOrOpenStream(protocol.StreamID(i)) + s, err := session.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) err = s.Close() Expect(err).NotTo(HaveOccurred()) diff --git a/streams_map.go b/streams_map.go index 2e757477..05ce1a39 100644 --- a/streams_map.go +++ b/streams_map.go @@ -38,6 +38,7 @@ func newStreamsMap(newStream newStreamLambda) *streamsMap { } // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. +// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { m.mutex.RLock() s, ok := m.streams[id] @@ -56,6 +57,9 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if len(m.openStreams) == maxNumStreams { return nil, qerr.TooManyOpenStreams } + if id%2 == 0 { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) + } s, err := m.newStream(id) if err != nil { return nil, err @@ -64,6 +68,11 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { return s, nil } +// OpenStream opens a stream from the server's side +func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) { + panic("OpenStream: not implemented") +} + func (m *streamsMap) Iterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/streams_map_test.go b/streams_map_test.go index 6c2f7f8a..4e04a559 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -31,6 +31,11 @@ var _ = Describe("Streams Map", func() { Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) }) + It("rejects streams with even IDs", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) + }) + It("gets existing streams", func() { s, err := m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) @@ -49,10 +54,14 @@ var _ = Describe("Streams Map", func() { Expect(s).To(BeNil()) }) + It("panics on OpenStream", func() { + Expect(func() { m.OpenStream(0) }).To(Panic()) + }) + Context("counting streams", func() { It("errors when too many streams are opened", func() { for i := 0; i < maxNumStreams; i++ { - _, err := m.GetOrOpenStream(protocol.StreamID(i)) + _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } _, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams)) @@ -61,9 +70,9 @@ var _ = Describe("Streams Map", func() { It("does not error when many streams are opened and closed", func() { for i := 2; i < 10*maxNumStreams; i++ { - _, err := m.GetOrOpenStream(protocol.StreamID(i)) + _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) - m.RemoveStream(protocol.StreamID(i)) + m.RemoveStream(protocol.StreamID(i*2 + 1)) } }) })