From 543ce21a5a9ac477f37f240c07bbd86dc88c8de4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 24 Aug 2016 00:59:11 +0700 Subject: [PATCH 1/3] prevent opening of new streams with very low StreamIDs ref #256 --- protocol/server_parameters.go | 4 ++++ streams_map.go | 16 ++++++++++++++-- streams_map_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index ad25001e5..745465306 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -40,6 +40,10 @@ const MaxStreamsMultiplier = 1.1 // MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used. const MaxStreamsMinimumIncrement = 10 +// MaxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened +// note that the number of streams is half this value, since the client can only open streams with open StreamID +const MaxNewStreamIDDelta = 4 * MaxStreamsPerConnection + // MaxIdleConnectionStateLifetime is the maximum value accepted for the idle connection state lifetime // TODO: set a reasonable value here const MaxIdleConnectionStateLifetime = 60 * time.Second diff --git a/streams_map.go b/streams_map.go index 0bb2a200f..eb7f309f6 100644 --- a/streams_map.go +++ b/streams_map.go @@ -11,8 +11,11 @@ import ( ) type streamsMap struct { - streams map[protocol.StreamID]*stream - openStreams []protocol.StreamID + streams map[protocol.StreamID]*stream + openStreams []protocol.StreamID + + highestStreamOpenedByClient protocol.StreamID + mutex sync.RWMutex newStream newStreamLambda maxNumStreams int @@ -61,10 +64,19 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if id%2 == 0 { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) } + if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient)) + } + s, err := m.newStream(id) if err != nil { return nil, err } + + if id > m.highestStreamOpenedByClient { + m.highestStreamOpenedByClient = id + } + m.putStream(s) return s, nil } diff --git a/streams_map_test.go b/streams_map_test.go index ddd772a7d..e62596552 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -76,6 +76,37 @@ var _ = Describe("Streams Map", func() { } }) }) + + Context("DoS mitigation", func() { + It("opens and closes a lot of streams", func() { + for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 { + streamID := protocol.StreamID(i) + _, err := m.GetOrOpenStream(streamID) + Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(streamID) + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() { + for i := 1; i < protocol.MaxNewStreamIDDelta+14; i += 2 { + if i == 11 || i == 13 { + continue + } + streamID := protocol.StreamID(i) + _, err := m.GetOrOpenStream(streamID) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(streamID) + Expect(err).NotTo(HaveOccurred()) + } + Expect(m.highestStreamOpenedByClient).To(Equal(protocol.StreamID(protocol.MaxNewStreamIDDelta + 13))) + _, err := m.GetOrOpenStream(11) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 11, which is a lot smaller than the highest opened stream, 413")) + _, err = m.GetOrOpenStream(13) + Expect(err).ToNot(HaveOccurred()) + }) + }) }) Context("deleting streams", func() { From 15352e9591e5a0abbb5691e5372d40526dfc4cd7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 24 Aug 2016 13:04:28 +0700 Subject: [PATCH 2/3] implement garbage collection of closed streams in streamsMap ref #256 --- streams_map.go | 14 ++++++++++++++ streams_map_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/streams_map.go b/streams_map.go index eb7f309f6..3fd553b43 100644 --- a/streams_map.go +++ b/streams_map.go @@ -180,3 +180,17 @@ func (m *streamsMap) NumberOfStreams() int { m.mutex.RUnlock() return n } + +// garbageCollectClosedStreams deletes nil values in the streams if they are smaller than protocol.MaxNewStreamIDDelta than the highest stream opened by the client +func (m *streamsMap) garbageCollectClosedStreams() { + m.mutex.Lock() + for id, str := range m.streams { + if str != nil { + continue + } + if id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient { + delete(m.streams, id) + } + } + m.mutex.Unlock() +} diff --git a/streams_map_test.go b/streams_map_test.go index e62596552..4c1b82609 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -106,6 +106,42 @@ var _ = Describe("Streams Map", func() { _, err = m.GetOrOpenStream(13) Expect(err).ToNot(HaveOccurred()) }) + + It("garbage-collects closed streams", func() { + for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 { + streamID := protocol.StreamID(i) + _, err := m.GetOrOpenStream(streamID) + Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(streamID) + Expect(err).NotTo(HaveOccurred()) + } + m.garbageCollectClosedStreams() + for i := 1; i < 3*protocol.MaxNewStreamIDDelta; i += 2 { + Expect(m.streams).ToNot(HaveKey(protocol.StreamID(i))) + } + for i := 3*protocol.MaxNewStreamIDDelta + 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 { + Expect(m.streams).To(HaveKey(protocol.StreamID(i))) + } + }) + + It("does not garbage-collects open streams", func() { + for i := 1; i < 1002; i += 2 { + streamID := protocol.StreamID(i) + _, err := m.GetOrOpenStream(streamID) + Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) + Expect(err).NotTo(HaveOccurred()) + if streamID != 23 { + err = m.RemoveStream(streamID) + Expect(err).NotTo(HaveOccurred()) + } + } + lengthBefore := len(m.streams) + m.garbageCollectClosedStreams() + Expect(len(m.streams)).To(BeNumerically("<", lengthBefore)) + Expect(m.streams).To(HaveKey(protocol.StreamID(23))) + Expect(m.streams[23]).ToNot(BeNil()) + }) }) }) From 416e3f9e2eb8ffb8d30699bbaf2128f067108aa7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 24 Aug 2016 13:24:58 +0700 Subject: [PATCH 3/3] garbage collect streams map fixes #256 --- streams_map.go | 13 ++++++++++--- streams_map_test.go | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/streams_map.go b/streams_map.go index 3fd553b43..14b74c696 100644 --- a/streams_map.go +++ b/streams_map.go @@ -14,7 +14,8 @@ type streamsMap struct { streams map[protocol.StreamID]*stream openStreams []protocol.StreamID - highestStreamOpenedByClient protocol.StreamID + highestStreamOpenedByClient protocol.StreamID + streamsOpenedAfterLastGarbageCollect int mutex sync.RWMutex newStream newStreamLambda @@ -50,6 +51,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if ok { return s, nil // s may be nil } + // ... we don't have an existing stream, try opening a new one m.mutex.Lock() defer m.mutex.Unlock() @@ -77,6 +79,11 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { m.highestStreamOpenedByClient = id } + m.streamsOpenedAfterLastGarbageCollect++ + if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 { + m.garbageCollectClosedStreams() + } + m.putStream(s) return s, nil } @@ -182,8 +189,8 @@ func (m *streamsMap) NumberOfStreams() int { } // garbageCollectClosedStreams deletes nil values in the streams if they are smaller than protocol.MaxNewStreamIDDelta than the highest stream opened by the client +// note that this garbage collection is relatively expensive, since it iterates over the whole streams map. It should not be called every time a stream is openend or closed func (m *streamsMap) garbageCollectClosedStreams() { - m.mutex.Lock() for id, str := range m.streams { if str != nil { continue @@ -192,5 +199,5 @@ func (m *streamsMap) garbageCollectClosedStreams() { delete(m.streams, id) } } - m.mutex.Unlock() + m.streamsOpenedAfterLastGarbageCollect = 0 } diff --git a/streams_map_test.go b/streams_map_test.go index 4c1b82609..59268c744 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -142,6 +142,26 @@ var _ = Describe("Streams Map", func() { Expect(m.streams).To(HaveKey(protocol.StreamID(23))) Expect(m.streams[23]).ToNot(BeNil()) }) + + It("runs garbage-collection after a bunch of streams have been opened", func() { + numGarbageCollections := 0 + numSavedStreams := 0 + for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 { + streamID := protocol.StreamID(i) + _, err := m.GetOrOpenStream(streamID) + Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(streamID) + Expect(err).NotTo(HaveOccurred()) + if len(m.streams) != numSavedStreams+1 { + numGarbageCollections++ + } + numSavedStreams = len(m.streams) + } + Expect(numGarbageCollections).ToNot(BeZero()) + Expect(numGarbageCollections).To(BeNumerically("<", 4)) + Expect(len(m.streams)).To(BeNumerically("<", 2*protocol.MaxNewStreamIDDelta)) + }) }) })