Files
quic-go/http3/client.go
Lorenzo Saino 8db2288382 Make http3.client.Close() succeed if session was not started
Invoking http3.client.Close() before client.dial() is invoked
causes a segmentation fault. That occurs because, in this
circumstance, invoking client.Close() results in invoking
client.session.CloseWithError(...) while client.session is nil.

This commit changes the behavior of
http3.client.Close() to return nil if client.session
is nil and adds an associated test case.
2020-02-23 00:21:19 +00:00

274 lines
6.9 KiB
Go

package http3
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"sync"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
)
// MethodGet0RTT allows a GET request to be sent using 0-RTT.
// Note that 0-RTT data doesn't provide replay protection.
// RoundTrip rewrites this method to http.MethodGet and, for such requests,
// does not wait for the handshake to complete before opening the stream.
const MethodGet0RTT = "GET_0RTT"
// defaultUserAgent is presumably the User-Agent value sent when a request does
// not set one; it is not referenced in this part of the file — TODO confirm.
const defaultUserAgent = "quic-go HTTP/3"
// defaultMaxResponseHeaderBytes is the limit maxHeaderBytes returns when
// roundTripperOpts.MaxHeaderBytes is not positive.
// `*` and `<<` share a precedence level in Go and associate left to right, so
// the expression below evaluates as (10 * 1) << 20.
const defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
// defaultQuicConfig is the QUIC configuration newClient uses when it is called
// with a nil quicConfig.
var defaultQuicConfig = &quic.Config{
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
KeepAlive: true,
}
// dialAddr is the function dial uses to open the session when the client has
// no custom dialer.
// NOTE(review): it is presumably a variable rather than a direct call so that
// tests can substitute it — confirm.
var dialAddr = quic.DialAddrEarly
// roundTripperOpts holds the per-client options consulted while performing
// requests.
type roundTripperOpts struct {
// DisableCompression, when true, stops doRequest from asking the request
// writer for gzip and from transparently decompressing gzip responses.
DisableCompression bool
// MaxHeaderBytes is the largest response HEADERS frame length accepted.
// Values <= 0 select defaultMaxResponseHeaderBytes.
MaxHeaderBytes int64
}
// client is an HTTP/3 client doing requests against a single authority
// (see hostname). The QUIC session is dialed on the first RoundTrip call.
type client struct {
// tlsConf is the TLS configuration passed to the dialer. newClient stores a
// clone (or a fresh config) with NextProtos replaced by the H3 ALPN.
tlsConf *tls.Config
// config is the QUIC configuration passed to the dialer.
config *quic.Config
// opts is dereferenced without a nil check by maxHeaderBytes and doRequest,
// so it must be non-nil.
opts *roundTripperOpts
// dialOnce makes sure RoundTrip invokes dial a single time.
dialOnce sync.Once
// dialer is an optional custom dial function; when nil, dial uses dialAddr.
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
// handshakeErr holds the result of dial; if non-nil, every RoundTrip returns it.
handshakeErr error
// requestWriter serializes requests onto the request stream in doRequest.
requestWriter *requestWriter
// decoder decodes the QPACK-encoded header block of a response.
decoder *qpack.Decoder
// hostname is the result of authorityAddr("https", hostname) computed in
// newClient; RoundTrip rejects requests for any other authority.
hostname string
// session is nil until dial has been run; Close treats nil as "nothing to close".
session quic.EarlySession
// logger is the default logger with the prefix "h3 client".
logger utils.Logger
}
// newClient creates an HTTP/3 client for a single authority.
//
// hostname is normalized with authorityAddr("https", hostname). tlsConf is
// cloned (or created if nil) and its ALPN list is replaced by H3. quicConfig
// falls back to defaultQuicConfig when nil. dialer may be nil, in which case
// dial uses the package-level dialAddr. opts is stored as-is.
//
// Neither tlsConf nor quicConfig is modified: the client works on its own
// copies. No connection is established here; that happens in dial.
func newClient(
	hostname string,
	tlsConf *tls.Config,
	opts *roundTripperOpts,
	quicConfig *quic.Config,
	dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
) *client {
	if tlsConf == nil {
		tlsConf = &tls.Config{}
	} else {
		tlsConf = tlsConf.Clone()
	}
	// Replace existing ALPNs by H3
	tlsConf.NextProtos = []string{nextProtoH3}
	if quicConfig == nil {
		quicConfig = defaultQuicConfig
	}
	// Apply the override to a shallow copy so that neither the caller's config
	// nor the shared defaultQuicConfig is written to.
	conf := *quicConfig
	conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
	logger := utils.DefaultLogger.WithPrefix("h3 client")
	return &client{
		hostname:      authorityAddr("https", hostname),
		tlsConf:       tlsConf,
		requestWriter: newRequestWriter(logger),
		decoder:       qpack.NewDecoder(func(hf qpack.HeaderField) {}),
		config:        &conf,
		opts:          opts,
		dialer:        dialer,
		logger:        logger,
	}
}
// dial opens the QUIC session to c.hostname, using the custom dialer if one was
// configured and the package-level dialAddr otherwise. A dial error is returned
// unchanged. On success the session setup is started in a separate goroutine
// and nil is returned; if that setup fails, the failure is logged at debug
// level and the session is closed with errorInternalError.
func (c *client) dial() error {
	var dialErr error
	if c.dialer == nil {
		c.session, dialErr = dialAddr(c.hostname, c.tlsConf, c.config)
	} else {
		c.session, dialErr = c.dialer("udp", c.hostname, c.tlsConf, c.config)
	}
	if dialErr != nil {
		return dialErr
	}

	// run the session setup using 0-RTT data
	go func() {
		setupErr := c.setupSession()
		if setupErr == nil {
			return
		}
		c.logger.Debugf("Setting up session failed: %s", setupErr)
		c.session.CloseWithError(quic.ErrorCode(errorInternalError), "")
	}()
	return nil
}
// setupSession opens the unidirectional control stream and writes the stream
// type byte (0x0) followed by a zero-value SETTINGS frame to it. It returns the
// error from opening the stream or from writing to it, if any.
func (c *client) setupSession() error {
	controlStr, err := c.session.OpenUniStream()
	if err != nil {
		return err
	}

	// Assemble type byte and SETTINGS frame first so they go out in one Write.
	var payload bytes.Buffer
	payload.WriteByte(0x0)
	(&settingsFrame{}).Write(&payload)

	_, err = controlStr.Write(payload.Bytes())
	return err
}
// Close closes the session with errorNoError and returns the result of that
// call. If no session has been set yet, it does nothing and returns nil.
func (c *client) Close() error {
	if sess := c.session; sess != nil {
		return sess.CloseWithError(quic.ErrorCode(errorNoError), "")
	}
	return nil
}
// maxHeaderBytes returns the configured limit for the response HEADERS frame
// length, or defaultMaxResponseHeaderBytes if the configured value is not
// positive.
func (c *client) maxHeaderBytes() uint64 {
	if limit := c.opts.MaxHeaderBytes; limit > 0 {
		return uint64(limit)
	}
	return defaultMaxResponseHeaderBytes
}
// RoundTrip executes a request and returns a response.
//
// Only the "https" scheme is accepted, and the request must target the
// authority this client was created for. The session is dialed on the first
// call; a dial error is remembered and returned by every later call. Unless the
// request uses MethodGet0RTT (which is rewritten to GET), RoundTrip waits for
// the handshake to complete or for the request context to be done.
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.URL.Scheme != "https" {
		return nil, errors.New("http3: unsupported scheme")
	}
	if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
		return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
	}

	c.dialOnce.Do(func() { c.handshakeErr = c.dial() })
	if c.handshakeErr != nil {
		return nil, c.handshakeErr
	}

	ctx := req.Context()
	switch req.Method {
	case MethodGet0RTT:
		// 0-RTT requests are sent out immediately, as a regular GET.
		req.Method = http.MethodGet
	default:
		// Everything else has to wait for the handshake to complete.
		select {
		case <-c.session.HandshakeComplete().Done():
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}

	str, err := c.session.OpenStreamSync(ctx)
	if err != nil {
		return nil, err
	}

	// Request cancellation:
	// This goroutine outlives RoundTrip. It ends when the request context is
	// done (cancelling both directions of the stream) or when reqDone is
	// closed, which happens below on error or once the application is done
	// processing the body.
	reqDone := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
			str.CancelRead(quic.ErrorCode(errorRequestCanceled))
		case <-reqDone:
		}
	}()

	rsp, rerr := c.doRequest(req, str, reqDone)
	if rerr.err == nil {
		return rsp, nil
	}

	// An error occurred: stop the cancellation goroutine, then signal the
	// error at stream and/or connection level as indicated by doRequest.
	close(reqDone)
	if rerr.streamErr != 0 {
		str.CancelWrite(quic.ErrorCode(rerr.streamErr))
	}
	if rerr.connErr != 0 {
		c.session.CloseWithError(quic.ErrorCode(rerr.connErr), rerr.err.Error())
	}
	return rsp, rerr.err
}
// doRequest writes req to str, reads the first frame of the response (which
// must be a HEADERS frame no longer than maxHeaderBytes), decodes its header
// block and builds the http.Response from it.
//
// The ":status" field populates StatusCode and Status; every other field is
// added to the response header. If gzip was requested by this function and the
// response carries "Content-Encoding: gzip", the body is wrapped in a gzip
// reader and the Content-Encoding / Content-Length headers are removed.
//
// On failure the response is nil and the returned requestError is built with
// newStreamError or newConnError; on success it is the zero requestError.
func (c *client) doRequest(
	req *http.Request,
	str quic.Stream,
	reqDone chan struct{},
) (*http.Response, requestError) {
	// Only ask for gzip when compression is enabled and the request neither is
	// a HEAD request nor sets its own Accept-Encoding or Range header.
	requestGzip := !c.opts.DisableCompression &&
		req.Method != "HEAD" &&
		req.Header.Get("Accept-Encoding") == "" &&
		req.Header.Get("Range") == ""

	if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
		return nil, newStreamError(errorInternalError, err)
	}

	frame, err := parseNextFrame(str)
	if err != nil {
		return nil, newStreamError(errorFrameError, err)
	}
	headers, isHeaders := frame.(*headersFrame)
	if !isHeaders {
		return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
	}
	// Check the announced length before allocating a buffer for it.
	if maxLen := c.maxHeaderBytes(); headers.Length > maxLen {
		return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", headers.Length, maxLen))
	}

	block := make([]byte, headers.Length)
	if _, err := io.ReadFull(str, block); err != nil {
		return nil, newStreamError(errorRequestIncomplete, err)
	}
	fields, err := c.decoder.DecodeFull(block)
	if err != nil {
		// TODO: use the right error code
		return nil, newConnError(errorGeneralProtocolError, err)
	}

	res := &http.Response{
		Proto:      "HTTP/3",
		ProtoMajor: 3,
		Header:     http.Header{},
	}
	for _, field := range fields {
		if field.Name != ":status" {
			res.Header.Add(field.Name, field.Value)
			continue
		}
		code, err := strconv.Atoi(field.Value)
		if err != nil {
			return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
		}
		res.StatusCode = code
		res.Status = field.Value + " " + http.StatusText(code)
	}

	// The callback handed to the response body closes the whole session with
	// errorFrameUnexpected.
	body := newResponseBody(str, reqDone, func() {
		c.session.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
	})
	if !requestGzip || res.Header.Get("Content-Encoding") != "gzip" {
		res.Body = body
		return res, requestError{}
	}

	// Transparent decompression: hide the encoding from the caller.
	res.Header.Del("Content-Encoding")
	res.Header.Del("Content-Length")
	res.ContentLength = -1
	res.Body = newGzipReader(body)
	res.Uncompressed = true
	return res, requestError{}
}