http3: add HTTP Trailer support for servers (#4630)

This commit is contained in:
Kevin McDonald
2024-09-07 06:51:54 +02:00
committed by GitHub
parent ad8f3c0ff8
commit 09c2ac8aab
4 changed files with 203 additions and 4 deletions

View File

@@ -5,11 +5,13 @@ import (
"fmt"
"log/slog"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
"github.com/quic-go/qpack"
"golang.org/x/net/http/httpguts"
)
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body.
@@ -28,10 +30,11 @@ const maxSmallResponseSize = 4096
type responseWriter struct {
str *stream
conn Connection
header http.Header
buf []byte
status int // status code passed to WriteHeader
conn Connection
header http.Header
trailers map[string]struct{}
buf []byte
status int // status code passed to WriteHeader
// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
// and automatically add the Content-Length header
@@ -42,6 +45,7 @@ type responseWriter struct {
headerComplete bool // set once WriteHeader is called with a status code >= 200
headerWritten bool // set once the response header has been serialized to the stream
isHead bool
trailerWritten bool // set once the response trailers has been serialized to the stream
hijacked bool // set on HTTPStream is called
@@ -198,7 +202,26 @@ func (w *responseWriter) writeHeader(status int) error {
return err
}
// Handle trailer fields
if vals, ok := w.header["Trailer"]; ok {
for _, val := range vals {
for _, trailer := range strings.Split(val, ",") {
// We need to convert to the canonical header key value here because this will be called when using
// headers.Add or headers.Set.
trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
w.declareTrailer(trailer)
}
}
}
for k, v := range w.header {
if _, excluded := w.trailers[k]; excluded {
continue
}
// Ignore "Trailer:" prefixed headers
if strings.HasPrefix(k, http.TrailerPrefix) {
continue
}
for index := range v {
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
return err
@@ -222,6 +245,15 @@ func (w *responseWriter) FlushError() error {
return err
}
func (w *responseWriter) flushTrailers() {
if w.trailerWritten {
return
}
if err := w.writeTrailers(); err != nil {
w.logger.Debug("could not write trailers", "error", err)
}
}
func (w *responseWriter) Flush() {
if err := w.FlushError(); err != nil {
if w.logger != nil {
@@ -230,6 +262,69 @@ func (w *responseWriter) Flush() {
}
}
// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a
// valid name.
func (w *responseWriter) declareTrailer(k string) {
if !httpguts.ValidTrailerHeader(k) {
// Forbidden by RFC 9110, section 6.5.1.
w.logger.Debug("ignoring invalid trailer", slog.String("header", k))
return
}
if w.trailers == nil {
w.trailers = make(map[string]struct{})
}
w.trailers[k] = struct{}{}
}
// hasNonEmptyTrailers checks to see if there are any trailers with an actual
// value set. This is possible by adding trailers to the "Trailers" header
// but never actually setting those names as trailers in the course of handling
// the request. In that case, this check may save us some allocations.
func (w *responseWriter) hasNonEmptyTrailers() bool {
for trailer := range w.trailers {
if _, ok := w.header[trailer]; ok {
return true
}
}
return false
}
// writeTrailers will write trailers to the stream if there are any.
func (w *responseWriter) writeTrailers() error {
// promote headers added via "Trailer:" convention as trailers, these can be added after
// streaming the status/headers have been written.
for k := range w.header {
// Handle "Trailer:" prefix
if strings.HasPrefix(k, http.TrailerPrefix) {
w.declareTrailer(k)
}
}
if !w.hasNonEmptyTrailers() {
return nil
}
var b bytes.Buffer
enc := qpack.NewEncoder(&b)
for trailer := range w.trailers {
if vals, ok := w.header[trailer]; ok {
name := strings.TrimPrefix(trailer, http.TrailerPrefix)
for _, val := range vals {
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(name), Value: val}); err != nil {
return err
}
}
}
}
buf := make([]byte, 0, frameHeaderLen+b.Len())
buf = (&headersFrame{Length: uint64(b.Len())}).Append(buf)
buf = append(buf, b.Bytes()...)
_, err := w.str.writeUnframed(buf)
w.trailerWritten = true
return err
}
func (w *responseWriter) HTTPStream() Stream {
w.hijacked = true
w.Flush()