package http3

import (
	"bytes"
	"fmt"
	"log/slog"
	"net/http"
	"net/textproto"
	"strconv"
	"strings"
	"time"

	"git.geeks-team.ru/gr1ffon/quic-go/http3/qlog"
	"github.com/quic-go/qpack"
	"golang.org/x/net/http/httpguts"
)

// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by the http.ResponseWriter.
// When a stream is taken over, it's the caller's responsibility to close the stream.
type HTTPStreamer interface {
	HTTPStream() *Stream
}

const (
	// frameHeaderLen is the maximum length of an encoded HTTP/3 frame header:
	// the frame has a type and a length field, both QUIC varints (maximum 8 bytes in length).
	frameHeaderLen = 16

	// maxSmallResponseSize is the threshold below which calls to Write are buffered
	// as long as the response header has not been written to the stream yet.
	maxSmallResponseSize = 4096
)

// ResponseWriter is the http.ResponseWriter used to answer a request on an HTTP/3 stream.
// It serializes the response header, the body and the trailers to the underlying Stream.
type ResponseWriter struct {
	str *Stream

	conn     *Conn
	header   http.Header
	trailers map[string]struct{}
	buf      []byte
	status   int // status code passed to WriteHeader

	// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
	// and automatically add the Content-Length header
	smallResponseBuf []byte

	contentLen     int64 // if handler set valid Content-Length header
	numWritten     int64 // bytes written
	headerComplete bool  // set once WriteHeader is called with a status code >= 200
	headerWritten  bool  // set once the response header has been serialized to the stream
	isHead         bool
	trailerWritten bool // set once the response trailers have been serialized to the stream

	hijacked bool // set once HTTPStream is called

	logger *slog.Logger
}

var (
	_ http.ResponseWriter = (*ResponseWriter)(nil)
	_ http.Flusher        = (*ResponseWriter)(nil)
	_ Hijacker            = (*ResponseWriter)(nil)
	_ HTTPStreamer        = (*ResponseWriter)(nil)
	// make sure that we implement (some of the) methods used by the http.ResponseController
	_ interface {
		SetReadDeadline(time.Time) error
		SetWriteDeadline(time.Time) error
		Flush()
		FlushError() error
	} = (*ResponseWriter)(nil)
)

// newResponseWriter creates a ResponseWriter for the given stream and connection.
// isHead marks the response as belonging to a HEAD request, in which case Write
// accepts body data but never sends it. The logger may be nil.
func newResponseWriter(str *Stream, conn *Conn, isHead bool, logger *slog.Logger) *ResponseWriter {
	return &ResponseWriter{
		str:    str,
		conn:   conn,
		header: http.Header{},
		buf:    make([]byte, frameHeaderLen),
		isHead: isHead,
		logger: logger,
	}
}

// Header returns the header map that is sent by WriteHeader.
func (w *ResponseWriter) Header() http.Header {
	return w.header
}

// WriteHeader records the status code of the response.
//
// Status codes below 200 are informational: their header is written to the stream
// immediately, and WriteHeader may be called again afterwards. The first call with a
// status code >= 200 completes the header; any later call is ignored.
//
// It panics if status is not a three-digit number.
func (w *ResponseWriter) WriteHeader(status int) {
	if w.headerComplete {
		return
	}

	// http status must be 3 digits
	if status < 100 || status > 999 {
		panic(fmt.Sprintf("invalid WriteHeader code %v", status))
	}
	w.status = status

	// immediately write 1xx headers
	if status < 200 {
		// WriteHeader has no way to report an error, so the best we can do is log it.
		if err := w.writeHeader(status); err != nil && w.logger != nil {
			w.logger.Debug("could not write informational response header", "error", err)
		}
		return
	}

	// We're done with headers once we write a status >= 200.
	w.headerComplete = true
	// Add Date header.
	// This is what the standard library does.
	// Can be disabled by setting the Date header to nil.
	if _, ok := w.header["Date"]; !ok {
		w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
	}
	// Content-Length checking
	// use ParseUint instead of ParseInt, as negative values are invalid
	if clen := w.header.Get("Content-Length"); clen != "" {
		if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
			w.contentLen = int64(cl)
		} else {
			// emit a warning for malformed Content-Length and remove it
			logger := w.logger
			if logger == nil {
				logger = slog.Default()
			}
			logger.Error("Malformed Content-Length", "value", clen)
			w.header.Del("Content-Length")
		}
	}
}

