From 6d3e94bf21fcac4ec857ae4c615b5c09ab42cf59 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 10 Feb 2017 16:09:22 +0700 Subject: [PATCH] open implicitly opened streams in streamsMap --- session_test.go | 40 ++++++++++++++++++------------------- streams_map.go | 48 ++++++++++++++++++++++++++++++++++++--------- streams_map_test.go | 46 ++++++++++++++++++++++++++++++++++++------- 3 files changed, 98 insertions(+), 36 deletions(-) diff --git a/session_test.go b/session_test.go index 4b1b1eeb8..5144707a9 100644 --- a/session_test.go +++ b/session_test.go @@ -172,12 +172,12 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) - Expect(session.streamsMap.openStreams).To(HaveLen(2)) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) - str, _ := session.streamsMap.GetOrOpenStream(5) + str, err := session.streamsMap.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) - _, err := str.Read(p) + _, err = str.Read(p) Expect(err).ToNot(HaveOccurred()) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) @@ -197,14 +197,14 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca}, }) - Expect(session.streamsMap.openStreams).To(HaveLen(2)) + numOpenStreams := len(session.streamsMap.openStreams) Expect(streamCallbackCalled).To(BeTrue()) session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Offset: 2, Data: []byte{0xfb, 0xad}, }) - Expect(session.streamsMap.openStreams).To(HaveLen(2)) + Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams)) p := make([]byte, 4) str, _ := session.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) @@ -218,8 +218,8 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) str.Close() session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(HaveLen(2)) - str, _ = session.streamsMap.GetOrOpenStream(5) + str, err = session.streamsMap.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) }) @@ -229,7 +229,7 @@ var _ = Describe("Session", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - Expect(session.streamsMap.openStreams).To(HaveLen(2)) + numOpenStreams := len(session.streamsMap.openStreams) str, _ := session.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) @@ -238,7 +238,7 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(HaveLen(2)) + Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams)) str, _ = session.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) }) @@ -249,7 +249,7 @@ var _ = Describe("Session", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - Expect(session.streamsMap.openStreams).To(HaveLen(2)) + numOpenStreams := len(session.streamsMap.openStreams) str, _ := session.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) @@ -258,7 +258,7 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(HaveLen(2)) + Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams)) str, _ = session.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) // We still need to close the stream locally @@ -266,7 +266,7 @@ var _ = Describe("Session", func() { // ... and simulate that we actually the FIN str.sentFin() session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(HaveLen(1)) + Expect(len(session.streamsMap.openStreams)).To(BeNumerically("<", numOpenStreams)) str, err = session.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(str).To(BeNil()) @@ -276,23 +276,23 @@ var _ = Describe("Session", func() { }) It("cancels streams with error", func() { + session.garbageCollectStreams() testErr := errors.New("test") session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) - Expect(session.streamsMap.openStreams).To(HaveLen(2)) - str, _ := session.streamsMap.GetOrOpenStream(5) + str, err := session.streamsMap.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) - _, err := str.Read(p) + _, err = str.Read(p) Expect(err).ToNot(HaveOccurred()) session.closeStreamsWithError(testErr) _, err = str.Read(p) Expect(err).To(MatchError(testErr)) session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(BeEmpty()) str, err = session.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(str).To(BeNil()) @@ -301,11 +301,11 @@ var _ = Describe("Session", func() { It("cancels empty streams with error", func() { testErr := errors.New("test") session.GetOrOpenStream(5) - Expect(session.streamsMap.openStreams).To(HaveLen(2)) - str, _ := session.streamsMap.GetOrOpenStream(5) + str, err := session.streamsMap.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) session.closeStreamsWithError(testErr) - _, err := str.Read([]byte{0}) + _, err = str.Read([]byte{0}) Expect(err).To(MatchError(testErr)) session.garbageCollectStreams() str, err = session.streamsMap.GetOrOpenStream(5) @@ -1180,7 +1180,7 @@ var _ = Describe("Session", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - for i := 2; i <= 110; i++ { + for i := 0; i < 110; i++ { _, err := session.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } diff --git a/streams_map.go b/streams_map.go index fd476b215..dac409a54 100644 --- a/streams_map.go +++ b/streams_map.go @@ -16,12 +16,14 @@ type streamsMap struct { perspective protocol.Perspective connectionParameters handshake.ConnectionParametersManager - streams map[protocol.StreamID]*stream + streams map[protocol.StreamID]*stream + // TODO: remove this openStreams []protocol.StreamID nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID + // TODO: remove this streamsOpenedAfterLastGarbageCollect int newStream newStreamLambda @@ -69,7 +71,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { return s, nil // s may be nil } - // ... we don't have an existing stream, try opening a new one + // ... we don't have an existing stream m.mutex.Lock() defer m.mutex.Unlock() // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). @@ -77,6 +79,35 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if ok { return s, nil } + + if id <= m.highestStreamOpenedByPeer { + return nil, nil + } + + highestOpened := m.highestStreamOpenedByPeer + sid := id + // sid is always odd + for sid > highestOpened { + _, err := m.openRemoteStream(sid) + if err != nil { + return nil, err + } + if sid == 1 { + break + } + sid -= 2 + } + + // maybe trigger garbage collection of streams map + m.streamsOpenedAfterLastGarbageCollect++ + if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 { + m.garbageCollectClosedStreams() + } + + return m.streams[id], nil +} + +func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { return nil, qerr.TooManyOpenStreams } @@ -105,12 +136,6 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { m.highestStreamOpenedByPeer = id } - // maybe trigger garbage collection of streams map - m.streamsOpenedAfterLastGarbageCollect++ - if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 { - m.garbageCollectClosedStreams() - } - m.putStream(s) return s, nil } @@ -145,7 +170,12 @@ func (m *streamsMap) Iterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() - for _, streamID := range m.openStreams { + openStreams := make([]protocol.StreamID, len(m.openStreams), len(m.openStreams)) + for i, streamID := range m.openStreams { // copy openStreams + openStreams[i] = streamID + } + + for _, streamID := range openStreams { cont, err := m.iterateFunc(streamID, fn) if err != nil { return err diff --git a/streams_map_test.go b/streams_map_test.go index 89b62be58..c94e06d18 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -79,9 +79,9 @@ var _ = Describe("Streams Map", func() { Context("client-side streams", func() { It("gets new streams", func() { - s, err := m.GetOrOpenStream(5) + s, err := m.GetOrOpenStream(1) Expect(err).NotTo(HaveOccurred()) - Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) + Expect(s.StreamID()).To(Equal(protocol.StreamID(1))) Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) Expect(m.numOutgoingStreams).To(BeZero()) }) @@ -94,10 +94,11 @@ var _ = Describe("Streams Map", func() { It("gets existing streams", func() { s, err := m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) + numStreams := m.numIncomingStreams s, err = m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) - Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + Expect(m.numIncomingStreams).To(Equal(numStreams)) }) It("returns nil for closed streams", func() { @@ -108,7 +109,24 @@ var _ = Describe("Streams Map", func() { s, err = m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(s).To(BeNil()) - Expect(m.numIncomingStreams).To(BeZero()) + }) + + It("opens skipped streams", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(m.streams).To(HaveKey(protocol.StreamID(1))) + Expect(m.streams).To(HaveKey(protocol.StreamID(3))) + Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + }) + + It("doesn't reopen an already closed stream", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + err = m.RemoveStream(5) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) }) Context("counting streams", func() { @@ -127,6 +145,11 @@ var _ = Describe("Streams Map", func() { Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) + It("errors when too many streams are opened implicitely", func() { + _, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams*2 + 1)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + It("does not error when many streams are opened and closed", func() { for i := 2; i < 10*maxNumStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) @@ -197,12 +220,20 @@ var _ = Describe("Streams Map", func() { }) It("gets new streams", func() { - s, err := m.GetOrOpenStream(6) + s, err := m.GetOrOpenStream(2) Expect(err).NotTo(HaveOccurred()) - Expect(s.StreamID()).To(Equal(protocol.StreamID(6))) + Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) Expect(m.numIncomingStreams).To(BeZero()) }) + + It("opens skipped streams", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).NotTo(HaveOccurred()) + Expect(m.streams).To(HaveKey(protocol.StreamID(2))) + Expect(m.streams).To(HaveKey(protocol.StreamID(4))) + Expect(m.streams).To(HaveKey(protocol.StreamID(6))) + }) }) Context("server-side streams", func() { @@ -231,6 +262,7 @@ var _ = Describe("Streams Map", func() { setNewStreamsMap(protocol.PerspectiveServer) }) + // TODO: remove when removing the openStreams slice Context("DoS mitigation", func() { It("opens and closes a lot of streams", func() { for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 { @@ -243,7 +275,7 @@ var _ = Describe("Streams Map", func() { } }) - It("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() { + PIt("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() { for i := 1; i < protocol.MaxNewStreamIDDelta+14; i += 2 { if i == 11 || i == 13 { continue