From 2e8a5807ba89336ef1ef1364934e7beabc8f481c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 6 May 2018 09:46:21 +0900 Subject: [PATCH 1/2] queue stream-level window updates from the flow controller directly --- internal/flowcontrol/interface.go | 4 +- .../flowcontrol/stream_flow_controller.go | 10 +++- .../stream_flow_controller_test.go | 45 +++++++++++++----- internal/mocks/stream_flow_controller.go | 22 ++++----- mock_stream_sender_test.go | 10 ---- receive_stream.go | 6 +-- receive_stream_test.go | 46 ++++++------------- session.go | 2 + stream.go | 5 -- 9 files changed, 73 insertions(+), 77 deletions(-) diff --git a/internal/flowcontrol/interface.go b/internal/flowcontrol/interface.go index 61d57e31b..20297be74 100644 --- a/internal/flowcontrol/interface.go +++ b/internal/flowcontrol/interface.go @@ -21,8 +21,8 @@ type StreamFlowController interface { // UpdateHighestReceived should be called when a new highest offset is received // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool) error - // HasWindowUpdate says if it is necessary to update the window - HasWindowUpdate() bool + // MaybeQueueWindowUpdate queues a window update, if necessary + MaybeQueueWindowUpdate() } // The ConnectionFlowController is the flow controller for the connection. diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index 6501278c2..aff47fc94 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -14,6 +14,8 @@ type streamFlowController struct { streamID protocol.StreamID + queueWindowUpdate func() + connection connectionFlowControllerI contributesToConnection bool // does the stream contribute to connection level flow control @@ -30,6 +32,7 @@ func NewStreamFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, initialSendWindow protocol.ByteCount, + queueWindowUpdate func(protocol.StreamID), rttStats *congestion.RTTStats, logger utils.Logger, ) StreamFlowController { @@ -37,6 +40,7 @@ func NewStreamFlowController( streamID: streamID, contributesToConnection: contributesToConnection, connection: cfc.(connectionFlowControllerI), + queueWindowUpdate: func() { queueWindowUpdate(streamID) }, baseFlowController: baseFlowController{ rttStats: rttStats, receiveWindow: receiveWindow, @@ -120,11 +124,13 @@ func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) { return true, c.sendWindow } -func (c *streamFlowController) HasWindowUpdate() bool { +func (c *streamFlowController) MaybeQueueWindowUpdate() { c.mutex.Lock() hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate() c.mutex.Unlock() - return hasWindowUpdate + if hasWindowUpdate { + c.queueWindowUpdate() + } } func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index dfca659a1..5707b331f 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -12,9 +12,13 @@ import ( ) var _ = Describe("Stream Flow controller", func() { - var controller *streamFlowController + var ( + controller *streamFlowController + queuedWindowUpdate bool + ) BeforeEach(func() { + queuedWindowUpdate = false rttStats := &congestion.RTTStats{} controller = &streamFlowController{ streamID: 10, @@ -23,24 +27,38 @@ var _ = Describe("Stream Flow controller", func() { controller.maxReceiveWindowSize = 10000 controller.rttStats = rttStats controller.logger = utils.DefaultLogger + controller.queueWindowUpdate = func() { queuedWindowUpdate = true } }) Context("Constructor", func() { rttStats := &congestion.RTTStats{} + receiveWindow := protocol.ByteCount(2000) + maxReceiveWindow := protocol.ByteCount(3000) + sendWindow := protocol.ByteCount(4000) It("sets the send and receive windows", func() { - receiveWindow := protocol.ByteCount(2000) - maxReceiveWindow := protocol.ByteCount(3000) - sendWindow := protocol.ByteCount(4000) - cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger) - fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController) + fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) Expect(fc.sendWindow).To(Equal(sendWindow)) Expect(fc.contributesToConnection).To(BeTrue()) }) + + It("queues window updates with the correction stream ID", func() { + var queued bool + queueWindowUpdate := func(id protocol.StreamID) { + Expect(id).To(Equal(protocol.StreamID(5))) + queued = true + } + + cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger) + fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController) + fc.AddBytesRead(receiveWindow) + fc.MaybeQueueWindowUpdate() + Expect(queued).To(BeTrue()) + }) }) Context("receiving data", func() { @@ -175,12 +193,16 @@ var _ = Describe("Stream Flow controller", func() { oldWindowSize = controller.receiveWindowSize }) - It("tells if it has window updates", func() { - Expect(controller.HasWindowUpdate()).To(BeFalse()) + It("queues window updates", func() { + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) controller.AddBytesRead(30) - Expect(controller.HasWindowUpdate()).To(BeTrue()) + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeTrue()) Expect(controller.GetWindowUpdate()).ToNot(BeZero()) - Expect(controller.HasWindowUpdate()).To(BeFalse()) + queuedWindowUpdate = false + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) }) It("tells the connection flow controller when the window was autotuned", func() { @@ -213,7 +235,8 @@ var _ = Describe("Stream Flow controller", func() { controller.AddBytesRead(30) err := controller.UpdateHighestReceived(90, true) Expect(err).ToNot(HaveOccurred()) - Expect(controller.HasWindowUpdate()).To(BeFalse()) + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) offset := controller.GetWindowUpdate() Expect(offset).To(BeZero()) }) diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index a69e73f19..0f355a817 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -66,18 +66,6 @@ func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) } -// HasWindowUpdate mocks base method -func (m *MockStreamFlowController) HasWindowUpdate() bool { - ret := m.ctrl.Call(m, "HasWindowUpdate") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasWindowUpdate indicates an expected call of HasWindowUpdate -func (mr *MockStreamFlowControllerMockRecorder) HasWindowUpdate() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).HasWindowUpdate)) -} - // IsBlocked mocks base method func (m *MockStreamFlowController) IsBlocked() (bool, protocol.ByteCount) { ret := m.ctrl.Call(m, "IsBlocked") @@ -91,6 +79,16 @@ func (mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked)) } +// MaybeQueueWindowUpdate mocks base method +func (m *MockStreamFlowController) MaybeQueueWindowUpdate() { + m.ctrl.Call(m, "MaybeQueueWindowUpdate") +} + +// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate +func (mr *MockStreamFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).MaybeQueueWindowUpdate)) +} + // SendWindowSize mocks base method func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { ret := m.ctrl.Call(m, "SendWindowSize") diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index da3ad8d06..d6f090a1f 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -45,16 +45,6 @@ func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) } -// onHasWindowUpdate mocks base method -func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) { - m.ctrl.Call(m, "onHasWindowUpdate", arg0) -} - -// onHasWindowUpdate indicates an expected call of onHasWindowUpdate -func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0) -} - // onStreamCompleted mocks base method func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { m.ctrl.Call(m, "onStreamCompleted", arg0) diff --git a/receive_stream.go b/receive_stream.go index 9fc158f1a..cec69f1cd 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -151,10 +151,8 @@ func (s *receiveStream) Read(p []byte) (int, error) { if !s.resetRemotely { s.flowController.AddBytesRead(protocol.ByteCount(m)) } - // this call triggers the flow controller to increase the flow control window, if necessary - if s.flowController.HasWindowUpdate() { - s.sender.onHasWindowUpdate(s.streamID) - } + // increase the flow control window, if necessary + s.flowController.MaybeQueueWindowUpdate() if s.readPosInFrame >= int(frame.DataLen()) { s.frameQueue.Pop() diff --git a/receive_stream_test.go b/receive_stream_test.go index 4a3e04e4f..434437859 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -43,7 +43,7 @@ var _ = Describe("Receive Stream", func() { It("reads a single STREAM frame", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -61,7 +61,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -83,7 +83,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -107,7 +107,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -130,7 +130,7 @@ var _ = Describe("Receive Stream", func() { It("waits until data is available", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() go func() { defer GinkgoRecover() frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} @@ -148,7 +148,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -173,7 +173,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -204,7 +204,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte("foob"), @@ -230,22 +230,6 @@ var _ = Describe("Receive Stream", func() { Expect(err).To(MatchError(errEmptyStreamData)) }) - It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockFC.EXPECT().HasWindowUpdate().Return(true) - mockSender.EXPECT().onHasWindowUpdate(streamID) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte("foobar"), - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - _, err = strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - }) - Context("deadlines", func() { It("the deadline error has the right net.Error properties", func() { Expect(errDeadline.Temporary()).To(BeTrue()) @@ -318,7 +302,7 @@ var _ = Describe("Receive Stream", func() { It("returns EOFs", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -339,7 +323,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -367,7 +351,7 @@ var _ = Describe("Receive Stream", func() { It("returns EOFs with partial read", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() err := str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, Data: []byte{0xde, 0xad}, @@ -385,7 +369,7 @@ var _ = Describe("Receive Stream", func() { It("handles immediate FINs", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() err := str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, FinBit: true, @@ -402,7 +386,7 @@ var _ = Describe("Receive Stream", func() { It("closes when CloseRemote is called", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() str.CloseRemote(0) mockSender.EXPECT().onStreamCompleted(streamID) b := make([]byte, 8) @@ -478,7 +462,7 @@ var _ = Describe("Receive Stream", func() { It("doesn't send a RST_STREAM frame, if the FIN was already read", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() // no calls to mockSender.queueControlFrame err := str.handleStreamFrame(&wire.StreamFrame{ StreamID: streamID, @@ -601,7 +585,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), mockSender.EXPECT().onStreamCompleted(streamID), ) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) readReturned := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/session.go b/session.go index 8194dfe08..d144e89b4 100644 --- a/session.go +++ b/session.go @@ -1137,6 +1137,7 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), initialSendWindow, + s.onHasWindowUpdate, s.rttStats, s.logger, ) @@ -1151,6 +1152,7 @@ func (s *session) newCryptoStream() cryptoStreamI { protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), 0, + s.onHasWindowUpdate, s.rttStats, s.logger, ) diff --git a/stream.go b/stream.go index 831234934..f8d851bed 100644 --- a/stream.go +++ b/stream.go @@ -18,7 +18,6 @@ const ( // The streamSender is notified by the stream about various events. type streamSender interface { queueControlFrame(wire.Frame) - onHasWindowUpdate(protocol.StreamID) onHasStreamData(protocol.StreamID) onStreamCompleted(protocol.StreamID) } @@ -34,10 +33,6 @@ func (s *uniStreamSender) queueControlFrame(f wire.Frame) { s.streamSender.queueControlFrame(f) } -func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) { - s.streamSender.onHasWindowUpdate(id) -} - func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { s.streamSender.onHasStreamData(id) } From 08160ab18f3b8cd8fa94b4d4d947f585e2140c74 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 6 May 2018 10:51:51 +0900 Subject: [PATCH 2/2] queue connection-level window updates from the flow controller directly It is not sufficient to check for connection-level window updates every time a packet is sent. When a connection-level window update needs to be sent, we need to make sure that it gets sent immediately (i.e. call scheduleSending() in the session). --- .../flowcontrol/connection_flow_controller.go | 13 ++++++ .../connection_flow_controller_test.go | 20 ++++++++- internal/flowcontrol/interface.go | 3 +- .../flowcontrol/stream_flow_controller.go | 3 ++ .../stream_flow_controller_test.go | 22 +++++++--- internal/mocks/connection_flow_controller.go | 10 +++++ session.go | 19 +++++---- session_test.go | 19 --------- window_update_queue.go | 42 ++++++++++++++----- window_update_queue_test.go | 30 +++++++++---- 10 files changed, 126 insertions(+), 55 deletions(-) diff --git a/internal/flowcontrol/connection_flow_controller.go b/internal/flowcontrol/connection_flow_controller.go index 9dd34e3e1..ab565d28a 100644 --- a/internal/flowcontrol/connection_flow_controller.go +++ b/internal/flowcontrol/connection_flow_controller.go @@ -12,6 +12,8 @@ import ( type connectionFlowController struct { lastBlockedAt protocol.ByteCount baseFlowController + + queueWindowUpdate func() } var _ ConnectionFlowController = &connectionFlowController{} @@ -21,6 +23,7 @@ var _ ConnectionFlowController = &connectionFlowController{} func NewConnectionFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, + queueWindowUpdate func(), rttStats *congestion.RTTStats, logger utils.Logger, ) ConnectionFlowController { @@ -32,6 +35,7 @@ func NewConnectionFlowController( maxReceiveWindowSize: maxReceiveWindow, logger: logger, }, + queueWindowUpdate: queueWindowUpdate, } } @@ -62,6 +66,15 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B return nil } +func (c *connectionFlowController) MaybeQueueWindowUpdate() { + c.mutex.Lock() + hasWindowUpdate := c.hasWindowUpdate() + c.mutex.Unlock() + if hasWindowUpdate { + c.queueWindowUpdate() + } +} + func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { c.mutex.Lock() oldWindowSize := c.receiveWindowSize diff --git a/internal/flowcontrol/connection_flow_controller_test.go b/internal/flowcontrol/connection_flow_controller_test.go index cba41eb28..2ff7dd760 100644 --- a/internal/flowcontrol/connection_flow_controller_test.go +++ b/internal/flowcontrol/connection_flow_controller_test.go @@ -11,7 +11,10 @@ import ( ) var _ = Describe("Connection Flow controller", func() { - var controller *connectionFlowController + var ( + controller *connectionFlowController + queuedWindowUpdate bool + ) // update the congestion such that it returns a given value for the smoothed RTT setRtt := func(t time.Duration) { @@ -23,6 +26,7 @@ var _ = Describe("Connection Flow controller", func() { controller = &connectionFlowController{} controller.rttStats = &congestion.RTTStats{} controller.logger = utils.DefaultLogger + controller.queueWindowUpdate = func() { queuedWindowUpdate = true } }) Context("Constructor", func() { @@ -32,7 +36,7 @@ var _ = Describe("Connection Flow controller", func() { receiveWindow := protocol.ByteCount(2000) maxReceiveWindow := protocol.ByteCount(3000) - fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats, utils.DefaultLogger).(*connectionFlowController) + fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, nil, rttStats, utils.DefaultLogger).(*connectionFlowController) Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) }) @@ -53,6 +57,18 @@ var _ = Describe("Connection Flow controller", func() { controller.bytesRead = 100 - 60 }) + It("queues window updates", func() { + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) + controller.AddBytesRead(30) + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeTrue()) + Expect(controller.GetWindowUpdate()).ToNot(BeZero()) + queuedWindowUpdate = false + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) + }) + It("gets a window update", func() { windowSize := controller.receiveWindowSize oldOffset := controller.bytesRead diff --git a/internal/flowcontrol/interface.go b/internal/flowcontrol/interface.go index 20297be74..450d06abf 100644 --- a/internal/flowcontrol/interface.go +++ b/internal/flowcontrol/interface.go @@ -10,6 +10,7 @@ type flowController interface { // for receiving AddBytesRead(protocol.ByteCount) GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary + MaybeQueueWindowUpdate() // queues a window update, if necessary } // A StreamFlowController is a flow controller for a QUIC stream. @@ -21,8 +22,6 @@ type StreamFlowController interface { // UpdateHighestReceived should be called when a new highest offset is received // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool) error - // MaybeQueueWindowUpdate queues a window update, if necessary - MaybeQueueWindowUpdate() } // The ConnectionFlowController is the flow controller for the connection. diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index aff47fc94..a394de0cc 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -131,6 +131,9 @@ func (c *streamFlowController) MaybeQueueWindowUpdate() { if hasWindowUpdate { c.queueWindowUpdate() } + if c.contributesToConnection { + c.connection.MaybeQueueWindowUpdate() + } } func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index 5707b331f..ab4bd67d2 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -13,16 +13,18 @@ import ( var _ = Describe("Stream Flow controller", func() { var ( - controller *streamFlowController - queuedWindowUpdate bool + controller *streamFlowController + queuedWindowUpdate bool + queuedConnWindowUpdate bool ) BeforeEach(func() { queuedWindowUpdate = false + queuedConnWindowUpdate = false rttStats := &congestion.RTTStats{} controller = &streamFlowController{ streamID: 10, - connection: NewConnectionFlowController(1000, 1000, rttStats, utils.DefaultLogger).(*connectionFlowController), + connection: NewConnectionFlowController(1000, 1000, func() { queuedConnWindowUpdate = true }, rttStats, utils.DefaultLogger).(*connectionFlowController), } controller.maxReceiveWindowSize = 10000 controller.rttStats = rttStats @@ -37,7 +39,7 @@ var _ = Describe("Stream Flow controller", func() { sendWindow := protocol.ByteCount(4000) It("sets the send and receive windows", func() { - cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger) + cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger) fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveWindow).To(Equal(receiveWindow)) @@ -53,7 +55,7 @@ var _ = Describe("Stream Flow controller", func() { queued = true } - cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger) + cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger) fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController) fc.AddBytesRead(receiveWindow) fc.MaybeQueueWindowUpdate() @@ -189,6 +191,7 @@ var _ = Describe("Stream Flow controller", func() { controller.receiveWindow = 100 controller.receiveWindowSize = 60 controller.bytesRead = 100 - 60 + controller.connection.(*connectionFlowController).receiveWindow = 100 controller.connection.(*connectionFlowController).receiveWindowSize = 120 oldWindowSize = controller.receiveWindowSize }) @@ -205,6 +208,15 @@ var _ = Describe("Stream Flow controller", func() { Expect(queuedWindowUpdate).To(BeFalse()) }) + It("queues connection-level window updates", func() { + controller.contributesToConnection = true + controller.MaybeQueueWindowUpdate() + Expect(queuedConnWindowUpdate).To(BeFalse()) + controller.AddBytesRead(60) + controller.MaybeQueueWindowUpdate() + Expect(queuedConnWindowUpdate).To(BeTrue()) + }) + It("tells the connection flow controller when the window was autotuned", func() { oldOffset := controller.bytesRead controller.contributesToConnection = true diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go index ae10e785f..1a47362b9 100644 --- a/internal/mocks/connection_flow_controller.go +++ b/internal/mocks/connection_flow_controller.go @@ -79,6 +79,16 @@ func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) } +// MaybeQueueWindowUpdate mocks base method +func (m *MockConnectionFlowController) MaybeQueueWindowUpdate() { + m.ctrl.Call(m, "MaybeQueueWindowUpdate") +} + +// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate +func (mr *MockConnectionFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).MaybeQueueWindowUpdate)) +} + // SendWindowSize mocks base method func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { ret := m.ctrl.Call(m, "SendWindowSize") diff --git a/session.go b/session.go index d144e89b4..a34f2edaf 100644 --- a/session.go +++ b/session.go @@ -405,6 +405,7 @@ func (s *session) preSetup() { s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), + s.onHasConnectionWindowUpdate, s.rttStats, s.logger, ) @@ -425,7 +426,7 @@ func (s *session) postSetup() error { s.sessionCreationTime = now s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.version) - s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame) + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.connFlowController, s.packer.QueueControlFrame) return nil } @@ -1021,9 +1022,6 @@ func (s *session) maybeSendRetransmission() (bool, error) { } func (s *session) sendPacket() (bool, error) { - if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { - s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset}) - } if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) } @@ -1137,7 +1135,7 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), initialSendWindow, - s.onHasWindowUpdate, + s.onHasStreamWindowUpdate, s.rttStats, s.logger, ) @@ -1152,7 +1150,7 @@ func (s *session) newCryptoStream() cryptoStreamI { protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), 0, - s.onHasWindowUpdate, + s.onHasStreamWindowUpdate, s.rttStats, s.logger, ) @@ -1202,8 +1200,13 @@ func (s *session) queueControlFrame(f wire.Frame) { s.scheduleSending() } -func (s *session) onHasWindowUpdate(id protocol.StreamID) { - s.windowUpdateQueue.Add(id) +func (s *session) onHasStreamWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.AddStream(id) + s.scheduleSending() +} + +func (s *session) onHasConnectionWindowUpdate() { + s.windowUpdateQueue.AddConnection() s.scheduleSending() } diff --git a/session_test.go b/session_test.go index 4c6623fb1..571421df6 100644 --- a/session_test.go +++ b/session_test.go @@ -691,24 +691,6 @@ var _ = Describe("Session", func() { Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) - It("adds a MAX_DATA frames", func() { - fc := mocks.NewMockConnectionFlowController(mockCtrl) - fc.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) - fc.EXPECT().IsNewlyBlocked() - sess.connFlowController = fc - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(Equal([]wire.Frame{ - &wire.MaxDataFrame{ByteOffset: 0x1337}, - })) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sess.sentPacketHandler = sph - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - }) - It("adds MAX_STREAM_DATA frames", func() { sess.windowUpdateQueue.callback(&wire.MaxStreamDataFrame{ StreamID: 2, @@ -726,7 +708,6 @@ var _ = Describe("Session", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { fc := mocks.NewMockConnectionFlowController(mockCtrl) - fc.EXPECT().GetWindowUpdate() fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) sess.connFlowController = fc sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) diff --git a/window_update_queue.go b/window_update_queue.go index ed006aa2b..dfbc45ab4 100644 --- a/window_update_queue.go +++ b/window_update_queue.go @@ -3,6 +3,7 @@ package quic import ( "sync" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -10,29 +11,50 @@ import ( type windowUpdateQueue struct { mutex sync.Mutex - queue map[protocol.StreamID]bool // used as a set - callback func(wire.Frame) - cryptoStream cryptoStreamI - streamGetter streamGetter + queue map[protocol.StreamID]bool // used as a set + queuedConn bool // connection-level window update + + cryptoStream cryptoStreamI + streamGetter streamGetter + connFlowController flowcontrol.ConnectionFlowController + callback func(wire.Frame) } -func newWindowUpdateQueue(streamGetter streamGetter, cryptoStream cryptoStreamI, cb func(wire.Frame)) *windowUpdateQueue { +func newWindowUpdateQueue( + streamGetter streamGetter, + cryptoStream cryptoStreamI, + connFC flowcontrol.ConnectionFlowController, + cb func(wire.Frame), +) *windowUpdateQueue { return &windowUpdateQueue{ - queue: make(map[protocol.StreamID]bool), - streamGetter: streamGetter, - cryptoStream: cryptoStream, - callback: cb, + queue: make(map[protocol.StreamID]bool), + streamGetter: streamGetter, + cryptoStream: cryptoStream, + connFlowController: connFC, + callback: cb, } } -func (q *windowUpdateQueue) Add(id protocol.StreamID) { +func (q *windowUpdateQueue) AddStream(id protocol.StreamID) { q.mutex.Lock() q.queue[id] = true q.mutex.Unlock() } +func (q *windowUpdateQueue) AddConnection() { + q.mutex.Lock() + q.queuedConn = true + q.mutex.Unlock() +} + func (q *windowUpdateQueue) QueueAll() { q.mutex.Lock() + // queue a connection-level window update + if q.queuedConn { + q.callback(&wire.MaxDataFrame{ByteOffset: q.connFlowController.GetWindowUpdate()}) + q.queuedConn = false + } + // queue all stream-level window updates var offset protocol.ByteCount for id := range q.queue { if id == q.cryptoStream.StreamID() { diff --git a/window_update_queue_test.go b/window_update_queue_test.go index cf0511f12..317b03502 100644 --- a/window_update_queue_test.go +++ b/window_update_queue_test.go @@ -1,6 +1,7 @@ package quic import ( + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -12,6 +13,7 @@ var _ = Describe("Window Update Queue", func() { var ( q *windowUpdateQueue streamGetter *MockStreamGetter + connFC *mocks.MockConnectionFlowController queuedFrames []wire.Frame cryptoStream *MockCryptoStream ) @@ -19,9 +21,10 @@ var _ = Describe("Window Update Queue", func() { BeforeEach(func() { streamGetter = NewMockStreamGetter(mockCtrl) cryptoStream = NewMockCryptoStream(mockCtrl) + connFC = mocks.NewMockConnectionFlowController(mockCtrl) cryptoStream.EXPECT().StreamID().Return(protocol.StreamID(0)).AnyTimes() queuedFrames = queuedFrames[:0] - q = newWindowUpdateQueue(streamGetter, cryptoStream, func(f wire.Frame) { + q = newWindowUpdateQueue(streamGetter, cryptoStream, connFC, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }) }) @@ -33,8 +36,8 @@ var _ = Describe("Window Update Queue", func() { stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30)) streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(stream3, nil) streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(stream1, nil) - q.Add(3) - q.Add(1) + q.AddStream(3) + q.AddStream(1) q.QueueAll() Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, ByteOffset: 10})) Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, ByteOffset: 30})) @@ -44,7 +47,7 @@ var _ = Describe("Window Update Queue", func() { stream10 := NewMockStreamI(mockCtrl) stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100)) streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) - q.Add(10) + q.AddStream(10) q.QueueAll() Expect(queuedFrames).To(HaveLen(1)) q.QueueAll() @@ -53,7 +56,7 @@ var _ = Describe("Window Update Queue", func() { It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() { streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil) - q.Add(12) + q.AddStream(12) q.QueueAll() Expect(queuedFrames).To(BeEmpty()) }) @@ -62,26 +65,35 @@ var _ = Describe("Window Update Queue", func() { stream5 := NewMockStreamI(mockCtrl) stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil) - q.Add(5) + q.AddStream(5) q.QueueAll() Expect(queuedFrames).To(BeEmpty()) }) It("adds MAX_STREAM_DATA frames for the crypto stream", func() { cryptoStream.EXPECT().getWindowUpdate().Return(protocol.ByteCount(42)) - q.Add(0) + q.AddStream(0) q.QueueAll() Expect(queuedFrames).To(Equal([]wire.Frame{ &wire.MaxStreamDataFrame{StreamID: 0, ByteOffset: 42}, })) }) + It("queues MAX_DATA frames", func() { + connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) + q.AddConnection() + q.QueueAll() + Expect(queuedFrames).To(Equal([]wire.Frame{ + &wire.MaxDataFrame{ByteOffset: 0x1337}, + })) + }) + It("deduplicates", func() { stream10 := NewMockStreamI(mockCtrl) stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200)) streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) - q.Add(10) - q.Add(10) + q.AddStream(10) + q.AddStream(10) q.QueueAll() Expect(queuedFrames).To(Equal([]wire.Frame{ &wire.MaxStreamDataFrame{StreamID: 10, ByteOffset: 200},