diff --git a/session.go b/session.go index af61cd96..0f02e3ea 100644 --- a/session.go +++ b/session.go @@ -393,7 +393,6 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error return err } -// TODO: Handle frame.byteOffset func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { @@ -402,6 +401,10 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { if str == nil { return errRstStreamOnInvalidStream } + err = s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) + if err != nil { + return err + } s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) return nil } diff --git a/session_test.go b/session_test.go index ae41c2be..9749bf31 100644 --- a/session_test.go +++ b/session_test.go @@ -317,6 +317,30 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError("RST_STREAM received with code 42")) }) + It("passes the byte offset to the flow controller", func() { + session.streamsMap.GetOrOpenStream(5) + session.flowControlManager = newMockFlowControlHandler() + err := session.handleRstStreamFrame(&frames.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x1337, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(session.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(protocol.StreamID(5))) + Expect(session.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(0x1337))) + }) + + It("returns errors from the flow controller", func() { + session.streamsMap.GetOrOpenStream(5) + session.flowControlManager = newMockFlowControlHandler() + testErr := errors.New("flow control violation") + session.flowControlManager.(*mockFlowControlHandler).flowControlViolation = testErr + err := session.handleRstStreamFrame(&frames.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x1337, + }) + Expect(err).To(MatchError(testErr)) + }) + It("ignores the error when the stream is not known", func() { err := session.handleFrames([]frames.Frame{&frames.RstStreamFrame{ StreamID: 5, diff --git a/stream_test.go b/stream_test.go index 9a7caf5d..386cf095 100644 --- a/stream_test.go +++ b/stream_test.go @@ -62,7 +62,7 @@ func (m *mockFlowControlHandler) AddBytesRead(streamID protocol.StreamID, n prot } func (m *mockFlowControlHandler) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - panic("not implemented") + return m.UpdateHighestReceived(streamID, byteOffset) } func (m *mockFlowControlHandler) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {