From c7b4ad6e80bcc617ad58db5486c279eafa74415a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 15 Feb 2017 22:29:08 +0700 Subject: [PATCH] return plain nil value for Session.GetOrOpenStream for closed streams fixes #418 --- session.go | 7 ++++++- session_test.go | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/session.go b/session.go index ee67fe85..551aec4a 100644 --- a/session.go +++ b/session.go @@ -658,7 +658,12 @@ func (s *Session) logPacket(packet *packedPacket) { // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { - return s.streamsMap.GetOrOpenStream(id) + str, err := s.streamsMap.GetOrOpenStream(id) + if str != nil { + return str, err + } + // make sure to return an actual nil value here, not an utils.Stream with value nil + return nil, err } // OpenStream opens a stream from the server's side diff --git a/session_test.go b/session_test.go index c5375c8b..4b1b1eeb 100644 --- a/session_test.go +++ b/session_test.go @@ -1155,6 +1155,29 @@ var _ = Describe("Session", func() { Expect(conn.written[0]).To(ContainSubstring("foobar")) }) + Context("getting streams", func() { + It("returns a new stream", func() { + str, err := session.GetOrOpenStream(11) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(11))) + }) + + It("returns a nil-value (not an interface with value nil) for closed streams", func() { + _, err := session.GetOrOpenStream(9) + Expect(err).ToNot(HaveOccurred()) + session.streamsMap.RemoveStream(9) + session.garbageCollectStreams() + Expect(session.streamsMap.GetOrOpenStream(9)).To(BeNil()) + str, err := session.GetOrOpenStream(9) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + // make sure that the returned value is a plain nil, not an utils.Stream with value nil + _, ok := str.(utils.Stream) + Expect(ok).To(BeFalse()) + }) + }) + Context("counting streams", func() { It("errors when too many streams are opened", func() { for i := 2; i <= 110; i++ {