From 64bc80339e1199520c2bbd277a3b9d524d3cc86a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 9 Jan 2017 23:47:01 +0700 Subject: [PATCH] reset streams that the request body is not read from fixes #384 --- h2quic/request_body.go | 29 +++++++++++++++++++++++++ h2quic/request_body_test.go | 39 ++++++++++++++++++++++++++++++++++ h2quic/response_writer_test.go | 6 ++++-- h2quic/server.go | 9 +++++--- h2quic/server_test.go | 38 ++++++++++++++++++++++++++++++++- 5 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 h2quic/request_body.go create mode 100644 h2quic/request_body_test.go diff --git a/h2quic/request_body.go b/h2quic/request_body.go new file mode 100644 index 00000000..41ff5c69 --- /dev/null +++ b/h2quic/request_body.go @@ -0,0 +1,29 @@ +package h2quic + +import ( + "io" + + "github.com/lucas-clemente/quic-go/utils" +) + +type requestBody struct { + requestRead bool + dataStream utils.Stream +} + +// make sure the requestBody can be used as a http.Request.Body +var _ io.ReadCloser = &requestBody{} + +func newRequestBody(stream utils.Stream) *requestBody { + return &requestBody{dataStream: stream} +} + +func (b *requestBody) Read(p []byte) (int, error) { + b.requestRead = true + return b.dataStream.Read(p) +} + +func (b *requestBody) Close() error { + // stream's Close() closes the write side, not the read side + return nil +} diff --git a/h2quic/request_body_test.go b/h2quic/request_body_test.go new file mode 100644 index 00000000..98206666 --- /dev/null +++ b/h2quic/request_body_test.go @@ -0,0 +1,39 @@ +package h2quic + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Request body", func() { + var ( + stream *mockStream + rb *requestBody + ) + + BeforeEach(func() { + stream = &mockStream{} + stream.Write([]byte("foobar")) // provides data to be read + rb = newRequestBody(stream) + }) + + It("reads from the stream", func() { + b := make([]byte, 10) + n, _ := stream.Read(b) + Expect(n).To(Equal(6)) + Expect(b[0:6]).To(Equal([]byte("foobar"))) + }) + + It("saves if the stream was read from", func() { + Expect(rb.requestRead).To(BeFalse()) + rb.Read(make([]byte, 1)) + Expect(rb.requestRead).To(BeTrue()) + }) + + It("doesn't close the stream when closing the request body", func() { + Expect(stream.closed).To(BeFalse()) + err := rb.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(stream.closed).To(BeFalse()) + }) +}) diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index aa46c032..a896980b 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -14,10 +14,12 @@ type mockStream struct { id protocol.StreamID bytes.Buffer remoteClosed bool + reset bool + closed bool } -func (mockStream) Close() error { return nil } -func (mockStream) Reset(error) { panic("not implemented") } +func (s *mockStream) Close() error { s.closed = true; return nil } +func (s *mockStream) Reset(error) { s.reset = true } func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true } func (s mockStream) StreamID() protocol.StreamID { return s.id } diff --git a/h2quic/server.go b/h2quic/server.go index 59890359..fbc83278 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "errors" "fmt" - "io/ioutil" "net" "net/http" "runtime" @@ -155,13 +154,14 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, return err } + var streamEnded bool if h2headersFrame.StreamEnded() { dataStream.CloseRemote(0) + streamEnded = true _, _ = dataStream.Read([]byte{0}) // read the eof } - // stream's Close() closes the write side, not the read side - req.Body = ioutil.NopCloser(dataStream) + req.Body = newRequestBody(dataStream) responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) @@ -190,6 +190,9 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, responseWriter.WriteHeader(200) } if responseWriter.dataStream != nil { + if !streamEnded && !req.Body.(*requestBody).requestRead { + responseWriter.dataStream.Reset(nil) + } responseWriter.dataStream.Close() } if s.CloseAfterFirstRequest { diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 3c1eabcb..068d1c32 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -79,6 +79,7 @@ var _ = Describe("H2 server", func() { Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Expect(dataStream.remoteClosed).To(BeTrue()) + Expect(dataStream.reset).To(BeFalse()) }) It("returns 200 with an empty handler", func() { @@ -111,7 +112,7 @@ var _ = Describe("H2 server", func() { }).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500 }) - It("does not close the dataStream when end of stream is not set", func() { + It("resets the dataStream when client sends a body in GET request", func() { var handlerCalled bool s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Expect(r.Host).To(Equal("www.example.com")) @@ -126,6 +127,41 @@ var _ = Describe("H2 server", func() { Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Expect(dataStream.remoteClosed).To(BeFalse()) + Expect(dataStream.reset).To(BeTrue()) + }) + + It("resets the dataStream when the body of POST request is not read", func() { + var handlerCalled bool + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Expect(r.Host).To(Equal("www.example.com")) + Expect(r.Method).To(Equal("POST")) + handlerCalled = true + }) + headerStream.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7}) + err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) + Expect(err).NotTo(HaveOccurred()) + Eventually(func() bool { return handlerCalled }).Should(BeTrue()) + Expect(dataStream.remoteClosed).To(BeFalse()) + Expect(dataStream.reset).To(BeTrue()) + }) + + It("closes the dataStream if the body of POST request was read", func() { + var handlerCalled bool + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Expect(r.Host).To(Equal("www.example.com")) + Expect(r.Method).To(Equal("POST")) + handlerCalled = true + // read the request body + b := make([]byte, 1000) + n, _ := r.Body.Read(b) + Expect(n).ToNot(BeZero()) + }) + headerStream.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7}) + dataStream.Write([]byte("foo=bar")) + err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) + Expect(err).NotTo(HaveOccurred()) + Eventually(func() bool { return handlerCalled }).Should(BeTrue()) + Expect(dataStream.reset).To(BeFalse()) }) It("errors when non-header frames are received", func() {