From e0bf13be01aaa715329d32da90d526671b5245ce Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Sat, 16 Dec 2023 11:39:49 +0800 Subject: [PATCH] http3: reset stream when a handler panics (#4181) * interrupt the stream when a panick happened * move the declaration of errPanicked * check what's read is a prefix of what's written * check errPanicked * use MatchError instead of Equal * use channel to notify the response has been received --- http3/server.go | 6 ++++++ http3/server_test.go | 4 ++-- integrationtests/self/http_test.go | 20 ++++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/http3/server.go b/http3/server.go index 5c70271d4..a66a2711c 100644 --- a/http3/server.go +++ b/http3/server.go @@ -30,6 +30,7 @@ var ( quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { return quic.ListenAddrEarly(addr, tlsConf, config) } + errPanicked = errors.New("panicked") ) // NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. @@ -652,6 +653,11 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q } // If the EOF was read by the handler, CancelRead() is a no-op. str.CancelRead(quic.StreamErrorCode(ErrCodeNoError)) + + // abort the stream when there is a panic + if panicked { + return newStreamError(ErrCodeInternalError, errPanicked) + } return requestError{} } diff --git a/http3/server_test.go b/http3/server_test.go index 5ec58668f..8da806111 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -272,7 +272,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) + Expect(serr.err).To(MatchError(errPanicked)) Expect(responseBuf.Bytes()).To(HaveLen(0)) }) @@ -288,7 +288,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) + Expect(serr.err).To(MatchError(errPanicked)) Expect(responseBuf.Bytes()).To(HaveLen(0)) }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index b746816df..5ec17e2d9 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -140,6 +140,26 @@ var _ = Describe("HTTP tests", func() { Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar")))) }) + It("detects stream errors when server panics when writing response", func() { + respChan := make(chan struct{}) + mux.HandleFunc("/writing_and_panicking", func(w http.ResponseWriter, r *http.Request) { + // no recover here as it will interfere with the handler + w.Write([]byte("foobar")) + w.(http.Flusher).Flush() + // wait for the client to receive the response + <-respChan + panic(http.ErrAbortHandler) + }) + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/writing_and_panicking", port)) + close(respChan) + Expect(err).ToNot(HaveOccurred()) + body, err := io.ReadAll(resp.Body) + Expect(err).To(HaveOccurred()) + // the body will be a prefix of what's written + Expect(bytes.HasPrefix([]byte("foobar"), body)).To(BeTrue()) + }) + It("requests to different servers with the same udpconn", func() { resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port)) Expect(err).ToNot(HaveOccurred())