diff --git a/internal/protocol/stream.go b/internal/protocol/stream.go index 1735a235..988dcc82 100644 --- a/internal/protocol/stream.go +++ b/internal/protocol/stream.go @@ -1,19 +1,5 @@ package protocol -// A StreamID in QUIC -type StreamID int64 - -// InvalidPacketNumber is a stream ID that is invalid. -// The first valid stream ID in QUIC is 0. -const InvalidStreamID StreamID = -1 - -// StreamNum is the stream number -type StreamNum int64 - -// MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames -// and as the stream count in the transport parameters -const MaxStreamCount StreamNum = 1 << 60 - // StreamType encodes if this is a unidirectional or bidirectional stream type StreamType uint8 @@ -24,6 +10,49 @@ const ( StreamTypeBidi ) +// InvalidPacketNumber is a stream ID that is invalid. +// The first valid stream ID in QUIC is 0. +const InvalidStreamID StreamID = -1 + +// StreamNum is the stream number +type StreamNum int64 + +const ( + // InvalidStreamNum is an invalid stream number. + InvalidStreamNum = -1 + // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames + // and as the stream count in the transport parameters + MaxStreamCount StreamNum = 1 << 60 +) + +// StreamID calculates the stream ID. +func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { + if s == 0 { + return InvalidStreamID + } + var first StreamID + switch stype { + case StreamTypeBidi: + switch pers { + case PerspectiveClient: + first = 0 + case PerspectiveServer: + first = 1 + } + case StreamTypeUni: + switch pers { + case PerspectiveClient: + first = 2 + case PerspectiveServer: + first = 3 + } + } + return first + 4*StreamID(s-1) +} + +// A StreamID in QUIC +type StreamID int64 + // InitiatedBy says if the stream was initiated by the client or by the server func (s StreamID) InitiatedBy() Perspective { if s%2 == 0 { @@ -45,34 +74,3 @@ func (s StreamID) Type() StreamType { func (s StreamID) StreamNum() StreamNum { return StreamNum(s/4) + 1 } - -// MaxStreamID is the highest stream ID that a peer is allowed to open, -// when it is allowed to open numStreams. -func MaxStreamID(stype StreamType, numStreams StreamNum, pers Perspective) StreamID { - if numStreams == 0 { - return InvalidStreamID - } - var first StreamID - switch stype { - case StreamTypeBidi: - switch pers { - case PerspectiveClient: - first = 0 - case PerspectiveServer: - first = 1 - } - case StreamTypeUni: - switch pers { - case PerspectiveClient: - first = 2 - case PerspectiveServer: - first = 3 - } - } - return first + 4*StreamID(numStreams-1) -} - -// FirstStream returns the first valid stream ID -func FirstStream(stype StreamType, pers Perspective) StreamID { - return MaxStreamID(stype, 1, pers) -} diff --git a/internal/protocol/stream_test.go b/internal/protocol/stream_test.go index eb9eea85..a8d4654c 100644 --- a/internal/protocol/stream_test.go +++ b/internal/protocol/stream_test.go @@ -24,13 +24,6 @@ var _ = Describe("Stream ID", func() { Expect(StreamID(7).Type()).To(Equal(StreamTypeUni)) }) - It("tells the first stream ID", func() { - Expect(FirstStream(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0))) - Expect(FirstStream(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1))) - Expect(FirstStream(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2))) - Expect(FirstStream(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3))) - }) - It("tells the stream number", func() { Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1)) Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1)) @@ -42,26 +35,26 @@ var _ = Describe("Stream ID", func() { Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3)) }) - Context("maximum stream IDs", func() { - It("doesn't allow any", func() { - Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveClient)).To(Equal(InvalidStreamID)) - Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveServer)).To(Equal(InvalidStreamID)) - Expect(MaxStreamID(StreamTypeUni, 0, PerspectiveClient)).To(Equal(InvalidStreamID)) - Expect(MaxStreamID(StreamTypeUni, 0, PerspectiveServer)).To(Equal(InvalidStreamID)) + Context("converting stream nums to stream IDs", func() { + It("handles 0", func() { + Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(InvalidStreamID)) + Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(InvalidStreamID)) }) - It("allows one", func() { - Expect(MaxStreamID(StreamTypeBidi, 1, PerspectiveClient)).To(Equal(StreamID(0))) - Expect(MaxStreamID(StreamTypeBidi, 1, PerspectiveServer)).To(Equal(StreamID(1))) - Expect(MaxStreamID(StreamTypeUni, 1, PerspectiveClient)).To(Equal(StreamID(2))) - Expect(MaxStreamID(StreamTypeUni, 1, PerspectiveServer)).To(Equal(StreamID(3))) + It("handles the first", func() { + Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0))) + Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1))) + Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2))) + Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3))) }) - It("allows many", func() { - Expect(MaxStreamID(StreamTypeBidi, 100, PerspectiveClient)).To(Equal(StreamID(396))) - Expect(MaxStreamID(StreamTypeBidi, 100, PerspectiveServer)).To(Equal(StreamID(397))) - Expect(MaxStreamID(StreamTypeUni, 100, PerspectiveClient)).To(Equal(StreamID(398))) - Expect(MaxStreamID(StreamTypeUni, 100, PerspectiveServer)).To(Equal(StreamID(399))) + It("handles others", func() { + Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(396))) + Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(397))) + Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(398))) + Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(399))) }) }) }) diff --git a/streams_map.go b/streams_map.go index f4b9d581..cc53efd1 100644 --- a/streams_map.go +++ b/streams_map.go @@ -12,6 +12,27 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type streamError struct { + message string + nums []protocol.StreamNum +} + +func (e streamError) Error() string { + return e.message +} + +func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { + strError, ok := err.(streamError) + if !ok { + return err + } + ids := make([]interface{}, len(strError.nums)) + for i, num := range strError.nums { + ids[i] = num.StreamID(stype, pers) + } + return fmt.Errorf(strError.Error(), ids...) +} + type streamOpenErr struct{ error } var _ net.Error = &streamOpenErr{} @@ -49,112 +70,144 @@ func newStreamsMap( newFlowController: newFlowController, sender: sender, } - newBidiStream := func(id protocol.StreamID) streamI { - return newStream(id, m.sender, m.newFlowController(id), version) - } - newUniSendStream := func(id protocol.StreamID) sendStreamI { - return newSendStream(id, m.sender, m.newFlowController(id), version) - } - newUniReceiveStream := func(id protocol.StreamID) receiveStreamI { - return newReceiveStream(id, m.sender, m.newFlowController(id), version) - } m.outgoingBidiStreams = newOutgoingBidiStreamsMap( - protocol.FirstStream(protocol.StreamTypeBidi, perspective), - newBidiStream, + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, perspective) + return newStream(id, m.sender, m.newFlowController(id), version) + }, sender.queueControlFrame, ) m.incomingBidiStreams = newIncomingBidiStreamsMap( - protocol.FirstStream(protocol.StreamTypeBidi, perspective.Opposite()), - protocol.MaxStreamID(protocol.StreamTypeBidi, protocol.StreamNum(maxIncomingBidiStreams), perspective.Opposite()), + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, perspective.Opposite()) + return newStream(id, m.sender, m.newFlowController(id), version) + }, maxIncomingBidiStreams, sender.queueControlFrame, - newBidiStream, ) m.outgoingUniStreams = newOutgoingUniStreamsMap( - protocol.FirstStream(protocol.StreamTypeUni, perspective), - newUniSendStream, + func(num protocol.StreamNum) sendStreamI { + id := num.StreamID(protocol.StreamTypeUni, perspective) + return newSendStream(id, m.sender, m.newFlowController(id), version) + }, sender.queueControlFrame, ) m.incomingUniStreams = newIncomingUniStreamsMap( - protocol.FirstStream(protocol.StreamTypeUni, perspective.Opposite()), - protocol.MaxStreamID(protocol.StreamTypeUni, protocol.StreamNum(maxIncomingUniStreams), perspective.Opposite()), + func(num protocol.StreamNum) receiveStreamI { + id := num.StreamID(protocol.StreamTypeUni, perspective.Opposite()) + return newReceiveStream(id, m.sender, m.newFlowController(id), version) + }, maxIncomingUniStreams, sender.queueControlFrame, - newUniReceiveStream, ) return m } func (m *streamsMap) OpenStream() (Stream, error) { - return m.outgoingBidiStreams.OpenStream() + str, err := m.outgoingBidiStreams.OpenStream() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } func (m *streamsMap) OpenStreamSync() (Stream, error) { - return m.outgoingBidiStreams.OpenStreamSync() + str, err := m.outgoingBidiStreams.OpenStreamSync() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } func (m *streamsMap) OpenUniStream() (SendStream, error) { - return m.outgoingUniStreams.OpenStream() + str, err := m.outgoingUniStreams.OpenStream() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } func (m *streamsMap) OpenUniStreamSync() (SendStream, error) { - return m.outgoingUniStreams.OpenStreamSync() + str, err := m.outgoingUniStreams.OpenStreamSync() + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) } func (m *streamsMap) AcceptStream() (Stream, error) { - return m.incomingBidiStreams.AcceptStream() + str, err := m.incomingBidiStreams.AcceptStream() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) } func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { - return m.incomingUniStreams.AcceptStream() + str, err := m.incomingUniStreams.AcceptStream() + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) } func (m *streamsMap) DeleteStream(id protocol.StreamID) error { + num := id.StreamNum() switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() == m.perspective { - return m.outgoingUniStreams.DeleteStream(id) + return m.outgoingUniStreams.DeleteStream(num) } - return m.incomingUniStreams.DeleteStream(id) + return m.incomingUniStreams.DeleteStream(num) case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { - return m.outgoingBidiStreams.DeleteStream(id) + return m.outgoingBidiStreams.DeleteStream(num) } - return m.incomingBidiStreams.DeleteStream(id) + return m.incomingBidiStreams.DeleteStream(num) } panic("") } func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + str, err := m.getOrOpenReceiveStream(id) + if err != nil { + return nil, qerr.Error(qerr.StreamStateError, err.Error()) + } + return str, nil +} + +func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + num := id.StreamNum() switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() == m.perspective { // an outgoing unidirectional stream is a send stream, not a receive stream return nil, fmt.Errorf("peer attempted to open receive stream %d", id) } - return m.incomingUniStreams.GetOrOpenStream(id) + str, err := m.incomingUniStreams.GetOrOpenStream(num) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) case protocol.StreamTypeBidi: + var str receiveStreamI + var err error if id.InitiatedBy() == m.perspective { - return m.outgoingBidiStreams.GetStream(id) + str, err = m.outgoingBidiStreams.GetStream(num) + } else { + str, err = m.incomingBidiStreams.GetOrOpenStream(num) } - return m.incomingBidiStreams.GetOrOpenStream(id) + return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) } panic("") } func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + str, err := m.getOrOpenSendStream(id) + if err != nil { + return nil, qerr.Error(qerr.StreamStateError, err.Error()) + } + return str, nil +} + +func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + num := id.StreamNum() switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() == m.perspective { - return m.outgoingUniStreams.GetStream(id) + str, err := m.outgoingUniStreams.GetStream(num) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) } // an incoming unidirectional stream is a receive stream, not a send stream return nil, fmt.Errorf("peer attempted to open send stream %d", id) case protocol.StreamTypeBidi: + var str sendStreamI + var err error if id.InitiatedBy() == m.perspective { - return m.outgoingBidiStreams.GetStream(id) + str, err = m.outgoingBidiStreams.GetStream(num) + } else { + str, err = m.incomingBidiStreams.GetOrOpenStream(num) } - return m.incomingBidiStreams.GetOrOpenStream(id) + return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) } panic("") } @@ -163,12 +216,11 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error { if f.MaxStreamNum > protocol.MaxStreamCount { return qerr.StreamLimitError } - id := protocol.MaxStreamID(f.Type, f.MaxStreamNum, m.perspective) - switch id.Type() { + switch f.Type { case protocol.StreamTypeUni: - m.outgoingUniStreams.SetMaxStream(id) + m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) case protocol.StreamTypeBidi: - m.outgoingBidiStreams.SetMaxStream(id) + m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) } return nil } @@ -179,8 +231,8 @@ func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) error { return qerr.StreamLimitError } // Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open. - m.outgoingBidiStreams.SetMaxStream(protocol.MaxStreamID(protocol.StreamTypeBidi, p.MaxBidiStreamNum, m.perspective)) - m.outgoingUniStreams.SetMaxStream(protocol.MaxStreamID(protocol.StreamTypeUni, p.MaxUniStreamNum, m.perspective)) + m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) + m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) return nil } diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index 76f8cb8f..f0aad6a2 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -5,7 +5,6 @@ package quic import ( - "fmt" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -16,38 +15,39 @@ type incomingBidiStreamsMap struct { mutex sync.RWMutex cond sync.Cond - streams map[protocol.StreamID]streamI + streams map[protocol.StreamNum]streamI // When a stream is deleted before it was accepted, we can't delete it immediately. // We need to wait until the application accepts it, and delete it immediately then. - streamsToDelete map[protocol.StreamID]struct{} // used as a set + streamsToDelete map[protocol.StreamNum]struct{} // used as a set - nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamID // the highest stream that the peer openend - maxStream protocol.StreamID // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams - newStream func(protocol.StreamID) streamI + newStream func(protocol.StreamNum) streamI queueMaxStreamID func(*wire.MaxStreamsFrame) + // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } func newIncomingBidiStreamsMap( - nextStreamToAccept protocol.StreamID, - initialMaxStreamID protocol.StreamID, - maxNumStreams uint64, + newStream func(protocol.StreamNum) streamI, + maxStreams uint64, queueControlFrame func(wire.Frame), - newStream func(protocol.StreamID) streamI, + // streamNumToID func(protocol.StreamNum) protocol.StreamID, ) *incomingBidiStreamsMap { m := &incomingBidiStreamsMap{ - streams: make(map[protocol.StreamID]streamI), - streamsToDelete: make(map[protocol.StreamID]struct{}), - nextStreamToAccept: nextStreamToAccept, - nextStreamToOpen: nextStreamToAccept, - maxStream: initialMaxStreamID, - maxNumStreams: maxNumStreams, + streams: make(map[protocol.StreamNum]streamI), + streamsToDelete: make(map[protocol.StreamNum]struct{}), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + // streamNumToID: streamNumToID, } m.cond.L = &m.mutex return m @@ -57,45 +57,48 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() - var id protocol.StreamID + var num protocol.StreamNum var str streamI for { - id = m.nextStreamToAccept + num = m.nextStreamToAccept var ok bool if m.closeErr != nil { return nil, m.closeErr } - str, ok = m.streams[id] + str, ok = m.streams[num] if ok { break } m.cond.Wait() } - m.nextStreamToAccept += 4 + m.nextStreamToAccept++ // If this stream was completed before being accepted, we can delete it now. - if _, ok := m.streamsToDelete[id]; ok { - delete(m.streamsToDelete, id) - if err := m.deleteStream(id); err != nil { + if _, ok := m.streamsToDelete[num]; ok { + delete(m.streamsToDelete, num) + if err := m.deleteStream(num); err != nil { return nil, err } } return str, nil } -func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { +func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (streamI, error) { m.mutex.RLock() - if id > m.maxStream { + if num > m.maxStream { m.mutex.RUnlock() - return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } } - // if the id is smaller than the highest we accepted + // if the num is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil - if id < m.nextStreamToOpen { + if num < m.nextStreamToOpen { var s streamI // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. - if _, ok := m.streamsToDelete[id]; !ok { - s = m.streams[id] + if _, ok := m.streamsToDelete[num]; !ok { + s = m.streams[num] } m.mutex.RUnlock() return s, nil @@ -106,46 +109,52 @@ func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, // no need to check the two error conditions from above again // * maxStream can only increase, so if the id was valid before, it definitely is valid now // * highestStream is only modified by this function - for newID := m.nextStreamToOpen; newID <= id; newID += 4 { - m.streams[newID] = m.newStream(newID) + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = m.newStream(newNum) m.cond.Signal() } - m.nextStreamToOpen = id + 4 - s := m.streams[id] + m.nextStreamToOpen = num + 1 + s := m.streams[num] m.mutex.Unlock() return s, nil } -func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { +func (m *incomingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.deleteStream(id) + return m.deleteStream(num) } -func (m *incomingBidiStreamsMap) deleteStream(id protocol.StreamID) error { - if _, ok := m.streams[id]; !ok { - return fmt.Errorf("Tried to delete unknown stream %d", id) +func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "Tried to delete unknown stream %d", + nums: []protocol.StreamNum{num}, + } } // Don't delete this stream yet, if it was not yet accepted. // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if id >= m.nextStreamToAccept { - if _, ok := m.streamsToDelete[id]; ok { - return fmt.Errorf("Tried to delete stream %d multiple times", id) + if num >= m.nextStreamToAccept { + if _, ok := m.streamsToDelete[num]; ok { + return streamError{ + message: "Tried to delete stream %d multiple times", + nums: []protocol.StreamNum{num}, + } } - m.streamsToDelete[id] = struct{}{} + m.streamsToDelete[num] = struct{}{} return nil } - delete(m.streams, id) + delete(m.streams, num) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { numNewStreams := m.maxNumStreams - uint64(len(m.streams)) - m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4) + m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 m.queueMaxStreamID(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeBidi, - MaxStreamNum: m.maxStream.StreamNum(), + MaxStreamNum: m.maxStream, }) } return nil diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index 2d873db2..c7485aca 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -1,7 +1,6 @@ package quic import ( - "fmt" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -14,38 +13,39 @@ type incomingItemsMap struct { mutex sync.RWMutex cond sync.Cond - streams map[protocol.StreamID]item + streams map[protocol.StreamNum]item // When a stream is deleted before it was accepted, we can't delete it immediately. // We need to wait until the application accepts it, and delete it immediately then. - streamsToDelete map[protocol.StreamID]struct{} // used as a set + streamsToDelete map[protocol.StreamNum]struct{} // used as a set - nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamID // the highest stream that the peer openend - maxStream protocol.StreamID // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams - newStream func(protocol.StreamID) item + newStream func(protocol.StreamNum) item queueMaxStreamID func(*wire.MaxStreamsFrame) + // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } func newIncomingItemsMap( - nextStreamToAccept protocol.StreamID, - initialMaxStreamID protocol.StreamID, - maxNumStreams uint64, + newStream func(protocol.StreamNum) item, + maxStreams uint64, queueControlFrame func(wire.Frame), - newStream func(protocol.StreamID) item, + // streamNumToID func(protocol.StreamNum) protocol.StreamID, ) *incomingItemsMap { m := &incomingItemsMap{ - streams: make(map[protocol.StreamID]item), - streamsToDelete: make(map[protocol.StreamID]struct{}), - nextStreamToAccept: nextStreamToAccept, - nextStreamToOpen: nextStreamToAccept, - maxStream: initialMaxStreamID, - maxNumStreams: maxNumStreams, + streams: make(map[protocol.StreamNum]item), + streamsToDelete: make(map[protocol.StreamNum]struct{}), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + // streamNumToID: streamNumToID, } m.cond.L = &m.mutex return m @@ -55,45 +55,48 @@ func (m *incomingItemsMap) AcceptStream() (item, error) { m.mutex.Lock() defer m.mutex.Unlock() - var id protocol.StreamID + var num protocol.StreamNum var str item for { - id = m.nextStreamToAccept + num = m.nextStreamToAccept var ok bool if m.closeErr != nil { return nil, m.closeErr } - str, ok = m.streams[id] + str, ok = m.streams[num] if ok { break } m.cond.Wait() } - m.nextStreamToAccept += 4 + m.nextStreamToAccept++ // If this stream was completed before being accepted, we can delete it now. - if _, ok := m.streamsToDelete[id]; ok { - delete(m.streamsToDelete, id) - if err := m.deleteStream(id); err != nil { + if _, ok := m.streamsToDelete[num]; ok { + delete(m.streamsToDelete, num) + if err := m.deleteStream(num); err != nil { return nil, err } } return str, nil } -func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { +func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error) { m.mutex.RLock() - if id > m.maxStream { + if num > m.maxStream { m.mutex.RUnlock() - return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } } - // if the id is smaller than the highest we accepted + // if the num is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil - if id < m.nextStreamToOpen { + if num < m.nextStreamToOpen { var s item // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. - if _, ok := m.streamsToDelete[id]; !ok { - s = m.streams[id] + if _, ok := m.streamsToDelete[num]; !ok { + s = m.streams[num] } m.mutex.RUnlock() return s, nil @@ -104,46 +107,52 @@ func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { // no need to check the two error conditions from above again // * maxStream can only increase, so if the id was valid before, it definitely is valid now // * highestStream is only modified by this function - for newID := m.nextStreamToOpen; newID <= id; newID += 4 { - m.streams[newID] = m.newStream(newID) + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = m.newStream(newNum) m.cond.Signal() } - m.nextStreamToOpen = id + 4 - s := m.streams[id] + m.nextStreamToOpen = num + 1 + s := m.streams[num] m.mutex.Unlock() return s, nil } -func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { +func (m *incomingItemsMap) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.deleteStream(id) + return m.deleteStream(num) } -func (m *incomingItemsMap) deleteStream(id protocol.StreamID) error { - if _, ok := m.streams[id]; !ok { - return fmt.Errorf("Tried to delete unknown stream %d", id) +func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "Tried to delete unknown stream %d", + nums: []protocol.StreamNum{num}, + } } // Don't delete this stream yet, if it was not yet accepted. // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if id >= m.nextStreamToAccept { - if _, ok := m.streamsToDelete[id]; ok { - return fmt.Errorf("Tried to delete stream %d multiple times", id) + if num >= m.nextStreamToAccept { + if _, ok := m.streamsToDelete[num]; ok { + return streamError{ + message: "Tried to delete stream %d multiple times", + nums: []protocol.StreamNum{num}, + } } - m.streamsToDelete[id] = struct{}{} + m.streamsToDelete[num] = struct{}{} return nil } - delete(m.streams, id) + delete(m.streams, num) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { numNewStreams := m.maxNumStreams - uint64(len(m.streams)) - m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4) + m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 m.queueMaxStreamID(&wire.MaxStreamsFrame{ Type: streamTypeGeneric, - MaxStreamNum: m.maxStream.StreamNum(), + MaxStreamNum: m.maxStream, }) } return nil diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 6f503c54..0a59b7f5 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -2,7 +2,6 @@ package quic import ( "errors" - "fmt" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -13,7 +12,7 @@ import ( ) type mockGenericStream struct { - id protocol.StreamID + num protocol.StreamNum closed bool closeErr error @@ -26,64 +25,65 @@ func (s *mockGenericStream) closeForShutdown(err error) { var _ = Describe("Streams Map (incoming)", func() { const ( - firstNewStream protocol.StreamID = 2 - maxNumStreams uint64 = 5 - initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1) + maxNumStreams uint64 = 5 ) var ( m *incomingItemsMap - newItem func(id protocol.StreamID) item newItemCounter int mockSender *MockStreamSender ) BeforeEach(func() { newItemCounter = 0 - newItem = func(id protocol.StreamID) item { - newItemCounter++ - return &mockGenericStream{id: id} - } mockSender = NewMockStreamSender(mockCtrl) - m = newIncomingItemsMap(firstNewStream, initialMaxStream, maxNumStreams, mockSender.queueControlFrame, newItem) + m = newIncomingItemsMap( + func(num protocol.StreamNum) item { + newItemCounter++ + return &mockGenericStream{num: num} + }, + maxNumStreams, + mockSender.queueControlFrame, + ) }) It("opens all streams up to the id on GetOrOpenStream", func() { - _, err := m.GetOrOpenStream(firstNewStream + 4*4) + _, err := m.GetOrOpenStream(4) Expect(err).ToNot(HaveOccurred()) - Expect(newItemCounter).To(Equal(5)) + Expect(newItemCounter).To(Equal(4)) }) It("starts opening streams at the right position", func() { // like the test above, but with 2 calls to GetOrOpenStream - _, err := m.GetOrOpenStream(firstNewStream + 4) + _, err := m.GetOrOpenStream(2) Expect(err).ToNot(HaveOccurred()) Expect(newItemCounter).To(Equal(2)) - _, err = m.GetOrOpenStream(firstNewStream + 4*4) + _, err = m.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Expect(newItemCounter).To(Equal(5)) }) It("accepts streams in the right order", func() { - _, err := m.GetOrOpenStream(firstNewStream + 4) // open stream 20 and 24 + _, err := m.GetOrOpenStream(2) // open streams 1 and 2 Expect(err).ToNot(HaveOccurred()) str, err := m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) It("allows opening the maximum stream ID", func() { - str, err := m.GetOrOpenStream(initialMaxStream) + str, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(initialMaxStream)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) }) It("errors when trying to get a stream ID higher than the maximum", func() { - _, err := m.GetOrOpenStream(initialMaxStream + 4) - Expect(err).To(MatchError(fmt.Errorf("peer tried to open stream %d (current limit: %d)", initialMaxStream+4, initialMaxStream))) + _, err := m.GetOrOpenStream(6) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)")) }) It("blocks AcceptStream until a new stream is available", func() { @@ -95,30 +95,12 @@ var _ = Describe("Streams Map (incoming)", func() { strChan <- str }() Consistently(strChan).ShouldNot(Receive()) - str, err := m.GetOrOpenStream(firstNewStream) + str, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) var acceptedStr item Eventually(strChan).Should(Receive(&acceptedStr)) - Expect(acceptedStr.(*mockGenericStream).id).To(Equal(firstNewStream)) - }) - - It("works with stream 0", func() { - m = newIncomingItemsMap(0, 1000, 1000, mockSender.queueControlFrame, newItem) - strChan := make(chan item) - go func() { - defer GinkgoRecover() - str, err := m.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - strChan <- str - }() - Consistently(strChan).ShouldNot(Receive()) - str, err := m.GetOrOpenStream(0) - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(BeZero()) - var acceptedStr item - Eventually(strChan).Should(Receive(&acceptedStr)) - Expect(acceptedStr.(*mockGenericStream).id).To(BeZero()) + Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) }) It("unblocks AcceptStream when it is closed", func() { @@ -143,9 +125,9 @@ var _ = Describe("Streams Map (incoming)", func() { }) It("closes all streams when CloseWithError is called", func() { - str1, err := m.GetOrOpenStream(firstNewStream) + str1, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - str2, err := m.GetOrOpenStream(firstNewStream + 8) + str2, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) testErr := errors.New("test err") m.CloseWithError(testErr) @@ -157,37 +139,37 @@ var _ = Describe("Streams Map (incoming)", func() { It("deletes streams", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) - _, err := m.GetOrOpenStream(firstNewStream) + _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) str, err := m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) - Expect(m.DeleteStream(firstNewStream)).To(Succeed()) - str, err = m.GetOrOpenStream(firstNewStream) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(m.DeleteStream(1)).To(Succeed()) + str, err = m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) }) It("waits until a stream is accepted before actually deleting it", func() { - _, err := m.GetOrOpenStream(firstNewStream + 4) + _, err := m.GetOrOpenStream(2) Expect(err).ToNot(HaveOccurred()) - Expect(m.DeleteStream(firstNewStream + 4)).To(Succeed()) + Expect(m.DeleteStream(2)).To(Succeed()) str, err := m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued mockSender.EXPECT().queueControlFrame(gomock.Any()) str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) It("doesn't return a stream queued for deleting from GetOrOpenStream", func() { - str, err := m.GetOrOpenStream(firstNewStream) + str, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) - Expect(m.DeleteStream(firstNewStream)).To(Succeed()) - str, err = m.GetOrOpenStream(firstNewStream) + Expect(m.DeleteStream(1)).To(Succeed()) + str, err = m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued @@ -199,12 +181,13 @@ var _ = Describe("Streams Map (incoming)", func() { It("errors when deleting a non-existing stream", func() { err := m.DeleteStream(1337) - Expect(err).To(MatchError("Tried to delete unknown stream 1337")) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown stream 1337")) }) It("sends MAX_STREAMS frames when streams are deleted", func() { // open a bunch of streams - _, err := m.GetOrOpenStream(firstNewStream + 4*4) + _, err := m.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) // accept all streams for i := 0; i < 5; i++ { @@ -214,10 +197,10 @@ var _ = Describe("Streams Map (incoming)", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) }) - Expect(m.DeleteStream(firstNewStream + 2*4)).To(Succeed()) + Expect(m.DeleteStream(3)).To(Succeed()) mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) }) - Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed()) + Expect(m.DeleteStream(4)).To(Succeed()) }) }) diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 39ae05a2..cc83eb66 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -5,7 +5,6 @@ package quic import ( - "fmt" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -16,38 +15,39 @@ type incomingUniStreamsMap struct { mutex sync.RWMutex cond sync.Cond - streams map[protocol.StreamID]receiveStreamI + streams map[protocol.StreamNum]receiveStreamI // When a stream is deleted before it was accepted, we can't delete it immediately. // We need to wait until the application accepts it, and delete it immediately then. - streamsToDelete map[protocol.StreamID]struct{} // used as a set + streamsToDelete map[protocol.StreamNum]struct{} // used as a set - nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamID // the highest stream that the peer openend - maxStream protocol.StreamID // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams - newStream func(protocol.StreamID) receiveStreamI + newStream func(protocol.StreamNum) receiveStreamI queueMaxStreamID func(*wire.MaxStreamsFrame) + // streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors closeErr error } func newIncomingUniStreamsMap( - nextStreamToAccept protocol.StreamID, - initialMaxStreamID protocol.StreamID, - maxNumStreams uint64, + newStream func(protocol.StreamNum) receiveStreamI, + maxStreams uint64, queueControlFrame func(wire.Frame), - newStream func(protocol.StreamID) receiveStreamI, + // streamNumToID func(protocol.StreamNum) protocol.StreamID, ) *incomingUniStreamsMap { m := &incomingUniStreamsMap{ - streams: make(map[protocol.StreamID]receiveStreamI), - streamsToDelete: make(map[protocol.StreamID]struct{}), - nextStreamToAccept: nextStreamToAccept, - nextStreamToOpen: nextStreamToAccept, - maxStream: initialMaxStreamID, - maxNumStreams: maxNumStreams, + streams: make(map[protocol.StreamNum]receiveStreamI), + streamsToDelete: make(map[protocol.StreamNum]struct{}), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + // streamNumToID: streamNumToID, } m.cond.L = &m.mutex return m @@ -57,45 +57,48 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { m.mutex.Lock() defer m.mutex.Unlock() - var id protocol.StreamID + var num protocol.StreamNum var str receiveStreamI for { - id = m.nextStreamToAccept + num = m.nextStreamToAccept var ok bool if m.closeErr != nil { return nil, m.closeErr } - str, ok = m.streams[id] + str, ok = m.streams[num] if ok { break } m.cond.Wait() } - m.nextStreamToAccept += 4 + m.nextStreamToAccept++ // If this stream was completed before being accepted, we can delete it now. - if _, ok := m.streamsToDelete[id]; ok { - delete(m.streamsToDelete, id) - if err := m.deleteStream(id); err != nil { + if _, ok := m.streamsToDelete[num]; ok { + delete(m.streamsToDelete, num) + if err := m.deleteStream(num); err != nil { return nil, err } } return str, nil } -func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) { +func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receiveStreamI, error) { m.mutex.RLock() - if id > m.maxStream { + if num > m.maxStream { m.mutex.RUnlock() - return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } } - // if the id is smaller than the highest we accepted + // if the num is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil - if id < m.nextStreamToOpen { + if num < m.nextStreamToOpen { var s receiveStreamI // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. - if _, ok := m.streamsToDelete[id]; !ok { - s = m.streams[id] + if _, ok := m.streamsToDelete[num]; !ok { + s = m.streams[num] } m.mutex.RUnlock() return s, nil @@ -106,46 +109,52 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveSt // no need to check the two error conditions from above again // * maxStream can only increase, so if the id was valid before, it definitely is valid now // * highestStream is only modified by this function - for newID := m.nextStreamToOpen; newID <= id; newID += 4 { - m.streams[newID] = m.newStream(newID) + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = m.newStream(newNum) m.cond.Signal() } - m.nextStreamToOpen = id + 4 - s := m.streams[id] + m.nextStreamToOpen = num + 1 + s := m.streams[num] m.mutex.Unlock() return s, nil } -func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { +func (m *incomingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.deleteStream(id) + return m.deleteStream(num) } -func (m *incomingUniStreamsMap) deleteStream(id protocol.StreamID) error { - if _, ok := m.streams[id]; !ok { - return fmt.Errorf("Tried to delete unknown stream %d", id) +func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "Tried to delete unknown stream %d", + nums: []protocol.StreamNum{num}, + } } // Don't delete this stream yet, if it was not yet accepted. // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if id >= m.nextStreamToAccept { - if _, ok := m.streamsToDelete[id]; ok { - return fmt.Errorf("Tried to delete stream %d multiple times", id) + if num >= m.nextStreamToAccept { + if _, ok := m.streamsToDelete[num]; ok { + return streamError{ + message: "Tried to delete stream %d multiple times", + nums: []protocol.StreamNum{num}, + } } - m.streamsToDelete[id] = struct{}{} + m.streamsToDelete[num] = struct{}{} return nil } - delete(m.streams, id) + delete(m.streams, num) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { numNewStreams := m.maxNumStreams - uint64(len(m.streams)) - m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4) + m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 m.queueMaxStreamID(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeUni, - MaxStreamNum: m.maxStream.StreamNum(), + MaxStreamNum: m.maxStream, }) } return nil diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index ca8727e6..bce33b52 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -5,11 +5,9 @@ package quic import ( - "fmt" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -17,27 +15,26 @@ type outgoingBidiStreamsMap struct { mutex sync.RWMutex cond sync.Cond - streams map[protocol.StreamID]streamI + streams map[protocol.StreamNum]streamI - nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamID // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - newStream func(protocol.StreamID) streamI + newStream func(protocol.StreamNum) streamI queueStreamIDBlocked func(*wire.StreamsBlockedFrame) closeErr error } func newOutgoingBidiStreamsMap( - nextStream protocol.StreamID, - newStream func(protocol.StreamID) streamI, + newStream func(protocol.StreamNum) streamI, queueControlFrame func(wire.Frame), ) *outgoingBidiStreamsMap { m := &outgoingBidiStreamsMap{ - streams: make(map[protocol.StreamID]streamI), - nextStream: nextStream, - maxStream: protocol.InvalidStreamID, + streams: make(map[protocol.StreamNum]streamI), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, newStream: newStream, queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, } @@ -83,8 +80,8 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { if m.nextStream > m.maxStream { if !m.blockedSent { var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamID { - streamNum = m.maxStream.StreamNum() + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream } m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ Type: protocol.StreamTypeBidi, @@ -96,36 +93,42 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s - m.nextStream += 4 + m.nextStream++ return s, nil } -func (m *outgoingBidiStreamsMap) GetStream(id protocol.StreamID) (streamI, error) { +func (m *outgoingBidiStreamsMap) GetStream(num protocol.StreamNum) (streamI, error) { m.mutex.RLock() - if id >= m.nextStream { + if num >= m.nextStream { m.mutex.RUnlock() - return nil, qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id)) + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } } - s := m.streams[id] + s := m.streams[num] m.mutex.RUnlock() return s, nil } -func (m *outgoingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { +func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() - if _, ok := m.streams[id]; !ok { - return fmt.Errorf("Tried to delete unknown stream %d", id) + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "Tried to delete unknown stream %d", + nums: []protocol.StreamNum{num}, + } } - delete(m.streams, id) + delete(m.streams, num) return nil } -func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) { +func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) { m.mutex.Lock() - if id > m.maxStream { - m.maxStream = id + if num > m.maxStream { + m.maxStream = num m.blockedSent = false m.cond.Broadcast() } diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 4fcf0471..3fc04a4d 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -1,11 +1,9 @@ package quic import ( - "fmt" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -15,27 +13,26 @@ type outgoingItemsMap struct { mutex sync.RWMutex cond sync.Cond - streams map[protocol.StreamID]item + streams map[protocol.StreamNum]item - nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamID // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - newStream func(protocol.StreamID) item + newStream func(protocol.StreamNum) item queueStreamIDBlocked func(*wire.StreamsBlockedFrame) closeErr error } func newOutgoingItemsMap( - nextStream protocol.StreamID, - newStream func(protocol.StreamID) item, + newStream func(protocol.StreamNum) item, queueControlFrame func(wire.Frame), ) *outgoingItemsMap { m := &outgoingItemsMap{ - streams: make(map[protocol.StreamID]item), - nextStream: nextStream, - maxStream: protocol.InvalidStreamID, + streams: make(map[protocol.StreamNum]item), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, newStream: newStream, queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, } @@ -81,8 +78,8 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) { if m.nextStream > m.maxStream { if !m.blockedSent { var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamID { - streamNum = m.maxStream.StreamNum() + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream } m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ Type: streamTypeGeneric, @@ -94,36 +91,42 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) { } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s - m.nextStream += 4 + m.nextStream++ return s, nil } -func (m *outgoingItemsMap) GetStream(id protocol.StreamID) (item, error) { +func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) { m.mutex.RLock() - if id >= m.nextStream { + if num >= m.nextStream { m.mutex.RUnlock() - return nil, qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id)) + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } } - s := m.streams[id] + s := m.streams[num] m.mutex.RUnlock() return s, nil } -func (m *outgoingItemsMap) DeleteStream(id protocol.StreamID) error { +func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() - if _, ok := m.streams[id]; !ok { - return fmt.Errorf("Tried to delete unknown stream %d", id) + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "Tried to delete unknown stream %d", + nums: []protocol.StreamNum{num}, + } } - delete(m.streams, id) + delete(m.streams, num) return nil } -func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) { +func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) { m.mutex.Lock() - if id > m.maxStream { - m.maxStream = id + if num > m.maxStream { + m.maxStream = num m.blockedSent = false m.cond.Broadcast() } diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index 5ab66089..9779a901 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -5,27 +5,24 @@ import ( "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Streams Map (outgoing)", func() { - const firstNewStream protocol.StreamID = 3 - var ( m *outgoingItemsMap - newItem func(id protocol.StreamID) item + newItem func(num protocol.StreamNum) item mockSender *MockStreamSender ) BeforeEach(func() { - newItem = func(id protocol.StreamID) item { - return &mockGenericStream{id: id} + newItem = func(num protocol.StreamNum) item { + return &mockGenericStream{num: num} } mockSender = NewMockStreamSender(mockCtrl) - m = newOutgoingItemsMap(firstNewStream, newItem, mockSender.queueControlFrame) + m = newOutgoingItemsMap(newItem, mockSender.queueControlFrame) }) Context("no stream ID limit", func() { @@ -36,10 +33,10 @@ var _ = Describe("Streams Map (outgoing)", func() { It("opens streams", func() { str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) str, err = m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) It("doesn't open streams after it has been closed", func() { @@ -52,38 +49,40 @@ var _ = Describe("Streams Map (outgoing)", func() { It("gets streams", func() { _, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - str, err := m.GetStream(firstNewStream) + str, err := m.GetStream(1) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) }) It("errors when trying to get a stream that has not yet been opened", func() { - _, err := m.GetStream(firstNewStream) - Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, "peer attempted to open stream 3"))) + _, err := m.GetStream(1) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("peer attempted to open stream 1")) }) It("deletes streams", func() { - _, err := m.OpenStream() // opens firstNewStream + _, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - err = m.DeleteStream(firstNewStream) + Expect(m.DeleteStream(1)).To(Succeed()) Expect(err).ToNot(HaveOccurred()) - str, err := m.GetStream(firstNewStream) + str, err := m.GetStream(1) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) }) It("errors when deleting a non-existing stream", func() { err := m.DeleteStream(1337) - Expect(err).To(MatchError("Tried to delete unknown stream 1337")) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown stream 1337")) }) It("errors when deleting a stream twice", func() { _, err := m.OpenStream() // opens firstNewStream Expect(err).ToNot(HaveOccurred()) - err = m.DeleteStream(firstNewStream) - Expect(err).ToNot(HaveOccurred()) - err = m.DeleteStream(firstNewStream) - Expect(err).To(MatchError("Tried to delete unknown stream 3")) + Expect(m.DeleteStream(1)).To(Succeed()) + err = m.DeleteStream(1) + Expect(err).To(HaveOccurred()) + Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown stream 1")) }) It("closes all streams when CloseWithError is called", func() { @@ -114,31 +113,12 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) close(done) }() Consistently(done).ShouldNot(BeClosed()) - m.SetMaxStream(firstNewStream) - Eventually(done).Should(BeClosed()) - }) - - It("works with stream 0", func() { - m = newOutgoingItemsMap(0, newItem, mockSender.queueControlFrame) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeZero()) - }) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - str, err := m.OpenStreamSync() - Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(BeZero()) - close(done) - }() - - Consistently(done).ShouldNot(BeClosed()) - m.SetMaxStream(0) + m.SetMaxStream(1) Eventually(done).Should(BeClosed()) }) @@ -159,17 +139,17 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("doesn't reduce the stream limit", func() { - m.SetMaxStream(firstNewStream + 4) - m.SetMaxStream(firstNewStream) + m.SetMaxStream(2) + m.SetMaxStream(1) _, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4)) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) It("queues a STREAM_ID_BLOCKED frame if no stream can be opened", func() { - m.SetMaxStream(firstNewStream + 5*4) + m.SetMaxStream(6) // open the 6 allowed streams for i := 0; i < 6; i++ { _, err := m.OpenStream() @@ -185,7 +165,7 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("only sends one STREAM_ID_BLOCKED frame for one stream ID", func() { - m.SetMaxStream(firstNewStream) + m.SetMaxStream(1) mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) }) diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 42befdaa..3ed98ff5 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -5,11 +5,9 @@ package quic import ( - "fmt" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -17,27 +15,26 @@ type outgoingUniStreamsMap struct { mutex sync.RWMutex cond sync.Cond - streams map[protocol.StreamID]sendStreamI + streams map[protocol.StreamNum]sendStreamI - nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamID // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - newStream func(protocol.StreamID) sendStreamI + newStream func(protocol.StreamNum) sendStreamI queueStreamIDBlocked func(*wire.StreamsBlockedFrame) closeErr error } func newOutgoingUniStreamsMap( - nextStream protocol.StreamID, - newStream func(protocol.StreamID) sendStreamI, + newStream func(protocol.StreamNum) sendStreamI, queueControlFrame func(wire.Frame), ) *outgoingUniStreamsMap { m := &outgoingUniStreamsMap{ - streams: make(map[protocol.StreamID]sendStreamI), - nextStream: nextStream, - maxStream: protocol.InvalidStreamID, + streams: make(map[protocol.StreamNum]sendStreamI), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, newStream: newStream, queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, } @@ -83,8 +80,8 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { if m.nextStream > m.maxStream { if !m.blockedSent { var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamID { - streamNum = m.maxStream.StreamNum() + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream } m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ Type: protocol.StreamTypeUni, @@ -96,36 +93,42 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s - m.nextStream += 4 + m.nextStream++ return s, nil } -func (m *outgoingUniStreamsMap) GetStream(id protocol.StreamID) (sendStreamI, error) { +func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) { m.mutex.RLock() - if id >= m.nextStream { + if num >= m.nextStream { m.mutex.RUnlock() - return nil, qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id)) + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } } - s := m.streams[id] + s := m.streams[num] m.mutex.RUnlock() return s, nil } -func (m *outgoingUniStreamsMap) DeleteStream(id protocol.StreamID) error { +func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() - if _, ok := m.streams[id]; !ok { - return fmt.Errorf("Tried to delete unknown stream %d", id) + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "Tried to delete unknown stream %d", + nums: []protocol.StreamNum{num}, + } } - delete(m.streams, id) + delete(m.streams, num) return nil } -func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) { +func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { m.mutex.Lock() - if id > m.maxStream { - m.maxStream = id + if num > m.maxStream { + m.maxStream = num m.blockedSent = false m.cond.Broadcast() } diff --git a/streams_map_test.go b/streams_map_test.go index 5b233de6..cbace466 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -17,6 +17,14 @@ import ( . "github.com/onsi/gomega" ) +func (e streamError) TestError() error { + nums := make([]interface{}, len(e.nums)) + for i, num := range e.nums { + nums[i] = num + } + return fmt.Errorf(e.message, nums...) +} + type streamMapping struct { firstIncomingBidiStream protocol.StreamID firstIncomingUniStream protocol.StreamID @@ -221,7 +229,7 @@ var _ = Describe("Streams Map", func() { It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { id := ids.firstOutgoingBidiStream + 5*4 _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id)))) + Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id))) }) It("gets an outgoing unidirectional stream", func() { @@ -237,7 +245,7 @@ var _ = Describe("Streams Map", func() { It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { id := ids.firstOutgoingUniStream + 5*4 _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id)))) + Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id))) }) It("gets an incoming bidirectional stream", func() { @@ -250,7 +258,7 @@ var _ = Describe("Streams Map", func() { It("errors when trying to get an incoming unidirectional stream", func() { id := ids.firstIncomingUniStream _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(fmt.Errorf("peer attempted to open send stream %d", id))) + Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open send stream %d", id))) }) }) @@ -268,7 +276,7 @@ var _ = Describe("Streams Map", func() { It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { id := ids.firstOutgoingBidiStream + 5*4 _, err := m.GetOrOpenReceiveStream(id) - Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id)))) + Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id))) }) It("gets an incoming bidirectional stream", func() { @@ -288,37 +296,44 @@ var _ = Describe("Streams Map", func() { It("errors when trying to get an outgoing unidirectional stream", func() { id := ids.firstOutgoingUniStream _, err := m.GetOrOpenReceiveStream(id) - Expect(err).To(MatchError(fmt.Errorf("peer attempted to open receive stream %d", id))) + Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open receive stream %d", id))) }) }) }) Context("updating stream ID limits", func() { - It("processes the parameter for outgoing streams, as a server", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - m.perspective = protocol.PerspectiveServer - _, err := m.OpenStream() - expectTooManyStreamsError(err) - Expect(m.UpdateLimits(&handshake.TransportParameters{ - MaxBidiStreamNum: 5, - MaxUniStreamNum: 5, - })).To(Succeed()) - Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(17))) - Expect(m.outgoingUniStreams.maxStream).To(Equal(protocol.StreamID(19))) - }) + for _, p := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} { + pers := p - It("processes the parameter for outgoing streams, as a client", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - m.perspective = protocol.PerspectiveClient - _, err := m.OpenUniStream() - expectTooManyStreamsError(err) - Expect(m.UpdateLimits(&handshake.TransportParameters{ - MaxBidiStreamNum: 5, - MaxUniStreamNum: 5, - })).To(Succeed()) - Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(16))) - Expect(m.outgoingUniStreams.maxStream).To(Equal(protocol.StreamID(18))) - }) + It(fmt.Sprintf("processes the parameter for outgoing streams, as a %s", pers), func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + m.perspective = pers + _, err := m.OpenStream() + expectTooManyStreamsError(err) + Expect(m.UpdateLimits(&handshake.TransportParameters{ + MaxBidiStreamNum: 5, + MaxUniStreamNum: 8, + })).To(Succeed()) + + mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) + // test we can only 5 bidirectional streams + for i := 0; i < 5; i++ { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i))) + } + _, err = m.OpenStream() + expectTooManyStreamsError(err) + // test we can only 8 unidirectional streams + for i := 0; i < 8; i++ { + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i))) + } + _, err = m.OpenUniStream() + expectTooManyStreamsError(err) + }) + } It("rejects parameters with too large unidirectional stream counts", func() { Expect(m.UpdateLimits(&handshake.TransportParameters{