From 1ec720f2f2a486638bf05d0e1763aed9022c440d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 5 Feb 2018 16:03:02 +0800 Subject: [PATCH] implement sending of STREAM_ID_BLOCKED frames --- session.go | 1 + session_test.go | 10 ++++++++ streams_map.go | 12 ++++++++-- streams_map_outgoing_bidi.go | 27 ++++++++++++++++------ streams_map_outgoing_generic.go | 27 ++++++++++++++++------ streams_map_outgoing_generic_test.go | 34 +++++++++++++++++++++++++--- streams_map_outgoing_uni.go | 27 ++++++++++++++++------ streams_map_test.go | 8 +++++++ 8 files changed, 120 insertions(+), 26 deletions(-) diff --git a/session.go b/session.go index f93bf012..49c61651 100644 --- a/session.go +++ b/session.go @@ -568,6 +568,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve err = s.handleMaxStreamIDFrame(frame) case *wire.BlockedFrame: case *wire.StreamBlockedFrame: + case *wire.StreamIDBlockedFrame: case *wire.StopSendingFrame: err = s.handleStopSendingFrame(frame) case *wire.PingFrame: diff --git a/session_test.go b/session_test.go index 53c27ed0..45eca72b 100644 --- a/session_test.go +++ b/session_test.go @@ -415,6 +415,16 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) }) + It("handles STREAM_BLOCKED frames", func() { + err := sess.handleFrames([]wire.Frame{&wire.StreamBlockedFrame{}}, protocol.EncryptionUnspecified) + Expect(err).NotTo(HaveOccurred()) + }) + + It("handles STREAM_ID_BLOCKED frames", func() { + err := sess.handleFrames([]wire.Frame{&wire.StreamIDBlockedFrame{}}, protocol.EncryptionUnspecified) + Expect(err).NotTo(HaveOccurred()) + }) + It("errors on GOAWAY frames", func() { err := sess.handleFrames([]wire.Frame{&wire.GoawayFrame{}}, protocol.EncryptionUnspecified) Expect(err).To(MatchError("unimplemented: handling GOAWAY frames")) diff --git a/streams_map.go b/streams_map.go index 8287c83b..c3ce2ef4 100644 --- a/streams_map.go +++ b/streams_map.go @@ -64,7 +64,11 @@ func newStreamsMap( newUniReceiveStream := func(id protocol.StreamID) receiveStreamI { return newReceiveStream(id, m.sender, m.newFlowController(id), version) } - m.outgoingBidiStreams = newOutgoingBidiStreamsMap(firstOutgoingBidiStream, newBidiStream) + m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + firstOutgoingBidiStream, + newBidiStream, + sender.queueControlFrame, + ) // TODO(#523): make these values configurable m.incomingBidiStreams = newIncomingBidiStreamsMap( firstIncomingBidiStream, @@ -73,7 +77,11 @@ func newStreamsMap( sender.queueControlFrame, newBidiStream, ) - m.outgoingUniStreams = newOutgoingUniStreamsMap(firstOutgoingUniStream, newUniSendStream) + m.outgoingUniStreams = newOutgoingUniStreamsMap( + firstOutgoingUniStream, + newUniSendStream, + sender.queueControlFrame, + ) // TODO(#523): make these values configurable m.incomingUniStreams = newIncomingUniStreamsMap( firstIncomingUniStream, diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index d6951a23..d2c92dec 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -18,18 +19,26 @@ type outgoingBidiStreamsMap struct { streams map[protocol.StreamID]streamI - nextStream protocol.StreamID - maxStream protocol.StreamID - newStream func(protocol.StreamID) streamI + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) streamI + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) closeErr error } -func newOutgoingBidiStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) streamI) *outgoingBidiStreamsMap { +func newOutgoingBidiStreamsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) streamI, + queueControlFrame func(wire.Frame), +) *outgoingBidiStreamsMap { m := &outgoingBidiStreamsMap{ - streams: make(map[protocol.StreamID]streamI), - nextStream: nextStream, - newStream: newStream, + streams: make(map[protocol.StreamID]streamI), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -63,6 +72,10 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { return nil, m.closeErr } if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } return nil, qerr.TooManyOpenStreams } s := m.newStream(m.nextStream) diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 49ead874..5a283602 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -6,6 +6,7 @@ import ( "github.com/cheekybits/genny/generic" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -19,18 +20,26 @@ type outgoingItemsMap struct { streams map[protocol.StreamID]item - nextStream protocol.StreamID - maxStream protocol.StreamID - newStream func(protocol.StreamID) item + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) item + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) closeErr error } -func newOutgoingItemsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) item) *outgoingItemsMap { +func newOutgoingItemsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) item, + queueControlFrame func(wire.Frame), +) *outgoingItemsMap { m := &outgoingItemsMap{ - streams: make(map[protocol.StreamID]item), - nextStream: nextStream, - newStream: newStream, + streams: make(map[protocol.StreamID]item), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -64,6 +73,10 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) { return nil, m.closeErr } if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } return nil, qerr.TooManyOpenStreams } s := m.newStream(m.nextStream) diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index a4cf22f1..eb28784a 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -3,7 +3,9 @@ package quic import ( "errors" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -12,15 +14,17 @@ import ( var _ = Describe("Streams Map (outgoing)", func() { const firstNewStream protocol.StreamID = 10 var ( - m *outgoingItemsMap - newItem func(id protocol.StreamID) item + m *outgoingItemsMap + newItem func(id protocol.StreamID) item + mockSender *MockStreamSender ) BeforeEach(func() { newItem = func(id protocol.StreamID) item { return id } - m = newOutgoingItemsMap(firstNewStream, newItem) + mockSender = NewMockStreamSender(mockCtrl) + m = newOutgoingItemsMap(firstNewStream, newItem, mockSender.queueControlFrame) }) Context("no stream ID limit", func() { @@ -84,11 +88,13 @@ var _ = Describe("Streams Map (outgoing)", func() { Context("with stream ID limits", func() { It("errors when no stream can be opened immediately", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.OpenStream() Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("blocks until a stream can be opened synchronously", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -104,6 +110,7 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("stops opening synchronously when it is closed", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) testErr := errors.New("test error") done := make(chan struct{}) go func() { @@ -125,5 +132,26 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(firstNewStream)) }) + + It("queues a STREAM_ID_BLOCKED frame if no stream can be opened", func() { + m.SetMaxStream(firstNewStream) + mockSender.EXPECT().queueControlFrame(&wire.StreamIDBlockedFrame{StreamID: firstNewStream}) + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("only sends one STREAM_ID_BLOCKED frame for one stream ID", func() { + m.SetMaxStream(firstNewStream) + mockSender.EXPECT().queueControlFrame(&wire.StreamIDBlockedFrame{StreamID: firstNewStream}) + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + // try to open a stream twice, but expect only one STREAM_ID_BLOCKED to be sent + _, err = m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + _, err = m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) }) }) diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 2f0e6bc0..77511b78 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -18,18 +19,26 @@ type outgoingUniStreamsMap struct { streams map[protocol.StreamID]sendStreamI - nextStream protocol.StreamID - maxStream protocol.StreamID - newStream func(protocol.StreamID) sendStreamI + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) sendStreamI + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) closeErr error } -func newOutgoingUniStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) sendStreamI) *outgoingUniStreamsMap { +func newOutgoingUniStreamsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) sendStreamI, + queueControlFrame func(wire.Frame), +) *outgoingUniStreamsMap { m := &outgoingUniStreamsMap{ - streams: make(map[protocol.StreamID]sendStreamI), - nextStream: nextStream, - newStream: newStream, + streams: make(map[protocol.StreamID]sendStreamI), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -63,6 +72,10 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { return nil, m.closeErr } if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } return nil, qerr.TooManyOpenStreams } s := m.newStream(m.nextStream) diff --git a/streams_map_test.go b/streams_map_test.go index a9901d8c..b53d7d65 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -257,6 +257,10 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { }) Context("updating stream ID limits", func() { + BeforeEach(func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + }) + It("processes the parameter for outgoing bidirectional streams", func() { _, err := m.OpenStream() Expect(err).To(MatchError(qerr.TooManyOpenStreams)) @@ -281,6 +285,10 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { }) Context("handling MAX_STREAM_ID frames", func() { + BeforeEach(func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + }) + It("processes IDs for outgoing bidirectional streams", func() { _, err := m.OpenStream() Expect(err).To(MatchError(qerr.TooManyOpenStreams))