// sniffContentType sets the Content-Type header based on the first bytes of the body,
// unless a Content-Type or a non-blank Content-Encoding was already set, or p is empty.
func (w *ResponseWriter) sniffContentType(p []byte) {
	// If no content type, apply sniffing algorithm to body.
	// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shouldn't do sniffing.
	_, haveType := w.header["Content-Type"]

	// If the Content-Encoding was set and is non-blank, we shouldn't sniff the body.
	hasCE := w.header.Get("Content-Encoding") != ""
	if !hasCE && !haveType && len(p) > 0 {
		w.header.Set("Content-Type", http.DetectContentType(p))
	}
}

// Write writes p as part of the response body.
//
// If the header has not been completed yet, the Content-Type is sniffed from p and
// WriteHeader(http.StatusOK) is called. It returns http.ErrBodyNotAllowed if the status
// code doesn't permit a body, and http.ErrContentLength if more bytes are written than
// announced by a valid Content-Length header. For HEAD requests the data is counted,
// but not sent.
func (w *ResponseWriter) Write(p []byte) (int, error) {
	bodyAllowed := bodyAllowedForStatus(w.status)
	if !w.headerComplete {
		w.sniffContentType(p)
		w.WriteHeader(http.StatusOK)
		bodyAllowed = true
	}
	if !bodyAllowed {
		return 0, http.ErrBodyNotAllowed
	}

	w.numWritten += int64(len(p))
	if w.contentLen != 0 && w.numWritten > w.contentLen {
		return 0, http.ErrContentLength
	}

	if w.isHead {
		return len(p), nil
	}

	if !w.headerWritten {
		// Buffer small responses.
		// This allows us to automatically set the Content-Length field.
		if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize {
			w.smallResponseBuf = append(w.smallResponseBuf, p...)
			return len(p), nil
		}
	}
	return w.doWrite(p)
}

// doWrite writes the response header (if that hasn't happened yet), followed by a single
// DATA frame containing the buffered small response and p.
// The returned count only covers the bytes of p that were written.
func (w *ResponseWriter) doWrite(p []byte) (int, error) {
	if !w.headerWritten {
		w.sniffContentType(w.smallResponseBuf)
		if err := w.writeHeader(w.status); err != nil {
			return 0, maybeReplaceError(err)
		}
		w.headerWritten = true
	}

	l := uint64(len(w.smallResponseBuf) + len(p))
	if l == 0 {
		return 0, nil
	}
	df := &dataFrame{Length: l}
	w.buf = w.buf[:0]
	w.buf = df.Append(w.buf)
	if w.str.qlogger != nil {
		w.str.qlogger.RecordEvent(qlog.FrameCreated{
			StreamID: w.str.StreamID(),
			Raw:      qlog.RawInfo{Length: len(w.buf) + int(l), PayloadLength: int(l)},
			Frame:    qlog.Frame{Frame: qlog.DataFrame{}},
		})
	}
	if _, err := w.str.writeUnframed(w.buf); err != nil {
		return 0, maybeReplaceError(err)
	}
	if len(w.smallResponseBuf) > 0 {
		if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil {
			return 0, maybeReplaceError(err)
		}
		w.smallResponseBuf = nil
	}
	var n int
	if len(p) > 0 {
		var err error
		n, err = w.str.writeUnframed(p)
		if err != nil {
			return n, maybeReplaceError(err)
		}
	}
	return n, nil
}

// writeHeader QPACK-encodes the :status pseudo-header and the header fields, and writes
// them to the stream as a HEADERS frame. Fields declared in the Trailer header and fields
// using the http.TrailerPrefix are not part of the header; they are left for the trailers.
func (w *ResponseWriter) writeHeader(status int) error {
	var headerFields []qlog.HeaderField // only used for qlog
	var headers bytes.Buffer
	enc := qpack.NewEncoder(&headers)
	statusStr := strconv.Itoa(status)
	if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: statusStr}); err != nil {
		return err
	}
	if w.str.qlogger != nil {
		headerFields = append(headerFields, qlog.HeaderField{Name: ":status", Value: statusStr})
	}

	// Handle trailer fields
	if vals, ok := w.header["Trailer"]; ok {
		for _, val := range vals {
			for _, trailer := range strings.Split(val, ",") {
				// We need to convert to the canonical header key value here because this will be called when using
				// headers.Add or headers.Set.
				trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
				w.declareTrailer(trailer)
			}
		}
	}

	for k, v := range w.header {
		if _, excluded := w.trailers[k]; excluded {
			continue
		}
		// Ignore "Trailer:" prefixed headers
		if strings.HasPrefix(k, http.TrailerPrefix) {
			continue
		}
		// The lowercased name is the same for all values of this key.
		name := strings.ToLower(k)
		for _, value := range v {
			if err := enc.WriteField(qpack.HeaderField{Name: name, Value: value}); err != nil {
				return err
			}
			if w.str.qlogger != nil {
				headerFields = append(headerFields, qlog.HeaderField{Name: name, Value: value})
			}
		}
	}

	buf := make([]byte, 0, frameHeaderLen+headers.Len())
	buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
	buf = append(buf, headers.Bytes()...)

	if w.str.qlogger != nil {
		qlogCreatedHeadersFrame(w.str.qlogger, w.str.StreamID(), len(buf), headers.Len(), headerFields)
	}

	_, err := w.str.writeUnframed(buf)
	return err
}

