diff --git a/h2quic/client_test.go b/h2quic/client_test.go index aa61f36b..ad8df477 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -3,6 +3,7 @@ package h2quic import ( "bytes" "compress/gzip" + "context" "crypto/tls" "errors" "io" @@ -36,6 +37,7 @@ var _ = Describe("Client", func() { client = newClient(hostname, nil, &roundTripperOpts{}, nil) Expect(client.hostname).To(Equal(hostname)) session = &mockSession{} + session.ctx, session.ctxCancel = context.WithCancel(context.Background()) client.session = session headerStream = newMockStream(3) diff --git a/h2quic/response_writer.go b/h2quic/response_writer.go index e3eeb3b0..1dd4e928 100644 --- a/h2quic/response_writer.go +++ b/h2quic/response_writer.go @@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) { func (w *responseWriter) Flush() {} -// TODO: Implement a functional CloseNotify method. +// This is a NOP. Use http.Request.Context func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) } // test that we implement http.Flusher diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index e8de1a4f..e3e3e27b 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -25,20 +25,24 @@ type mockStream struct { remoteClosed bool unblockRead chan struct{} + ctx context.Context + ctxCancel context.CancelFunc } func newMockStream(id protocol.StreamID) *mockStream { - return &mockStream{ + s := &mockStream{ id: id, unblockRead: make(chan struct{}), } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s } -func (s *mockStream) Close() error { s.closed = true; return nil } +func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil } func (s *mockStream) Reset(error) { s.reset = true } -func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true } +func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() } func (s mockStream) StreamID() protocol.StreamID { return s.id } -func (s *mockStream) Context() context.Context { panic("not implemented") } +func (s *mockStream) Context() context.Context { return s.ctx } func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") } func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") } func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") } diff --git a/h2quic/server.go b/h2quic/server.go index b01cd743..1c1d9bfe 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -194,6 +194,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, _, _ = dataStream.Read([]byte{0}) // read the eof } + req = req.WithContext(dataStream.Context()) reqBody := newRequestBody(dataStream) req.Body = reqBody diff --git a/h2quic/server_test.go b/h2quic/server_test.go index b576dbba..b555ddcf 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -34,6 +34,8 @@ type mockSession struct { streamsToOpen []quic.Stream blockOpenStreamSync bool streamOpenErr error + ctx context.Context + ctxCancel context.CancelFunc } func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) { @@ -57,6 +59,7 @@ func (s *mockSession) OpenStreamSync() (quic.Stream, error) { func (s *mockSession) Close(e error) error { s.closed = true s.closedWithError = e + s.ctxCancel() return nil } func (s *mockSession) LocalAddr() net.Addr { @@ -65,7 +68,9 @@ func (s *mockSession) LocalAddr() net.Addr { func (s *mockSession) RemoteAddr() net.Addr { return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42} } -func (s *mockSession) Context() context.Context { panic("not implemented") } +func (s *mockSession) Context() context.Context { + return s.ctx +} var _ = Describe("H2 server", func() { var ( @@ -84,6 +89,7 @@ var _ = Describe("H2 server", func() { dataStream = newMockStream(0) close(dataStream.unblockRead) session = &mockSession{dataStream: dataStream} + session.ctx, session.ctxCancel = context.WithCancel(context.Background()) origQuicListenAddr = quicListenAddr }) @@ -247,6 +253,29 @@ var _ = Describe("H2 server", func() { err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame")) }) + + It("Cancels the request context when the datstream is closed", func() { + var handlerCalled bool + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := r.Context().Err() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("context canceled")) + handlerCalled = true + }) + headerStream.dataToRead.Write([]byte{ + 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, + // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding + 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, + }) + dataStream.Close() + err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) + Expect(err).NotTo(HaveOccurred()) + Eventually(func() bool { return handlerCalled }).Should(BeTrue()) + Expect(dataStream.remoteClosed).To(BeTrue()) + Expect(dataStream.reset).To(BeFalse()) + }) + }) It("handles the header stream", func() {