open implicitly opened streams in streamsMap

This commit is contained in:
Marten Seemann
2017-02-10 16:09:22 +07:00
parent f47142eaac
commit 6d3e94bf21
3 changed files with 98 additions and 36 deletions

View File

@@ -172,12 +172,12 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streamsMap.openStreams).To(HaveLen(2))
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
str, _ := session.streamsMap.GetOrOpenStream(5)
str, err := session.streamsMap.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
_, err := str.Read(p)
_, err = str.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
})
@@ -197,14 +197,14 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca},
})
Expect(session.streamsMap.openStreams).To(HaveLen(2))
numOpenStreams := len(session.streamsMap.openStreams)
Expect(streamCallbackCalled).To(BeTrue())
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
Offset: 2,
Data: []byte{0xfb, 0xad},
})
Expect(session.streamsMap.openStreams).To(HaveLen(2))
Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams))
p := make([]byte, 4)
str, _ := session.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
@@ -218,8 +218,8 @@ var _ = Describe("Session", func() {
Expect(err).ToNot(HaveOccurred())
str.Close()
session.garbageCollectStreams()
Expect(session.streamsMap.openStreams).To(HaveLen(2))
str, _ = session.streamsMap.GetOrOpenStream(5)
str, err = session.streamsMap.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
})
@@ -229,7 +229,7 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streamsMap.openStreams).To(HaveLen(2))
numOpenStreams := len(session.streamsMap.openStreams)
str, _ := session.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
@@ -238,7 +238,7 @@ var _ = Describe("Session", func() {
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streamsMap.openStreams).To(HaveLen(2))
Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams))
str, _ = session.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
})
@@ -249,7 +249,7 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streamsMap.openStreams).To(HaveLen(2))
numOpenStreams := len(session.streamsMap.openStreams)
str, _ := session.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
@@ -258,7 +258,7 @@ var _ = Describe("Session", func() {
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streamsMap.openStreams).To(HaveLen(2))
Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams))
str, _ = session.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
// We still need to close the stream locally
@@ -266,7 +266,7 @@ var _ = Describe("Session", func() {
// ... and simulate that we actually the FIN
str.sentFin()
session.garbageCollectStreams()
Expect(session.streamsMap.openStreams).To(HaveLen(1))
Expect(len(session.streamsMap.openStreams)).To(BeNumerically("<", numOpenStreams))
str, err = session.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(str).To(BeNil())
@@ -276,23 +276,23 @@ var _ = Describe("Session", func() {
})
It("cancels streams with error", func() {
session.garbageCollectStreams()
testErr := errors.New("test")
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streamsMap.openStreams).To(HaveLen(2))
str, _ := session.streamsMap.GetOrOpenStream(5)
str, err := session.streamsMap.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := str.Read(p)
_, err = str.Read(p)
Expect(err).ToNot(HaveOccurred())
session.closeStreamsWithError(testErr)
_, err = str.Read(p)
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streamsMap.openStreams).To(BeEmpty())
str, err = session.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(str).To(BeNil())
@@ -301,11 +301,11 @@ var _ = Describe("Session", func() {
It("cancels empty streams with error", func() {
testErr := errors.New("test")
session.GetOrOpenStream(5)
Expect(session.streamsMap.openStreams).To(HaveLen(2))
str, _ := session.streamsMap.GetOrOpenStream(5)
str, err := session.streamsMap.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
session.closeStreamsWithError(testErr)
_, err := str.Read([]byte{0})
_, err = str.Read([]byte{0})
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
str, err = session.streamsMap.GetOrOpenStream(5)
@@ -1180,7 +1180,7 @@ var _ = Describe("Session", func() {
Context("counting streams", func() {
It("errors when too many streams are opened", func() {
for i := 2; i <= 110; i++ {
for i := 0; i < 110; i++ {
_, err := session.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred())
}

View File

@@ -17,11 +17,13 @@ type streamsMap struct {
connectionParameters handshake.ConnectionParametersManager
streams map[protocol.StreamID]*stream
// TODO: remove this
openStreams []protocol.StreamID
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
// TODO: remove this
streamsOpenedAfterLastGarbageCollect int
newStream newStreamLambda
@@ -69,7 +71,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
return s, nil // s may be nil
}
// ... we don't have an existing stream, try opening a new one
// ... we don't have an existing stream
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
@@ -77,6 +79,35 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if ok {
return s, nil
}
if id <= m.highestStreamOpenedByPeer {
return nil, nil
}
highestOpened := m.highestStreamOpenedByPeer
sid := id
// sid is always odd
for sid > highestOpened {
_, err := m.openRemoteStream(sid)
if err != nil {
return nil, err
}
if sid == 1 {
break
}
sid -= 2
}
// maybe trigger garbage collection of streams map
m.streamsOpenedAfterLastGarbageCollect++
if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 {
m.garbageCollectClosedStreams()
}
return m.streams[id], nil
}
func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() {
return nil, qerr.TooManyOpenStreams
}
@@ -105,12 +136,6 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
m.highestStreamOpenedByPeer = id
}
// maybe trigger garbage collection of streams map
m.streamsOpenedAfterLastGarbageCollect++
if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 {
m.garbageCollectClosedStreams()
}
m.putStream(s)
return s, nil
}
@@ -145,7 +170,12 @@ func (m *streamsMap) Iterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, streamID := range m.openStreams {
openStreams := make([]protocol.StreamID, len(m.openStreams), len(m.openStreams))
for i, streamID := range m.openStreams { // copy openStreams
openStreams[i] = streamID
}
for _, streamID := range openStreams {
cont, err := m.iterateFunc(streamID, fn)
if err != nil {
return err

View File

@@ -79,9 +79,9 @@ var _ = Describe("Streams Map", func() {
Context("client-side streams", func() {
It("gets new streams", func() {
s, err := m.GetOrOpenStream(5)
s, err := m.GetOrOpenStream(1)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
Expect(s.StreamID()).To(Equal(protocol.StreamID(1)))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
Expect(m.numOutgoingStreams).To(BeZero())
})
@@ -94,10 +94,11 @@ var _ = Describe("Streams Map", func() {
It("gets existing streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
numStreams := m.numIncomingStreams
s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(Equal(numStreams))
})
It("returns nil for closed streams", func() {
@@ -108,7 +109,24 @@ var _ = Describe("Streams Map", func() {
s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s).To(BeNil())
Expect(m.numIncomingStreams).To(BeZero())
})
It("opens skipped streams", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(m.streams).To(HaveKey(protocol.StreamID(1)))
Expect(m.streams).To(HaveKey(protocol.StreamID(3)))
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
})
It("doesn't reopen an already closed stream", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
err = m.RemoveStream(5)
Expect(err).ToNot(HaveOccurred())
str, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
Context("counting streams", func() {
@@ -127,6 +145,11 @@ var _ = Describe("Streams Map", func() {
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("errors when too many streams are opened implicitely", func() {
_, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams*2 + 1))
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("does not error when many streams are opened and closed", func() {
for i := 2; i < 10*maxNumStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
@@ -197,12 +220,20 @@ var _ = Describe("Streams Map", func() {
})
It("gets new streams", func() {
s, err := m.GetOrOpenStream(6)
s, err := m.GetOrOpenStream(2)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(BeZero())
})
It("opens skipped streams", func() {
_, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
Expect(m.streams).To(HaveKey(protocol.StreamID(2)))
Expect(m.streams).To(HaveKey(protocol.StreamID(4)))
Expect(m.streams).To(HaveKey(protocol.StreamID(6)))
})
})
Context("server-side streams", func() {
@@ -231,6 +262,7 @@ var _ = Describe("Streams Map", func() {
setNewStreamsMap(protocol.PerspectiveServer)
})
// TODO: remove when removing the openStreams slice
Context("DoS mitigation", func() {
It("opens and closes a lot of streams", func() {
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {
@@ -243,7 +275,7 @@ var _ = Describe("Streams Map", func() {
}
})
It("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() {
PIt("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() {
for i := 1; i < protocol.MaxNewStreamIDDelta+14; i += 2 {
if i == 11 || i == 13 {
continue