diff --git a/flowcontrol/flow_control_manager.go b/flowcontrol/flow_control_manager.go index fa2d46ae..1ac9f13f 100644 --- a/flowcontrol/flow_control_manager.go +++ b/flowcontrol/flow_control_manager.go @@ -63,32 +63,32 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { // ResetStream should be called when receiving a RstStreamFrame // it updates the byte offset to the value in the RstStreamFrame // streamID must not be 0 here -func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { +func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) (protocol.ByteCount, error) { f.mutex.Lock() defer f.mutex.Unlock() streamFlowController, err := f.getFlowController(streamID) if err != nil { - return err + return 0, err } increment, err := streamFlowController.UpdateHighestReceived(byteOffset) if err != nil { - return qerr.StreamDataAfterTermination + return 0, qerr.StreamDataAfterTermination } if streamFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) + return 0, qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) } if f.contributesToConnectionFlowControl[streamID] { connectionFlowController := f.streamFlowController[0] connectionFlowController.IncrementHighestReceived(increment) if connectionFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow)) + return 0, qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow)) } } - return nil + return streamFlowController.GetBytesSent(), nil } // UpdateHighestReceived updates the highest received byte offset for a stream diff --git a/flowcontrol/flow_control_manager_test.go b/flowcontrol/flow_control_manager_test.go index 14cd3c39..b5f87685 100644 --- a/flowcontrol/flow_control_manager_test.go +++ b/flowcontrol/flow_control_manager_test.go @@ -131,17 +131,21 @@ var _ = Describe("Flow Control Manager", func() { fcm.NewStream(1, false) fcm.NewStream(4, true) fcm.NewStream(6, true) + fcm.streamFlowController[1].bytesSent = 0x41 + fcm.streamFlowController[4].bytesSent = 0x42 }) It("updates the connection level flow controller if the stream contributes", func() { - err := fcm.ResetStream(4, 0x100) + bytesSent, err := fcm.ResetStream(4, 0x100) + Expect(bytesSent).To(Equal(protocol.ByteCount(0x42))) Expect(err).ToNot(HaveOccurred()) Expect(fcm.streamFlowController[0].highestReceived).To(Equal(protocol.ByteCount(0x100))) Expect(fcm.streamFlowController[4].highestReceived).To(Equal(protocol.ByteCount(0x100))) }) It("does not update the connection level flow controller if the stream does not contribute", func() { - err := fcm.ResetStream(1, 0x100) + bytesSent, err := fcm.ResetStream(1, 0x100) + Expect(bytesSent).To(Equal(protocol.ByteCount(0x41))) Expect(err).ToNot(HaveOccurred()) Expect(fcm.streamFlowController[0].highestReceived).To(BeZero()) Expect(fcm.streamFlowController[1].highestReceived).To(Equal(protocol.ByteCount(0x100))) @@ -150,24 +154,24 @@ var _ = Describe("Flow Control Manager", func() { It("errors if the byteOffset is smaller than a byteOffset that set earlier", func() { err := fcm.UpdateHighestReceived(4, 0x100) Expect(err).ToNot(HaveOccurred()) - err = fcm.ResetStream(4, 0x50) + _, err = fcm.ResetStream(4, 0x50) Expect(err).To(MatchError(qerr.StreamDataAfterTermination)) }) It("returns an error when called with an unknown stream", func() { - err := fcm.ResetStream(1337, 0x1337) + _, err := fcm.ResetStream(1337, 0x1337) Expect(err).To(MatchError(errMapAccess)) }) Context("flow control violations", func() { It("errors when encountering a stream level flow control violation", func() { - err := fcm.ResetStream(4, 0x101) + _, err := fcm.ResetStream(4, 0x101) Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 257 bytes on stream 4, allowed 256 bytes"))) // 0x100 = 256, 0x101 = 257 }) It("errors when encountering a connection-level flow control violation", func() { fcm.streamFlowController[4].receiveFlowControlWindow = 0x300 - err := fcm.ResetStream(4, 0x201) + _, err := fcm.ResetStream(4, 0x201) Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 513 bytes for the connection, allowed 512 bytes"))) // 0x200 = 512, 0x201 = 513 }) }) diff --git a/flowcontrol/flow_controller.go b/flowcontrol/flow_controller.go index f5becf09..6acd8d35 100644 --- a/flowcontrol/flow_controller.go +++ b/flowcontrol/flow_controller.go @@ -66,6 +66,10 @@ func (c *flowController) AddBytesSent(n protocol.ByteCount) { c.bytesSent += n } +func (c *flowController) GetBytesSent() protocol.ByteCount { + return c.bytesSent +} + // UpdateSendWindow should be called after receiving a WindowUpdateFrame // it returns true if the window was actually updated func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { diff --git a/flowcontrol/flow_controller_test.go b/flowcontrol/flow_controller_test.go index 4df40618..4473ec71 100644 --- a/flowcontrol/flow_controller_test.go +++ b/flowcontrol/flow_controller_test.go @@ -108,6 +108,11 @@ var _ = Describe("Flow controller", func() { Expect(controller.bytesSent).To(Equal(protocol.ByteCount(5 + 6))) }) + It("gets the bytesSent", func() { + controller.bytesSent = 8 + Expect(controller.GetBytesSent()).To(Equal(protocol.ByteCount(8))) + }) + It("gets the size of the remaining flow control window", func() { controller.bytesSent = 5 controller.sendFlowControlWindow = 12 diff --git a/flowcontrol/interface.go b/flowcontrol/interface.go index f98f0974..22f76916 100644 --- a/flowcontrol/interface.go +++ b/flowcontrol/interface.go @@ -13,7 +13,7 @@ type FlowControlManager interface { NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) RemoveStream(streamID protocol.StreamID) // methods needed for receiving data - ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error + ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) (protocol.ByteCount, error) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error GetWindowUpdates() []WindowUpdate diff --git a/session.go b/session.go index 0f02e3ea..2d1d5460 100644 --- a/session.go +++ b/session.go @@ -401,11 +401,11 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { if str == nil { return errRstStreamOnInvalidStream } - err = s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) + s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) + _, err = s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) if err != nil { return err } - s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) return nil } diff --git a/stream_test.go b/stream_test.go index c5ca23a6..5d279e1f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -61,8 +61,8 @@ func (m *mockFlowControlHandler) AddBytesRead(streamID protocol.StreamID, n prot return nil } -func (m *mockFlowControlHandler) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - return m.UpdateHighestReceived(streamID, byteOffset) +func (m *mockFlowControlHandler) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) (protocol.ByteCount, error) { + return 0, m.UpdateHighestReceived(streamID, byteOffset) } func (m *mockFlowControlHandler) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {