forked from quic-go/quic-go
If Dial is set, it will be used for dialing new QUIC connections. If it is nil, quic.DialAddr will be used.
298 lines
7.0 KiB
Go
298 lines
7.0 KiB
Go
package h2quic
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
|
|
"golang.org/x/net/http2"
|
|
"golang.org/x/net/http2/hpack"
|
|
"golang.org/x/net/idna"
|
|
|
|
quic "github.com/lucas-clemente/quic-go"
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
"github.com/lucas-clemente/quic-go/qerr"
|
|
)
|
|
|
|
// roundTripperOpts holds the options that influence how the client sends
// requests.
type roundTripperOpts struct {
	// DisableCompression, when true, stops RoundTrip from transparently
	// requesting gzip-compressed responses via Accept-Encoding.
	DisableCompression bool
}
|
|
|
|
// dialAddr is the dial function client.dial falls back to when no custom
// dialer was configured. It is held in a variable rather than called
// directly, presumably so it can be replaced (e.g. in tests).
var (
	dialAddr = quic.DialAddr
)
|
|
|
|
// client is a HTTP2 client doing QUIC requests
|
|
type client struct {
|
|
mutex sync.RWMutex
|
|
|
|
tlsConf *tls.Config
|
|
config *quic.Config
|
|
opts *roundTripperOpts
|
|
|
|
hostname string
|
|
handshakeErr error
|
|
dialOnce sync.Once
|
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
|
|
|
session quic.Session
|
|
headerStream quic.Stream
|
|
headerErr *qerr.QuicError
|
|
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
|
requestWriter *requestWriter
|
|
|
|
responses map[protocol.StreamID]chan *http.Response
|
|
}
|
|
|
|
// Compile-time assertion that *client implements http.RoundTripper.
var _ http.RoundTripper = (*client)(nil)
|
|
|
|
var defaultQuicConfig = &quic.Config{
|
|
RequestConnectionIDOmission: true,
|
|
KeepAlive: true,
|
|
}
|
|
|
|
// newClient creates a new client
|
|
func newClient(
|
|
hostname string,
|
|
tlsConfig *tls.Config,
|
|
opts *roundTripperOpts,
|
|
quicConfig *quic.Config,
|
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
|
) *client {
|
|
config := defaultQuicConfig
|
|
if quicConfig != nil {
|
|
config = quicConfig
|
|
}
|
|
return &client{
|
|
hostname: authorityAddr("https", hostname),
|
|
responses: make(map[protocol.StreamID]chan *http.Response),
|
|
tlsConf: tlsConfig,
|
|
config: config,
|
|
opts: opts,
|
|
headerErrored: make(chan struct{}),
|
|
dialer: dialer,
|
|
}
|
|
}
|
|
|
|
// dial dials the connection
|
|
func (c *client) dial() error {
|
|
var err error
|
|
if c.dialer != nil {
|
|
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
|
} else {
|
|
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// once the version has been negotiated, open the header stream
|
|
c.headerStream, err = c.session.OpenStream()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.requestWriter = newRequestWriter(c.headerStream)
|
|
go c.handleHeaderStream()
|
|
return nil
|
|
}
|
|
|
|
func (c *client) handleHeaderStream() {
|
|
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
h2framer := http2.NewFramer(nil, c.headerStream)
|
|
|
|
var err error
|
|
for err == nil {
|
|
err = c.readResponse(h2framer, decoder)
|
|
}
|
|
utils.Debugf("Error handling header stream: %s", err)
|
|
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
|
// stop all running request
|
|
close(c.headerErrored)
|
|
}
|
|
|
|
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
|
frame, err := h2framer.ReadFrame()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
hframe, ok := frame.(*http2.HeadersFrame)
|
|
if !ok {
|
|
return errors.New("not a headers frame")
|
|
}
|
|
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
|
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
|
if err != nil {
|
|
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
|
}
|
|
|
|
c.mutex.RLock()
|
|
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
|
c.mutex.RUnlock()
|
|
if !ok {
|
|
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
|
}
|
|
|
|
rsp, err := responseFromHeaders(mhframe)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
responseChan <- rsp
|
|
return nil
|
|
}
|
|
|
|
// Roundtrip executes a request and returns a response
|
|
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
// TODO: add port to address, if it doesn't have one
|
|
if req.URL.Scheme != "https" {
|
|
return nil, errors.New("quic http2: unsupported scheme")
|
|
}
|
|
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
|
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
|
}
|
|
|
|
c.dialOnce.Do(func() {
|
|
c.handshakeErr = c.dial()
|
|
})
|
|
|
|
if c.handshakeErr != nil {
|
|
return nil, c.handshakeErr
|
|
}
|
|
|
|
hasBody := (req.Body != nil)
|
|
|
|
responseChan := make(chan *http.Response)
|
|
dataStream, err := c.session.OpenStreamSync()
|
|
if err != nil {
|
|
_ = c.CloseWithError(err)
|
|
return nil, err
|
|
}
|
|
c.mutex.Lock()
|
|
c.responses[dataStream.StreamID()] = responseChan
|
|
c.mutex.Unlock()
|
|
|
|
var requestedGzip bool
|
|
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
|
requestedGzip = true
|
|
}
|
|
// TODO: add support for trailers
|
|
endStream := !hasBody
|
|
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
|
if err != nil {
|
|
_ = c.CloseWithError(err)
|
|
return nil, err
|
|
}
|
|
|
|
resc := make(chan error, 1)
|
|
if hasBody {
|
|
go func() {
|
|
resc <- c.writeRequestBody(dataStream, req.Body)
|
|
}()
|
|
}
|
|
|
|
var res *http.Response
|
|
|
|
var receivedResponse bool
|
|
var bodySent bool
|
|
|
|
if !hasBody {
|
|
bodySent = true
|
|
}
|
|
|
|
for !(bodySent && receivedResponse) {
|
|
select {
|
|
case res = <-responseChan:
|
|
receivedResponse = true
|
|
c.mutex.Lock()
|
|
delete(c.responses, dataStream.StreamID())
|
|
c.mutex.Unlock()
|
|
case err := <-resc:
|
|
bodySent = true
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case <-c.headerErrored:
|
|
// an error occured on the header stream
|
|
_ = c.CloseWithError(c.headerErr)
|
|
return nil, c.headerErr
|
|
}
|
|
}
|
|
|
|
// TODO: correctly set this variable
|
|
var streamEnded bool
|
|
isHead := (req.Method == "HEAD")
|
|
|
|
res = setLength(res, isHead, streamEnded)
|
|
|
|
if streamEnded || isHead {
|
|
res.Body = noBody
|
|
} else {
|
|
res.Body = dataStream
|
|
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
|
res.Header.Del("Content-Encoding")
|
|
res.Header.Del("Content-Length")
|
|
res.ContentLength = -1
|
|
res.Body = &gzipReader{body: res.Body}
|
|
res.Uncompressed = true
|
|
}
|
|
}
|
|
|
|
res.Request = req
|
|
return res, nil
|
|
}
|
|
|
|
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
|
|
defer func() {
|
|
cerr := body.Close()
|
|
if err == nil {
|
|
// TODO: what to do with dataStream here? Maybe reset it?
|
|
err = cerr
|
|
}
|
|
}()
|
|
|
|
_, err = io.Copy(dataStream, body)
|
|
if err != nil {
|
|
// TODO: what to do with dataStream here? Maybe reset it?
|
|
return err
|
|
}
|
|
return dataStream.Close()
|
|
}
|
|
|
|
// Close closes the client
|
|
func (c *client) CloseWithError(e error) error {
|
|
if c.session == nil {
|
|
return nil
|
|
}
|
|
return c.session.Close(e)
|
|
}
|
|
|
|
func (c *client) Close() error {
|
|
return c.CloseWithError(nil)
|
|
}
|
|
|
|
// copied from net/transport.go
|
|
|
|
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
|
// and returns a host:port. The port 443 is added if needed.
|
|
func authorityAddr(scheme string, authority string) (addr string) {
|
|
host, port, err := net.SplitHostPort(authority)
|
|
if err != nil { // authority didn't have a port
|
|
port = "443"
|
|
if scheme == "http" {
|
|
port = "80"
|
|
}
|
|
host = authority
|
|
}
|
|
if a, err := idna.ToASCII(host); err == nil {
|
|
host = a
|
|
}
|
|
// IPv6 address literal, without a port:
|
|
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
|
return host + ":" + port
|
|
}
|
|
return net.JoinHostPort(host, port)
|
|
}
|