From 8ce763682a53f47205d4c903e08a96f7d12fddda Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 3 Jan 2017 17:51:56 +0700 Subject: [PATCH] send a RstStreamFrame when receiving a RstStreamFrame on an open stream fixes #378 --- session.go | 15 ++++++++++++--- session_test.go | 27 +++++++++++++++++++++++++++ stream_test.go | 2 +- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/session.go b/session.go index 2d1d5460..a8ff25dc 100644 --- a/session.go +++ b/session.go @@ -317,7 +317,6 @@ func (s *Session) handleFrames(fs []frames.Frame) error { switch frame := ff.(type) { case *frames.StreamFrame: err = s.handleStreamFrame(frame) - // TODO: send RstStreamFrame case *frames.AckFrame: err = s.handleAckFrame(frame) case *frames.ConnectionCloseFrame: @@ -369,7 +368,8 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { return err } if str == nil { - // Stream is closed, ignore + // Stream is closed and already garbage collected + // ignore this StreamFrame return nil } err = str.AddStreamFrame(frame) @@ -401,11 +401,20 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { if str == nil { return errRstStreamOnInvalidStream } + + shouldSendRst := !str.finishedWriting() s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) - _, err = s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) + bytesSent, err := s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) if err != nil { return err } + + if shouldSendRst { + s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ + StreamID: frame.StreamID, + ByteOffset: bytesSent, + }) + } return nil } diff --git a/session_test.go b/session_test.go index 9749bf31..820a1a7a 100644 --- a/session_test.go +++ b/session_test.go @@ -317,6 +317,33 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError("RST_STREAM received with code 42")) }) + It("queues a RST_STERAM frame with the correct offset", func() { + _, err := session.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + session.flowControlManager = newMockFlowControlHandler() + session.flowControlManager.(*mockFlowControlHandler).bytesSent = 0x1337 + err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + StreamID: 5, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(session.packer.controlFrames).To(HaveLen(1)) + Expect(session.packer.controlFrames[0].(*frames.RstStreamFrame)).To(Equal(&frames.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x1337, + })) + }) + + It("doesn't queue a RST_STREAM for a stream that it already sent a FIN on", func() { + str, err := session.GetOrOpenStream(5) + str.(*stream).sentFin() + str.Close() + err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + StreamID: 5, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(session.packer.controlFrames).To(BeEmpty()) + }) + It("passes the byte offset to the flow controller", func() { session.streamsMap.GetOrOpenStream(5) session.flowControlManager = newMockFlowControlHandler() diff --git a/stream_test.go b/stream_test.go index 5d279e1f..2b59dc32 100644 --- a/stream_test.go +++ b/stream_test.go @@ -62,7 +62,7 @@ func (m *mockFlowControlHandler) AddBytesRead(streamID protocol.StreamID, n prot } func (m *mockFlowControlHandler) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) (protocol.ByteCount, error) { - return 0, m.UpdateHighestReceived(streamID, byteOffset) + return m.bytesSent, m.UpdateHighestReceived(streamID, byteOffset) } func (m *mockFlowControlHandler) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {