forked from quic-go/quic-go
http3: add HTTP Trailer support for servers (#4630)
This commit is contained in:
@@ -5,11 +5,13 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body.
|
||||
@@ -28,10 +30,11 @@ const maxSmallResponseSize = 4096
|
||||
type responseWriter struct {
|
||||
str *stream
|
||||
|
||||
conn Connection
|
||||
header http.Header
|
||||
buf []byte
|
||||
status int // status code passed to WriteHeader
|
||||
conn Connection
|
||||
header http.Header
|
||||
trailers map[string]struct{}
|
||||
buf []byte
|
||||
status int // status code passed to WriteHeader
|
||||
|
||||
// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
|
||||
// and automatically add the Content-Length header
|
||||
@@ -42,6 +45,7 @@ type responseWriter struct {
|
||||
headerComplete bool // set once WriteHeader is called with a status code >= 200
|
||||
headerWritten bool // set once the response header has been serialized to the stream
|
||||
isHead bool
|
||||
trailerWritten bool // set once the response trailers has been serialized to the stream
|
||||
|
||||
hijacked bool // set on HTTPStream is called
|
||||
|
||||
@@ -198,7 +202,26 @@ func (w *responseWriter) writeHeader(status int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle trailer fields
|
||||
if vals, ok := w.header["Trailer"]; ok {
|
||||
for _, val := range vals {
|
||||
for _, trailer := range strings.Split(val, ",") {
|
||||
// We need to convert to the canonical header key value here because this will be called when using
|
||||
// headers.Add or headers.Set.
|
||||
trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
|
||||
w.declareTrailer(trailer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range w.header {
|
||||
if _, excluded := w.trailers[k]; excluded {
|
||||
continue
|
||||
}
|
||||
// Ignore "Trailer:" prefixed headers
|
||||
if strings.HasPrefix(k, http.TrailerPrefix) {
|
||||
continue
|
||||
}
|
||||
for index := range v {
|
||||
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
|
||||
return err
|
||||
@@ -222,6 +245,15 @@ func (w *responseWriter) FlushError() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *responseWriter) flushTrailers() {
|
||||
if w.trailerWritten {
|
||||
return
|
||||
}
|
||||
if err := w.writeTrailers(); err != nil {
|
||||
w.logger.Debug("could not write trailers", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Flush() {
|
||||
if err := w.FlushError(); err != nil {
|
||||
if w.logger != nil {
|
||||
@@ -230,6 +262,69 @@ func (w *responseWriter) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a
|
||||
// valid name.
|
||||
func (w *responseWriter) declareTrailer(k string) {
|
||||
if !httpguts.ValidTrailerHeader(k) {
|
||||
// Forbidden by RFC 9110, section 6.5.1.
|
||||
w.logger.Debug("ignoring invalid trailer", slog.String("header", k))
|
||||
return
|
||||
}
|
||||
if w.trailers == nil {
|
||||
w.trailers = make(map[string]struct{})
|
||||
}
|
||||
w.trailers[k] = struct{}{}
|
||||
}
|
||||
|
||||
// hasNonEmptyTrailers checks to see if there are any trailers with an actual
|
||||
// value set. This is possible by adding trailers to the "Trailers" header
|
||||
// but never actually setting those names as trailers in the course of handling
|
||||
// the request. In that case, this check may save us some allocations.
|
||||
func (w *responseWriter) hasNonEmptyTrailers() bool {
|
||||
for trailer := range w.trailers {
|
||||
if _, ok := w.header[trailer]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeTrailers will write trailers to the stream if there are any.
|
||||
func (w *responseWriter) writeTrailers() error {
|
||||
// promote headers added via "Trailer:" convention as trailers, these can be added after
|
||||
// streaming the status/headers have been written.
|
||||
for k := range w.header {
|
||||
// Handle "Trailer:" prefix
|
||||
if strings.HasPrefix(k, http.TrailerPrefix) {
|
||||
w.declareTrailer(k)
|
||||
}
|
||||
}
|
||||
|
||||
if !w.hasNonEmptyTrailers() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
enc := qpack.NewEncoder(&b)
|
||||
for trailer := range w.trailers {
|
||||
if vals, ok := w.header[trailer]; ok {
|
||||
name := strings.TrimPrefix(trailer, http.TrailerPrefix)
|
||||
for _, val := range vals {
|
||||
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(name), Value: val}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, frameHeaderLen+b.Len())
|
||||
buf = (&headersFrame{Length: uint64(b.Len())}).Append(buf)
|
||||
buf = append(buf, b.Bytes()...)
|
||||
_, err := w.str.writeUnframed(buf)
|
||||
w.trailerWritten = true
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *responseWriter) HTTPStream() Stream {
|
||||
w.hijacked = true
|
||||
w.Flush()
|
||||
|
||||
Reference in New Issue
Block a user