forked from quic-go/quic-go
open implicitly opened streams in streamsMap
This commit is contained in:
@@ -172,12 +172,12 @@ var _ = Describe("Session", func() {
|
||||
StreamID: 5,
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
})
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
p := make([]byte, 4)
|
||||
str, _ := session.streamsMap.GetOrOpenStream(5)
|
||||
str, err := session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
_, err := str.Read(p)
|
||||
_, err = str.Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
|
||||
})
|
||||
@@ -197,14 +197,14 @@ var _ = Describe("Session", func() {
|
||||
StreamID: 5,
|
||||
Data: []byte{0xde, 0xca},
|
||||
})
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
numOpenStreams := len(session.streamsMap.openStreams)
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
session.handleStreamFrame(&frames.StreamFrame{
|
||||
StreamID: 5,
|
||||
Offset: 2,
|
||||
Data: []byte{0xfb, 0xad},
|
||||
})
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams))
|
||||
p := make([]byte, 4)
|
||||
str, _ := session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(str).ToNot(BeNil())
|
||||
@@ -218,8 +218,8 @@ var _ = Describe("Session", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.Close()
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
str, _ = session.streamsMap.GetOrOpenStream(5)
|
||||
str, err = session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
})
|
||||
|
||||
@@ -229,7 +229,7 @@ var _ = Describe("Session", func() {
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
numOpenStreams := len(session.streamsMap.openStreams)
|
||||
str, _ := session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
@@ -238,7 +238,7 @@ var _ = Describe("Session", func() {
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams))
|
||||
str, _ = session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(str).ToNot(BeNil())
|
||||
})
|
||||
@@ -249,7 +249,7 @@ var _ = Describe("Session", func() {
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
numOpenStreams := len(session.streamsMap.openStreams)
|
||||
str, _ := session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
@@ -258,7 +258,7 @@ var _ = Describe("Session", func() {
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams))
|
||||
str, _ = session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(str).ToNot(BeNil())
|
||||
// We still need to close the stream locally
|
||||
@@ -266,7 +266,7 @@ var _ = Describe("Session", func() {
|
||||
// ... and simulate that we actually the FIN
|
||||
str.sentFin()
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(1))
|
||||
Expect(len(session.streamsMap.openStreams)).To(BeNumerically("<", numOpenStreams))
|
||||
str, err = session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
@@ -276,23 +276,23 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("cancels streams with error", func() {
|
||||
session.garbageCollectStreams()
|
||||
testErr := errors.New("test")
|
||||
session.handleStreamFrame(&frames.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
})
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
str, _ := session.streamsMap.GetOrOpenStream(5)
|
||||
str, err := session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
p := make([]byte, 4)
|
||||
_, err := str.Read(p)
|
||||
_, err = str.Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
session.closeStreamsWithError(testErr)
|
||||
_, err = str.Read(p)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streamsMap.openStreams).To(BeEmpty())
|
||||
str, err = session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
@@ -301,11 +301,11 @@ var _ = Describe("Session", func() {
|
||||
It("cancels empty streams with error", func() {
|
||||
testErr := errors.New("test")
|
||||
session.GetOrOpenStream(5)
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
str, _ := session.streamsMap.GetOrOpenStream(5)
|
||||
str, err := session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
session.closeStreamsWithError(testErr)
|
||||
_, err := str.Read([]byte{0})
|
||||
_, err = str.Read([]byte{0})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
session.garbageCollectStreams()
|
||||
str, err = session.streamsMap.GetOrOpenStream(5)
|
||||
@@ -1180,7 +1180,7 @@ var _ = Describe("Session", func() {
|
||||
|
||||
Context("counting streams", func() {
|
||||
It("errors when too many streams are opened", func() {
|
||||
for i := 2; i <= 110; i++ {
|
||||
for i := 0; i < 110; i++ {
|
||||
_, err := session.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
@@ -17,11 +17,13 @@ type streamsMap struct {
|
||||
connectionParameters handshake.ConnectionParametersManager
|
||||
|
||||
streams map[protocol.StreamID]*stream
|
||||
// TODO: remove this
|
||||
openStreams []protocol.StreamID
|
||||
|
||||
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||
highestStreamOpenedByPeer protocol.StreamID
|
||||
|
||||
// TODO: remove this
|
||||
streamsOpenedAfterLastGarbageCollect int
|
||||
|
||||
newStream newStreamLambda
|
||||
@@ -69,7 +71,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||
return s, nil // s may be nil
|
||||
}
|
||||
|
||||
// ... we don't have an existing stream, try opening a new one
|
||||
// ... we don't have an existing stream
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
|
||||
@@ -77,6 +79,35 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||
if ok {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
if id <= m.highestStreamOpenedByPeer {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
highestOpened := m.highestStreamOpenedByPeer
|
||||
sid := id
|
||||
// sid is always odd
|
||||
for sid > highestOpened {
|
||||
_, err := m.openRemoteStream(sid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sid == 1 {
|
||||
break
|
||||
}
|
||||
sid -= 2
|
||||
}
|
||||
|
||||
// maybe trigger garbage collection of streams map
|
||||
m.streamsOpenedAfterLastGarbageCollect++
|
||||
if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 {
|
||||
m.garbageCollectClosedStreams()
|
||||
}
|
||||
|
||||
return m.streams[id], nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
|
||||
if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
}
|
||||
@@ -105,12 +136,6 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||
m.highestStreamOpenedByPeer = id
|
||||
}
|
||||
|
||||
// maybe trigger garbage collection of streams map
|
||||
m.streamsOpenedAfterLastGarbageCollect++
|
||||
if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 {
|
||||
m.garbageCollectClosedStreams()
|
||||
}
|
||||
|
||||
m.putStream(s)
|
||||
return s, nil
|
||||
}
|
||||
@@ -145,7 +170,12 @@ func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, streamID := range m.openStreams {
|
||||
openStreams := make([]protocol.StreamID, len(m.openStreams), len(m.openStreams))
|
||||
for i, streamID := range m.openStreams { // copy openStreams
|
||||
openStreams[i] = streamID
|
||||
}
|
||||
|
||||
for _, streamID := range openStreams {
|
||||
cont, err := m.iterateFunc(streamID, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -79,9 +79,9 @@ var _ = Describe("Streams Map", func() {
|
||||
|
||||
Context("client-side streams", func() {
|
||||
It("gets new streams", func() {
|
||||
s, err := m.GetOrOpenStream(5)
|
||||
s, err := m.GetOrOpenStream(1)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(1)))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
})
|
||||
@@ -94,10 +94,11 @@ var _ = Describe("Streams Map", func() {
|
||||
It("gets existing streams", func() {
|
||||
s, err := m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
numStreams := m.numIncomingStreams
|
||||
s, err = m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||
Expect(m.numIncomingStreams).To(Equal(numStreams))
|
||||
})
|
||||
|
||||
It("returns nil for closed streams", func() {
|
||||
@@ -108,7 +109,24 @@ var _ = Describe("Streams Map", func() {
|
||||
s, err = m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s).To(BeNil())
|
||||
Expect(m.numIncomingStreams).To(BeZero())
|
||||
})
|
||||
|
||||
It("opens skipped streams", func() {
|
||||
_, err := m.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(1)))
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(3)))
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
|
||||
})
|
||||
|
||||
It("doesn't reopen an already closed stream", func() {
|
||||
_, err := m.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = m.RemoveStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
})
|
||||
|
||||
Context("counting streams", func() {
|
||||
@@ -127,6 +145,11 @@ var _ = Describe("Streams Map", func() {
|
||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||
})
|
||||
|
||||
It("errors when too many streams are opened implicitely", func() {
|
||||
_, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams*2 + 1))
|
||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||
})
|
||||
|
||||
It("does not error when many streams are opened and closed", func() {
|
||||
for i := 2; i < 10*maxNumStreams; i++ {
|
||||
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||
@@ -197,12 +220,20 @@ var _ = Describe("Streams Map", func() {
|
||||
})
|
||||
|
||||
It("gets new streams", func() {
|
||||
s, err := m.GetOrOpenStream(6)
|
||||
s, err := m.GetOrOpenStream(2)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||
Expect(m.numIncomingStreams).To(BeZero())
|
||||
})
|
||||
|
||||
It("opens skipped streams", func() {
|
||||
_, err := m.GetOrOpenStream(6)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(2)))
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(4)))
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(6)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("server-side streams", func() {
|
||||
@@ -231,6 +262,7 @@ var _ = Describe("Streams Map", func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer)
|
||||
})
|
||||
|
||||
// TODO: remove when removing the openStreams slice
|
||||
Context("DoS mitigation", func() {
|
||||
It("opens and closes a lot of streams", func() {
|
||||
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {
|
||||
@@ -243,7 +275,7 @@ var _ = Describe("Streams Map", func() {
|
||||
}
|
||||
})
|
||||
|
||||
It("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() {
|
||||
PIt("prevents opening of streams with very low StreamIDs, if higher streams have already been opened", func() {
|
||||
for i := 1; i < protocol.MaxNewStreamIDDelta+14; i += 2 {
|
||||
if i == 11 || i == 13 {
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user