forked from quic-go/quic-go
http3: expose an OpenStream method on the RoundTripper (#4409)
The stream exposes two methods required for doing an HTTP request: SendRequestHeader and ReadResponse. This can be used by applications that wish to use the stream for non-HTTP content afterwards. This will lead to a simplification in the API we need to expose for WebTransport, and will make it easier to send HTTP Datagrams associated with this stream.
This commit is contained in:
159
http3/client.go
159
http3/client.go
@@ -8,7 +8,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -114,7 +113,7 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con
|
||||
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
|
||||
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
hostname: authorityAddr(hostname),
|
||||
tlsConf: tlsConf,
|
||||
requestWriter: newRequestWriter(logger),
|
||||
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
|
||||
@@ -222,16 +221,20 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
||||
return rsp, err
|
||||
}
|
||||
|
||||
func (c *client) dialConn(ctx context.Context) error {
|
||||
c.dialOnce.Do(func() {
|
||||
c.handshakeErr = c.dial(ctx)
|
||||
})
|
||||
return c.handshakeErr
|
||||
}
|
||||
|
||||
func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||
if authorityAddr(hostnameFromURL(req.URL)) != c.hostname {
|
||||
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||
}
|
||||
|
||||
c.dialOnce.Do(func() {
|
||||
c.handshakeErr = c.dial(req.Context())
|
||||
})
|
||||
if c.handshakeErr != nil {
|
||||
return nil, c.handshakeErr
|
||||
if err := c.dialConn(req.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// At this point, c.conn is guaranteed to be set.
|
||||
@@ -290,31 +293,34 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
||||
}
|
||||
}()
|
||||
|
||||
doneChan := reqDone
|
||||
if opt.DontCloseRequestStream {
|
||||
doneChan = nil
|
||||
}
|
||||
rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
|
||||
if rerr.err != nil { // if any error occurred
|
||||
rsp, err := c.doRequest(req, conn, str, reqDone)
|
||||
if err != nil { // if any error occurred
|
||||
close(reqDone)
|
||||
<-done
|
||||
if rerr.streamErr != 0 { // if it was a stream error
|
||||
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
|
||||
}
|
||||
if rerr.connErr != 0 { // if it was a connection error
|
||||
var reason string
|
||||
if rerr.err != nil {
|
||||
reason = rerr.err.Error()
|
||||
}
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
}
|
||||
return nil, maybeReplaceError(rerr.err)
|
||||
return nil, maybeReplaceError(err)
|
||||
}
|
||||
if opt.DontCloseRequestStream {
|
||||
close(reqDone)
|
||||
<-done
|
||||
return rsp, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
func (c *client) OpenStream(ctx context.Context) (RequestStream, error) {
|
||||
if err := c.dialConn(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rsp, maybeReplaceError(rerr.err)
|
||||
conn := *c.conn.Load()
|
||||
str, err := conn.OpenStreamSync(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newRequestStream(
|
||||
newStream(str, func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }),
|
||||
c.hconn,
|
||||
c.requestWriter,
|
||||
nil,
|
||||
c.decoder,
|
||||
c.opts.DisableCompression,
|
||||
c.maxHeaderBytes(),
|
||||
func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") },
|
||||
), nil
|
||||
}
|
||||
|
||||
// cancelingReader reads from the io.Reader.
|
||||
@@ -356,21 +362,23 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength i
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
||||
var requestGzip bool
|
||||
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||
requestGzip = true
|
||||
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, reqDone chan<- struct{}) (*http.Response, error) {
|
||||
hstr := newRequestStream(
|
||||
newStream(str, func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }),
|
||||
c.hconn,
|
||||
c.requestWriter,
|
||||
reqDone,
|
||||
c.decoder,
|
||||
c.opts.DisableCompression,
|
||||
c.maxHeaderBytes(),
|
||||
func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") },
|
||||
)
|
||||
if err := hstr.SendRequestHeader(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
|
||||
return nil, newStreamError(ErrCodeInternalError, err)
|
||||
}
|
||||
|
||||
if req.Body == nil && !opt.DontCloseRequestStream {
|
||||
str.Close()
|
||||
}
|
||||
|
||||
hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") })
|
||||
if req.Body != nil {
|
||||
if req.Body == nil {
|
||||
hstr.Close()
|
||||
} else {
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
contentLength := int64(-1)
|
||||
@@ -382,75 +390,18 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
|
||||
if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
|
||||
c.logger.Errorf("Error writing request: %s", err)
|
||||
}
|
||||
if !opt.DontCloseRequestStream {
|
||||
hstr.Close()
|
||||
}
|
||||
hstr.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
frame, err := parseNextFrame(str, nil)
|
||||
res, err := hstr.ReadResponse()
|
||||
if err != nil {
|
||||
return nil, newStreamError(ErrCodeFrameError, err)
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
|
||||
}
|
||||
if hf.Length > c.maxHeaderBytes() {
|
||||
return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||
return nil, newStreamError(ErrCodeRequestIncomplete, err)
|
||||
}
|
||||
hfs, err := c.decoder.DecodeFull(headerBlock)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
return nil, newConnError(ErrCodeGeneralProtocolError, err)
|
||||
}
|
||||
|
||||
res, err := responseFromHeaders(hfs)
|
||||
if err != nil {
|
||||
return nil, newStreamError(ErrCodeMessageError, err)
|
||||
return nil, err
|
||||
}
|
||||
connState := conn.ConnectionState().TLS
|
||||
res.TLS = &connState
|
||||
res.Request = req
|
||||
// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(hstr, res.ContentLength)
|
||||
} else {
|
||||
httpStr = hstr
|
||||
}
|
||||
respBody := newResponseBody(httpStr, conn, reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
|
||||
isInformational := res.StatusCode >= 100 && res.StatusCode < 200
|
||||
isNoContent := res.StatusCode == http.StatusNoContent
|
||||
isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
|
||||
if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
|
||||
res.ContentLength = -1
|
||||
if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
|
||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
||||
res.ContentLength = clen64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||
res.Header.Del("Content-Encoding")
|
||||
res.Header.Del("Content-Length")
|
||||
res.ContentLength = -1
|
||||
res.Body = newGzipReader(respBody)
|
||||
res.Uncompressed = true
|
||||
} else {
|
||||
res.Body = respBody
|
||||
}
|
||||
|
||||
return res, requestError{}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *client) HandshakeComplete() bool {
|
||||
|
||||
Reference in New Issue
Block a user