http3: simplify response header writing (#4441)

This commit is contained in:
Marten Seemann
2024-04-17 09:44:36 -07:00
committed by GitHub
parent 788b6f4409
commit e48e1d465d
5 changed files with 54 additions and 59 deletions

View File

@@ -26,7 +26,7 @@ func encodeResponse(status int) []byte {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw := newResponseWriter(newStream(rstr, nil), nil, false, utils.DefaultLogger)
rw.WriteHeader(status)
rw.Flush()
return buf.Bytes()
@@ -738,7 +738,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw := newResponseWriter(newStream(rstr, nil), nil, false, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
@@ -764,7 +764,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw := newResponseWriter(newStream(rstr, nil), nil, false, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })

View File

@@ -101,6 +101,10 @@ func (s *stream) Write(b []byte) (int, error) {
return s.Stream.Write(b)
}
func (s *stream) writeUnframed(b []byte) (int, error) {
return s.Stream.Write(b)
}
func (s *stream) StreamID() protocol.StreamID {
return s.Stream.StreamID()
}

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/qpack"
@@ -18,46 +17,15 @@ import (
// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
const frameHeaderLen = 16
// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
type headerWriter struct {
str quic.Stream
header http.Header
status int // status code passed to WriteHeader
logger utils.Logger
}
// writeHeader encodes and flush header to the stream
func (hw *headerWriter) writeHeader() error {
var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}); err != nil {
return err
}
for k, v := range hw.header {
for index := range v {
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
return err
}
}
}
buf := make([]byte, 0, frameHeaderLen+headers.Len())
buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
hw.logger.Infof("Responding with %d", hw.status)
buf = append(buf, headers.Bytes()...)
_, err := hw.str.Write(buf)
return err
}
const maxSmallResponseSize = 4096
type responseWriter struct {
*headerWriter
conn Connection
buf []byte
str *stream
conn Connection
header http.Header
buf []byte
status int // status code passed to WriteHeader
// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
// and automatically add the Content-Length header
@@ -68,6 +36,8 @@ type responseWriter struct {
headerComplete bool // set once WriteHeader is called with a status code >= 200
headerWritten bool // set once the response header has been serialized to the stream
isHead bool
logger utils.Logger
}
var (
@@ -76,17 +46,14 @@ var (
_ Hijacker = &responseWriter{}
)
func newResponseWriter(str quic.Stream, conn Connection, isHead bool, logger utils.Logger) *responseWriter {
hw := &headerWriter{
str: str,
header: http.Header{},
logger: logger,
}
func newResponseWriter(str *stream, conn Connection, isHead bool, logger utils.Logger) *responseWriter {
return &responseWriter{
headerWriter: hw,
buf: make([]byte, frameHeaderLen),
conn: conn,
isHead: isHead,
str: str,
conn: conn,
header: http.Header{},
buf: make([]byte, frameHeaderLen),
isHead: isHead,
logger: logger,
}
}
@@ -107,7 +74,7 @@ func (w *responseWriter) WriteHeader(status int) {
// immediately write 1xx headers
if status < 200 {
w.writeHeader()
w.writeHeader(status)
return
}
@@ -178,7 +145,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
func (w *responseWriter) doWrite(p []byte) (int, error) {
if !w.headerWritten {
if err := w.writeHeader(); err != nil {
if err := w.writeHeader(w.status); err != nil {
return 0, maybeReplaceError(err)
}
w.headerWritten = true
@@ -191,11 +158,11 @@ func (w *responseWriter) doWrite(p []byte) (int, error) {
df := &dataFrame{Length: l}
w.buf = w.buf[:0]
w.buf = df.Append(w.buf)
if _, err := w.str.Write(w.buf); err != nil {
if _, err := w.str.writeUnframed(w.buf); err != nil {
return 0, maybeReplaceError(err)
}
if len(w.smallResponseBuf) > 0 {
if _, err := w.str.Write(w.smallResponseBuf); err != nil {
if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil {
return 0, maybeReplaceError(err)
}
w.smallResponseBuf = nil
@@ -203,7 +170,7 @@ func (w *responseWriter) doWrite(p []byte) (int, error) {
var n int
if len(p) > 0 {
var err error
n, err = w.str.Write(p)
n, err = w.str.writeUnframed(p)
if err != nil {
return n, maybeReplaceError(err)
}
@@ -211,6 +178,29 @@ func (w *responseWriter) doWrite(p []byte) (int, error) {
return n, nil
}
func (w *responseWriter) writeHeader(status int) error {
var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil {
return err
}
for k, v := range w.header {
for index := range v {
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
return err
}
}
}
buf := make([]byte, 0, frameHeaderLen+headers.Len())
buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
buf = append(buf, headers.Bytes()...)
_, err := w.str.writeUnframed(buf)
return err
}
func (w *responseWriter) FlushError() error {
if !w.headerComplete {
w.WriteHeader(http.StatusOK)

View File

@@ -28,7 +28,7 @@ var _ = Describe("Response Writer", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes()
str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes()
rw = newResponseWriter(str, nil, false, utils.DefaultLogger)
rw = newResponseWriter(newStream(str, nil), nil, false, utils.DefaultLogger)
})
decodeHeader := func(str io.Reader) map[string][]string {

View File

@@ -531,7 +531,8 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
contentLength = req.ContentLength
}
body := newRequestBody(newStream(str, conn), contentLength, conn.Context(), conn.ReceivedSettings(), conn.Settings)
hstr := newStream(str, conn)
body := newRequestBody(hstr, contentLength, conn.Context(), conn.ReceivedSettings(), conn.Settings)
req.Body = body
if s.logger.Debug() {
@@ -551,7 +552,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}
}
req = req.WithContext(ctx)
r := newResponseWriter(str, conn, req.Method == http.MethodHead, s.logger)
r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux