diff --git a/streams_map.go b/streams_map.go index 3fd553b4..14b74c69 100644 --- a/streams_map.go +++ b/streams_map.go @@ -14,7 +14,8 @@ type streamsMap struct { streams map[protocol.StreamID]*stream openStreams []protocol.StreamID - highestStreamOpenedByClient protocol.StreamID + highestStreamOpenedByClient protocol.StreamID + streamsOpenedAfterLastGarbageCollect int mutex sync.RWMutex newStream newStreamLambda @@ -50,6 +51,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if ok { return s, nil // s may be nil } + // ... we don't have an existing stream, try opening a new one m.mutex.Lock() defer m.mutex.Unlock() @@ -77,6 +79,11 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { m.highestStreamOpenedByClient = id } + m.streamsOpenedAfterLastGarbageCollect++ + if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 { + m.garbageCollectClosedStreams() + } + m.putStream(s) return s, nil } @@ -182,8 +189,8 @@ func (m *streamsMap) NumberOfStreams() int { } // garbageCollectClosedStreams deletes nil values in the streams if they are smaller than protocol.MaxNewStreamIDDelta than the highest stream opened by the client +// note that this garbage collection is relatively expensive, since it iterates over the whole streams map. It should not be called every time a stream is openend or closed func (m *streamsMap) garbageCollectClosedStreams() { - m.mutex.Lock() for id, str := range m.streams { if str != nil { continue @@ -192,5 +199,5 @@ func (m *streamsMap) garbageCollectClosedStreams() { delete(m.streams, id) } } - m.mutex.Unlock() + m.streamsOpenedAfterLastGarbageCollect = 0 } diff --git a/streams_map_test.go b/streams_map_test.go index 4c1b8260..59268c74 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -142,6 +142,26 @@ var _ = Describe("Streams Map", func() { Expect(m.streams).To(HaveKey(protocol.StreamID(23))) Expect(m.streams[23]).ToNot(BeNil()) }) + + It("runs garbage-collection after a bunch of streams have been opened", func() { + numGarbageCollections := 0 + numSavedStreams := 0 + for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 { + streamID := protocol.StreamID(i) + _, err := m.GetOrOpenStream(streamID) + Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(streamID) + Expect(err).NotTo(HaveOccurred()) + if len(m.streams) != numSavedStreams+1 { + numGarbageCollections++ + } + numSavedStreams = len(m.streams) + } + Expect(numGarbageCollections).ToNot(BeZero()) + Expect(numGarbageCollections).To(BeNumerically("<", 4)) + Expect(len(m.streams)).To(BeNumerically("<", 2*protocol.MaxNewStreamIDDelta)) + }) }) })