From a0c4e284852a9ced93333bada7889396cc239365 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Dec 2017 18:39:28 +0700 Subject: [PATCH] send and handle STOP_SENDING frames (for IETF QUIC) --- internal/mocks/stream.go | 10 ++++++ session.go | 18 +++++++++++ session_test.go | 46 +++++++++++++++++++++++++++ stream.go | 31 ++++++++++++++---- stream_test.go | 68 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 167 insertions(+), 6 deletions(-) diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index 8da42d9a0..2c36be0b6 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -140,6 +140,16 @@ func (_mr *MockStreamIMockRecorder) HandleRstStreamFrame(arg0 interface{}) *gomo return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "HandleRstStreamFrame", reflect.TypeOf((*MockStreamI)(nil).HandleRstStreamFrame), arg0) } +// HandleStopSendingFrame mocks base method +func (_m *MockStreamI) HandleStopSendingFrame(_param0 *wire.StopSendingFrame) { + _m.ctrl.Call(_m, "HandleStopSendingFrame", _param0) +} + +// HandleStopSendingFrame indicates an expected call of HandleStopSendingFrame +func (_mr *MockStreamIMockRecorder) HandleStopSendingFrame(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "HandleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).HandleStopSendingFrame), arg0) +} + // HandleStreamFrame mocks base method func (_m *MockStreamI) HandleStreamFrame(_param0 *wire.StreamFrame) error { ret := _m.ctrl.Call(_m, "HandleStreamFrame", _param0) diff --git a/session.go b/session.go index b880181f1..2caf73ae8 100644 --- a/session.go +++ b/session.go @@ -532,6 +532,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve err = s.handleMaxStreamDataFrame(frame) case *wire.BlockedFrame: case *wire.StreamBlockedFrame: + case *wire.StopSendingFrame: + err = s.handleStopSendingFrame(frame) case *wire.PingFrame: default: return errors.New("Session BUG: unexpected frame type") @@ -599,6 +601,22 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error return nil } +func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received a STOP_SENDING frame for the crypto stream") + } + str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.HandleStopSendingFrame(frame) + return nil +} + func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { if frame.StreamID == s.version.CryptoStreamID() { return errors.New("Received RST_STREAM frame for the crypto stream") diff --git a/session_test.go b/session_test.go index 0001e7f27..96b26763f 100644 --- a/session_test.go +++ b/session_test.go @@ -425,6 +425,52 @@ var _ = Describe("Session", func() { }) }) + Context("handling STOP_SENDING frames", func() { + It("opens a new stream when receiving a STOP_SENDING frame for an unknown stream", func() { + f := &wire.StopSendingFrame{ + StreamID: 5, + ErrorCode: 10, + } + newStreamLambda := sess.streamsMap.newStream + sess.streamsMap.newStream = func(id protocol.StreamID) streamI { + str := newStreamLambda(id) + if id == 5 { + str.(*mocks.MockStreamI).EXPECT().HandleStopSendingFrame(f) + } + return str + } + err := sess.handleStopSendingFrame(f) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.streamsMap.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(str).ToNot(BeNil()) + }) + + It("errors when receiving a STOP_SENDING for the crypto stream", func() { + err := sess.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: sess.version.CryptoStreamID(), + ErrorCode: 10, + }) + Expect(err).To(MatchError("Received a STOP_SENDING frame for the crypto stream")) + }) + + It("ignores STOP_SENDING frames for a closed stream", func() { + str, err := sess.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().Finished().Return(true) + err = sess.streamsMap.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + str, err = sess.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + err = sess.handleFrames([]wire.Frame{&wire.StopSendingFrame{ + StreamID: 3, + ErrorCode: 1337, + }}, protocol.EncryptionUnspecified) + Expect(err).NotTo(HaveOccurred()) + }) + }) + It("handles PING frames", func() { err := sess.handleFrames([]wire.Frame{&wire.PingFrame{}}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) diff --git a/stream.go b/stream.go index 89c41e52a..4330a161d 100644 --- a/stream.go +++ b/stream.go @@ -25,13 +25,17 @@ func (e streamCanceledError) ErrorCode() protocol.ApplicationErrorCode { return var _ StreamError = &streamCanceledError{} var _ error = &streamCanceledError{} -const errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7 +const ( + errorCodeStopping protocol.ApplicationErrorCode = 0 + errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7 +) type streamI interface { Stream HandleStreamFrame(*wire.StreamFrame) error HandleRstStreamFrame(*wire.RstStreamFrame) error + HandleStopSendingFrame(*wire.StopSendingFrame) PopStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame Finished() bool CloseForShutdown(error) @@ -69,7 +73,7 @@ type stream struct { closedForShutdown bool // set when CloseForShutdown() is called finRead bool // set once we read a frame with a FinBit finishedWriting bool // set once Close() is called - canceledWrite bool // set when CancelWrite() is called + canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received canceledRead bool // set when CancelRead() is called finSent bool // set when a STREAM_FRAME with FIN bit has b resetRemotely bool // set when HandleRstStreamFrame() is called @@ -459,7 +463,12 @@ func (s *stream) CancelRead(errorCode protocol.ApplicationErrorCode) error { s.canceledRead = true s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) s.signalRead() - // TODO(#1034): queue a STOP_SENDING (in IETF QUIC) + if s.version.UsesIETFFrameFormat() { + s.queueControlFrame(&wire.StopSendingFrame{ + StreamID: s.streamID, + ErrorCode: errorCode, + }) + } return nil } @@ -474,7 +483,7 @@ func (s *stream) HandleRstStreamFrame(frame *wire.RstStreamFrame) error { return err } if !s.version.UsesIETFFrameFormat() { - s.HandleStopSendingFrame(&wire.StopSendingFrame{ + s.handleStopSendingFrameImpl(&wire.StopSendingFrame{ StreamID: s.streamID, ErrorCode: frame.ErrorCode, }) @@ -500,12 +509,22 @@ func (s *stream) HandleRstStreamFrame(frame *wire.RstStreamFrame) error { } func (s *stream) HandleStopSendingFrame(frame *wire.StopSendingFrame) { - // send a RST_STREAM frame + s.mutex.Lock() + defer s.mutex.Unlock() + s.handleStopSendingFrameImpl(frame) +} + +// must be called after locking the mutex +func (s *stream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) { writeErr := streamCanceledError{ errorCode: frame.ErrorCode, error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), } - s.cancelWriteImpl(errorCodeStoppingGQUIC, writeErr) + errorCode := errorCodeStopping + if !s.version.UsesIETFFrameFormat() { + errorCode = errorCodeStoppingGQUIC + } + s.cancelWriteImpl(errorCode, writeErr) } func (s *stream) Finished() bool { diff --git a/stream_test.go b/stream_test.go index 251a9d2be..aa94da6fd 100644 --- a/stream_test.go +++ b/stream_test.go @@ -886,6 +886,28 @@ var _ = Describe("Stream", func() { Expect(queuedControlFrames).To(BeEmpty()) // no RST_STREAM frame queued yet }) }) + + Context("for IETF QUIC", func() { + It("queues a STOP_SENDING frame", func() { + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + Expect(queuedControlFrames).To(Equal([]wire.Frame{ + &wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 1234, + }, + })) + }) + + It("doesn't queue a RST_STREAM after closing the stream", func() { // this is what it does for gQUIC + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + Expect(queuedControlFrames).To(HaveLen(1)) + Expect(queuedControlFrames[0]).To(BeAssignableToTypeOf(&wire.StopSendingFrame{})) + Expect(str.Close()).To(Succeed()) + Expect(queuedControlFrames).To(HaveLen(1)) + }) + }) }) Context("receiving RST_STREAM frames", func() { @@ -1061,6 +1083,52 @@ var _ = Describe("Stream", func() { }) }) }) + + Context("receiving STOP_SENDING frames", func() { + It("queues a RST_STREAM frames with error code Stopping", func() { + str.HandleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 101, + }) + Expect(queuedControlFrames).To(Equal([]wire.Frame{ + &wire.RstStreamFrame{ + StreamID: streamID, + ErrorCode: errorCodeStopping, + }, + })) + }) + + It("unblocks Write", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError("Stream 1337 was reset with error code 123")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123))) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + str.HandleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't allow further calls to Write", func() { + str.HandleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError("Stream 1337 was reset with error code 123")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123))) + }) + }) }) Context("saying if it is finished", func() {