diff --git a/flow_controller.go b/flow_controller.go index bbc18daff..d23b3ace3 100644 --- a/flow_controller.go +++ b/flow_controller.go @@ -15,6 +15,7 @@ type flowController struct { lastBlockedSentForOffset protocol.ByteCount bytesRead protocol.ByteCount + highestReceived protocol.ByteCount receiveWindowUpdateThreshold protocol.ByteCount receiveFlowControlWindow protocol.ByteCount receiveFlowControlWindowIncrement protocol.ByteCount @@ -61,6 +62,15 @@ func (c *flowController) SendWindowSize() protocol.ByteCount { return c.sendFlowControlWindow - c.bytesSent } +func (c *flowController) UpdateHighestReceived(n protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if n > c.highestReceived { + c.highestReceived = n + } +} + func (c *flowController) AddBytesRead(n protocol.ByteCount) { c.mutex.Lock() defer c.mutex.Unlock() @@ -100,11 +110,11 @@ func (c *flowController) MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) { return false, 0 } -func (c *flowController) CheckFlowControlViolation(highestByte protocol.ByteCount) bool { +func (c *flowController) CheckFlowControlViolation() bool { c.mutex.Lock() defer c.mutex.Unlock() - if highestByte > c.receiveFlowControlWindow { + if c.highestReceived > c.receiveFlowControlWindow { return true } return false diff --git a/flow_controller_test.go b/flow_controller_test.go index 5609f74c6..651d8b00f 100644 --- a/flow_controller_test.go +++ b/flow_controller_test.go @@ -103,12 +103,26 @@ var _ = Describe("Flow controller", func() { Expect(updateNecessary).To(BeFalse()) }) + It("updates the highestReceived", func() { + controller.highestReceived = 1337 + controller.UpdateHighestReceived(1338) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1338))) + }) + + It("does not decrease the highestReceived", func() { + controller.highestReceived = 1337 + controller.UpdateHighestReceived(1000) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337))) + }) + It("detects a flow control violation", func() { - Expect(controller.CheckFlowControlViolation(receiveFlowControlWindow + 1)).To(BeTrue()) + controller.UpdateHighestReceived(receiveFlowControlWindow + 1) + Expect(controller.CheckFlowControlViolation()).To(BeTrue()) }) It("does not give a flow control violation when using the window completely", func() { - Expect(controller.CheckFlowControlViolation(receiveFlowControlWindow)).To(BeFalse()) + controller.UpdateHighestReceived(receiveFlowControlWindow) + Expect(controller.CheckFlowControlViolation()).To(BeFalse()) }) }) }) diff --git a/stream.go b/stream.go index c2719349d..50745e36b 100644 --- a/stream.go +++ b/stream.go @@ -203,7 +203,8 @@ func (s *stream) Close() error { // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { maxOffset := frame.Offset + protocol.ByteCount(len(frame.Data)) - if s.flowController.CheckFlowControlViolation(maxOffset) { + s.flowController.UpdateHighestReceived(maxOffset) + if s.flowController.CheckFlowControlViolation() { return errFlowControlViolation }