From 8ef69143bab525fed6d48dddc0b05569be785412 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 29 Jul 2017 09:18:33 +0700 Subject: [PATCH] implement a context for the stream The context is cancelled when the write-side of the stream is closed. --- h2quic/response_writer_test.go | 2 ++ interface.go | 4 ++++ stream.go | 18 +++++++++++++++++- stream_test.go | 31 +++++++++++++++++++++++++++++++ streams_map_test.go | 2 +- 5 files changed, 55 insertions(+), 2 deletions(-) diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 93ebae16..735dbd03 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -2,6 +2,7 @@ package h2quic import ( "bytes" + "context" "io" "net/http" "sync" @@ -37,6 +38,7 @@ func (s *mockStream) Close() error { s.closed = true; r func (s *mockStream) Reset(error) { s.reset = true } func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true } func (s mockStream) StreamID() protocol.StreamID { return s.id } +func (s *mockStream) Context() context.Context { panic("not implemented") } func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") } func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") } func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") } diff --git a/interface.go b/interface.go index 69ace9aa..77cd91af 100644 --- a/interface.go +++ b/interface.go @@ -23,6 +23,10 @@ type Stream interface { StreamID() protocol.StreamID // Reset closes the stream with an error. Reset(error) + // The context is canceled as soon as the write-side of the stream is closed. + // This happens when Close() is called, or when the stream is reset (either locally or remotely). + // Warning: This API should not be considered stable and might change soon. + Context() context.Context // SetReadDeadline sets the deadline for future Read calls and // any currently-blocked Read call. // A zero value for t means Read will not time out. diff --git a/stream.go b/stream.go index 45cf01a1..ffadc183 100644 --- a/stream.go +++ b/stream.go @@ -1,6 +1,7 @@ package quic import ( + "context" "fmt" "io" "net" @@ -19,6 +20,9 @@ import ( type stream struct { mutex sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + streamID protocol.StreamID onData func() // onReset is a callback that should send a RST_STREAM @@ -55,6 +59,8 @@ type stream struct { flowControlManager flowcontrol.FlowControlManager } +var _ Stream = &stream{} + type deadlineError struct{} func (deadlineError) Error() string { return "deadline exceeded" } @@ -68,7 +74,7 @@ func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowControlManager flowcontrol.FlowControlManager) *stream { - return &stream{ + s := &stream{ onData: onData, onReset: onReset, streamID: StreamID, @@ -77,6 +83,8 @@ func newStream(StreamID protocol.StreamID, readChan: make(chan struct{}, 1), writeChan: make(chan struct{}, 1), } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s } // Read implements io.Reader. It is not thread safe! @@ -257,6 +265,7 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { // Close implements io.Closer func (s *stream) Close() error { s.finishedWriting.Set(true) + s.ctxCancel() s.onData() return nil } @@ -352,6 +361,7 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) { func (s *stream) Cancel(err error) { s.mutex.Lock() s.cancelled.Set(true) + s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err @@ -368,6 +378,7 @@ func (s *stream) Reset(err error) { } s.mutex.Lock() s.resetLocally.Set(true) + s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err @@ -388,6 +399,7 @@ func (s *stream) RegisterRemoteError(err error) { } s.mutex.Lock() s.resetRemotely.Set(true) + s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err @@ -412,6 +424,10 @@ func (s *stream) finished() bool { (s.finishedWriteAndSentFin() && s.resetRemotely.Get()) } +func (s *stream) Context() context.Context { + return s.ctx +} + func (s *stream) StreamID() protocol.StreamID { return s.streamID } diff --git a/stream_test.go b/stream_test.go index d14900b5..d347881d 100644 --- a/stream_test.go +++ b/stream_test.go @@ -435,6 +435,12 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) Expect(err).To(MatchError(io.EOF)) }) + + It("doesn't cancel the context", func() { + mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0)) + str.CloseRemote(0) + Expect(str.Context().Done()).ToNot(BeClosed()) + }) }) }) @@ -463,6 +469,12 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) }) + + It("cancels the context", func() { + Expect(str.Context().Done()).ToNot(BeClosed()) + str.Cancel(testErr) + Expect(str.Context().Done()).To(BeClosed()) + }) }) }) @@ -713,6 +725,12 @@ var _ = Describe("Stream", func() { str.Reset(testErr) Expect(resetCalled).To(BeFalse()) }) + + It("cancels the context", func() { + Expect(str.Context().Done()).ToNot(BeClosed()) + str.Reset(testErr) + Expect(str.Context().Done()).To(BeClosed()) + }) }) }) @@ -958,6 +976,13 @@ var _ = Describe("Stream", func() { Expect(str.finished()).To(BeFalse()) }) + It("cancels the context after it is closed", func() { + Expect(str.Context().Done()).ToNot(BeClosed()) + str.Close() + str.sentFin() + Expect(str.Context().Done()).To(BeClosed()) + }) + It("is not finished if it is only closed for reading", func() { mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0)) mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(0)) @@ -972,6 +997,12 @@ var _ = Describe("Stream", func() { Expect(str.finished()).To(BeTrue()) }) + It("cancels the context after receiving a RST", func() { + Expect(str.Context().Done()).ToNot(BeClosed()) + str.RegisterRemoteError(testErr) + Expect(str.Context().Done()).To(BeClosed()) + }) + It("is finished after being locally reset and receiving a RST in response", func() { str.Reset(testErr) Expect(str.finished()).To(BeFalse()) diff --git a/streams_map_test.go b/streams_map_test.go index b364f6da..db6ecede 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -29,7 +29,7 @@ var _ = Describe("Streams Map", func() { m = newStreamsMap(nil, p, mockCpm) m.newStream = func(id protocol.StreamID) *stream { - return &stream{streamID: id} + return newStream(id, nil, nil, nil) } }