Merge pull request #257 from lucas-clemente/streamsmap

implement a StreamsMap and use it in Session and StreamFramer
This commit is contained in:
Marten Seemann
2016-08-05 20:01:54 +07:00
committed by GitHub
7 changed files with 325 additions and 100 deletions

View File

@@ -2,7 +2,6 @@ package quic
import (
"bytes"
"sync"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake"
@@ -24,7 +23,7 @@ var _ = Describe("Packet packer", func() {
fcm.sendWindowSizes[5] = protocol.MaxByteCount
fcm.sendWindowSizes[7] = protocol.MaxByteCount
streamFramer = newStreamFramer(&map[protocol.StreamID]*stream{}, &sync.RWMutex{}, fcm)
streamFramer = newStreamFramer(newStreamsMap(), fcm)
packer = &packetPacker{
cryptoSetup: &handshake.CryptoSetup{},

View File

@@ -48,7 +48,7 @@ type Session struct {
conn connection
streams map[protocol.StreamID]*stream
streamsMap *streamsMap
openStreamsCount uint32
streamsMutex sync.RWMutex
@@ -108,12 +108,13 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
}
session := &Session{
connectionID: connectionID,
version: v,
conn: conn,
streamCallback: streamCallback,
closeCallback: closeCallback,
streams: make(map[protocol.StreamID]*stream),
connectionID: connectionID,
version: v,
conn: conn,
streamCallback: streamCallback,
closeCallback: closeCallback,
// streams: make(map[protocol.StreamID]*stream),
streamsMap: newStreamsMap(),
sentPacketHandler: sentPacketHandler,
receivedPacketHandler: receivedPacketHandler,
stopWaitingManager: stopWaitingManager,
@@ -135,7 +136,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return nil, err
}
session.streamFramer = newStreamFramer(&session.streams, &session.streamsMutex, flowControlManager)
session.streamFramer = newStreamFramer(session.streamsMap, flowControlManager)
session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParametersManager, session.streamFramer, v)
session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v}
@@ -331,10 +332,10 @@ func (s *Session) handlePacket(remoteAddr interface{}, hdr *PublicHeader, data [
func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
str, streamExists := s.streams[frame.StreamID]
str, strExists := s.streamsMap.GetStream(frame.StreamID)
var err error
if !streamExists {
if !strExists {
if !s.isValidStreamID(frame.StreamID) {
return qerr.InvalidStreamID
}
@@ -352,7 +353,7 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error {
if err != nil {
return err
}
if !streamExists {
if !strExists {
s.streamCallback(s, str)
}
return nil
@@ -366,14 +367,14 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error
s.streamsMutex.RLock()
defer s.streamsMutex.RUnlock()
if frame.StreamID != 0 {
stream, ok := s.streams[frame.StreamID]
if ok && stream == nil {
str, strExists := s.streamsMap.GetStream(frame.StreamID)
if strExists && str == nil {
return errWindowUpdateOnClosedStream
}
// open new stream when receiving a WindowUpdate for a non-existing stream
// this can occur if the client immediately sends a WindowUpdate for a newly opened stream, and packet reordering occurs such that the packet opening the new stream arrives after the WindowUpdate
if !ok {
if !strExists {
s.newStreamImpl(frame.StreamID)
}
}
@@ -384,7 +385,7 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error
// TODO: Handle frame.byteOffset
func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
s.streamsMutex.RLock()
str, streamExists := s.streams[frame.StreamID]
str, streamExists := s.streamsMap.GetStream(frame.StreamID)
s.streamsMutex.RUnlock()
if !streamExists || str == nil {
return errRstStreamOnInvalidStream
@@ -445,12 +446,16 @@ func (s *Session) closeImpl(e error, remoteClose bool) error {
func (s *Session) closeStreamsWithError(err error) {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
for _, str := range s.streams {
fn := func(str *stream) (bool, error) {
if str == nil {
continue
return true, nil
}
s.closeStreamWithError(str, err)
return true, nil
}
s.streamsMap.Iterate(fn)
}
func (s *Session) closeStreamWithError(str *stream, err error) {
@@ -595,7 +600,8 @@ func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) {
func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
if stream, ok := s.streams[id]; ok {
stream, strExists := s.streamsMap.GetStream(id)
if strExists {
return stream, nil
}
return s.newStreamImpl(id)
@@ -608,7 +614,8 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
go s.Close(qerr.TooManyOpenStreams)
return nil, qerr.TooManyOpenStreams
}
if _, ok := s.streams[id]; ok {
_, strExists := s.streamsMap.GetStream(id)
if strExists {
return nil, fmt.Errorf("Session: stream with ID %d already exists", id)
}
stream, err := newStream(s.scheduleSending, s.connectionParametersManager, s.flowControlManager, id)
@@ -624,26 +631,33 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
}
atomic.AddUint32(&s.openStreamsCount, 1)
s.streams[id] = stream
err = s.streamsMap.PutStream(stream)
if err != nil {
return nil, err
}
return stream, nil
}
// garbageCollectStreams goes through all streams and removes EOF'ed streams
// from the streams map.
func (s *Session) garbageCollectStreams() {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
for k, v := range s.streams {
if v == nil {
continue
fn := func(str *stream) (bool, error) {
if str == nil {
return true, nil
}
if v.finished() {
utils.Debugf("Garbage-collecting stream %d", k)
id := str.StreamID()
if str.finished() {
atomic.AddUint32(&s.openStreamsCount, ^uint32(0)) // decrement
s.streams[k] = nil
s.flowControlManager.RemoveStream(k)
err := s.streamsMap.RemoveStream(id)
if err != nil {
return false, err
}
s.flowControlManager.RemoveStream(id)
}
return true, nil
}
s.streamsMap.Iterate(fn)
}
func (s *Session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
@@ -680,20 +694,24 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) {
var res []*frames.WindowUpdateFrame
for id, str := range s.streams {
fn := func(str *stream) (bool, error) {
if str == nil {
continue
return true, nil
}
id := str.StreamID()
doUpdate, offset, err := s.flowControlManager.MaybeTriggerStreamWindowUpdate(id)
if err != nil {
return nil, err
return false, err
}
if doUpdate {
res = append(res, &frames.WindowUpdateFrame{StreamID: id, ByteOffset: offset})
}
return true, nil
}
s.streamsMap.Iterate(fn)
doUpdate, offset := s.flowControlManager.MaybeTriggerConnectionWindowUpdate()
if doUpdate {
res = append(res, &frames.WindowUpdateFrame{StreamID: 0, ByteOffset: offset})

View File

@@ -114,7 +114,7 @@ var _ = Describe("Session", func() {
)
Expect(err).NotTo(HaveOccurred())
session = pSession.(*Session)
Expect(session.streams).To(HaveLen(1)) // Crypto stream
Expect(session.streamsMap.NumberOfStreams()).To(Equal(1)) // Crypto stream
})
Context("when handling stream frames", func() {
@@ -123,10 +123,12 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
_, err := str.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
})
@@ -154,16 +156,18 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
Expect(streamCallbackCalled).To(BeTrue())
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
Offset: 2,
Data: []byte{0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
_, err := str.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
})
@@ -173,8 +177,9 @@ var _ = Describe("Session", func() {
Expect(err).ToNot(HaveOccurred())
str.Close()
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ = session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
})
It("does not delete streams with FIN bit", func() {
@@ -183,16 +188,18 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
_, err := str.Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ = session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
})
It("deletes streams with FIN bit & close", func() {
@@ -201,23 +208,27 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
_, err := str.Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ = session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
// We still need to close the stream locally
session.streams[5].Close()
str.Close()
// ... and simulate that we actually the FIN
session.streams[5].sentFin()
str.sentFin()
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(1))
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).To(BeNil())
// flow controller should have been notified
_, err = session.flowControlManager.SendWindowSize(5)
Expect(err).To(MatchError("Error accessing the flowController map."))
@@ -229,31 +240,35 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
Expect(err).ToNot(HaveOccurred())
_, err := str.Read(p)
session.closeStreamsWithError(testErr)
_, err = session.streams[5].Read(p)
_, err = str.Read(p)
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(1))
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).To(BeNil())
})
It("closes empty streams with error", func() {
testErr := errors.New("test")
session.newStreamImpl(5)
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
session.closeStreamsWithError(testErr)
_, err := session.streams[5].Read([]byte{0})
_, err := str.Read([]byte{0})
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).To(BeNil())
})
It("informs the FlowControlManager about new streams", func() {
@@ -271,10 +286,12 @@ var _ = Describe("Session", func() {
Data: []byte{},
FinBit: true,
})
_, err := session.streams[5].Read([]byte{0})
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
_, err := str.Read([]byte{0})
Expect(err).To(MatchError(io.EOF))
session.streams[5].Close()
session.streams[5].sentFin()
str.Close()
str.sentFin()
session.garbageCollectStreams()
err = session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
@@ -344,13 +361,17 @@ var _ = Describe("Session", func() {
ByteOffset: 1337,
})
Expect(err).ToNot(HaveOccurred())
Expect(session.streams).To(HaveKey(protocol.StreamID(5)))
Expect(session.streams[5]).ToNot(BeNil())
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).ToNot(BeNil())
})
It("errors when receiving a WindowUpdateFrame for a closed stream", func() {
session.streams[5] = nil // this is what the garbageCollectStreams() does when a Stream is closed
err := session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{
session.handleStreamFrame(&frames.StreamFrame{StreamID: 5})
err := session.streamsMap.RemoveStream(5)
Expect(err).ToNot(HaveOccurred())
session.garbageCollectStreams()
err = session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{
StreamID: 5,
ByteOffset: 1337,
})
@@ -358,8 +379,11 @@ var _ = Describe("Session", func() {
})
It("ignores errors when receiving a WindowUpdateFrame for a closed stream", func() {
session.streams[5] = nil // this is what the garbageCollectStreams() does when a Stream is closed
err := session.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{
session.handleStreamFrame(&frames.StreamFrame{StreamID: 5})
err := session.streamsMap.RemoveStream(5)
Expect(err).ToNot(HaveOccurred())
session.garbageCollectStreams()
err = session.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{
StreamID: 5,
ByteOffset: 1337,
}})

View File

@@ -1,8 +1,6 @@
package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
@@ -11,8 +9,7 @@ import (
type streamFramer struct {
// TODO: Simplify by extracting the streams map into a separate object
streams *map[protocol.StreamID]*stream
streamsMutex *sync.RWMutex
streamsMap *streamsMap
flowControlManager flowcontrol.FlowControlManager
@@ -20,10 +17,9 @@ type streamFramer struct {
blockedFrameQueue []*frames.BlockedFrame
}
func newStreamFramer(streams *map[protocol.StreamID]*stream, streamsMutex *sync.RWMutex, flowControlManager flowcontrol.FlowControlManager) *streamFramer {
func newStreamFramer(streamsMap *streamsMap, flowControlManager flowcontrol.FlowControlManager) *streamFramer {
return &streamFramer{
streams: streams,
streamsMutex: streamsMutex,
streamsMap: streamsMap,
flowControlManager: flowControlManager,
}
}
@@ -73,15 +69,12 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount
}
func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*frames.StreamFrame) {
f.streamsMutex.RLock()
defer f.streamsMutex.RUnlock()
frame := &frames.StreamFrame{DataLenPresent: true}
var currentLen protocol.ByteCount
for _, s := range *f.streams {
fn := func(s *stream) (bool, error) {
if s == nil {
continue
return true, nil
}
frame.StreamID = s.streamID
@@ -89,7 +82,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
frame.Offset = s.writeOffset
frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error
if currentLen+frameHeaderBytes > maxBytes {
return // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here
return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here
}
maxLen := maxBytes - currentLen - frameHeaderBytes
@@ -99,7 +92,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
}
if maxLen == 0 {
continue
return true, nil
}
data := s.getDataForWriting(maxLen)
@@ -111,7 +104,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
currentLen += frameHeaderBytes + frame.DataLen()
frame = &frames.StreamFrame{DataLenPresent: true}
}
continue
return true, nil
}
frame.Data = data
@@ -129,7 +122,11 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
res = append(res, frame)
currentLen += frameHeaderBytes + frame.DataLen()
frame = &frames.StreamFrame{DataLenPresent: true}
return true, nil
}
f.streamsMap.Iterate(fn)
return
}

View File

@@ -2,7 +2,6 @@ package quic
import (
"bytes"
"sync"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
@@ -14,7 +13,7 @@ var _ = Describe("Stream Framer", func() {
var (
retransmittedFrame1, retransmittedFrame2 *frames.StreamFrame
framer *streamFramer
streamsMap map[protocol.StreamID]*stream
streamsMap *streamsMap
stream1, stream2 *stream
fcm *mockFlowControlHandler
)
@@ -31,18 +30,17 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: 10}
stream2 = &stream{streamID: 11}
streamsMap = map[protocol.StreamID]*stream{
1: nil, 2: nil, 3: nil, 4: nil, // we have to be able to deal with nil frames
10: stream1,
11: stream2,
}
streamsMap = newStreamsMap()
streamsMap.PutStream(stream1)
streamsMap.PutStream(stream2)
fcm = newMockFlowControlHandler()
fcm.sendWindowSizes[stream1.streamID] = protocol.MaxByteCount
fcm.sendWindowSizes[stream2.streamID] = protocol.MaxByteCount
fcm.sendWindowSizes[retransmittedFrame1.StreamID] = protocol.MaxByteCount
fcm.sendWindowSizes[retransmittedFrame2.StreamID] = protocol.MaxByteCount
framer = newStreamFramer(&streamsMap, &sync.RWMutex{}, fcm)
framer = newStreamFramer(streamsMap, fcm)
})
It("sets the DataLenPresent for dequeued retransmitted frames", func() {

76
streams_map.go Normal file
View File

@@ -0,0 +1,76 @@
package quic
import (
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/protocol"
)
type streamsMap struct {
streams map[protocol.StreamID]*stream
nStreams int
mutex sync.RWMutex
}
type streamLambda func(*stream) (bool, error)
func newStreamsMap() *streamsMap {
return &streamsMap{
streams: map[protocol.StreamID]*stream{},
}
}
func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if !ok {
return nil, false
}
return s, true
}
func (m *streamsMap) Iterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, str := range m.streams {
cont, err := fn(str)
if err != nil {
return err
}
if !cont {
break
}
}
return nil
}
func (m *streamsMap) PutStream(s *stream) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, ok := m.streams[s.StreamID()]; ok {
return fmt.Errorf("a stream with ID %d already exists", s.StreamID())
}
m.streams[s.StreamID()] = s
m.nStreams++
return nil
}
// Attention: this function must only be called if a mutex has been acquired previously
func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
s, ok := m.streams[id]
if !ok || s == nil {
return fmt.Errorf("attempted to remove non-existing stream: %d", id)
}
m.streams[id] = nil
m.nStreams--
return nil
}
func (m *streamsMap) NumberOfStreams() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.nStreams
}

113
streams_map_test.go Normal file
View File

@@ -0,0 +1,113 @@
package quic
import (
"errors"
"github.com/lucas-clemente/quic-go/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map", func() {
var (
m *streamsMap
)
BeforeEach(func() {
m = newStreamsMap()
})
It("returns an error for non-existant streams", func() {
_, exists := m.GetStream(1)
Expect(exists).To(BeFalse())
})
It("returns nil for previously existing streams", func() {
err := m.PutStream(&stream{streamID: 1})
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(1)
Expect(err).NotTo(HaveOccurred())
s, exists := m.GetStream(1)
Expect(exists).To(BeTrue())
Expect(s).To(BeNil())
})
It("errors when removing non-existing stream", func() {
err := m.RemoveStream(1)
Expect(err).To(MatchError("attempted to remove non-existing stream: 1"))
})
It("stores streams", func() {
err := m.PutStream(&stream{streamID: 5})
Expect(err).NotTo(HaveOccurred())
s, exists := m.GetStream(5)
Expect(exists).To(BeTrue())
Expect(s.streamID).To(Equal(protocol.StreamID(5)))
})
It("does not store multiple streams with the same ID", func() {
err := m.PutStream(&stream{streamID: 5})
Expect(err).NotTo(HaveOccurred())
err = m.PutStream(&stream{streamID: 5})
Expect(err).To(MatchError("a stream with ID 5 already exists"))
})
It("gets the number of streams", func() {
Expect(m.NumberOfStreams()).To(Equal(0))
m.PutStream(&stream{streamID: 5})
Expect(m.NumberOfStreams()).To(Equal(1))
m.RemoveStream(5)
Expect(m.NumberOfStreams()).To(Equal(0))
})
Context("Lambda", func() {
// create 5 streams, ids 1 to 3
BeforeEach(func() {
for i := 1; i <= 3; i++ {
err := m.PutStream(&stream{streamID: protocol.StreamID(i)})
Expect(err).NotTo(HaveOccurred())
}
})
It("executes the lambda exactly once for every stream", func() {
var numIterations int
callbackCalled := make(map[protocol.StreamID]bool)
fn := func(str *stream) (bool, error) {
callbackCalled[str.StreamID()] = true
numIterations++
return true, nil
}
err := m.Iterate(fn)
Expect(err).ToNot(HaveOccurred())
Expect(callbackCalled).To(HaveKey(protocol.StreamID(1)))
Expect(callbackCalled).To(HaveKey(protocol.StreamID(2)))
Expect(callbackCalled).To(HaveKey(protocol.StreamID(3)))
Expect(numIterations).To(Equal(3))
})
It("stops iterating when the callback returns false", func() {
var numIterations int
fn := func(str *stream) (bool, error) {
numIterations++
return false, nil
}
err := m.Iterate(fn)
Expect(err).ToNot(HaveOccurred())
// due to map access randomization, we don't know for which stream the callback was executed
// but it must only be executed once
Expect(numIterations).To(Equal(1))
})
It("returns the error, if the lambda returns one", func() {
var numIterations int
expectedError := errors.New("test")
fn := func(str *stream) (bool, error) {
numIterations++
return true, expectedError
}
err := m.Iterate(fn)
Expect(err).To(MatchError(expectedError))
Expect(numIterations).To(Equal(1))
})
})
})