diff --git a/flow_controller.go b/flow_controller.go new file mode 100644 index 000000000..ad4413aba --- /dev/null +++ b/flow_controller.go @@ -0,0 +1,70 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/handshake" + "github.com/lucas-clemente/quic-go/protocol" +) + +type flowController struct { + streamID protocol.StreamID + + bytesSent protocol.ByteCount + sendFlowControlWindow protocol.ByteCount + + bytesRead protocol.ByteCount + receiveWindowUpdateThreshold protocol.ByteCount + receiveFlowControlWindow protocol.ByteCount + receiveFlowControlWindowIncrement protocol.ByteCount +} + +func newFlowController(connectionParametersManager *handshake.ConnectionParametersManager) *flowController { + return &flowController{ + sendFlowControlWindow: connectionParametersManager.GetSendStreamFlowControlWindow(), + receiveFlowControlWindow: connectionParametersManager.GetReceiveStreamFlowControlWindow(), + receiveWindowUpdateThreshold: protocol.WindowUpdateThreshold, + receiveFlowControlWindowIncrement: protocol.ReceiveStreamFlowControlWindowIncrement, + } +} + +func (c *flowController) AddBytesSent(n protocol.ByteCount) { + c.bytesSent += n +} + +// UpdateSendWindow should be called after receiving a WindowUpdateFrame +// it returns true if the window was actually updated +func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { + if newOffset > c.sendFlowControlWindow { + c.sendFlowControlWindow = newOffset + return true + } + return false +} + +func (c *flowController) SendWindowSize() protocol.ByteCount { + if c.bytesSent > c.sendFlowControlWindow { // should never happen, but make sure we don't do an underflow here + return 0 + } + return c.sendFlowControlWindow - c.bytesSent +} + +func (c *flowController) AddBytesRead(n protocol.ByteCount) { + c.bytesRead += n +} + +// MaybeTriggerWindowUpdate determines if it is necessary to send a WindowUpdate +// if so, it returns true and the offset of the window +func (c *flowController) MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) { + diff := c.receiveFlowControlWindow - c.bytesRead + if diff < c.receiveWindowUpdateThreshold { + c.receiveFlowControlWindow += c.receiveFlowControlWindowIncrement + return true, c.bytesRead + c.receiveFlowControlWindowIncrement + } + return false, 0 +} + +func (c *flowController) CheckFlowControlViolation(highestByte protocol.ByteCount) bool { + if highestByte > c.receiveFlowControlWindow { + return true + } + return false +} diff --git a/flow_controller_test.go b/flow_controller_test.go new file mode 100644 index 000000000..88142bf67 --- /dev/null +++ b/flow_controller_test.go @@ -0,0 +1,86 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Flow controller", func() { + var controller *flowController + + BeforeEach(func() { + controller = &flowController{} + }) + + Context("send flow control", func() { + It("adds bytes sent", func() { + controller.bytesSent = 5 + controller.AddBytesSent(6) + Expect(controller.bytesSent).To(Equal(protocol.ByteCount(5 + 6))) + }) + + It("gets the size of the remaining flow control window", func() { + controller.bytesSent = 5 + controller.sendFlowControlWindow = 12 + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(12 - 5))) + }) + + It("updates the size of the flow control window", func() { + controller.bytesSent = 5 + updateSuccessful := controller.UpdateSendWindow(15) + Expect(updateSuccessful).To(BeTrue()) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(15 - 5))) + }) + + It("does not decrease the flow control window", func() { + updateSuccessful := controller.UpdateSendWindow(20) + Expect(updateSuccessful).To(BeTrue()) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20))) + updateSuccessful = controller.UpdateSendWindow(10) + Expect(updateSuccessful).To(BeFalse()) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20))) + }) + }) + + Context("receive flow control", func() { + var receiveFlowControlWindow protocol.ByteCount = 1337 + var receiveWindowUpdateThreshold protocol.ByteCount = 500 + var receiveFlowControlWindowIncrement protocol.ByteCount = 600 + + BeforeEach(func() { + controller.receiveFlowControlWindow = receiveFlowControlWindow + controller.receiveWindowUpdateThreshold = receiveWindowUpdateThreshold + controller.receiveFlowControlWindowIncrement = receiveFlowControlWindowIncrement + }) + + It("adds bytes read", func() { + controller.bytesRead = 5 + controller.AddBytesRead(6) + Expect(controller.bytesRead).To(Equal(protocol.ByteCount(5 + 6))) + }) + + It("triggers a window update when necessary", func() { + readPosition := receiveFlowControlWindow - receiveWindowUpdateThreshold + 1 + controller.bytesRead = readPosition + updateNecessary, offset := controller.MaybeTriggerWindowUpdate() + Expect(updateNecessary).To(BeTrue()) + Expect(offset).To(Equal(readPosition + receiveFlowControlWindowIncrement)) + }) + + It("triggers a window update when not necessary", func() { + readPosition := receiveFlowControlWindow - receiveWindowUpdateThreshold - 1 + controller.bytesRead = readPosition + updateNecessary, _ := controller.MaybeTriggerWindowUpdate() + Expect(updateNecessary).To(BeFalse()) + }) + + It("detects a flow control violation", func() { + Expect(controller.CheckFlowControlViolation(receiveFlowControlWindow + 1)).To(BeTrue()) + }) + + It("does not give a flow control violation when using the window completely", func() { + Expect(controller.CheckFlowControlViolation(receiveFlowControlWindow)).To(BeFalse()) + }) + }) +}) diff --git a/stream.go b/stream.go index 975a133ec..7e14335e0 100644 --- a/stream.go +++ b/stream.go @@ -40,26 +40,22 @@ type stream struct { frameQueue streamFrameSorter newFrameOrErrCond sync.Cond - sendFlowControlWindow protocol.ByteCount - receiveFlowControlWindow protocol.ByteCount - receiveFlowControlWindowIncrement protocol.ByteCount - windowUpdateOrErrCond sync.Cond + flowController *flowController + + windowUpdateOrErrCond sync.Cond } // newStream creates a new Stream func newStream(session streamHandler, connectionParameterManager *handshake.ConnectionParametersManager, StreamID protocol.StreamID) (*stream, error) { s := &stream{ - session: session, - streamID: StreamID, + session: session, + streamID: StreamID, + flowController: newFlowController(connectionParameterManager), } s.newFrameOrErrCond.L = &s.mutex s.windowUpdateOrErrCond.L = &s.mutex - s.sendFlowControlWindow = connectionParameterManager.GetSendStreamFlowControlWindow() - s.receiveFlowControlWindow = connectionParameterManager.GetReceiveStreamFlowControlWindow() - s.receiveFlowControlWindowIncrement = protocol.ReceiveStreamFlowControlWindowIncrement - return s, nil } @@ -112,9 +108,11 @@ func (s *stream) Read(p []byte) (int, error) { m := utils.Min(len(p)-bytesRead, len(frame.Data)-s.readPosInFrame) copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) + s.readPosInFrame += m bytesRead += m s.readOffset += protocol.ByteCount(m) + s.flowController.AddBytesRead(protocol.ByteCount(m)) s.maybeTriggerWindowUpdate() @@ -140,17 +138,11 @@ func (s *stream) ReadByte() (byte, error) { return p[0], err } -func (s *stream) updateReceiveFlowControlWindow() { - n := s.readOffset + s.receiveFlowControlWindowIncrement - s.receiveFlowControlWindow = n - s.session.updateReceiveFlowControlWindow(s.streamID, n) -} - func (s *stream) UpdateSendFlowControlWindow(n protocol.ByteCount) { s.mutex.Lock() defer s.mutex.Unlock() - if n > s.sendFlowControlWindow { - s.sendFlowControlWindow = n + + if s.flowController.UpdateSendWindow(n) { s.windowUpdateOrErrCond.Broadcast() } } @@ -168,10 +160,10 @@ func (s *stream) Write(p []byte) (int, error) { for dataWritten < len(p) { s.mutex.Lock() - remainingBytesInWindow := int64(s.sendFlowControlWindow) - int64(s.writeOffset) + remainingBytesInWindow := s.flowController.SendWindowSize() for remainingBytesInWindow == 0 && s.err == nil { s.windowUpdateOrErrCond.Wait() - remainingBytesInWindow = int64(s.sendFlowControlWindow) - int64(s.writeOffset) + remainingBytesInWindow = s.flowController.SendWindowSize() } s.mutex.Unlock() @@ -193,6 +185,7 @@ func (s *stream) Write(p []byte) (int, error) { } dataWritten += dataLen + s.flowController.AddBytesSent(protocol.ByteCount(dataLen)) s.writeOffset += protocol.ByteCount(dataLen) } @@ -212,7 +205,7 @@ func (s *stream) Close() error { // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { maxOffset := frame.Offset + protocol.ByteCount(len(frame.Data)) - if maxOffset > s.receiveFlowControlWindow { + if s.flowController.CheckFlowControlViolation(maxOffset) { return errFlowControlViolation } @@ -224,9 +217,10 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { } func (s *stream) maybeTriggerWindowUpdate() { - diff := s.receiveFlowControlWindow - s.readOffset - if diff < protocol.WindowUpdateThreshold { - s.updateReceiveFlowControlWindow() + doUpdate, byteOffset := s.flowController.MaybeTriggerWindowUpdate() + + if doUpdate { + s.session.updateReceiveFlowControlWindow(s.streamID, byteOffset) } } diff --git a/stream_test.go b/stream_test.go index bcfdf818a..d34f58327 100644 --- a/stream_test.go +++ b/stream_test.go @@ -15,6 +15,9 @@ import ( type mockStreamHandler struct { frames []frames.Frame + + receiveFlowControlWindowCalled bool + receiveFlowControlWindowCalledForStream protocol.StreamID } func (m *mockStreamHandler) queueStreamFrame(f *frames.StreamFrame) error { @@ -23,6 +26,8 @@ func (m *mockStreamHandler) queueStreamFrame(f *frames.StreamFrame) error { } func (m *mockStreamHandler) updateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + m.receiveFlowControlWindowCalled = true + m.receiveFlowControlWindowCalledForStream = streamID return nil } @@ -277,7 +282,7 @@ var _ = Describe("Stream", func() { Context("flow control", func() { It("writes everything if the flow control window is big enough", func() { - str.sendFlowControlWindow = 4 + str.flowController.sendFlowControlWindow = 4 n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) Expect(n).To(Equal(4)) Expect(err).ToNot(HaveOccurred()) @@ -285,7 +290,7 @@ var _ = Describe("Stream", func() { It("waits for a flow control window update", func() { var b bool - str.sendFlowControlWindow = 1 + str.flowController.sendFlowControlWindow = 1 _, err := str.Write([]byte{0x42}) Expect(err).ToNot(HaveOccurred()) @@ -301,7 +306,7 @@ var _ = Describe("Stream", func() { }) It("splits writing of frames when given more data than the flow control windows size", func() { - str.sendFlowControlWindow = 2 + str.flowController.sendFlowControlWindow = 2 var b bool go func() { @@ -319,7 +324,7 @@ var _ = Describe("Stream", func() { It("writes after a flow control window update", func() { var b bool - str.sendFlowControlWindow = 1 + str.flowController.sendFlowControlWindow = 1 _, err := str.Write([]byte{0x42}) Expect(err).ToNot(HaveOccurred()) @@ -336,7 +341,7 @@ var _ = Describe("Stream", func() { It("immediately returns on remote errors", func() { var b bool - str.sendFlowControlWindow = 1 + str.flowController.sendFlowControlWindow = 1 testErr := errors.New("test error") @@ -353,23 +358,16 @@ var _ = Describe("Stream", func() { }) }) - Context("flow control window updating, for sending", func() { - It("updates the flow control window", func() { - str.sendFlowControlWindow = 3 - str.UpdateSendFlowControlWindow(4) - Expect(str.sendFlowControlWindow).To(Equal(protocol.ByteCount(4))) - }) - - It("never shrinks the flow control window", func() { - str.sendFlowControlWindow = 100 - str.UpdateSendFlowControlWindow(50) - Expect(str.sendFlowControlWindow).To(Equal(protocol.ByteCount(100))) - }) - }) - Context("flow control window updating, for receiving", func() { + var receiveFlowControlWindow protocol.ByteCount = 1337 + var receiveWindowUpdateThreshold protocol.ByteCount = 1000 + BeforeEach(func() { + str.flowController.receiveFlowControlWindow = receiveFlowControlWindow + str.flowController.receiveWindowUpdateThreshold = receiveWindowUpdateThreshold + }) + It("updates the flow control window", func() { - len := int(protocol.WindowUpdateThreshold) + 1 + len := int(receiveFlowControlWindow) - int(receiveWindowUpdateThreshold) + 1 frame := frames.StreamFrame{ Offset: 0, Data: bytes.Repeat([]byte{'f'}, len), @@ -380,12 +378,12 @@ var _ = Describe("Stream", func() { n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(len)) - Expect(str.receiveFlowControlWindow).To(Equal(protocol.ByteCount(len) + str.receiveFlowControlWindowIncrement)) + Expect(handler.receiveFlowControlWindowCalled).To(BeTrue()) + Expect(handler.receiveFlowControlWindowCalledForStream).To(Equal(str.streamID)) }) It("does not update the flow control window when not enough data was received", func() { - len := int(protocol.WindowUpdateThreshold) - 1 - receiveFlowControlWindow := str.receiveFlowControlWindow + len := int(receiveFlowControlWindow) - int(receiveWindowUpdateThreshold) - 1 frame := frames.StreamFrame{ Offset: 0, Data: bytes.Repeat([]byte{'f'}, len), @@ -396,11 +394,11 @@ var _ = Describe("Stream", func() { n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(len)) - Expect(str.receiveFlowControlWindow).To(Equal(receiveFlowControlWindow)) + Expect(handler.receiveFlowControlWindowCalled).To(BeFalse()) }) It("accepts frames that completely fill the flow control window", func() { - len := int(protocol.ReceiveStreamFlowControlWindow) + len := int(receiveFlowControlWindow) frame := frames.StreamFrame{ Offset: 0, Data: bytes.Repeat([]byte{'f'}, len), @@ -418,15 +416,6 @@ var _ = Describe("Stream", func() { err := str.AddStreamFrame(&frame) Expect(err).To(MatchError(errFlowControlViolation)) }) - - It("rejects a small frames that would violate the flow control window", func() { - frame := frames.StreamFrame{ - Offset: protocol.ReceiveStreamFlowControlWindow - 1, - Data: []byte{0x13, 0x37}, - } - err := str.AddStreamFrame(&frame) - Expect(err).To(MatchError(errFlowControlViolation)) - }) }) Context("closing", func() {