// FlushError writes the response header (if it hasn't been written yet) and any buffered
// body data to the stream. If WriteHeader hasn't been called with a final status code,
// it is called with http.StatusOK first.
func (w *ResponseWriter) FlushError() error {
	if !w.headerComplete {
		w.WriteHeader(http.StatusOK)
	}
	_, err := w.doWrite(nil)
	return err
}

// flushTrailers writes the trailers to the stream, unless that already happened.
// Errors are only logged.
func (w *ResponseWriter) flushTrailers() {
	if w.trailerWritten {
		return
	}
	// The logger is optional (see Flush and WriteHeader), so guard against nil.
	if err := w.writeTrailers(); err != nil && w.logger != nil {
		w.logger.Debug("could not write trailers", "error", err)
	}
}

// Flush behaves like FlushError, but logs the error instead of returning it.
func (w *ResponseWriter) Flush() {
	if err := w.FlushError(); err != nil {
		if w.logger != nil {
			w.logger.Debug("could not flush to stream", "error", err)
		}
	}
}

// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a
// valid name.
func (w *ResponseWriter) declareTrailer(k string) { if !httpguts.ValidTrailerHeader(k) { // Forbidden by RFC 9110, section 6.5.1. w.logger.Debug("ignoring invalid trailer", slog.String("header", k)) return } if w.trailers == nil { w.trailers = make(map[string]struct{}) } w.trailers[k] = struct{}{} } // hasNonEmptyTrailers checks to see if there are any trailers with an actual // value set. This is possible by adding trailers to the "Trailers" header // but never actually setting those names as trailers in the course of handling // the request. In that case, this check may save us some allocations. func (w *ResponseWriter) hasNonEmptyTrailers() bool { for trailer := range w.trailers { if _, ok := w.header[trailer]; ok { return true } } return false } // writeTrailers will write trailers to the stream if there are any. func (w *ResponseWriter) writeTrailers() error { // promote headers added via "Trailer:" convention as trailers, these can be added after // streaming the status/headers have been written. for k := range w.header { // Handle "Trailer:" prefix if strings.HasPrefix(k, http.TrailerPrefix) { w.declareTrailer(k) } } if !w.hasNonEmptyTrailers() { return nil } var b bytes.Buffer var headerFields []qlog.HeaderField enc := qpack.NewEncoder(&b) for trailer := range w.trailers { trailerName := strings.ToLower(strings.TrimPrefix(trailer, http.TrailerPrefix)) if vals, ok := w.header[trailer]; ok { for _, val := range vals { if err := enc.WriteField(qpack.HeaderField{Name: trailerName, Value: val}); err != nil { return err } if w.str.qlogger != nil { headerFields = append(headerFields, qlog.HeaderField{Name: trailerName, Value: val}) } } } } buf := make([]byte, 0, frameHeaderLen+b.Len()) buf = (&headersFrame{Length: uint64(b.Len())}).Append(buf) buf = append(buf, b.Bytes()...) 
if w.str.qlogger != nil { qlogCreatedHeadersFrame(w.str.qlogger, w.str.StreamID(), len(buf), b.Len(), headerFields) } _, err := w.str.writeUnframed(buf) w.trailerWritten = true return err } func (w *ResponseWriter) HTTPStream() *Stream { w.hijacked = true w.Flush() return w.str } func (w *ResponseWriter) wasStreamHijacked() bool { return w.hijacked } func (w *ResponseWriter) Connection() *Conn { return w.conn } func (w *ResponseWriter) SetReadDeadline(deadline time.Time) error { return w.str.SetReadDeadline(deadline) } func (w *ResponseWriter) SetWriteDeadline(deadline time.Time) error { return w.str.SetWriteDeadline(deadline) } // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 2616, section 4.4. func bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: return false case status == http.StatusNoContent: return false case status == http.StatusNotModified: return false } return true }