From 09c2ac8aab55fb26170b4db8e5c4dbff76a36151 Mon Sep 17 00:00:00 2001 From: Kevin McDonald Date: Sat, 7 Sep 2024 06:51:54 +0200 Subject: [PATCH] http3: add HTTP Trailer support for servers (#4630) --- http3/response_writer.go | 103 +++++++++++++++++++++++++++-- http3/response_writer_test.go | 62 +++++++++++++++++ http3/server.go | 1 + integrationtests/self/http_test.go | 41 ++++++++++++ 4 files changed, 203 insertions(+), 4 deletions(-) diff --git a/http3/response_writer.go b/http3/response_writer.go index 71f6551ef..14ba1ba8f 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -5,11 +5,13 @@ import ( "fmt" "log/slog" "net/http" + "net/textproto" "strconv" "strings" "time" "github.com/quic-go/qpack" + "golang.org/x/net/http/httpguts" ) // The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body. @@ -28,10 +30,11 @@ const maxSmallResponseSize = 4096 type responseWriter struct { str *stream - conn Connection - header http.Header - buf []byte - status int // status code passed to WriteHeader + conn Connection + header http.Header + trailers map[string]struct{} + buf []byte + status int // status code passed to WriteHeader // for responses smaller than maxSmallResponseSize, we buffer calls to Write, // and automatically add the Content-Length header @@ -42,6 +45,7 @@ type responseWriter struct { headerComplete bool // set once WriteHeader is called with a status code >= 200 headerWritten bool // set once the response header has been serialized to the stream isHead bool + trailerWritten bool // set once the response trailers has been serialized to the stream hijacked bool // set on HTTPStream is called @@ -198,7 +202,26 @@ func (w *responseWriter) writeHeader(status int) error { return err } + // Handle trailer fields + if vals, ok := w.header["Trailer"]; ok { + for _, val := range vals { + for _, trailer := range strings.Split(val, ",") { + // We need to convert to the canonical header key value here because this will be called when using + // headers.Add or headers.Set. + trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer)) + w.declareTrailer(trailer) + } + } + } + for k, v := range w.header { + if _, excluded := w.trailers[k]; excluded { + continue + } + // Ignore "Trailer:" prefixed headers + if strings.HasPrefix(k, http.TrailerPrefix) { + continue + } for index := range v { if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil { return err @@ -222,6 +245,15 @@ func (w *responseWriter) FlushError() error { return err } +func (w *responseWriter) flushTrailers() { + if w.trailerWritten { + return + } + if err := w.writeTrailers(); err != nil { + w.logger.Debug("could not write trailers", "error", err) + } +} + func (w *responseWriter) Flush() { if err := w.FlushError(); err != nil { if w.logger != nil { @@ -230,6 +262,69 @@ func (w *responseWriter) Flush() { } } +// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a +// valid name. +func (w *responseWriter) declareTrailer(k string) { + if !httpguts.ValidTrailerHeader(k) { + // Forbidden by RFC 9110, section 6.5.1. + w.logger.Debug("ignoring invalid trailer", slog.String("header", k)) + return + } + if w.trailers == nil { + w.trailers = make(map[string]struct{}) + } + w.trailers[k] = struct{}{} +} + +// hasNonEmptyTrailers checks to see if there are any trailers with an actual +// value set. This is possible by adding trailers to the "Trailers" header +// but never actually setting those names as trailers in the course of handling +// the request. In that case, this check may save us some allocations. +func (w *responseWriter) hasNonEmptyTrailers() bool { + for trailer := range w.trailers { + if _, ok := w.header[trailer]; ok { + return true + } + } + return false +} + +// writeTrailers will write trailers to the stream if there are any. +func (w *responseWriter) writeTrailers() error { + // promote headers added via "Trailer:" convention as trailers, these can be added after + // streaming the status/headers have been written. + for k := range w.header { + // Handle "Trailer:" prefix + if strings.HasPrefix(k, http.TrailerPrefix) { + w.declareTrailer(k) + } + } + + if !w.hasNonEmptyTrailers() { + return nil + } + + var b bytes.Buffer + enc := qpack.NewEncoder(&b) + for trailer := range w.trailers { + if vals, ok := w.header[trailer]; ok { + name := strings.TrimPrefix(trailer, http.TrailerPrefix) + for _, val := range vals { + if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(name), Value: val}); err != nil { + return err + } + } + } + } + + buf := make([]byte, 0, frameHeaderLen+b.Len()) + buf = (&headersFrame{Length: uint64(b.Len())}).Append(buf) + buf = append(buf, b.Bytes()...) + _, err := w.str.writeUnframed(buf) + w.trailerWritten = true + return err +} + func (w *responseWriter) HTTPStream() Stream { w.hijacked = true w.Flush() diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 3394043a0..94437c779 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -31,6 +31,7 @@ var _ = Describe("Response Writer", func() { decodeHeader := func(str io.Reader) map[string][]string { rw.Flush() + rw.flushTrailers() fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) @@ -184,4 +185,65 @@ var _ = Describe("Response Writer", func() { Expect(func() { rw.WriteHeader(99) }).To(Panic()) Expect(func() { rw.WriteHeader(1000) }).To(Panic()) }) + + It("write announced trailer", func() { + rw.Header().Add("Trailer", "Key") + rw.WriteHeader(http.StatusTeapot) + n, err := rw.Write([]byte("foobar")) + Expect(n).To(Equal(6)) + Expect(err).ToNot(HaveOccurred()) + rw.Header().Set("Key", "Value") + + // writeTrailers needs to be called after writing the full body + Expect(rw.writeTrailers()).ToNot(HaveOccurred()) + + fields := decodeHeader(strBuf) + Expect(fields).To(HaveKeyWithValue(":status", []string{"418"})) + Expect(fields).To(HaveKeyWithValue("trailer", []string{"Key"})) + Expect(getData(strBuf)).To(Equal([]byte("foobar"))) + + fields = decodeHeader(strBuf) + Expect(fields).To(HaveKeyWithValue("key", []string{"Value"})) + }) + + It("ignore non-announced trailer (without trailer prefix)", func() { + rw.Header().Set("Trailer", "Key") + rw.WriteHeader(200) + rw.Write([]byte("foobar")) + rw.Header().Set("UnknownKey", "Value") + rw.Header().Set("Key", "Value") + + // Needs to call writeTrailers to simulate the end of the handler + Expect(rw.writeTrailers()).ToNot(HaveOccurred()) + headers := decodeHeader(strBuf) + Expect(headers).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(headers).To(HaveKeyWithValue("trailer", []string{"Key"})) + + Expect(getData(strBuf)).To(Equal([]byte("foobar"))) + + trailers := decodeHeader(strBuf) + Expect(trailers).To(HaveKeyWithValue("key", []string{"Value"})) + Expect(trailers).To(Not(HaveKeyWithValue("unknownkey", []string{"Value"}))) + }) + + It("write non-announced trailer (with trailer prefix)", func() { + rw.Header().Set("Trailer", "Key") + rw.WriteHeader(200) + rw.Write([]byte("foobar")) + rw.Header().Set("Key", "Value") + rw.Header().Set(http.TrailerPrefix+"Key2", "Value") + rw.Flush() + + // Needs to call writeTrailers to simulate the end of the handler + Expect(rw.writeTrailers()).ToNot(HaveOccurred()) + headers := decodeHeader(strBuf) + Expect(headers).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(headers).To(HaveKeyWithValue("trailer", []string{"Key"})) + + Expect(getData(strBuf)).To(Equal([]byte("foobar"))) + + trailers := decodeHeader(strBuf) + Expect(trailers).To(HaveKeyWithValue("key", []string{"Value"})) + Expect(trailers).To(HaveKeyWithValue("key2", []string{"Value"})) + }) }) diff --git a/http3/server.go b/http3/server.go index d3e436735..9f285b6e7 100644 --- a/http3/server.go +++ b/http3/server.go @@ -625,6 +625,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat } } r.Flush() + r.flushTrailers() } // abort the stream when there is a panic diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index cef97ba2b..2242e0c4d 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -1019,4 +1019,45 @@ var _ = Describe("HTTP tests", func() { Expect(num0RTTPackets.Load()).To(BeNumerically(">", 0)) }) }) + + It("sends and receives trailers", func() { + mux.HandleFunc("/trailers", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Header().Set("Trailer", "AtEnd1, AtEnd2") + w.Header().Add("Trailer", "LAST") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") // normal header + w.WriteHeader(http.StatusOK) + w.Header().Set("AtEnd1", "value 1") + io.WriteString(w, "This HTTP response has both headers before this text and trailers at the end.\n") + w.(http.Flusher).Flush() + w.Header().Set("AtEnd2", "value 2") + io.WriteString(w, "More text\n") + w.(http.Flusher).Flush() + w.Header().Set("LAST", "value 3") + w.Header().Set(http.TrailerPrefix+"Unannounced", "Surprise!") + w.Header().Set("Late-Header", "No surprise!") + }) + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/trailers", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Values("Trailer")).To(Equal([]string{"AtEnd1, AtEnd2", "LAST"})) + Expect(resp.Header).To(Not(HaveKey("Atend1"))) + Expect(resp.Header).To(Not(HaveKey("Atend2"))) + Expect(resp.Header).To(Not(HaveKey("Last"))) + Expect(resp.Header).To(Not(HaveKey("Late-Header"))) + + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("This HTTP response has both headers before this text and trailers at the end.\nMore text\n")) + for k := range resp.Header { + Expect(k).To(Not(HavePrefix(http.TrailerPrefix))) + } + Expect(resp.Trailer).To(Equal(http.Header(map[string][]string{ + "Atend1": {"value 1"}, + "Atend2": {"value 2"}, + "Last": {"value 3"}, + "Unannounced": {"Surprise!"}, + }))) + }) })