diff --git a/streams_map.go b/streams_map.go index 05ce1a390..0bb2a200f 100644 --- a/streams_map.go +++ b/streams_map.go @@ -7,17 +7,15 @@ import ( "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" -) - -const ( - maxNumStreams = int(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier) + "github.com/lucas-clemente/quic-go/utils" ) type streamsMap struct { - streams map[protocol.StreamID]*stream - openStreams []protocol.StreamID - mutex sync.RWMutex - newStream newStreamLambda + streams map[protocol.StreamID]*stream + openStreams []protocol.StreamID + mutex sync.RWMutex + newStream newStreamLambda + maxNumStreams int roundRobinIndex int } @@ -30,10 +28,13 @@ var ( ) func newStreamsMap(newStream newStreamLambda) *streamsMap { + maxNumStreams := utils.Max(int(float32(protocol.MaxIncomingDynamicStreams)*protocol.MaxStreamsMultiplier), int(protocol.MaxIncomingDynamicStreams)) + return &streamsMap{ - streams: map[protocol.StreamID]*stream{}, - openStreams: make([]protocol.StreamID, 0, maxNumStreams), - newStream: newStream, + streams: map[protocol.StreamID]*stream{}, + openStreams: make([]protocol.StreamID, 0, maxNumStreams), + newStream: newStream, + maxNumStreams: maxNumStreams, } } @@ -54,7 +55,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if ok { return s, nil } - if len(m.openStreams) == maxNumStreams { + if len(m.openStreams) == m.maxNumStreams { return nil, qerr.TooManyOpenStreams } if id%2 == 0 { diff --git a/streams_map_test.go b/streams_map_test.go index 4e04a559a..ddd772a7d 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -60,16 +60,16 @@ var _ = Describe("Streams Map", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - for i := 0; i < maxNumStreams; i++ { + for i := 0; i < m.maxNumStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } - _, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams)) + _, err := m.GetOrOpenStream(protocol.StreamID(m.maxNumStreams)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("does not error when many streams are opened and closed", func() { - for i := 2; i < 10*maxNumStreams; i++ { + for i := 2; i < 10*m.maxNumStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) m.RemoveStream(protocol.StreamID(i*2 + 1))