diff --git a/flowcontrol/flow_control_manager.go b/flowcontrol/flow_control_manager.go index 1ac9f13f..9b2616ff 100644 --- a/flowcontrol/flow_control_manager.go +++ b/flowcontrol/flow_control_manager.go @@ -63,32 +63,43 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { // ResetStream should be called when receiving a RstStreamFrame // it updates the byte offset to the value in the RstStreamFrame // streamID must not be 0 here -func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) (protocol.ByteCount, error) { +func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { f.mutex.Lock() defer f.mutex.Unlock() streamFlowController, err := f.getFlowController(streamID) if err != nil { - return 0, err + return err } increment, err := streamFlowController.UpdateHighestReceived(byteOffset) if err != nil { - return 0, qerr.StreamDataAfterTermination + return qerr.StreamDataAfterTermination } if streamFlowController.CheckFlowControlViolation() { - return 0, qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) } if f.contributesToConnectionFlowControl[streamID] { connectionFlowController := f.streamFlowController[0] connectionFlowController.IncrementHighestReceived(increment) if connectionFlowController.CheckFlowControlViolation() { - return 0, qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow)) + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow)) } } - return streamFlowController.GetBytesSent(), nil + return nil +} + +func (f *flowControlManager) GetBytesSent(streamID protocol.StreamID) (protocol.ByteCount, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + fc, err := f.getFlowController(streamID) + if err != nil { + return 0, err + } + return fc.GetBytesSent(), nil } // UpdateHighestReceived updates the highest received byte offset for a stream diff --git a/flowcontrol/flow_control_manager_test.go b/flowcontrol/flow_control_manager_test.go index b5f87685..d2555f1a 100644 --- a/flowcontrol/flow_control_manager_test.go +++ b/flowcontrol/flow_control_manager_test.go @@ -136,16 +136,14 @@ var _ = Describe("Flow Control Manager", func() { }) It("updates the connection level flow controller if the stream contributes", func() { - bytesSent, err := fcm.ResetStream(4, 0x100) - Expect(bytesSent).To(Equal(protocol.ByteCount(0x42))) + err := fcm.ResetStream(4, 0x100) Expect(err).ToNot(HaveOccurred()) Expect(fcm.streamFlowController[0].highestReceived).To(Equal(protocol.ByteCount(0x100))) Expect(fcm.streamFlowController[4].highestReceived).To(Equal(protocol.ByteCount(0x100))) }) It("does not update the connection level flow controller if the stream does not contribute", func() { - bytesSent, err := fcm.ResetStream(1, 0x100) - Expect(bytesSent).To(Equal(protocol.ByteCount(0x41))) + err := fcm.ResetStream(1, 0x100) Expect(err).ToNot(HaveOccurred()) Expect(fcm.streamFlowController[0].highestReceived).To(BeZero()) Expect(fcm.streamFlowController[1].highestReceived).To(Equal(protocol.ByteCount(0x100))) @@ -154,24 +152,24 @@ var _ = Describe("Flow Control Manager", func() { It("errors if the byteOffset is smaller than a byteOffset that set earlier", func() { err := fcm.UpdateHighestReceived(4, 0x100) Expect(err).ToNot(HaveOccurred()) - _, err = fcm.ResetStream(4, 0x50) + err = fcm.ResetStream(4, 0x50) Expect(err).To(MatchError(qerr.StreamDataAfterTermination)) }) It("returns an error when called with an unknown stream", func() { - _, err := fcm.ResetStream(1337, 0x1337) + err := fcm.ResetStream(1337, 0x1337) Expect(err).To(MatchError(errMapAccess)) }) Context("flow control violations", func() { It("errors when encountering a stream level flow control violation", func() { - _, err := fcm.ResetStream(4, 0x101) + err := fcm.ResetStream(4, 0x101) Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 257 bytes on stream 4, allowed 256 bytes"))) // 0x100 = 256, 0x101 = 257 }) It("errors when encountering a connection-level flow control violation", func() { fcm.streamFlowController[4].receiveFlowControlWindow = 0x300 - _, err := fcm.ResetStream(4, 0x201) + err := fcm.ResetStream(4, 0x201) Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 513 bytes for the connection, allowed 512 bytes"))) // 0x200 = 512, 0x201 = 513 }) }) diff --git a/flowcontrol/interface.go b/flowcontrol/interface.go index 22f76916..f98f0974 100644 --- a/flowcontrol/interface.go +++ b/flowcontrol/interface.go @@ -13,7 +13,7 @@ type FlowControlManager interface { NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) RemoveStream(streamID protocol.StreamID) // methods needed for receiving data - ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) (protocol.ByteCount, error) + ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error GetWindowUpdates() []WindowUpdate diff --git a/session.go b/session.go index a846ea11..9c39fbf9 100644 --- a/session.go +++ b/session.go @@ -402,20 +402,8 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { return errRstStreamOnInvalidStream } - shouldSendRst := !str.finishedWriteAndSentFin() str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) - bytesSent, err := s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) - if err != nil { - return err - } - - if shouldSendRst { - s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ - StreamID: frame.StreamID, - ByteOffset: bytesSent, - }) - } - return nil + return s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) } func (s *Session) handleAckFrame(frame *frames.AckFrame) error { @@ -625,8 +613,16 @@ func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) { return s.streamsMap.OpenStream(id) } +func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { + s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ + StreamID: id, + ByteOffset: offset, + }) + s.scheduleSending() +} + func (s *Session) newStream(id protocol.StreamID) (*stream, error) { - stream, err := newStream(id, s.scheduleSending, s.flowControlManager) + stream, err := newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index 0cd123c9..0f876920 100644 --- a/session_test.go +++ b/session_test.go @@ -335,10 +335,9 @@ var _ = Describe("Session", func() { }) It("queues a RST_STERAM frame with the correct offset", func() { - _, err := session.GetOrOpenStream(5) + str, err := session.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - session.flowControlManager = newMockFlowControlHandler() - session.flowControlManager.(*mockFlowControlHandler).bytesSent = 0x1337 + str.(*stream).writeOffset = 0x1337 err = session.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, }) @@ -392,6 +391,34 @@ var _ = Describe("Session", func() { }}) Expect(err).NotTo(HaveOccurred()) }) + + It("queues a RST_STREAM when a stream gets reset locally", func() { + testErr := errors.New("testErr") + str, err := session.streamsMap.GetOrOpenStream(5) + str.writeOffset = 0x1337 + Expect(err).ToNot(HaveOccurred()) + str.Reset(testErr) + Expect(session.packer.controlFrames).To(HaveLen(1)) + Expect(session.packer.controlFrames[0]).To(Equal(&frames.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x1337, + })) + }) + + It("doesn't queue another RST_STREAM, when it receives an RST_STREAM as a response for the first", func() { + testErr := errors.New("testErr") + str, err := session.streamsMap.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + str.Reset(testErr) + Expect(session.packer.controlFrames).To(HaveLen(1)) + err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x42, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(session.packer.controlFrames).To(HaveLen(1)) + }) + }) Context("handling WINDOW_UPDATE frames", func() { diff --git a/stream.go b/stream.go index ee992656..6bf95782 100644 --- a/stream.go +++ b/stream.go @@ -19,6 +19,8 @@ type stream struct { streamID protocol.StreamID onData func() + // onReset is a callback that should send a RST_STREAM + onReset func(protocol.StreamID, protocol.ByteCount) readPosInFrame int writeOffset protocol.ByteCount @@ -43,15 +45,17 @@ type stream struct { dataForWriting []byte finSent utils.AtomicBool + rstSent utils.AtomicBool doneWritingOrErrCond sync.Cond flowControlManager flowcontrol.FlowControlManager } // newStream creates a new Stream -func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flowcontrol.FlowControlManager) (*stream, error) { +func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowControlManager flowcontrol.FlowControlManager) (*stream, error) { s := &stream{ onData: onData, + onReset: onReset, streamID: StreamID, flowControlManager: flowControlManager, frameQueue: newStreamFrameSorter(), @@ -207,6 +211,13 @@ func (s *stream) Close() error { return nil } +func (s *stream) shouldSendReset() bool { + if s.rstSent.Get() { + return false + } + return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() +} + func (s *stream) shouldSendFin() bool { s.mutex.Lock() res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil @@ -257,6 +268,9 @@ func (s *stream) Cancel(err error) { // resets the stream locally func (s *stream) Reset(err error) { + if s.resetLocally.Get() { + return + } s.mutex.Lock() s.resetLocally.Set(true) // errors must not be changed! @@ -265,11 +279,18 @@ func (s *stream) Reset(err error) { s.newFrameOrErrCond.Signal() s.doneWritingOrErrCond.Signal() } + if s.shouldSendReset() { + s.onReset(s.streamID, s.writeOffset) + s.rstSent.Set(true) + } s.mutex.Unlock() } // resets the stream remotely func (s *stream) RegisterRemoteError(err error) { + if s.resetRemotely.Get() { + return + } s.mutex.Lock() s.resetRemotely.Set(true) // errors must not be changed! @@ -277,6 +298,10 @@ func (s *stream) RegisterRemoteError(err error) { s.err = err s.doneWritingOrErrCond.Signal() } + if s.shouldSendReset() { + s.onReset(s.streamID, s.writeOffset) + s.rstSent.Set(true) + } s.mutex.Unlock() } diff --git a/stream_test.go b/stream_test.go index 422faeec..fac305ef 100644 --- a/stream_test.go +++ b/stream_test.go @@ -31,6 +31,8 @@ type mockFlowControlHandler struct { triggerConnectionWindowUpdate bool } +var _ flowcontrol.FlowControlManager = &mockFlowControlHandler{} + func newMockFlowControlHandler() *mockFlowControlHandler { return &mockFlowControlHandler{ sendWindowSizes: make(map[protocol.StreamID]protocol.ByteCount), @@ -61,8 +63,8 @@ func (m *mockFlowControlHandler) AddBytesRead(streamID protocol.StreamID, n prot return nil } -func (m *mockFlowControlHandler) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) (protocol.ByteCount, error) { - return m.bytesSent, m.UpdateHighestReceived(streamID, byteOffset) +func (m *mockFlowControlHandler) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + return m.UpdateHighestReceived(streamID, byteOffset) } func (m *mockFlowControlHandler) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { @@ -108,19 +110,30 @@ var _ = Describe("Stream", func() { var ( str *stream onDataCalled bool + + resetCalled bool + resetCalledForStream protocol.StreamID + resetCalledAtOffset protocol.ByteCount ) onData := func() { onDataCalled = true } + onReset := func(id protocol.StreamID, offset protocol.ByteCount) { + resetCalled = true + resetCalledForStream = id + resetCalledAtOffset = offset + } + BeforeEach(func() { onDataCalled = false + resetCalled = false var streamID protocol.StreamID = 1337 cpm := &mockConnectionParametersManager{} flowControlManager := flowcontrol.NewFlowControlManager(cpm, &congestion.RTTStats{}) flowControlManager.NewStream(streamID, true) - str, _ = newStream(streamID, onData, flowControlManager) + str, _ = newStream(streamID, onData, onReset, flowControlManager) }) It("gets stream id", func() { @@ -432,6 +445,43 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) }) + + It("calls onReset when receiving a remote error", func() { + var writeReturned bool + str.writeOffset = 0x1000 + go func() { + str.Write([]byte("foobar")) + writeReturned = true + }() + str.RegisterRemoteError(testErr) + Expect(resetCalled).To(BeTrue()) + Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337))) + Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000))) + Eventually(func() bool { return writeReturned }).Should(BeTrue()) + }) + + It("doesn't call onReset if it already sent a FIN", func() { + str.Close() + str.sentFin() + str.RegisterRemoteError(testErr) + Expect(resetCalled).To(BeFalse()) + }) + + It("doesn't call onReset if the stream was reset locally before", func() { + str.Reset(testErr) + Expect(resetCalled).To(BeTrue()) + resetCalled = false + str.RegisterRemoteError(testErr) + Expect(resetCalled).To(BeFalse()) + }) + + It("doesn't call onReset twice, when it gets two remote errors", func() { + str.RegisterRemoteError(testErr) + Expect(resetCalled).To(BeTrue()) + resetCalled = false + str.RegisterRemoteError(testErr) + Expect(resetCalled).To(BeFalse()) + }) }) Context("reset locally", func() { @@ -446,8 +496,7 @@ var _ = Describe("Stream", func() { }() Consistently(func() bool { return writeReturned }).Should(BeFalse()) str.Reset(testErr) - data := str.getDataForWriting(6) - Expect(data).To(BeNil()) + Expect(str.getDataForWriting(6)).To(BeNil()) Eventually(func() bool { return writeReturned }).Should(BeTrue()) Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) @@ -458,6 +507,7 @@ var _ = Describe("Stream", func() { n, err := str.Write([]byte("foobar")) Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) + Expect(str.getDataForWriting(6)).To(BeNil()) }) It("stops reading", func() { @@ -487,6 +537,37 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) }) + + It("calls onReset", func() { + str.writeOffset = 0x1000 + str.Reset(testErr) + Expect(resetCalled).To(BeTrue()) + Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337))) + Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000))) + }) + + It("doesn't call onReset if it already sent a FIN", func() { + str.Close() + str.sentFin() + str.Reset(testErr) + Expect(resetCalled).To(BeFalse()) + }) + + It("doesn't call onReset if the stream was reset remotely before", func() { + str.RegisterRemoteError(testErr) + Expect(resetCalled).To(BeTrue()) + resetCalled = false + str.Reset(testErr) + Expect(resetCalled).To(BeFalse()) + }) + + It("doesn't call onReset twice", func() { + str.Reset(testErr) + Expect(resetCalled).To(BeTrue()) + resetCalled = false + str.Reset(testErr) + Expect(resetCalled).To(BeFalse()) + }) }) })