diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 16b0d817..0a042585 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -12,11 +12,12 @@ import ( type mockStream struct { id protocol.StreamID bytes.Buffer + remoteClosed bool } -func (mockStream) Close() error { return nil } -func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } -func (s mockStream) StreamID() protocol.StreamID { return s.id } +func (mockStream) Close() error { return nil } +func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true } +func (s mockStream) StreamID() protocol.StreamID { return s.id } var _ = Describe("Response Writer", func() { var ( diff --git a/h2quic/server.go b/h2quic/server.go index eb416bd8..23ead498 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -99,6 +99,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, return err } + if h2headersFrame.StreamEnded() { + dataStream.CloseRemote(0) + } + // stream's Close() closes the write side, not the read side req.Body = ioutil.NopCloser(dataStream) diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 1d71f159..85c3499b 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -15,19 +15,21 @@ import ( ) type mockSession struct { - closed bool + closed bool + dataStream *mockStream } func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { - return &mockStream{}, nil + return s.dataStream, nil } func (s *mockSession) Close(error) error { s.closed = true; return nil } var _ = Describe("H2 server", func() { var ( - s *Server - session *mockSession + s *Server + session *mockSession + dataStream *mockStream ) BeforeEach(func() { @@ -35,7 +37,8 @@ var _ = Describe("H2 server", func() { s, err = NewServer(testdata.GetTLSConfig()) Expect(err).NotTo(HaveOccurred()) Expect(s).NotTo(BeNil()) - session = &mockSession{} + dataStream = &mockStream{} + session = &mockSession{dataStream: dataStream} }) It("uses default handler", func() { @@ -66,7 +69,24 @@ var _ = Describe("H2 server", func() { h2framer = http2.NewFramer(nil, headerStream) }) - It("handles a sample request", func() { + It("handles a sample GET request", func() { + var handlerCalled bool + s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Expect(r.Host).To(Equal("www.example.com")) + handlerCalled = true + }) + headerStream.Write([]byte{ + 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, + // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding + 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, + }) + err := s.handleRequest(session, headerStream, hpackDecoder, h2framer) + Expect(err).NotTo(HaveOccurred()) + Eventually(func() bool { return handlerCalled }).Should(BeTrue()) + Expect(dataStream.remoteClosed).To(BeTrue()) + }) + + It("does not close the dataStream when end of stream is not set", func() { var handlerCalled bool s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Expect(r.Host).To(Equal("www.example.com")) @@ -80,6 +100,7 @@ var _ = Describe("H2 server", func() { err := s.handleRequest(session, headerStream, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) + Expect(dataStream.remoteClosed).To(BeFalse()) }) })