diff --git a/h2quic/client.go b/h2quic/client.go index ae1ddc45..1ec9c4be 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -241,7 +241,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { if streamEnded || isHead { res.Body = noBody } else { - res.Body = &responseBody{dataStream: dataStream} + res.Body = &responseBody{dataStream} if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") diff --git a/h2quic/client_test.go b/h2quic/client_test.go index f78f799c..81006708 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -221,7 +221,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) Expect(rsp).To(Equal(teapot)) Expect(rsp.Body).To(BeAssignableToTypeOf(&responseBody{})) - Expect(rsp.Body.(*responseBody).dataStream).To(Equal(dataStream)) + Expect(rsp.Body.(*responseBody).Stream).To(Equal(dataStream)) Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) Expect(rsp.Request).To(Equal(request)) close(done) diff --git a/h2quic/response_body.go b/h2quic/response_body.go index f7b8d2c3..11d79e80 100644 --- a/h2quic/response_body.go +++ b/h2quic/response_body.go @@ -4,28 +4,15 @@ import ( "io" quic "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/internal/utils" ) type responseBody struct { - eofRead utils.AtomicBool - - dataStream quic.Stream + quic.Stream } var _ io.ReadCloser = &responseBody{} -func (rb *responseBody) Read(b []byte) (int, error) { - n, err := rb.dataStream.Read(b) - if err == io.EOF { - rb.eofRead.Set(true) - } - return n, err -} - func (rb *responseBody) Close() error { - if !rb.eofRead.Get() { - rb.dataStream.CancelRead(0) - } + rb.Stream.CancelRead(0) return nil } diff --git a/h2quic/response_body_test.go b/h2quic/response_body_test.go index 8e1192f2..de6cf561 100644 --- a/h2quic/response_body_test.go +++ b/h2quic/response_body_test.go @@ -2,7 +2,6 @@ package h2quic import ( "bytes" - "io" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -16,10 +15,10 @@ var _ = Describe("Response Body", func() { BeforeEach(func() { stream = newMockStream(42) - body = &responseBody{dataStream: stream} + body = &responseBody{stream} }) - It("calls CancelRead if the stream is closed before being completely read", func() { + It("calls CancelRead when closing", func() { stream.dataToRead = *bytes.NewBuffer([]byte("foobar")) n, err := body.Read(make([]byte, 3)) Expect(err).ToNot(HaveOccurred()) @@ -27,15 +26,4 @@ var _ = Describe("Response Body", func() { Expect(body.Close()).To(Succeed()) Expect(stream.canceledRead).To(BeTrue()) }) - - It("doesn't calls CancelRead if the stream was completely read", func() { - stream.dataToRead = *bytes.NewBuffer([]byte("foobar")) - close(stream.unblockRead) - n, _ := body.Read(make([]byte, 6)) - Expect(n).To(Equal(6)) - _, err := body.Read(make([]byte, 6)) - Expect(err).To(Equal(io.EOF)) - Expect(body.Close()).To(Succeed()) - Expect(stream.canceledRead).To(BeFalse()) - }) })