Merge pull request #305 from lucas-clemente/fix-256

implement a garbage collection for the streamsMap
This commit is contained in:
Lucas Clemente
2016-08-24 16:18:12 +02:00
committed by Lucas Clemente
3 changed files with 126 additions and 2 deletions

View File

@@ -40,6 +40,10 @@ const MaxStreamsMultiplier = 1.1
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
const MaxStreamsMinimumIncrement = 10
// MaxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened
// note that the number of streams is half this value, since the client can only open streams with open StreamID
const MaxNewStreamIDDelta = 4 * MaxStreamsPerConnection
// MaxIdleConnectionStateLifetime is the maximum value accepted for the idle connection state lifetime
// TODO: set a reasonable value here
const MaxIdleConnectionStateLifetime = 60 * time.Second

View File

@@ -11,8 +11,12 @@ import (
)
type streamsMap struct {
streams map[protocol.StreamID]*stream
openStreams []protocol.StreamID
streams map[protocol.StreamID]*stream
openStreams []protocol.StreamID
highestStreamOpenedByClient protocol.StreamID
streamsOpenedAfterLastGarbageCollect int
mutex sync.RWMutex
newStream newStreamLambda
maxNumStreams int
@@ -47,6 +51,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if ok {
return s, nil // s may be nil
}
// ... we don't have an existing stream, try opening a new one
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -61,10 +66,24 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if id%2 == 0 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient))
}
s, err := m.newStream(id)
if err != nil {
return nil, err
}
if id > m.highestStreamOpenedByClient {
m.highestStreamOpenedByClient = id
}
m.streamsOpenedAfterLastGarbageCollect++
if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 {
m.garbageCollectClosedStreams()
}
m.putStream(s)
return s, nil
}
@@ -168,3 +187,17 @@ func (m *streamsMap) NumberOfStreams() int {
m.mutex.RUnlock()
return n
}
// garbageCollectClosedStreams deletes nil values in the streams if they are smaller than protocol.MaxNewStreamIDDelta than the highest stream opened by the client
// note that this garbage collection is relatively expensive, since it iterates over the whole streams map. It should not be called every time a stream is openend or closed
func (m *streamsMap) garbageCollectClosedStreams() {
for id, str := range m.streams {
if str != nil {
continue
}
if id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient {
delete(m.streams, id)
}
}
m.streamsOpenedAfterLastGarbageCollect = 0
}

View File

@@ -76,6 +76,93 @@ var _ = Describe("Streams Map", func() {
}
})
})
Context("DoS mitigation", func() {
It("opens and closes a lot of streams", func() {
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {
streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred())
}
})
It("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() {
for i := 1; i < protocol.MaxNewStreamIDDelta+14; i += 2 {
if i == 11 || i == 13 {
continue
}
streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID)
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred())
}
Expect(m.highestStreamOpenedByClient).To(Equal(protocol.StreamID(protocol.MaxNewStreamIDDelta + 13)))
_, err := m.GetOrOpenStream(11)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 11, which is a lot smaller than the highest opened stream, 413"))
_, err = m.GetOrOpenStream(13)
Expect(err).ToNot(HaveOccurred())
})
It("garbage-collects closed streams", func() {
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred())
}
m.garbageCollectClosedStreams()
for i := 1; i < 3*protocol.MaxNewStreamIDDelta; i += 2 {
Expect(m.streams).ToNot(HaveKey(protocol.StreamID(i)))
}
for i := 3*protocol.MaxNewStreamIDDelta + 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
Expect(m.streams).To(HaveKey(protocol.StreamID(i)))
}
})
It("does not garbage-collects open streams", func() {
for i := 1; i < 1002; i += 2 {
streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred())
if streamID != 23 {
err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred())
}
}
lengthBefore := len(m.streams)
m.garbageCollectClosedStreams()
Expect(len(m.streams)).To(BeNumerically("<", lengthBefore))
Expect(m.streams).To(HaveKey(protocol.StreamID(23)))
Expect(m.streams[23]).ToNot(BeNil())
})
It("runs garbage-collection after a bunch of streams have been opened", func() {
numGarbageCollections := 0
numSavedStreams := 0
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred())
if len(m.streams) != numSavedStreams+1 {
numGarbageCollections++
}
numSavedStreams = len(m.streams)
}
Expect(numGarbageCollections).ToNot(BeZero())
Expect(numGarbageCollections).To(BeNumerically("<", 4))
Expect(len(m.streams)).To(BeNumerically("<", 2*protocol.MaxNewStreamIDDelta))
})
})
})
Context("deleting streams", func() {