From c546f5c9dca63f1735ea49ce1c0d0bba57d9eac5 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Fri, 27 May 2016 22:39:37 +0200 Subject: [PATCH] enforce max streams per connection limit fixes #13 --- handshake/connection_parameters_manager.go | 1 + protocol/server_parameters.go | 4 +++- session.go | 12 +++++++++-- session_test.go | 25 ++++++++++++++++++++++ 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/handshake/connection_parameters_manager.go b/handshake/connection_parameters_manager.go index 3a41e988..820c1269 100644 --- a/handshake/connection_parameters_manager.go +++ b/handshake/connection_parameters_manager.go @@ -46,6 +46,7 @@ func NewConnectionParamatersManager() *ConnectionParametersManager { sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + maxStreamsPerConnection: protocol.MaxStreamsPerConnection, } } diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index de12c5de..8fb250ae 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -33,9 +33,11 @@ const ReceiveStreamFlowControlWindow ByteCount = (1 << 20) // 1 MB const ReceiveConnectionFlowControlWindow ByteCount = (1 << 20) * 1.5 // 1.5 MB // MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection -// TODO: set a reasonable value here const MaxStreamsPerConnection uint32 = 100 +// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. +const MaxStreamsMultiplier = 1.1 + // MaxIdleConnectionStateLifetime is the maximum value accepted for the idle connection state lifetime // TODO: set a reasonable value here const MaxIdleConnectionStateLifetime = 60 * time.Second diff --git a/session.go b/session.go index 4ddf78ac..b2a46a23 100644 --- a/session.go +++ b/session.go @@ -44,8 +44,9 @@ type Session struct { conn connection - streams map[protocol.StreamID]*stream - streamsMutex sync.RWMutex + streams map[protocol.StreamID]*stream + openStreamsCount uint32 + streamsMutex sync.RWMutex sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler @@ -600,7 +601,12 @@ func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { return s.newStreamImpl(id) } +// The streamsMutex is locked by OpenStream or GetOrOpenStream before calling this function. func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { + maxAllowedStreams := uint32(protocol.MaxStreamsMultiplier * float32(s.connectionParametersManager.GetMaxStreamsPerConnection())) + if s.openStreamsCount >= maxAllowedStreams { + return nil, qerr.TooManyOpenStreams + } stream, err := newStream(s, s.connectionParametersManager, s.flowController, id) if err != nil { return nil, err @@ -608,6 +614,7 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { if s.streams[id] != nil { return nil, fmt.Errorf("Session: stream with ID %d already exists", id) } + s.openStreamsCount++ s.streams[id] = stream return stream, nil } @@ -628,6 +635,7 @@ func (s *Session) garbageCollectStreams() { s.windowUpdateManager.RemoveStream(k) } if v.finished() { + s.openStreamsCount-- s.streams[k] = nil } } diff --git a/session_test.go b/session_test.go index 35ec3917..2b0b25dc 100644 --- a/session_test.go +++ b/session_test.go @@ -634,4 +634,29 @@ var _ = Describe("Session", func() { session.scheduleSending() Eventually(func() bool { return len(conn.written) > 0 }).Should(BeTrue()) }) + + Context("counting streams", func() { + It("errors when too many streams are opened", func() { + // 1.1 * 100 + for i := 2; i <= 110; i++ { + _, err := session.OpenStream(protocol.StreamID(i)) + Expect(err).NotTo(HaveOccurred()) + } + _, err := session.OpenStream(protocol.StreamID(110)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := 2; i <= 1000; i++ { + s, err := session.OpenStream(protocol.StreamID(i)) + Expect(err).NotTo(HaveOccurred()) + err = s.Close() + Expect(err).NotTo(HaveOccurred()) + s.CloseRemote(0) + _, err = s.Read([]byte("a")) + Expect(err).To(MatchError(io.EOF)) + session.garbageCollectStreams() + } + }) + }) })