diff --git a/streams_map.go b/streams_map.go index 694b8609..6cfd45a8 100644 --- a/streams_map.go +++ b/streams_map.go @@ -2,6 +2,7 @@ package quic import ( "fmt" + "math" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" @@ -65,9 +66,11 @@ func newStreamsMap( return newReceiveStream(id, m.sender, m.newFlowController(id), version) } m.outgoingBidiStreams = newOutgoingBidiStreamsMap(firstOutgoingBidiStream, newBidiStream) - m.incomingBidiStreams = newIncomingBidiStreamsMap(firstIncomingBidiStream, newBidiStream) + // TODO(#1150): use a reasonable stream limit + m.incomingBidiStreams = newIncomingBidiStreamsMap(firstIncomingBidiStream, protocol.StreamID(math.MaxUint32), newBidiStream) m.outgoingUniStreams = newOutgoingUniStreamsMap(firstOutgoingUniStream, newUniSendStream) - m.incomingUniStreams = newIncomingUniStreamsMap(firstIncomingUniStream, newUniReceiveStream) + // TODO(#1150): use a reasonable stream limit + m.incomingUniStreams = newIncomingUniStreamsMap(firstIncomingUniStream, protocol.StreamID(math.MaxUint32), newUniReceiveStream) return m } diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index 774bf1a6..abd446d4 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -17,17 +17,23 @@ type incomingBidiStreamsMap struct { streams map[protocol.StreamID]streamI - nextStream protocol.StreamID - highestStream protocol.StreamID + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open newStream func(protocol.StreamID) streamI closeErr error } -func newIncomingBidiStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) streamI) *incomingBidiStreamsMap { +func newIncomingBidiStreamsMap( + nextStream protocol.StreamID, + maxStream protocol.StreamID, + newStream func(protocol.StreamID) streamI, +) *incomingBidiStreamsMap { m := &incomingBidiStreamsMap{ streams: make(map[protocol.StreamID]streamI), nextStream: nextStream, + maxStream: maxStream, newStream: newStream, } m.cond.L = &m.mutex @@ -55,6 +61,9 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { } func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } // if the id is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index e03311c6..44811f67 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -15,17 +15,23 @@ type incomingItemsMap struct { streams map[protocol.StreamID]item - nextStream protocol.StreamID - highestStream protocol.StreamID + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open newStream func(protocol.StreamID) item closeErr error } -func newIncomingItemsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) item) *incomingItemsMap { +func newIncomingItemsMap( + nextStream protocol.StreamID, + maxStream protocol.StreamID, + newStream func(protocol.StreamID) item, +) *incomingItemsMap { m := &incomingItemsMap{ streams: make(map[protocol.StreamID]item), nextStream: nextStream, + maxStream: maxStream, newStream: newStream, } m.cond.L = &m.mutex @@ -53,6 +59,9 @@ func (m *incomingItemsMap) AcceptStream() (item, error) { } func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } // if the id is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 6b2b9b43..ad73df45 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -2,6 +2,7 @@ package quic import ( "errors" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -10,7 +11,11 @@ import ( ) var _ = Describe("Streams Map (outgoing)", func() { - const firstNewStream protocol.StreamID = 20 + const ( + firstNewStream protocol.StreamID = 20 + maxStream protocol.StreamID = firstNewStream + 4*100 + ) + var ( m *incomingItemsMap newItem func(id protocol.StreamID) item @@ -23,7 +28,7 @@ var _ = Describe("Streams Map (outgoing)", func() { newItemCounter++ return id } - m = newIncomingItemsMap(firstNewStream, newItem) + m = newIncomingItemsMap(firstNewStream, maxStream, newItem) }) It("opens all streams up to the id on GetOrOpenStream", func() { @@ -53,6 +58,17 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(str).To(Equal(firstNewStream + 4)) }) + It("allows opening the maximum stream ID", func() { + str, err := m.GetOrOpenStream(maxStream) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(maxStream)) + }) + + It("errors when trying to get a stream ID higher than the maximum", func() { + _, err := m.GetOrOpenStream(maxStream + 4) + Expect(err).To(MatchError(fmt.Errorf("peer tried to open stream %d (current limit: %d)", maxStream+4, maxStream))) + }) + It("blocks AcceptStream until a new stream is available", func() { strChan := make(chan item) go func() { diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 7cf57afb..2dd2deeb 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -17,17 +17,23 @@ type incomingUniStreamsMap struct { streams map[protocol.StreamID]receiveStreamI - nextStream protocol.StreamID - highestStream protocol.StreamID + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open newStream func(protocol.StreamID) receiveStreamI closeErr error } -func newIncomingUniStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) receiveStreamI) *incomingUniStreamsMap { +func newIncomingUniStreamsMap( + nextStream protocol.StreamID, + maxStream protocol.StreamID, + newStream func(protocol.StreamID) receiveStreamI, +) *incomingUniStreamsMap { m := &incomingUniStreamsMap{ streams: make(map[protocol.StreamID]receiveStreamI), nextStream: nextStream, + maxStream: maxStream, newStream: newStream, } m.cond.L = &m.mutex @@ -55,6 +61,9 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { } func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } // if the id is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil