diff --git a/h2quic/response_writer.go b/h2quic/response_writer.go index a1e608a02..f6a5961bb 100644 --- a/h2quic/response_writer.go +++ b/h2quic/response_writer.go @@ -38,6 +38,9 @@ func (w *responseWriter) Header() http.Header { } func (w *responseWriter) WriteHeader(status int) { + if w.headerWritten { + return + } w.headerWritten = true var headers bytes.Buffer @@ -68,9 +71,3 @@ func (w *responseWriter) Write(p []byte) (int, error) { } return w.dataStream.Write(p) } - -func (w *responseWriter) finish() { - if !w.headerWritten { - w.WriteHeader(200) - } -} diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index e93332d26..837fd19b3 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -79,14 +79,9 @@ var _ = Describe("Response Writer", func() { })) }) - It("writes a 200 in finish if nothing was called", func() { - w.finish() - Expect(headerStream.Bytes()).To(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200 - }) - - It("doesn't do anything in finish if data was written before", func() { + It("does not WriteHeader() twice", func() { w.WriteHeader(200) - w.finish() + w.WriteHeader(500) Expect(headerStream.Bytes()).To(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200 }) }) diff --git a/h2quic/server.go b/h2quic/server.go index 07757dbd0..ae1328ef4 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net" "net/http" + "runtime" "sync" "sync/atomic" "time" @@ -166,8 +167,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, if handler == nil { handler = http.DefaultServeMux } - handler.ServeHTTP(responseWriter, req) - responseWriter.finish() + panicked := false + func() { + defer func() { + if p := recover(); p != nil { + // Copied from net/http/server.go + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + utils.Errorf("http: panic serving: %v\n%s", p, buf) + panicked = true + } + }() + handler.ServeHTTP(responseWriter, req) + }() + if panicked { + responseWriter.WriteHeader(500) + } else { + responseWriter.WriteHeader(200) + } if responseWriter.dataStream != nil { responseWriter.dataStream.Close() } diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 55e3695c4..f1fd2a58d 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -98,6 +98,22 @@ var _ = Describe("H2 server", func() { }).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200 }) + It("correctly handles a panicking handler", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("foobar") + }) + headerStream.Write([]byte{ + 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, + // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding + 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, + }) + err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) + Expect(err).NotTo(HaveOccurred()) + Eventually(func() []byte { + return headerStream.Buffer.Bytes() + }).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500 + }) + It("does not close the dataStream when end of stream is not set", func() { var handlerCalled bool s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {