From 05be874b11c767bcc944709cce5e88100d409f5b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 28 Dec 2018 20:02:51 +0700 Subject: [PATCH] cancel reading from the response stream when the response body is closed --- h2quic/client.go | 2 +- h2quic/client_test.go | 9 ++-- h2quic/response_body.go | 31 ++++++++++++++ h2quic/response_body_test.go | 41 +++++++++++++++++++ h2quic/response_writer_test.go | 4 +- h2quic/server_test.go | 14 +++---- .../self/{client_test.go => http_test.go} | 16 +++++++- integrationtests/tools/testserver/server.go | 15 +++---- 8 files changed, 106 insertions(+), 26 deletions(-) create mode 100644 h2quic/response_body.go create mode 100644 h2quic/response_body_test.go rename integrationtests/self/{client_test.go => http_test.go} (84%) diff --git a/h2quic/client.go b/h2quic/client.go index 62a81248..ae1ddc45 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -241,7 +241,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { if streamEnded || isHead { res.Body = noBody } else { - res.Body = dataStream + res.Body = &responseBody{dataStream: dataStream} if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 21c47768..f78f799c 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -220,7 +220,8 @@ var _ = Describe("Client", func() { rsp, err := client.RoundTrip(request) Expect(err).ToNot(HaveOccurred()) Expect(rsp).To(Equal(teapot)) - Expect(rsp.Body).To(Equal(dataStream)) + Expect(rsp.Body).To(BeAssignableToTypeOf(&responseBody{})) + Expect(rsp.Body.(*responseBody).dataStream).To(Equal(dataStream)) Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) Expect(rsp.Request).To(Equal(request)) close(done) @@ -246,7 +247,7 @@ var _ = Describe("Client", func() { cancel() Eventually(done).Should(BeClosed()) - Expect(dataStream.reset).To(BeTrue()) + Expect(dataStream.canceledRead).To(BeTrue()) Expect(dataStream.canceledWrite).To(BeTrue()) Expect(client.headerErrored).ToNot(BeClosed()) }) @@ -267,7 +268,7 @@ var _ = Describe("Client", func() { time.Sleep(10 * time.Millisecond) cancel() Eventually(done).Should(BeClosed()) - Expect(dataStream.reset).To(BeTrue()) + Expect(dataStream.canceledRead).To(BeTrue()) Expect(dataStream.canceledWrite).To(BeTrue()) Expect(client.headerErrored).ToNot(BeClosed()) }) @@ -288,7 +289,7 @@ var _ = Describe("Client", func() { }() Eventually(done).Should(BeClosed()) - Expect(dataStream.reset).To(BeTrue()) + Expect(dataStream.canceledRead).To(BeTrue()) Expect(dataStream.canceledWrite).To(BeTrue()) Expect(client.headerErrored).ToNot(BeClosed()) }) diff --git a/h2quic/response_body.go b/h2quic/response_body.go new file mode 100644 index 00000000..f7b8d2c3 --- /dev/null +++ b/h2quic/response_body.go @@ -0,0 +1,31 @@ +package h2quic + +import ( + "io" + + quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type responseBody struct { + eofRead utils.AtomicBool + + dataStream quic.Stream +} + +var _ io.ReadCloser = &responseBody{} + +func (rb *responseBody) Read(b []byte) (int, error) { + n, err := rb.dataStream.Read(b) + if err == io.EOF { + rb.eofRead.Set(true) + } + return n, err +} + +func (rb *responseBody) Close() error { + if !rb.eofRead.Get() { + rb.dataStream.CancelRead(0) + } + return nil +} diff --git a/h2quic/response_body_test.go b/h2quic/response_body_test.go new file mode 100644 index 00000000..8e1192f2 --- /dev/null +++ b/h2quic/response_body_test.go @@ -0,0 +1,41 @@ +package h2quic + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Response Body", func() { + var ( + stream *mockStream + body *responseBody + ) + + BeforeEach(func() { + stream = newMockStream(42) + body = &responseBody{dataStream: stream} + }) + + It("calls CancelRead if the stream is closed before being completely read", func() { + stream.dataToRead = *bytes.NewBuffer([]byte("foobar")) + n, err := body.Read(make([]byte, 3)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(body.Close()).To(Succeed()) + Expect(stream.canceledRead).To(BeTrue()) + }) + + It("doesn't calls CancelRead if the stream was completely read", func() { + stream.dataToRead = *bytes.NewBuffer([]byte("foobar")) + close(stream.unblockRead) + n, _ := body.Read(make([]byte, 6)) + Expect(n).To(Equal(6)) + _, err := body.Read(make([]byte, 6)) + Expect(err).To(Equal(io.EOF)) + Expect(body.Close()).To(Succeed()) + Expect(stream.canceledRead).To(BeFalse()) + }) +}) diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 77c67a93..bb005be5 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -22,7 +22,7 @@ type mockStream struct { id protocol.StreamID dataToRead bytes.Buffer dataWritten bytes.Buffer - reset bool + canceledRead bool canceledWrite bool closed bool remoteClosed bool @@ -44,7 +44,7 @@ func newMockStream(id protocol.StreamID) *mockStream { } func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil } -func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil } +func (s *mockStream) CancelRead(quic.ErrorCode) error { s.canceledRead = true; return nil } func (s *mockStream) CancelWrite(quic.ErrorCode) error { s.canceledWrite = true; return nil } func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() } func (s mockStream) StreamID() protocol.StreamID { return s.id } diff --git a/h2quic/server_test.go b/h2quic/server_test.go index e37ef97c..572f5985 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -144,7 +144,7 @@ var _ = Describe("H2 server", func() { Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Expect(dataStream.remoteClosed).To(BeTrue()) - Expect(dataStream.reset).To(BeFalse()) + Expect(dataStream.canceledRead).To(BeFalse()) }) It("returns 200 with an empty handler", func() { @@ -191,7 +191,7 @@ var _ = Describe("H2 server", func() { err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) - Eventually(func() bool { return dataStream.reset }).Should(BeTrue()) + Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue()) Expect(dataStream.remoteClosed).To(BeFalse()) }) @@ -205,7 +205,7 @@ var _ = Describe("H2 server", func() { headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7}) err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) - Eventually(func() bool { return dataStream.reset }).Should(BeTrue()) + Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue()) Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse()) Expect(handlerCalled).To(BeTrue()) }) @@ -238,7 +238,7 @@ var _ = Describe("H2 server", func() { headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7}) err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) - Eventually(func() bool { return dataStream.reset }).Should(BeTrue()) + Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue()) Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse()) Expect(handlerCalled).To(BeTrue()) }) @@ -259,7 +259,7 @@ var _ = Describe("H2 server", func() { err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) - Expect(dataStream.reset).To(BeFalse()) + Expect(dataStream.canceledRead).To(BeFalse()) }) It("ignores PRIORITY frames", func() { @@ -276,7 +276,7 @@ var _ = Describe("H2 server", func() { err = s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).ToNot(HaveOccurred()) Consistently(handlerCalled).ShouldNot(BeClosed()) - Expect(dataStream.reset).To(BeFalse()) + Expect(dataStream.canceledRead).To(BeFalse()) Expect(dataStream.closed).To(BeFalse()) }) @@ -308,7 +308,7 @@ var _ = Describe("H2 server", func() { Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Expect(dataStream.remoteClosed).To(BeTrue()) - Expect(dataStream.reset).To(BeFalse()) + Expect(dataStream.canceledRead).To(BeFalse()) }) }) diff --git a/integrationtests/self/client_test.go b/integrationtests/self/http_test.go similarity index 84% rename from integrationtests/self/client_test.go rename to integrationtests/self/http_test.go index 82459a9c..d176224b 100644 --- a/integrationtests/self/client_test.go +++ b/integrationtests/self/http_test.go @@ -19,7 +19,7 @@ import ( "github.com/onsi/gomega/gbytes" ) -var _ = Describe("Client tests", func() { +var _ = Describe("HTTP tests", func() { var client *http.Client versions := protocol.SupportedVersions @@ -43,7 +43,8 @@ var _ = Describe("Client tests", func() { RootCAs: testdata.GetRootCA(), }, QuicConfig: &quic.Config{ - Versions: []protocol.VersionNumber{version}, + Versions: []protocol.VersionNumber{version}, + IdleTimeout: 10 * time.Second, }, }, } @@ -76,6 +77,17 @@ var _ = Describe("Client tests", func() { Expect(body).To(Equal(testserver.PRDataLong)) }) + It("downloads many files, if the response is not read", func() { + const num = 150 + + for i := 0; i < num; i++ { + resp, err := client.Get("https://localhost:" + testserver.Port() + "/prdata") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Body.Close()).To(Succeed()) + } + }) + It("uploads a file", func() { resp, err := client.Post( "https://localhost:"+testserver.Port()+"/echo", diff --git a/integrationtests/tools/testserver/server.go b/integrationtests/tools/testserver/server.go index 70ba2dda..9517636e 100644 --- a/integrationtests/tools/testserver/server.go +++ b/integrationtests/tools/testserver/server.go @@ -40,32 +40,27 @@ func init() { var err error l, err := strconv.Atoi(sl) Expect(err).NotTo(HaveOccurred()) - _, err = w.Write(GeneratePRData(l)) - Expect(err).NotTo(HaveOccurred()) + w.Write(GeneratePRData(l)) // don't check the error here. Stream may be reset. } else { - _, err := w.Write(PRData) - Expect(err).NotTo(HaveOccurred()) + w.Write(PRData) // don't check the error here. Stream may be reset. } }) http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() - _, err := w.Write(PRDataLong) - Expect(err).NotTo(HaveOccurred()) + w.Write(PRDataLong) // don't check the error here. Stream may be reset. }) http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() - _, err := io.WriteString(w, "Hello, World!\n") - Expect(err).NotTo(HaveOccurred()) + io.WriteString(w, "Hello, World!\n") // don't check the error here. Stream may be reset. }) http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() body, err := ioutil.ReadAll(r.Body) Expect(err).NotTo(HaveOccurred()) - _, err = w.Write(body) - Expect(err).NotTo(HaveOccurred()) + w.Write(body) // don't check the error here. Stream may be reset. }) }