diff --git a/session.go b/session.go index 78602f7aa..02fe321e0 100644 --- a/session.go +++ b/session.go @@ -574,7 +574,7 @@ func (s *Session) garbageCollectStreams() { if v == nil { continue } - if v.finishedReading() { + if v.finished() { s.streams[k] = nil } } diff --git a/session_test.go b/session_test.go index 1d2ecc132..5464a8bbf 100644 --- a/session_test.go +++ b/session_test.go @@ -151,7 +151,16 @@ var _ = Describe("Session", func() { Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) - It("closes streams with FIN bits", func() { + It("does not delete streams with Close()", func() { + str, err := session.NewStream(5) + Expect(err).ToNot(HaveOccurred()) + str.Close() + session.garbageCollectStreams() + Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).ToNot(BeNil()) + }) + + It("does not delete streams with FIN bit", func() { session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, @@ -166,6 +175,29 @@ var _ = Describe("Session", func() { Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) session.garbageCollectStreams() Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).ToNot(BeNil()) + }) + + It("closes streams with FIN bit & close", func() { + session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + FinBit: true, + }) + Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).ToNot(BeNil()) + Expect(callbackCalled).To(BeTrue()) + p := make([]byte, 4) + _, err := session.streams[5].Read(p) + Expect(err).To(Equal(io.EOF)) + Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) + session.garbageCollectStreams() + Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).ToNot(BeNil()) + // We still need to close the stream locally + session.streams[5].Close() + session.garbageCollectStreams() + Expect(session.streams).To(HaveLen(1)) Expect(session.streams[5]).To(BeNil()) }) @@ -213,6 +245,7 @@ var _ = Describe("Session", func() { }) _, err := session.streams[5].Read([]byte{0}) Expect(err).To(Equal(io.EOF)) + session.streams[5].Close() session.garbageCollectStreams() err = session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, diff --git a/stream.go b/stream.go index 96be577b8..0235c2f17 100644 --- a/stream.go +++ b/stream.go @@ -29,7 +29,10 @@ type stream struct { err error mutex sync.Mutex + // eof is set if we are finished reading eof int32 // really a bool + // closed is set when we are finished writing + closed int32 // really a bool frameQueue streamFrameSorter newFrameOrErrCond sync.Cond @@ -195,6 +198,7 @@ func (s *stream) Write(p []byte) (int, error) { // Close implements io.Closer func (s *stream) Close() error { + atomic.StoreInt32(&s.closed, 1) return s.session.QueueStreamFrame(&frames.StreamFrame{ StreamID: s.streamID, Offset: s.writeOffset, @@ -222,6 +226,7 @@ func (s *stream) maybeTriggerWindowUpdate() { // RegisterError is called by session to indicate that an error occurred and the // stream should be closed. func (s *stream) RegisterError(err error) { + atomic.StoreInt32(&s.closed, 1) s.mutex.Lock() defer s.mutex.Unlock() if s.err != nil { // s.err must not be changed! @@ -236,6 +241,14 @@ func (s *stream) finishedReading() bool { return atomic.LoadInt32(&s.eof) != 0 } +func (s *stream) finishedWriting() bool { + return atomic.LoadInt32(&s.closed) != 0 +} + +func (s *stream) finished() bool { + return s.finishedReading() && s.finishedWriting() +} + func (s *stream) StreamID() protocol.StreamID { return s.streamID }