Files
quic-go/http3/body.go
Marten Seemann 34f4d1443f http3: implement on the HTTPStreamer on the ResponseWriter, flush header (#4469)
Currently the HTTPStreamer is implemented on the http.Request.Body. This
complicates usage, since it's not easily possible to flush the HTTP
header, requiring users to manually flush the header before taking over
the stream.

With this change, the HTTP header is now flushed automatically as soon
as HTTPStream is called.
2024-04-27 04:30:39 -07:00

135 lines
3.2 KiB
Go

package http3
import (
"context"
"errors"
"io"
"github.com/quic-go/quic-go"
)
// A Hijacker allows hijacking of the stream-creating part of the underlying
// QUIC connection from an http.Response.Body.
// It is used by WebTransport to create WebTransport streams after a session has been established.
type Hijacker interface {
	// Connection returns the Connection associated with the body
	// (presumably the one the response was received on — confirm against implementations).
	Connection() Connection
}
// errTooMuchData is returned by body.Read once the peer has sent more body
// data than announced by the content length (see checkContentLengthViolation).
var errTooMuchData = errors.New("peer sent too much data")
// body is the shared implementation used by requestBody (for an http.Request)
// and by hijackableBody (for an http.Response, see newResponseBody).
// If a content length was announced, body enforces it while reading.
type body struct {
	// str is the stream the body data is read from.
	str *stream
	// remainingContentLength is the number of body bytes still expected.
	// It is only meaningful if hasContentLength is set.
	remainingContentLength int64
	// violatedContentLength is set once a content-length violation was
	// detected and the stream was canceled, so the cancellation happens only once.
	violatedContentLength bool
	// hasContentLength reports whether a non-negative content length was
	// passed to newBody.
	hasContentLength bool
}
// newBody returns a body reading from str. A negative contentLength means
// the length is unknown; in that case no content-length checks are done.
func newBody(str *stream, contentLength int64) *body {
	known := contentLength >= 0
	var remaining int64
	if known {
		remaining = contentLength
	}
	return &body{
		str:                    str,
		hasContentLength:       known,
		remainingContentLength: remaining,
	}
}
// StreamID returns the ID of the stream the body is read from.
func (r *body) StreamID() quic.StreamID {
	return r.str.StreamID()
}
// checkContentLengthViolation returns errTooMuchData if the peer has sent more
// data than the announced content length: either more bytes were consumed than
// announced, or the full length was consumed and the stream still has data.
// On the first detected violation, both directions of the stream are canceled
// with ErrCodeMessageError. Bodies without a content length never violate it.
func (r *body) checkContentLengthViolation() error {
	if !r.hasContentLength {
		return nil
	}
	switch {
	case r.remainingContentLength > 0:
		return nil
	case r.remainingContentLength == 0 && !r.str.hasMoreData():
		return nil
	}
	// Violation detected. Only cancel the stream the first time around.
	if r.violatedContentLength {
		return errTooMuchData
	}
	code := quic.StreamErrorCode(ErrCodeMessageError)
	r.str.CancelRead(code)
	r.str.CancelWrite(code)
	r.violatedContentLength = true
	return errTooMuchData
}
// Read reads body data into b. If a content length is known, at most the
// remaining number of bytes is requested from the stream. A content-length
// violation, detected before or after reading, is reported as errTooMuchData.
// Otherwise, the stream's error is passed through maybeReplaceError.
func (r *body) Read(b []byte) (int, error) {
	if violation := r.checkContentLengthViolation(); violation != nil {
		return 0, violation
	}
	dst := b
	if r.hasContentLength && int64(len(dst)) > r.remainingContentLength {
		// The check above guarantees remainingContentLength >= 0 here.
		dst = dst[:r.remainingContentLength]
	}
	n, readErr := r.str.Read(dst)
	r.remainingContentLength -= int64(n)
	if violation := r.checkContentLengthViolation(); violation != nil {
		return n, violation
	}
	return n, maybeReplaceError(readErr)
}
// Close cancels reading from the stream with ErrCodeRequestCanceled.
// It always returns nil.
func (r *body) Close() error {
	code := quic.StreamErrorCode(ErrCodeRequestCanceled)
	r.str.CancelRead(code)
	return nil
}
// requestBody is the body of an http.Request. It embeds body (which provides
// Read and Close) and additionally keeps references to connection-level state.
type requestBody struct {
	body
	// connCtx is presumably the context of the connection the request
	// arrived on — confirm at the call site of newRequestBody.
	connCtx context.Context
	// rcvdSettings is presumably closed once the peer's SETTINGS were
	// received — TODO confirm.
	rcvdSettings <-chan struct{}
	// getSettings returns the peer's settings (presumably only valid after
	// rcvdSettings is closed — confirm).
	getSettings func() *Settings
}

// Compile-time check that *requestBody implements io.ReadCloser
// (via the Read and Close methods of the embedded body).
var _ io.ReadCloser = &requestBody{}
// newRequestBody returns a requestBody reading from str. contentLength follows
// the convention of newBody: a negative value means the length is unknown.
func newRequestBody(str *stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody {
	rb := &requestBody{body: *newBody(str, contentLength)}
	rb.connCtx = connCtx
	rb.rcvdSettings = rcvdSettings
	rb.getSettings = getSettings
	return rb
}
// hijackableBody is the body of an http.Response (it is created by
// newResponseBody). It wraps body and additionally signals through reqDone
// when the user is done with the response.
type hijackableBody struct {
	body body
	// only set for the http.Response
	// The channel is closed when the user is done with this response:
	// either when Read() errors, or when Close() is called.
	reqDone chan<- struct{}
	// reqDoneClosed records that reqDone was already closed, so it is closed at most once.
	reqDoneClosed bool
}

// Compile-time check that *hijackableBody implements io.ReadCloser.
var _ io.ReadCloser = &hijackableBody{}
// newResponseBody returns the body of an http.Response reading from str.
// done (which may be nil) is closed once the user is done with the response.
func newResponseBody(str *stream, contentLength int64, done chan<- struct{}) *hijackableBody {
	hb := &hijackableBody{reqDone: done}
	hb.body = *newBody(str, contentLength)
	return hb
}
// Read reads from the underlying body. Any non-nil error marks the request as
// done (see requestDone) before it is returned to the caller.
func (r *hijackableBody) Read(b []byte) (int, error) {
	n, readErr := r.body.Read(b)
	if readErr != nil {
		// Once reading fails or ends, the user is done with this response.
		r.requestDone()
	}
	return n, maybeReplaceError(readErr)
}
// requestDone closes reqDone at most once. It is a no-op if no reqDone
// channel was supplied (nil) or if it was already closed. Closing a nil or
// already-closed channel would panic, which is why both cases are guarded.
func (r *hijackableBody) requestDone() {
	if r.reqDone == nil || r.reqDoneClosed {
		return
	}
	close(r.reqDone)
	r.reqDoneClosed = true
}
// Close marks the request as done and cancels reading from the stream with
// ErrCodeRequestCanceled, by delegating to body.Close. It always returns nil.
// If the EOF was already read, the read cancellation is a no-op.
func (r *hijackableBody) Close() error {
	r.requestDone()
	return r.body.Close()
}