forked from quic-go/quic-go
http3: add HTTP Trailer support for servers (#4630)
This commit is contained in:
@@ -5,11 +5,13 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body.
|
||||
@@ -28,10 +30,11 @@ const maxSmallResponseSize = 4096
|
||||
type responseWriter struct {
|
||||
str *stream
|
||||
|
||||
conn Connection
|
||||
header http.Header
|
||||
buf []byte
|
||||
status int // status code passed to WriteHeader
|
||||
conn Connection
|
||||
header http.Header
|
||||
trailers map[string]struct{}
|
||||
buf []byte
|
||||
status int // status code passed to WriteHeader
|
||||
|
||||
// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
|
||||
// and automatically add the Content-Length header
|
||||
@@ -42,6 +45,7 @@ type responseWriter struct {
|
||||
headerComplete bool // set once WriteHeader is called with a status code >= 200
|
||||
headerWritten bool // set once the response header has been serialized to the stream
|
||||
isHead bool
|
||||
trailerWritten bool // set once the response trailers has been serialized to the stream
|
||||
|
||||
hijacked bool // set on HTTPStream is called
|
||||
|
||||
@@ -198,7 +202,26 @@ func (w *responseWriter) writeHeader(status int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle trailer fields
|
||||
if vals, ok := w.header["Trailer"]; ok {
|
||||
for _, val := range vals {
|
||||
for _, trailer := range strings.Split(val, ",") {
|
||||
// We need to convert to the canonical header key value here because this will be called when using
|
||||
// headers.Add or headers.Set.
|
||||
trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
|
||||
w.declareTrailer(trailer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range w.header {
|
||||
if _, excluded := w.trailers[k]; excluded {
|
||||
continue
|
||||
}
|
||||
// Ignore "Trailer:" prefixed headers
|
||||
if strings.HasPrefix(k, http.TrailerPrefix) {
|
||||
continue
|
||||
}
|
||||
for index := range v {
|
||||
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
|
||||
return err
|
||||
@@ -222,6 +245,15 @@ func (w *responseWriter) FlushError() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *responseWriter) flushTrailers() {
|
||||
if w.trailerWritten {
|
||||
return
|
||||
}
|
||||
if err := w.writeTrailers(); err != nil {
|
||||
w.logger.Debug("could not write trailers", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Flush() {
|
||||
if err := w.FlushError(); err != nil {
|
||||
if w.logger != nil {
|
||||
@@ -230,6 +262,69 @@ func (w *responseWriter) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a
|
||||
// valid name.
|
||||
func (w *responseWriter) declareTrailer(k string) {
|
||||
if !httpguts.ValidTrailerHeader(k) {
|
||||
// Forbidden by RFC 9110, section 6.5.1.
|
||||
w.logger.Debug("ignoring invalid trailer", slog.String("header", k))
|
||||
return
|
||||
}
|
||||
if w.trailers == nil {
|
||||
w.trailers = make(map[string]struct{})
|
||||
}
|
||||
w.trailers[k] = struct{}{}
|
||||
}
|
||||
|
||||
// hasNonEmptyTrailers checks to see if there are any trailers with an actual
|
||||
// value set. This is possible by adding trailers to the "Trailers" header
|
||||
// but never actually setting those names as trailers in the course of handling
|
||||
// the request. In that case, this check may save us some allocations.
|
||||
func (w *responseWriter) hasNonEmptyTrailers() bool {
|
||||
for trailer := range w.trailers {
|
||||
if _, ok := w.header[trailer]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeTrailers will write trailers to the stream if there are any.
|
||||
func (w *responseWriter) writeTrailers() error {
|
||||
// promote headers added via "Trailer:" convention as trailers, these can be added after
|
||||
// streaming the status/headers have been written.
|
||||
for k := range w.header {
|
||||
// Handle "Trailer:" prefix
|
||||
if strings.HasPrefix(k, http.TrailerPrefix) {
|
||||
w.declareTrailer(k)
|
||||
}
|
||||
}
|
||||
|
||||
if !w.hasNonEmptyTrailers() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
enc := qpack.NewEncoder(&b)
|
||||
for trailer := range w.trailers {
|
||||
if vals, ok := w.header[trailer]; ok {
|
||||
name := strings.TrimPrefix(trailer, http.TrailerPrefix)
|
||||
for _, val := range vals {
|
||||
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(name), Value: val}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, frameHeaderLen+b.Len())
|
||||
buf = (&headersFrame{Length: uint64(b.Len())}).Append(buf)
|
||||
buf = append(buf, b.Bytes()...)
|
||||
_, err := w.str.writeUnframed(buf)
|
||||
w.trailerWritten = true
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *responseWriter) HTTPStream() Stream {
|
||||
w.hijacked = true
|
||||
w.Flush()
|
||||
|
||||
@@ -31,6 +31,7 @@ var _ = Describe("Response Writer", func() {
|
||||
|
||||
decodeHeader := func(str io.Reader) map[string][]string {
|
||||
rw.Flush()
|
||||
rw.flushTrailers()
|
||||
fields := make(map[string][]string)
|
||||
decoder := qpack.NewDecoder(nil)
|
||||
|
||||
@@ -184,4 +185,65 @@ var _ = Describe("Response Writer", func() {
|
||||
Expect(func() { rw.WriteHeader(99) }).To(Panic())
|
||||
Expect(func() { rw.WriteHeader(1000) }).To(Panic())
|
||||
})
|
||||
|
||||
It("write announced trailer", func() {
|
||||
rw.Header().Add("Trailer", "Key")
|
||||
rw.WriteHeader(http.StatusTeapot)
|
||||
n, err := rw.Write([]byte("foobar"))
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.Header().Set("Key", "Value")
|
||||
|
||||
// writeTrailers needs to be called after writing the full body
|
||||
Expect(rw.writeTrailers()).ToNot(HaveOccurred())
|
||||
|
||||
fields := decodeHeader(strBuf)
|
||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
||||
Expect(fields).To(HaveKeyWithValue("trailer", []string{"Key"}))
|
||||
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
|
||||
|
||||
fields = decodeHeader(strBuf)
|
||||
Expect(fields).To(HaveKeyWithValue("key", []string{"Value"}))
|
||||
})
|
||||
|
||||
It("ignore non-announced trailer (without trailer prefix)", func() {
|
||||
rw.Header().Set("Trailer", "Key")
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("foobar"))
|
||||
rw.Header().Set("UnknownKey", "Value")
|
||||
rw.Header().Set("Key", "Value")
|
||||
|
||||
// Needs to call writeTrailers to simulate the end of the handler
|
||||
Expect(rw.writeTrailers()).ToNot(HaveOccurred())
|
||||
headers := decodeHeader(strBuf)
|
||||
Expect(headers).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||
Expect(headers).To(HaveKeyWithValue("trailer", []string{"Key"}))
|
||||
|
||||
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
|
||||
|
||||
trailers := decodeHeader(strBuf)
|
||||
Expect(trailers).To(HaveKeyWithValue("key", []string{"Value"}))
|
||||
Expect(trailers).To(Not(HaveKeyWithValue("unknownkey", []string{"Value"})))
|
||||
})
|
||||
|
||||
It("write non-announced trailer (with trailer prefix)", func() {
|
||||
rw.Header().Set("Trailer", "Key")
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("foobar"))
|
||||
rw.Header().Set("Key", "Value")
|
||||
rw.Header().Set(http.TrailerPrefix+"Key2", "Value")
|
||||
rw.Flush()
|
||||
|
||||
// Needs to call writeTrailers to simulate the end of the handler
|
||||
Expect(rw.writeTrailers()).ToNot(HaveOccurred())
|
||||
headers := decodeHeader(strBuf)
|
||||
Expect(headers).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||
Expect(headers).To(HaveKeyWithValue("trailer", []string{"Key"}))
|
||||
|
||||
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
|
||||
|
||||
trailers := decodeHeader(strBuf)
|
||||
Expect(trailers).To(HaveKeyWithValue("key", []string{"Value"}))
|
||||
Expect(trailers).To(HaveKeyWithValue("key2", []string{"Value"}))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -625,6 +625,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
|
||||
}
|
||||
}
|
||||
r.Flush()
|
||||
r.flushTrailers()
|
||||
}
|
||||
|
||||
// abort the stream when there is a panic
|
||||
|
||||
@@ -1019,4 +1019,45 @@ var _ = Describe("HTTP tests", func() {
|
||||
Expect(num0RTTPackets.Load()).To(BeNumerically(">", 0))
|
||||
})
|
||||
})
|
||||
|
||||
It("sends and receives trailers", func() {
|
||||
mux.HandleFunc("/trailers", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
w.Header().Set("Trailer", "AtEnd1, AtEnd2")
|
||||
w.Header().Add("Trailer", "LAST")
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8") // normal header
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("AtEnd1", "value 1")
|
||||
io.WriteString(w, "This HTTP response has both headers before this text and trailers at the end.\n")
|
||||
w.(http.Flusher).Flush()
|
||||
w.Header().Set("AtEnd2", "value 2")
|
||||
io.WriteString(w, "More text\n")
|
||||
w.(http.Flusher).Flush()
|
||||
w.Header().Set("LAST", "value 3")
|
||||
w.Header().Set(http.TrailerPrefix+"Unannounced", "Surprise!")
|
||||
w.Header().Set("Late-Header", "No surprise!")
|
||||
})
|
||||
|
||||
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/trailers", port))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
Expect(resp.Header.Values("Trailer")).To(Equal([]string{"AtEnd1, AtEnd2", "LAST"}))
|
||||
Expect(resp.Header).To(Not(HaveKey("Atend1")))
|
||||
Expect(resp.Header).To(Not(HaveKey("Atend2")))
|
||||
Expect(resp.Header).To(Not(HaveKey("Last")))
|
||||
Expect(resp.Header).To(Not(HaveKey("Late-Header")))
|
||||
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal("This HTTP response has both headers before this text and trailers at the end.\nMore text\n"))
|
||||
for k := range resp.Header {
|
||||
Expect(k).To(Not(HavePrefix(http.TrailerPrefix)))
|
||||
}
|
||||
Expect(resp.Trailer).To(Equal(http.Header(map[string][]string{
|
||||
"Atend1": {"value 1"},
|
||||
"Atend2": {"value 2"},
|
||||
"Last": {"value 3"},
|
||||
"Unannounced": {"Surprise!"},
|
||||
})))
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user