diff --git a/streams_map.go b/streams_map.go index c1650f18c..41d24615c 100644 --- a/streams_map.go +++ b/streams_map.go @@ -8,16 +8,18 @@ import ( ) type streamsMap struct { - streams map[protocol.StreamID]*stream - nStreams int - mutex sync.RWMutex + streams map[protocol.StreamID]*stream + openStreams []protocol.StreamID + mutex sync.RWMutex } type streamLambda func(*stream) (bool, error) func newStreamsMap() *streamsMap { + maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier) return &streamsMap{ - streams: map[protocol.StreamID]*stream{}, + streams: map[protocol.StreamID]*stream{}, + openStreams: make([]protocol.StreamID, 0, maxNumStreams), } } @@ -50,11 +52,15 @@ func (m *streamsMap) Iterate(fn streamLambda) error { func (m *streamsMap) PutStream(s *stream) error { m.mutex.Lock() defer m.mutex.Unlock() - if _, ok := m.streams[s.StreamID()]; ok { - return fmt.Errorf("a stream with ID %d already exists", s.StreamID()) + + id := s.StreamID() + if _, ok := m.streams[id]; ok { + return fmt.Errorf("a stream with ID %d already exists", id) } - m.streams[s.StreamID()] = s - m.nStreams++ + + m.streams[id] = s + m.openStreams = append(m.openStreams, id) + return nil } @@ -64,14 +70,23 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { if !ok || s == nil { return fmt.Errorf("attempted to remove non-existing stream: %d", id) } + m.streams[id] = nil - m.nStreams-- + + for i, s := range m.openStreams { + if s == id { + // delete the streamID from the openStreams slice + m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])] + } + } + return nil } // NumberOfStreams gets the number of open streams func (m *streamsMap) NumberOfStreams() int { m.mutex.RLock() - defer m.mutex.RUnlock() - return m.nStreams + n := len(m.openStreams) + m.mutex.RUnlock() + return n } diff --git a/streams_map_test.go b/streams_map_test.go index c286a7760..4c752d193 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -32,24 +32,68 @@ var _ = Describe("Streams Map", func() { Expect(s).To(BeNil()) }) - It("errors when removing non-existing stream", func() { - err := m.RemoveStream(1) - Expect(err).To(MatchError("attempted to remove non-existing stream: 1")) + Context("putting streams", func() { + It("stores streams", func() { + err := m.PutStream(&stream{streamID: 5}) + Expect(err).NotTo(HaveOccurred()) + s, exists := m.GetStream(5) + Expect(exists).To(BeTrue()) + Expect(s.streamID).To(Equal(protocol.StreamID(5))) + Expect(m.openStreams).To(HaveLen(1)) + Expect(m.openStreams[0]).To(Equal(protocol.StreamID(5))) + }) + + It("does not store multiple streams with the same ID", func() { + err := m.PutStream(&stream{streamID: 5}) + Expect(err).NotTo(HaveOccurred()) + err = m.PutStream(&stream{streamID: 5}) + Expect(err).To(MatchError("a stream with ID 5 already exists")) + Expect(m.openStreams).To(HaveLen(1)) + }) }) - It("stores streams", func() { - err := m.PutStream(&stream{streamID: 5}) - Expect(err).NotTo(HaveOccurred()) - s, exists := m.GetStream(5) - Expect(exists).To(BeTrue()) - Expect(s.streamID).To(Equal(protocol.StreamID(5))) - }) + Context("deleting streams", func() { + BeforeEach(func() { + for i := 1; i <= 5; i++ { + err := m.PutStream(&stream{streamID: protocol.StreamID(i)}) + Expect(err).ToNot(HaveOccurred()) + } + Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) + }) - It("does not store multiple streams with the same ID", func() { - err := m.PutStream(&stream{streamID: 5}) - Expect(err).NotTo(HaveOccurred()) - err = m.PutStream(&stream{streamID: 5}) - Expect(err).To(MatchError("a stream with ID 5 already exists")) + It("errors when removing non-existing stream", func() { + err := m.RemoveStream(1337) + Expect(err).To(MatchError("attempted to remove non-existing stream: 1337")) + }) + + It("removes the first stream", func() { + err := m.RemoveStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(m.openStreams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5})) + }) + + It("removes a stream in the middle", func() { + err := m.RemoveStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(m.openStreams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5})) + }) + + It("removes a stream at the end", func() { + err := m.RemoveStream(5) + Expect(err).ToNot(HaveOccurred()) + Expect(m.openStreams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4})) + }) + + It("removes all streams", func() { + for i := 1; i <= 5; i++ { + err := m.RemoveStream(protocol.StreamID(i)) + Expect(err).ToNot(HaveOccurred()) + } + Expect(m.openStreams).To(BeEmpty()) + }) }) Context("number of streams", func() { @@ -68,7 +112,7 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) err = m.RemoveStream(5) Expect(err).ToNot(HaveOccurred()) - Expect(m.NumberOfStreams()).To(Equal(0)) + Expect(m.NumberOfStreams()).To(BeZero()) }) })