diff --git a/connection.go b/connection.go index 583e9a90b..3371b76ab 100644 --- a/connection.go +++ b/connection.go @@ -2282,6 +2282,8 @@ func (s *connection) queueControlFrame(f wire.Frame) { s.scheduleSending() } +func (s *connection) onHasConnectionData() { s.scheduleSending() } + func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) { s.framer.AddActiveStream(id, str) s.scheduleSending() diff --git a/internal/flowcontrol/connection_flow_controller.go b/internal/flowcontrol/connection_flow_controller.go index 362a9a956..bbeb78892 100644 --- a/internal/flowcontrol/connection_flow_controller.go +++ b/internal/flowcontrol/connection_flow_controller.go @@ -57,10 +57,12 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B return nil } -func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) { +func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) (hasWindowUpdate bool) { c.mutex.Lock() + defer c.mutex.Unlock() + c.baseFlowController.addBytesRead(n) - c.mutex.Unlock() + return c.baseFlowController.hasWindowUpdate() } func (c *connectionFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount { diff --git a/internal/flowcontrol/connection_flow_controller_test.go b/internal/flowcontrol/connection_flow_controller_test.go index de54ca282..74b4a5f77 100644 --- a/internal/flowcontrol/connection_flow_controller_test.go +++ b/internal/flowcontrol/connection_flow_controller_test.go @@ -19,8 +19,9 @@ func TestConnectionFlowControlWindowUpdate(t *testing.T) { &utils.RTTStats{}, utils.DefaultLogger, ) + require.False(t, fc.AddBytesRead(1)) require.Zero(t, fc.GetWindowUpdate(time.Now())) - fc.AddBytesRead(100) + require.True(t, fc.AddBytesRead(99)) require.Equal(t, protocol.ByteCount(200), fc.GetWindowUpdate(time.Now())) } diff --git a/internal/flowcontrol/interface.go b/internal/flowcontrol/interface.go index be641de1a..23cf30c58 100644 --- a/internal/flowcontrol/interface.go +++ b/internal/flowcontrol/interface.go @@ -18,7 +18,7 @@ type flowController interface { // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController - AddBytesRead(protocol.ByteCount) (shouldQueueWindowUpdate bool) + AddBytesRead(protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) // UpdateHighestReceived is called when a new highest offset is received // final has to be to true if this is the final offset of the stream, // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame @@ -32,7 +32,7 @@ type StreamFlowController interface { // The ConnectionFlowController is the flow controller for the connection. type ConnectionFlowController interface { flowController - AddBytesRead(protocol.ByteCount) + AddBytesRead(protocol.ByteCount) (hasWindowUpdate bool) Reset() error IsNewlyBlocked() (bool, protocol.ByteCount) } diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index e12f1f467..ba0051224 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -98,12 +98,12 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, return c.connection.IncrementHighestReceived(increment, now) } -func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (shouldQueueWindowUpdate bool) { +func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) { c.mutex.Lock() c.baseFlowController.addBytesRead(n) - shouldQueueWindowUpdate = c.shouldQueueWindowUpdate() + hasStreamWindowUpdate = c.shouldQueueWindowUpdate() c.mutex.Unlock() - c.connection.AddBytesRead(n) + hasConnWindowUpdate = c.connection.AddBytesRead(n) return } diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index 34bc80842..92d572669 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -192,16 +192,20 @@ func TestStreamWindowUpdate(t *testing.T) { utils.DefaultLogger, ) require.Zero(t, fc.GetWindowUpdate(time.Now())) - fc.AddBytesRead(24) + hasStreamWindowUpdate, _ := fc.AddBytesRead(24) + require.False(t, hasStreamWindowUpdate) require.Zero(t, fc.GetWindowUpdate(time.Now())) // the window is updated when it's 25% filled - fc.AddBytesRead(1) + hasStreamWindowUpdate, _ = fc.AddBytesRead(1) + require.True(t, hasStreamWindowUpdate) require.Equal(t, protocol.ByteCount(125), fc.GetWindowUpdate(time.Now())) - fc.AddBytesRead(24) + hasStreamWindowUpdate, _ = fc.AddBytesRead(24) + require.False(t, hasStreamWindowUpdate) require.Zero(t, fc.GetWindowUpdate(time.Now())) // the window is updated when it's 25% filled - fc.AddBytesRead(1) + hasStreamWindowUpdate, _ = fc.AddBytesRead(1) + require.True(t, hasStreamWindowUpdate) require.Equal(t, protocol.ByteCount(150), fc.GetWindowUpdate(time.Now())) // Receive the final offset. @@ -211,6 +215,31 @@ func TestStreamWindowUpdate(t *testing.T) { require.Zero(t, fc.GetWindowUpdate(time.Now())) } +func TestStreamConnectionWindowUpdate(t *testing.T) { + connFC := NewConnectionFlowController( + 100, + protocol.MaxByteCount, + nil, + &utils.RTTStats{}, + utils.DefaultLogger, + ) + fc := NewStreamFlowController( + 42, + connFC, + 1000, + protocol.MaxByteCount, + protocol.MaxByteCount, + &utils.RTTStats{}, + utils.DefaultLogger, + ) + + hasStreamWindowUpdate, hasConnWindowUpdate := fc.AddBytesRead(50) + require.False(t, hasStreamWindowUpdate) + require.Zero(t, fc.GetWindowUpdate(time.Now())) + require.True(t, hasConnWindowUpdate) + require.NotZero(t, connFC.GetWindowUpdate(time.Now())) +} + func TestStreamWindowAutoTuning(t *testing.T) { // the RTT is 1 second rttStats := &utils.RTTStats{} diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go index face2b214..19a71d130 100644 --- a/internal/mocks/connection_flow_controller.go +++ b/internal/mocks/connection_flow_controller.go @@ -42,9 +42,11 @@ func (m *MockConnectionFlowController) EXPECT() *MockConnectionFlowControllerMoc } // AddBytesRead mocks base method. -func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { +func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) bool { m.ctrl.T.Helper() - m.ctrl.Call(m, "AddBytesRead", arg0) + ret := m.ctrl.Call(m, "AddBytesRead", arg0) + ret0, _ := ret[0].(bool) + return ret0 } // AddBytesRead indicates an expected call of AddBytesRead. @@ -60,19 +62,19 @@ type MockConnectionFlowControllerAddBytesReadCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockConnectionFlowControllerAddBytesReadCall) Return() *MockConnectionFlowControllerAddBytesReadCall { - c.Call = c.Call.Return() +func (c *MockConnectionFlowControllerAddBytesReadCall) Return(hasWindowUpdate bool) *MockConnectionFlowControllerAddBytesReadCall { + c.Call = c.Call.Return(hasWindowUpdate) return c } // Do rewrite *gomock.Call.Do -func (c *MockConnectionFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount)) *MockConnectionFlowControllerAddBytesReadCall { +func (c *MockConnectionFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount) bool) *MockConnectionFlowControllerAddBytesReadCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockConnectionFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount)) *MockConnectionFlowControllerAddBytesReadCall { +func (c *MockConnectionFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockConnectionFlowControllerAddBytesReadCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index 8a1167870..38216131e 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -78,11 +78,12 @@ func (c *MockStreamFlowControllerAbandonCall) DoAndReturn(f func()) *MockStreamF } // AddBytesRead mocks base method. -func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) bool { +func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) (bool, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddBytesRead", arg0) ret0, _ := ret[0].(bool) - return ret0 + ret1, _ := ret[1].(bool) + return ret0, ret1 } // AddBytesRead indicates an expected call of AddBytesRead. @@ -98,19 +99,19 @@ type MockStreamFlowControllerAddBytesReadCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockStreamFlowControllerAddBytesReadCall) Return(shouldQueueWindowUpdate bool) *MockStreamFlowControllerAddBytesReadCall { - c.Call = c.Call.Return(shouldQueueWindowUpdate) +func (c *MockStreamFlowControllerAddBytesReadCall) Return(hasStreamWindowUpdate, hasConnWindowUpdate bool) *MockStreamFlowControllerAddBytesReadCall { + c.Call = c.Call.Return(hasStreamWindowUpdate, hasConnWindowUpdate) return c } // Do rewrite *gomock.Call.Do -func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount) bool) *MockStreamFlowControllerAddBytesReadCall { +func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount) (bool, bool)) *MockStreamFlowControllerAddBytesReadCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockStreamFlowControllerAddBytesReadCall { +func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount) (bool, bool)) *MockStreamFlowControllerAddBytesReadCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index d3418ee83..7d3a76e57 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -40,6 +40,42 @@ func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { return m.recorder } +// onHasConnectionData mocks base method. +func (m *MockStreamSender) onHasConnectionData() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "onHasConnectionData") +} + +// onHasConnectionData indicates an expected call of onHasConnectionData. +func (mr *MockStreamSenderMockRecorder) onHasConnectionData() *MockStreamSenderonHasConnectionDataCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasConnectionData", reflect.TypeOf((*MockStreamSender)(nil).onHasConnectionData)) + return &MockStreamSenderonHasConnectionDataCall{Call: call} +} + +// MockStreamSenderonHasConnectionDataCall wrap *gomock.Call +type MockStreamSenderonHasConnectionDataCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSenderonHasConnectionDataCall) Return() *MockStreamSenderonHasConnectionDataCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSenderonHasConnectionDataCall) Do(f func()) *MockStreamSenderonHasConnectionDataCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSenderonHasConnectionDataCall) DoAndReturn(f func()) *MockStreamSenderonHasConnectionDataCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // onHasStreamControlFrame mocks base method. func (m *MockStreamSender) onHasStreamControlFrame(arg0 protocol.StreamID, arg1 streamControlFrameGetter) { m.ctrl.T.Helper() diff --git a/receive_stream.go b/receive_stream.go index 9b1b71cb6..192b92f7c 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -91,16 +91,19 @@ func (s *receiveStream) Read(p []byte) (int, error) { defer func() { <-s.readOnce }() s.mutex.Lock() - queuedNewControlFrame, n, err := s.readImpl(p) + queuedStreamWindowUpdate, queuedConnWindowUpdate, n, err := s.readImpl(p) completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { s.sender.onStreamCompleted(s.streamID) } - if queuedNewControlFrame { + if queuedStreamWindowUpdate { s.sender.onHasStreamControlFrame(s.streamID, s) } + if queuedConnWindowUpdate { + s.sender.onHasConnectionData() + } return n, err } @@ -125,20 +128,19 @@ func (s *receiveStream) isNewlyCompleted() bool { return false } -func (s *receiveStream) readImpl(p []byte) (bool, int, error) { +func (s *receiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnWindowUpdate bool, _ int, _ error) { if s.currentFrameIsLast && s.currentFrame == nil { s.errorRead = true - return false, 0, io.EOF + return false, false, 0, io.EOF } if s.cancelledRemotely || s.cancelledLocally { s.errorRead = true - return false, 0, s.cancelErr + return false, false, 0, s.cancelErr } if s.closeForShutdownErr != nil { - return false, 0, s.closeForShutdownErr + return false, false, 0, s.closeForShutdownErr } - var queuedNewControlFrame bool var bytesRead int var deadlineTimer *utils.Timer for bytesRead < len(p) { @@ -146,23 +148,23 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { s.dequeueNextFrame() } if s.currentFrame == nil && bytesRead > 0 { - return queuedNewControlFrame, bytesRead, s.closeForShutdownErr + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr } for { // Stop waiting on errors if s.closeForShutdownErr != nil { - return queuedNewControlFrame, bytesRead, s.closeForShutdownErr + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr } if s.cancelledRemotely || s.cancelledLocally { s.errorRead = true - return queuedNewControlFrame, 0, s.cancelErr + return hasStreamWindowUpdate, hasConnWindowUpdate, 0, s.cancelErr } deadline := s.deadline if !deadline.IsZero() { if !time.Now().Before(deadline) { - return queuedNewControlFrame, bytesRead, errDeadline + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, errDeadline } if deadlineTimer == nil { deadlineTimer = utils.NewTimer() @@ -192,10 +194,10 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { } if bytesRead > len(p) { - return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) } if s.readPosInFrame > len(s.currentFrame) { - return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) } m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) @@ -205,9 +207,13 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { // when a RESET_STREAM was received, the flow controller was already // informed about the final byteOffset for this stream if !s.cancelledRemotely { - if queueMaxStreamData := s.flowController.AddBytesRead(protocol.ByteCount(m)); queueMaxStreamData { + hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m)) + if hasStream { s.queuedMaxStreamData = true - queuedNewControlFrame = true + hasStreamWindowUpdate = true + } + if hasConn { + hasConnWindowUpdate = true } } @@ -217,10 +223,10 @@ func (s *receiveStream) readImpl(p []byte) (bool, int, error) { s.currentFrameDone() } s.errorRead = true - return queuedNewControlFrame, bytesRead, io.EOF + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF } } - return queuedNewControlFrame, bytesRead, nil + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil } func (s *receiveStream) dequeueNextFrame() { diff --git a/receive_stream_test.go b/receive_stream_test.go index b4ce411dc..a2ef35ece 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -152,7 +152,17 @@ func TestReceiveStreamReadOverlappingData(t *testing.T) { require.Equal(t, []byte{'f', 'o', 'o', 'b', 'a', 'r'}, b) } -func TestReceiveStreamMaxStreamData(t *testing.T) { +func TestReceiveStreamFlowControlUpdates(t *testing.T) { + t.Run("stream", func(t *testing.T) { + testReceiveStreamFlowControlUpdates(t, true, false) + }) + + t.Run("connection", func(t *testing.T) { + testReceiveStreamFlowControlUpdates(t, false, true) + }) +} + +func testReceiveStreamFlowControlUpdates(t *testing.T, hasStreamWindowUpdate, hasConnWindowUpdate bool) { const streamID protocol.StreamID = 42 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) @@ -163,19 +173,31 @@ func TestReceiveStreamMaxStreamData(t *testing.T) { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad, 0xbe, 0xef}}, now)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Return(true) - mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Return(hasStreamWindowUpdate, hasConnWindowUpdate) + if hasStreamWindowUpdate { + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + } + if hasConnWindowUpdate { + mockSender.EXPECT().onHasConnectionData() + } n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 3)) require.NoError(t, err) require.Equal(t, 3, n) require.True(t, mockCtrl.Satisfied()) - now = now.Add(time.Second) - mockFC.EXPECT().GetWindowUpdate(now).Return(protocol.ByteCount(1337)) - f, ok, hasMore := str.getControlFrame(now) - require.True(t, ok) - require.Equal(t, &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}, f.Frame) - require.False(t, hasMore) + if hasStreamWindowUpdate { + now = now.Add(time.Second) + mockFC.EXPECT().GetWindowUpdate(now).Return(protocol.ByteCount(1337)) + f, ok, hasMore := str.getControlFrame(now) + require.True(t, ok) + require.Equal(t, &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}, f.Frame) + require.False(t, hasMore) + } + if hasConnWindowUpdate { + _, ok, hasMore := str.getControlFrame(now) + require.False(t, ok) + require.False(t, hasMore) + } } func TestReceiveStreamDeadlineInThePast(t *testing.T) { @@ -583,9 +605,9 @@ func TestReceiveStreamConcurrentReads(t *testing.T) { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes() var bytesRead protocol.ByteCount - mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) bool { + mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) (bool, bool) { bytesRead += n - return false + return false, false }).AnyTimes() var numCompleted atomic.Int32 diff --git a/stream.go b/stream.go index b43a97024..9cd2695dd 100644 --- a/stream.go +++ b/stream.go @@ -24,6 +24,7 @@ var errDeadline net.Error = &deadlineError{} // The streamSender is notified by the stream about various events. type streamSender interface { + onHasConnectionData() onHasStreamData(protocol.StreamID, sendStreamI) onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter) // must be called without holding the mutex that is acquired by closeForShutdown