diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index f50574d9..107d9177 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -104,6 +104,18 @@ func (_mr *MockStreamIMockRecorder) HandleMaxStreamDataFrame(arg0 interface{}) * return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "HandleMaxStreamDataFrame", reflect.TypeOf((*MockStreamI)(nil).HandleMaxStreamDataFrame), arg0) } +// HandleRstStreamFrame mocks base method +func (_m *MockStreamI) HandleRstStreamFrame(_param0 *wire.RstStreamFrame) error { + ret := _m.ctrl.Call(_m, "HandleRstStreamFrame", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleRstStreamFrame indicates an expected call of HandleRstStreamFrame +func (_mr *MockStreamIMockRecorder) HandleRstStreamFrame(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "HandleRstStreamFrame", reflect.TypeOf((*MockStreamI)(nil).HandleRstStreamFrame), arg0) +} + // HandleStreamFrame mocks base method func (_m *MockStreamI) HandleStreamFrame(_param0 *wire.StreamFrame) error { ret := _m.ctrl.Call(_m, "HandleStreamFrame", _param0) @@ -153,18 +165,6 @@ func (_mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0) } -// RegisterRemoteError mocks base method -func (_m *MockStreamI) RegisterRemoteError(_param0 error, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "RegisterRemoteError", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -// RegisterRemoteError indicates an expected call of RegisterRemoteError -func (_mr *MockStreamIMockRecorder) RegisterRemoteError(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RegisterRemoteError", reflect.TypeOf((*MockStreamI)(nil).RegisterRemoteError), arg0, arg1) -} - // Reset mocks base method func (_m *MockStreamI) Reset(_param0 error) { _m.ctrl.Call(_m, "Reset", _param0) diff --git a/session.go b/session.go index 62aafc07..8b689290 100644 --- a/session.go +++ b/session.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net" "sync" "time" @@ -613,7 +612,7 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { // stream is closed and already garbage collected return nil } - return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset) + return str.HandleRstStreamFrame(frame) } func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { diff --git a/session_test.go b/session_test.go index eabc6d2f..62a9e89a 100644 --- a/session_test.go +++ b/session_test.go @@ -343,29 +343,28 @@ var _ = Describe("Session", func() { Context("handling RST_STREAM frames", func() { It("closes the streams for writing", func() { - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().RegisterRemoteError( - errors.New("RST_STREAM received with code 42"), - protocol.ByteCount(0x1337), - ) - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ + f := &wire.RstStreamFrame{ StreamID: 5, ErrorCode: 42, ByteOffset: 0x1337, - }) + } + str, err := sess.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().HandleRstStreamFrame(f) + err = sess.handleRstStreamFrame(f) Expect(err).ToNot(HaveOccurred()) }) It("returns errors", func() { + f := &wire.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x1337, + } testErr := errors.New("flow control violation") str, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().RegisterRemoteError(gomock.Any(), gomock.Any()).Return(testErr) - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x1337, - }) + str.(*mocks.MockStreamI).EXPECT().HandleRstStreamFrame(f).Return(testErr) + err = sess.handleRstStreamFrame(f) Expect(err).To(MatchError(testErr)) }) diff --git a/stream.go b/stream.go index 26ee53c9..6dacfd3f 100644 --- a/stream.go +++ b/stream.go @@ -18,7 +18,7 @@ type streamI interface { Stream HandleStreamFrame(*wire.StreamFrame) error - RegisterRemoteError(error, protocol.ByteCount) error + HandleRstStreamFrame(*wire.RstStreamFrame) error PopStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame Finished() bool Cancel(error) @@ -59,7 +59,7 @@ type stream struct { finishedWriting utils.AtomicBool // resetLocally is set if Reset() is called resetLocally utils.AtomicBool - // resetRemotely is set if RegisterRemoteError() is called + // resetRemotely is set if HandleRstStreamFrame() is called resetRemotely utils.AtomicBool frameQueue *streamFrameSorter @@ -433,8 +433,7 @@ func (s *stream) Reset(err error) { s.mutex.Unlock() } -// resets the stream remotely -func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error { +func (s *stream) HandleRstStreamFrame(frame *wire.RstStreamFrame) error { if s.resetRemotely.Get() { return nil } @@ -443,10 +442,10 @@ func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error s.ctxCancel() // errors must not be changed! if s.err == nil { - s.err = err + s.err = fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode) s.signalWrite() } - if err := s.flowController.UpdateHighestReceived(offset, true); err != nil { + if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil { return err } if s.shouldSendReset() { diff --git a/stream_test.go b/stream_test.go index 47c1cd17..5bc197cb 100644 --- a/stream_test.go +++ b/stream_test.go @@ -478,8 +478,6 @@ var _ = Describe("Stream", func() { }) Context("resetting", func() { - testErr := errors.New("testErr") - Context("reset by the peer", func() { It("continues reading after receiving a remote error", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) @@ -489,22 +487,30 @@ var _ = Describe("Stream", func() { Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } str.HandleStreamFrame(&frame) - str.RegisterRemoteError(testErr, 10) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 10, + }) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(4)) }) - It("reads a delayed StreamFrame that arrives after receiving a remote error", func() { + It("reads a delayed STREAM frame that arrives after receiving a remote error", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - str.RegisterRemoteError(testErr, 4) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 4, + }) + Expect(err).ToNot(HaveOccurred()) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } - err := str.HandleStreamFrame(&frame) + err = str.HandleStreamFrame(&frame) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := strWithTimeout.Read(b) @@ -520,15 +526,20 @@ var _ = Describe("Stream", func() { Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } str.HandleStreamFrame(&frame) - str.RegisterRemoteError(testErr, 8) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 8, + ErrorCode: 1337, + }) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 10) n, err := strWithTimeout.Read(b) Expect(b[0:4]).To(Equal(frame.Data)) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError("RST_STREAM received with code 1337")) Expect(n).To(Equal(4)) }) - It("returns an EOF when reading past the offset, if the stream received a finbit", func() { + It("returns an EOF when reading past the offset, if the stream received a FIN bit", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true) frame := wire.StreamFrame{ @@ -537,7 +548,11 @@ var _ = Describe("Stream", func() { FinBit: true, } str.HandleStreamFrame(&frame) - str.RegisterRemoteError(testErr, 8) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 8, + }) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 10) n, err := strWithTimeout.Read(b) Expect(b[:4]).To(Equal(frame.Data)) @@ -554,9 +569,12 @@ var _ = Describe("Stream", func() { FinBit: true, } str.HandleStreamFrame(&frame) - str.RegisterRemoteError(testErr, 4) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 4, + }) b := make([]byte, 3) - _, err := strWithTimeout.Read(b) + _, err = strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) Expect(b).To(Equal([]byte{0xde, 0xad, 0xbe})) b = make([]byte, 3) @@ -576,27 +594,36 @@ var _ = Describe("Stream", func() { Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } str.HandleStreamFrame(&frame) - str.RegisterRemoteError(testErr, 10) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 10, + }) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 3) - _, err := strWithTimeout.Read(b) + _, err = strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) }) It("stops writing after receiving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true) done := make(chan struct{}) go func() { defer GinkgoRecover() n, err := strWithTimeout.Write([]byte("foobar")) Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError("RST_STREAM received with code 1337")) close(done) }() - str.RegisterRemoteError(testErr, 10) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 8, + ErrorCode: 1337, + }) + Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) }) - It("returns how much was written when recieving a remote error", func() { + It("returns how much was written when receiving a remote error", func() { frameHeaderSize := protocol.ByteCount(4) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) @@ -605,7 +632,7 @@ var _ = Describe("Stream", func() { go func() { defer GinkgoRecover() n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError("RST_STREAM received with code 1337")) Expect(n).To(Equal(4)) close(done) }() @@ -614,22 +641,31 @@ var _ = Describe("Stream", func() { Eventually(func() *wire.StreamFrame { frame = str.PopStreamFrame(4 + frameHeaderSize); return frame }).ShouldNot(BeNil()) Expect(frame).ToNot(BeNil()) Expect(frame.DataLen()).To(BeEquivalentTo(4)) - str.RegisterRemoteError(testErr, 10) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 10, + ErrorCode: 1337, + }) + Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) }) - It("calls onReset when receiving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + It("calls queues a RST_STREAM frame when receiving a remote error", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) done := make(chan struct{}) str.writeOffset = 0x1000 go func() { _, _ = strWithTimeout.Write([]byte("foobar")) close(done) }() - str.RegisterRemoteError(testErr, 0) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 10, + }) + Expect(err).ToNot(HaveOccurred()) Expect(queuedControlFrames).To(Equal([]wire.Frame{ &wire.RstStreamFrame{ - StreamID: 1337, + StreamID: streamID, ByteOffset: 0x1000, }, })) @@ -637,32 +673,50 @@ var _ = Describe("Stream", func() { }) It("doesn't call onReset if it already sent a FIN", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) str.Close() f := str.PopStreamFrame(100) Expect(f.FinBit).To(BeTrue()) - str.RegisterRemoteError(testErr, 0) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 10, + }) + Expect(err).ToNot(HaveOccurred()) Expect(queuedControlFrames).To(BeEmpty()) }) - It("doesn't call queue a RST_STREAM if the stream was reset locally before", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.Reset(testErr) + It("doesn't queue a RST_STREAM if the stream was reset locally before", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) + str.Reset(errors.New("reset")) Expect(queuedControlFrames).To(HaveLen(1)) - str.RegisterRemoteError(testErr, 0) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 10, + }) + Expect(err).ToNot(HaveOccurred()) Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame }) It("doesn't queue two RST_STREAMs twice, when it gets two remote errors", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.RegisterRemoteError(testErr, 0) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 8, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(queuedControlFrames).To(HaveLen(1)) + err = str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 9, + }) + Expect(err).ToNot(HaveOccurred()) Expect(queuedControlFrames).To(HaveLen(1)) - str.RegisterRemoteError(testErr, 0) - Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame }) }) Context("reset locally", func() { + testErr := errors.New("test error") + It("stops writing", func() { done := make(chan struct{}) go func() { @@ -733,11 +787,14 @@ var _ = Describe("Stream", func() { }) It("doesn't queue a new RST_STREAM, if the stream was reset remotely before", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.RegisterRemoteError(testErr, 0) - Expect(queuedControlFrames).To(HaveLen(1)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) + err := str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 10, + }) + Expect(err).ToNot(HaveOccurred()) str.Reset(testErr) - Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame + Expect(queuedControlFrames).To(HaveLen(1)) }) It("doesn't call onReset twice", func() { @@ -1037,7 +1094,10 @@ var _ = Describe("Stream", func() { It("is finished after receiving a RST and sending one", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) // this directly sends a rst - str.RegisterRemoteError(testErr, 0) + str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 0, + }) Expect(str.rstSent.Get()).To(BeTrue()) Expect(str.Finished()).To(BeTrue()) }) @@ -1045,7 +1105,10 @@ var _ = Describe("Stream", func() { It("cancels the context after receiving a RST", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) Expect(str.Context().Done()).ToNot(BeClosed()) - str.RegisterRemoteError(testErr, 0) + str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 0, + }) Expect(str.Context().Done()).To(BeClosed()) }) @@ -1053,7 +1116,10 @@ var _ = Describe("Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true) str.Reset(testErr) Expect(str.Finished()).To(BeFalse()) - str.RegisterRemoteError(testErr, 13) + str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 13, + }) Expect(str.Finished()).To(BeTrue()) }) @@ -1062,7 +1128,10 @@ var _ = Describe("Stream", func() { str.Close() f := str.PopStreamFrame(1000) Expect(f.FinBit).To(BeTrue()) - str.RegisterRemoteError(testErr, 13) + str.HandleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 13, + }) Expect(str.Finished()).To(BeTrue()) })