diff --git a/session_test.go b/session_test.go index 6c07b986d..ba0c69d5e 100644 --- a/session_test.go +++ b/session_test.go @@ -394,10 +394,9 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) }) - It("ignores streams that existed previously", func() { + It("ignores STREAM frames for closed streams (client-side)", func() { sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, - Data: []byte{}, FinBit: true, }) str, _ := sess.streamsMap.GetOrOpenStream(5) @@ -407,11 +406,38 @@ var _ = Describe("Session", func() { str.Close() str.sentFin() sess.garbageCollectStreams() + str, _ = sess.streamsMap.GetOrOpenStream(5) + Expect(str).To(BeNil()) // make sure the stream is gone err = sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, - Data: []byte{}, + Data: []byte("foobar"), }) - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores STREAM frames for closed streams (server-side)", func() { + ostr, err := sess.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(ostr.StreamID()).To(Equal(protocol.StreamID(2))) + err = sess.handleStreamFrame(&frames.StreamFrame{ + StreamID: 2, + FinBit: true, + }) + Expect(err).ToNot(HaveOccurred()) + str, _ := sess.streamsMap.GetOrOpenStream(2) + Expect(str).ToNot(BeNil()) + _, err = str.Read([]byte{0}) + Expect(err).To(MatchError(io.EOF)) + str.Close() + str.sentFin() + sess.garbageCollectStreams() + str, _ = sess.streamsMap.GetOrOpenStream(2) + Expect(str).To(BeNil()) // make sure the stream is gone + err = sess.handleStreamFrame(&frames.StreamFrame{ + StreamID: 2, + FinBit: true, + }) + Expect(err).ToNot(HaveOccurred()) }) }) diff --git a/streams_map.go b/streams_map.go index c1a8c2793..74be17e08 100644 --- a/streams_map.go +++ b/streams_map.go @@ -83,15 +83,27 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { return s, nil } - if id <= m.highestStreamOpenedByPeer { - return nil, nil + if m.perspective == protocol.PerspectiveServer { + if id%2 == 0 { + if id <= m.nextStream { // this is a server-side stream that we already opened. Must have been closed already + return nil, nil + } + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) + } + if id <= m.highestStreamOpenedByPeer { // this is a client-side stream that doesn't exist anymore. Must have been closed already + return nil, nil + } } - - if m.perspective == protocol.PerspectiveServer && id%2 == 0 { - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) - } - if m.perspective == protocol.PerspectiveClient && id%2 == 1 { - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) + if m.perspective == protocol.PerspectiveClient { + if id%2 == 1 { + if id <= m.nextStream { // this is a client-side stream that we already opened. + return nil, nil + } + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) + } + if id <= m.highestStreamOpenedByPeer { // this is a server-side stream that doesn't exist anymore. Must have been closed already + return nil, nil + } } // sid is the next stream that will be opened diff --git a/streams_map_test.go b/streams_map_test.go index f73a75fcc..b364f6da0 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -57,6 +57,13 @@ var _ = Describe("Streams Map", func() { Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) }) + It("rejects streams with even IDs, which are lower thatn the highest client-side stream", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + _, err = m.GetOrOpenStream(4) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 4 from client-side")) + }) + It("gets existing streams", func() { s, err := m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) @@ -137,6 +144,17 @@ var _ = Describe("Streams Map", func() { Expect(err).To(MatchError(testErr)) }) + It("doesn't reopen an already closed stream", func() { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) + err = m.RemoveStream(2) + Expect(err).ToNot(HaveOccurred()) + str, err = m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + Context("counting streams", func() { It("errors when too many streams are opened", func() { for i := 1; i <= maxOutgoingStreams; i++ { @@ -353,12 +371,19 @@ var _ = Describe("Streams Map", func() { setNewStreamsMap(protocol.PerspectiveClient) }) - Context("client-side streams, as a client", func() { + Context("client-side streams", func() { It("rejects streams with odd IDs", func() { _, err := m.GetOrOpenStream(5) Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side")) }) + It("rejects streams with odds IDs, which are lower thatn the highest server-side stream", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).NotTo(HaveOccurred()) + _, err = m.GetOrOpenStream(5) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side")) + }) + It("gets new streams", func() { s, err := m.GetOrOpenStream(2) Expect(err).NotTo(HaveOccurred()) @@ -374,6 +399,17 @@ var _ = Describe("Streams Map", func() { Expect(m.streams).To(HaveKey(protocol.StreamID(4))) Expect(m.streams).To(HaveKey(protocol.StreamID(6))) }) + + It("doesn't reopen an already closed stream", func() { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + err = m.RemoveStream(1) + Expect(err).ToNot(HaveOccurred()) + str, err = m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) }) Context("server-side streams", func() { @@ -393,6 +429,16 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2)) }) + + It("doesn't reopen an already closed stream", func() { + _, err := m.GetOrOpenStream(4) + Expect(err).ToNot(HaveOccurred()) + err = m.RemoveStream(4) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenStream(4) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) }) Context("accepting streams", func() {