diff --git a/session.go b/session.go index 37036b2dd..ec0e9c2f9 100644 --- a/session.go +++ b/session.go @@ -610,6 +610,7 @@ func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { maxAllowedStreams := uint32(protocol.MaxStreamsMultiplier * float32(s.connectionParametersManager.GetMaxStreamsPerConnection())) if atomic.LoadUint32(&s.openStreamsCount) >= maxAllowedStreams { + go s.Close(qerr.TooManyOpenStreams) return nil, qerr.TooManyOpenStreams } if _, ok := s.streams[id]; ok { diff --git a/session_test.go b/session_test.go index 217ba7344..88c95ef89 100644 --- a/session_test.go +++ b/session_test.go @@ -682,14 +682,16 @@ var _ = Describe("Session", func() { }) Context("counting streams", func() { - It("errors when too many streams are opened", func() { + It("errors when too many streams are opened", func(done Done) { // 1.1 * 100 for i := 2; i <= 110; i++ { _, err := session.OpenStream(protocol.StreamID(i)) Expect(err).NotTo(HaveOccurred()) } - _, err := session.OpenStream(protocol.StreamID(110)) + _, err := session.OpenStream(protocol.StreamID(111)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + Eventually(session.closeChan).Should(Receive()) + close(done) }) It("does not error when many streams are opened and closed", func() {