http3: add HTTP Trailer support for servers (#4630)

This commit is contained in:
Kevin McDonald
2024-09-07 06:51:54 +02:00
committed by GitHub
parent ad8f3c0ff8
commit 09c2ac8aab
4 changed files with 203 additions and 4 deletions

View File

@@ -5,11 +5,13 @@ import (
"fmt"
"log/slog"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
"github.com/quic-go/qpack"
"golang.org/x/net/http/httpguts"
)
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body.
@@ -28,10 +30,11 @@ const maxSmallResponseSize = 4096
type responseWriter struct {
str *stream
conn Connection
header http.Header
buf []byte
status int // status code passed to WriteHeader
conn Connection
header http.Header
trailers map[string]struct{}
buf []byte
status int // status code passed to WriteHeader
// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
// and automatically add the Content-Length header
@@ -42,6 +45,7 @@ type responseWriter struct {
headerComplete bool // set once WriteHeader is called with a status code >= 200
headerWritten bool // set once the response header has been serialized to the stream
isHead bool
trailerWritten bool // set once the response trailers has been serialized to the stream
hijacked bool // set on HTTPStream is called
@@ -198,7 +202,26 @@ func (w *responseWriter) writeHeader(status int) error {
return err
}
// Handle trailer fields
if vals, ok := w.header["Trailer"]; ok {
for _, val := range vals {
for _, trailer := range strings.Split(val, ",") {
// We need to convert to the canonical header key value here because this will be called when using
// headers.Add or headers.Set.
trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
w.declareTrailer(trailer)
}
}
}
for k, v := range w.header {
if _, excluded := w.trailers[k]; excluded {
continue
}
// Ignore "Trailer:" prefixed headers
if strings.HasPrefix(k, http.TrailerPrefix) {
continue
}
for index := range v {
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
return err
@@ -222,6 +245,15 @@ func (w *responseWriter) FlushError() error {
return err
}
func (w *responseWriter) flushTrailers() {
if w.trailerWritten {
return
}
if err := w.writeTrailers(); err != nil {
w.logger.Debug("could not write trailers", "error", err)
}
}
func (w *responseWriter) Flush() {
if err := w.FlushError(); err != nil {
if w.logger != nil {
@@ -230,6 +262,69 @@ func (w *responseWriter) Flush() {
}
}
// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a
// valid name.
func (w *responseWriter) declareTrailer(k string) {
if !httpguts.ValidTrailerHeader(k) {
// Forbidden by RFC 9110, section 6.5.1.
w.logger.Debug("ignoring invalid trailer", slog.String("header", k))
return
}
if w.trailers == nil {
w.trailers = make(map[string]struct{})
}
w.trailers[k] = struct{}{}
}
// hasNonEmptyTrailers checks to see if there are any trailers with an actual
// value set. This is possible by adding trailers to the "Trailers" header
// but never actually setting those names as trailers in the course of handling
// the request. In that case, this check may save us some allocations.
func (w *responseWriter) hasNonEmptyTrailers() bool {
for trailer := range w.trailers {
if _, ok := w.header[trailer]; ok {
return true
}
}
return false
}
// writeTrailers will write trailers to the stream if there are any.
func (w *responseWriter) writeTrailers() error {
// promote headers added via "Trailer:" convention as trailers, these can be added after
// streaming the status/headers have been written.
for k := range w.header {
// Handle "Trailer:" prefix
if strings.HasPrefix(k, http.TrailerPrefix) {
w.declareTrailer(k)
}
}
if !w.hasNonEmptyTrailers() {
return nil
}
var b bytes.Buffer
enc := qpack.NewEncoder(&b)
for trailer := range w.trailers {
if vals, ok := w.header[trailer]; ok {
name := strings.TrimPrefix(trailer, http.TrailerPrefix)
for _, val := range vals {
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(name), Value: val}); err != nil {
return err
}
}
}
}
buf := make([]byte, 0, frameHeaderLen+b.Len())
buf = (&headersFrame{Length: uint64(b.Len())}).Append(buf)
buf = append(buf, b.Bytes()...)
_, err := w.str.writeUnframed(buf)
w.trailerWritten = true
return err
}
func (w *responseWriter) HTTPStream() Stream {
w.hijacked = true
w.Flush()

View File

@@ -31,6 +31,7 @@ var _ = Describe("Response Writer", func() {
decodeHeader := func(str io.Reader) map[string][]string {
rw.Flush()
rw.flushTrailers()
fields := make(map[string][]string)
decoder := qpack.NewDecoder(nil)
@@ -184,4 +185,65 @@ var _ = Describe("Response Writer", func() {
Expect(func() { rw.WriteHeader(99) }).To(Panic())
Expect(func() { rw.WriteHeader(1000) }).To(Panic())
})
It("write announced trailer", func() {
rw.Header().Add("Trailer", "Key")
rw.WriteHeader(http.StatusTeapot)
n, err := rw.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
rw.Header().Set("Key", "Value")
// writeTrailers needs to be called after writing the full body
Expect(rw.writeTrailers()).ToNot(HaveOccurred())
fields := decodeHeader(strBuf)
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
Expect(fields).To(HaveKeyWithValue("trailer", []string{"Key"}))
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
fields = decodeHeader(strBuf)
Expect(fields).To(HaveKeyWithValue("key", []string{"Value"}))
})
It("ignore non-announced trailer (without trailer prefix)", func() {
rw.Header().Set("Trailer", "Key")
rw.WriteHeader(200)
rw.Write([]byte("foobar"))
rw.Header().Set("UnknownKey", "Value")
rw.Header().Set("Key", "Value")
// Needs to call writeTrailers to simulate the end of the handler
Expect(rw.writeTrailers()).ToNot(HaveOccurred())
headers := decodeHeader(strBuf)
Expect(headers).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(headers).To(HaveKeyWithValue("trailer", []string{"Key"}))
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
trailers := decodeHeader(strBuf)
Expect(trailers).To(HaveKeyWithValue("key", []string{"Value"}))
Expect(trailers).To(Not(HaveKeyWithValue("unknownkey", []string{"Value"})))
})
It("write non-announced trailer (with trailer prefix)", func() {
rw.Header().Set("Trailer", "Key")
rw.WriteHeader(200)
rw.Write([]byte("foobar"))
rw.Header().Set("Key", "Value")
rw.Header().Set(http.TrailerPrefix+"Key2", "Value")
rw.Flush()
// Needs to call writeTrailers to simulate the end of the handler
Expect(rw.writeTrailers()).ToNot(HaveOccurred())
headers := decodeHeader(strBuf)
Expect(headers).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(headers).To(HaveKeyWithValue("trailer", []string{"Key"}))
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
trailers := decodeHeader(strBuf)
Expect(trailers).To(HaveKeyWithValue("key", []string{"Value"}))
Expect(trailers).To(HaveKeyWithValue("key2", []string{"Value"}))
})
})

View File

@@ -625,6 +625,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
}
}
r.Flush()
r.flushTrailers()
}
// abort the stream when there is a panic

View File

@@ -1019,4 +1019,45 @@ var _ = Describe("HTTP tests", func() {
Expect(num0RTTPackets.Load()).To(BeNumerically(">", 0))
})
})
It("sends and receives trailers", func() {
mux.HandleFunc("/trailers", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
w.Header().Set("Trailer", "AtEnd1, AtEnd2")
w.Header().Add("Trailer", "LAST")
w.Header().Set("Content-Type", "text/plain; charset=utf-8") // normal header
w.WriteHeader(http.StatusOK)
w.Header().Set("AtEnd1", "value 1")
io.WriteString(w, "This HTTP response has both headers before this text and trailers at the end.\n")
w.(http.Flusher).Flush()
w.Header().Set("AtEnd2", "value 2")
io.WriteString(w, "More text\n")
w.(http.Flusher).Flush()
w.Header().Set("LAST", "value 3")
w.Header().Set(http.TrailerPrefix+"Unannounced", "Surprise!")
w.Header().Set("Late-Header", "No surprise!")
})
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/trailers", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Values("Trailer")).To(Equal([]string{"AtEnd1, AtEnd2", "LAST"}))
Expect(resp.Header).To(Not(HaveKey("Atend1")))
Expect(resp.Header).To(Not(HaveKey("Atend2")))
Expect(resp.Header).To(Not(HaveKey("Last")))
Expect(resp.Header).To(Not(HaveKey("Late-Header")))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("This HTTP response has both headers before this text and trailers at the end.\nMore text\n"))
for k := range resp.Header {
Expect(k).To(Not(HavePrefix(http.TrailerPrefix)))
}
Expect(resp.Trailer).To(Equal(http.Header(map[string][]string{
"Atend1": {"value 1"},
"Atend2": {"value 2"},
"Last": {"value 3"},
"Unannounced": {"Surprise!"},
})))
})
})