diff --git a/http3/client.go b/http3/client.go index 05d769ac..e61cf6e6 100644 --- a/http3/client.go +++ b/http3/client.go @@ -2,14 +2,10 @@ package http3 import ( "context" - "crypto/tls" - "errors" "fmt" "io" - "net" "net/http" "sync" - "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -39,127 +35,63 @@ var defaultQuicConfig = &quic.Config{ KeepAlivePeriod: 10 * time.Second, } -type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) +// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server. +type SingleDestinationRoundTripper struct { + Connection quic.Connection -var dialAddr dialFunc = quic.DialAddrEarly + // Enable support for HTTP/3 datagrams (RFC 9297). + // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. + EnableDatagrams bool -type roundTripperOpts struct { - DisableCompression bool - EnableDatagram bool - MaxHeaderBytes int64 + // Additional HTTP/3 settings. + // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) -} -// client is a HTTP3 client doing requests -type client struct { - tlsConf *tls.Config - config *quic.Config - opts *roundTripperOpts + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 - dialOnce sync.Once - dialer dialFunc - handshakeErr error - - hconn *connection + // DisableCompression, if true, prevents the Transport from requesting compression with an + // "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value. + // If the Transport requests gzip on its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. + // However, if the user explicitly requested gzip it is not automatically uncompressed. + DisableCompression bool + initOnce sync.Once + hconn *connection requestWriter *requestWriter - - decoder *qpack.Decoder - - hostname string - conn atomic.Pointer[quic.EarlyConnection] - - logger utils.Logger + decoder *qpack.Decoder + logger utils.Logger } -var _ roundTripCloser = &client{} - -func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { - if conf == nil { - conf = defaultQuicConfig.Clone() - conf.EnableDatagrams = opts.EnableDatagram - } - if opts.EnableDatagram && !conf.EnableDatagrams { - return nil, errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") - } - if len(conf.Versions) == 0 { - conf = conf.Clone() - conf.Versions = []quic.Version{protocol.SupportedVersions[0]} - } - if len(conf.Versions) != 1 { - return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") - } - if conf.MaxIncomingStreams == 0 { - conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams - } - logger := utils.DefaultLogger.WithPrefix("h3 client") - - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - if tlsConf.ServerName == "" { - sni, _, err := net.SplitHostPort(hostname) - if err != nil { - // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. - sni = hostname - } - tlsConf.ServerName = sni - } - // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} - - return &client{ - hostname: authorityAddr(hostname), - tlsConf: tlsConf, - requestWriter: newRequestWriter(logger), - decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: conf, - opts: opts, - dialer: dialer, - logger: logger, - }, nil +func (c *SingleDestinationRoundTripper) Start() Connection { + c.initOnce.Do(func() { c.init() }) + return c.hconn } -func (c *client) dial(ctx context.Context) error { - var err error - var conn quic.EarlyConnection - if c.dialer != nil { - conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) - } else { - conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) - } - if err != nil { - return err - } - c.conn.Store(&conn) - +func (c *SingleDestinationRoundTripper) init() { + c.logger = utils.DefaultLogger.WithPrefix("h3 client") + c.requestWriter = newRequestWriter(c.logger) + c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {}) + c.hconn = newConnection(c.Connection, c.EnableDatagrams, c.UniStreamHijacker, protocol.PerspectiveClient, c.logger) // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupConn(conn); err != nil { + if err := c.setupConn(c.Connection); err != nil { c.logger.Debugf("Setting up connection failed: %s", err) - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } }() - - if c.opts.StreamHijacker != nil { - go c.handleBidirectionalStreams(conn) + if c.StreamHijacker != nil { + go c.handleBidirectionalStreams() } - c.hconn = newConnection( - conn, - c.opts.EnableDatagram, - c.opts.UniStreamHijacker, - protocol.PerspectiveClient, - c.logger, - ) go c.hconn.HandleUnidirectionalStreams() - return nil } -func (c *client) setupConn(conn quic.EarlyConnection) error { +func (c *SingleDestinationRoundTripper) setupConn(conn quic.Connection) error { // open the control stream str, err := conn.OpenUniStream() if err != nil { @@ -168,22 +100,22 @@ func (c *client) setupConn(conn quic.EarlyConnection) error { b := make([]byte, 0, 64) b = quicvarint.Append(b, streamTypeControlStream) // send the SETTINGS frame - b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b) + b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b) _, err = str.Write(b) return err } -func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { +func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { for { - str, err := conn.AcceptStream(context.Background()) + str, err := c.hconn.AcceptStream(context.Background()) if err != nil { c.logger.Debugf("accepting bidirectional stream failed: %s", err) return } go func(str quic.Stream) { _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { - id := conn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) - return c.opts.StreamHijacker(ft, id, str, e) + id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) + return c.StreamHijacker(ft, id, str, e) }) if err == errHijacked { return @@ -191,28 +123,22 @@ func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { if err != nil { c.logger.Debugf("error handling stream: %s", err) } - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") }(str) } } -func (c *client) Close() error { - conn := c.conn.Load() - if conn == nil { - return nil - } - return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") -} - -func (c *client) maxHeaderBytes() uint64 { - if c.opts.MaxHeaderBytes <= 0 { +func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 { + if c.MaxResponseHeaderBytes <= 0 { return defaultMaxResponseHeaderBytes } - return uint64(c.opts.MaxHeaderBytes) + return uint64(c.MaxResponseHeaderBytes) } // RoundTripOpt executes a request and returns a response -func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { +func (c *SingleDestinationRoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + c.initOnce.Do(func() { c.init() }) + rsp, err := c.roundTripOpt(req, opt) if err != nil && req.Context().Err() != nil { // if the context was canceled, return the context cancellation error @@ -221,25 +147,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, err } -func (c *client) dialConn(ctx context.Context) error { - c.dialOnce.Do(func() { - c.handshakeErr = c.dial(ctx) - }) - return c.handshakeErr -} - -func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { - if authorityAddr(hostnameFromURL(req.URL)) != c.hostname { - return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) - } - - if err := c.dialConn(req.Context()); err != nil { - return nil, err - } - - // At this point, c.conn is guaranteed to be set. - conn := *c.conn.Load() - +func (c *SingleDestinationRoundTripper) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { // Immediately send out this request, if this is a 0-RTT request. switch req.Method { case MethodGet0RTT: @@ -254,26 +162,30 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon req.Method = http.MethodHead default: // wait for the handshake to complete - select { - case <-conn.HandshakeComplete(): - case <-req.Context().Done(): - return nil, req.Context().Err() + earlyConn, ok := c.Connection.(quic.EarlyConnection) + if ok { + select { + case <-earlyConn.HandshakeComplete(): + case <-req.Context().Done(): + return nil, req.Context().Err() + } } } if opt.CheckSettings != nil { + connCtx := c.Connection.Context() // wait for the server's SETTINGS frame to arrive select { case <-c.hconn.ReceivedSettings(): - case <-conn.Context().Done(): - return nil, context.Cause(conn.Context()) + case <-connCtx.Done(): + return nil, context.Cause(connCtx) } if err := opt.CheckSettings(*c.hconn.Settings()); err != nil { return nil, err } } - str, err := conn.OpenStreamSync(req.Context()) + str, err := c.Connection.OpenStreamSync(req.Context()) if err != nil { return nil, err } @@ -293,7 +205,7 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } }() - rsp, err := c.doRequest(req, conn, str, reqDone) + rsp, err := c.doRequest(req, str, reqDone) if err != nil { // if any error occurred close(reqDone) <-done @@ -302,24 +214,22 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, maybeReplaceError(err) } -func (c *client) OpenStream(ctx context.Context) (RequestStream, error) { - if err := c.dialConn(ctx); err != nil { - return nil, err - } - conn := *c.conn.Load() - str, err := conn.OpenStreamSync(ctx) +func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) { + c.initOnce.Do(func() { c.init() }) + + str, err := c.Connection.OpenStreamSync(ctx) if err != nil { return nil, err } return newRequestStream( - newStream(str, func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }), + newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }), c.hconn, c.requestWriter, nil, c.decoder, - c.opts.DisableCompression, + c.DisableCompression, c.maxHeaderBytes(), - func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }, + func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }, ), nil } @@ -338,7 +248,7 @@ func (r *cancelingReader) Read(b []byte) (int, error) { return n, err } -func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { +func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { defer body.Close() buf := make([]byte, bodyCopyBufferSize) sr := &cancelingReader{str: str, r: body} @@ -362,16 +272,16 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength i return err } -func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, reqDone chan<- struct{}) (*http.Response, error) { +func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.Stream, reqDone chan<- struct{}) (*http.Response, error) { hstr := newRequestStream( - newStream(str, func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }), + newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }), c.hconn, c.requestWriter, reqDone, c.decoder, - c.opts.DisableCompression, + c.DisableCompression, c.maxHeaderBytes(), - func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }, + func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }, ) if err := hstr.SendRequestHeader(req); err != nil { return nil, err @@ -398,21 +308,8 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui if err != nil { return nil, err } - connState := conn.ConnectionState().TLS + connState := c.Connection.ConnectionState().TLS res.TLS = &connState res.Request = req return res, nil } - -func (c *client) HandshakeComplete() bool { - conn := c.conn.Load() - if conn == nil { - return false - } - select { - case <-(*conn).HandshakeComplete(): - return true - default: - return false - } -} diff --git a/http3/client_test.go b/http3/client_test.go index 259567af..3cd09f94 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -4,7 +4,6 @@ import ( "bytes" "compress/gzip" "context" - "crypto/tls" "errors" "io" "net/http" @@ -13,7 +12,6 @@ import ( "github.com/quic-go/quic-go" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" @@ -35,181 +33,14 @@ func encodeResponse(status int) []byte { } var _ = Describe("Client", func() { - var ( - cl *client - req *http.Request - origDialAddr = dialAddr - handshakeChan <-chan struct{} // a closed chan - ) + var handshakeChan <-chan struct{} // a closed chan BeforeEach(func() { - origDialAddr = dialAddr - hostname := "quic.clemente.io:1337" - c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - cl = c.(*client) - Expect(cl.hostname).To(Equal(hostname)) - - req, err = http.NewRequest("GET", "https://localhost:1337", nil) - Expect(err).ToNot(HaveOccurred()) - ch := make(chan struct{}) close(ch) handshakeChan = ch }) - AfterEach(func() { - dialAddr = origDialAddr - }) - - It("rejects quic.Configs that allow multiple QUIC versions", func() { - qconf := &quic.Config{ - Versions: []quic.Version{protocol.Version2, protocol.Version1}, - } - _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) - Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) - }) - - It("uses the default QUIC and TLS config if none is give", func() { - client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - var dialAddrCalled bool - dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { - Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) - Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) - Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1})) - dialAddrCalled = true - return nil, errors.New("test done") - } - client.RoundTripOpt(req, RoundTripOpt{}) - Expect(dialAddrCalled).To(BeTrue()) - }) - - It("adds the port to the hostname, if none is given", func() { - client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - var dialAddrCalled bool - dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { - Expect(hostname).To(Equal("quic.clemente.io:443")) - dialAddrCalled = true - return nil, errors.New("test done") - } - req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) - Expect(err).ToNot(HaveOccurred()) - client.RoundTripOpt(req, RoundTripOpt{}) - Expect(dialAddrCalled).To(BeTrue()) - }) - - It("sets the ServerName in the tls.Config, if not set", func() { - const host = "foo.bar" - dialCalled := false - dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - Expect(tlsCfg.ServerName).To(Equal(host)) - dialCalled = true - return nil, errors.New("test done") - } - client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc) - Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequest("GET", "https://foo.bar", nil) - Expect(err).ToNot(HaveOccurred()) - client.RoundTripOpt(req, RoundTripOpt{}) - Expect(dialCalled).To(BeTrue()) - }) - - It("uses the TLS config and QUIC config", func() { - tlsConf := &tls.Config{ - ServerName: "foo.bar", - NextProtos: []string{"proto foo", "proto bar"}, - } - quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} - client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) - Expect(err).ToNot(HaveOccurred()) - var dialAddrCalled bool - dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { - Expect(host).To(Equal("localhost:1337")) - Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) - Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) - Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) - dialAddrCalled = true - return nil, errors.New("test done") - } - client.RoundTripOpt(req, RoundTripOpt{}) - Expect(dialAddrCalled).To(BeTrue()) - // make sure the original tls.Config was not modified - Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) - }) - - It("uses the custom dialer, if provided", func() { - testErr := errors.New("test done") - tlsConf := &tls.Config{ServerName: "foo.bar"} - quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} - ctx, cancel := context.WithTimeout(context.Background(), time.Hour) - defer cancel() - var dialerCalled bool - dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { - Expect(ctxP).To(Equal(ctx)) - Expect(address).To(Equal("localhost:1337")) - Expect(tlsConfP.ServerName).To(Equal("foo.bar")) - Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) - dialerCalled = true - return nil, testErr - } - client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) - Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - Expect(dialerCalled).To(BeTrue()) - }) - - It("enables HTTP/3 Datagrams", func() { - testErr := errors.New("handshake error") - client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { - Expect(quicConf.EnableDatagrams).To(BeTrue()) - return nil, testErr - } - _, err = client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - }) - - It("errors when dialing fails", func() { - testErr := errors.New("handshake error") - client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return nil, testErr - } - _, err = client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - }) - - It("closes correctly if connection was not created", func() { - client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(client.Close()).To(Succeed()) - }) - - Context("validating the address", func() { - It("refuses to do requests for the wrong host", func() { - req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = cl.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) - }) - - It("allows requests using a different scheme", func() { - testErr := errors.New("handshake error") - req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return nil, testErr - } - _, err = cl.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError(testErr)) - }) - }) - Context("hijacking bidirectional streams", func() { var ( request *http.Request @@ -232,9 +63,6 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } var err error request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) @@ -248,11 +76,14 @@ var _ = Describe("Client", func() { It("hijacks a bidirectional stream of unknown frame type", func() { id := quic.ConnectionTracingID(1234) frameTypeChan := make(chan FrameType, 1) - cl.opts.StreamHijacker = func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - Expect(connTracingID).To(Equal(id)) - frameTypeChan <- ft - return true, nil + rt := &SingleDestinationRoundTripper{ + Connection: conn, + StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + Expect(connTracingID).To(Equal(id)) + frameTypeChan <- ft + return true, nil + }, } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) @@ -265,7 +96,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := cl.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -273,10 +104,13 @@ var _ = Describe("Client", func() { It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { frameTypeChan := make(chan FrameType, 1) - cl.opts.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return false, nil + rt := &SingleDestinationRoundTripper{ + Connection: conn, + StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return false, nil + }, } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) @@ -290,17 +124,20 @@ var _ = Describe("Client", func() { ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := cl.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) It("closes the connection when hijacker returned error", func() { frameTypeChan := make(chan FrameType, 1) - cl.opts.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - frameTypeChan <- ft - return false, errors.New("error in hijacker") + rt := &SingleDestinationRoundTripper{ + Connection: conn, + StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + frameTypeChan <- ft + return false, errors.New("error in hijacker") + }, } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) @@ -314,7 +151,7 @@ var _ = Describe("Client", func() { ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := cl.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) @@ -323,12 +160,15 @@ var _ = Describe("Client", func() { testErr := errors.New("test error") unknownStr := mockquic.NewMockStream(mockCtrl) done := make(chan struct{}) - cl.opts.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) { - defer close(done) - Expect(e).To(MatchError(testErr)) - Expect(ft).To(BeZero()) - Expect(str).To(Equal(unknownStr)) - return false, nil + rt := &SingleDestinationRoundTripper{ + Connection: conn, + StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) { + defer close(done) + Expect(e).To(MatchError(testErr)) + Expect(ft).To(BeZero()) + Expect(str).To(Equal(unknownStr)) + return false, nil + }, } unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() @@ -340,7 +180,7 @@ var _ = Describe("Client", func() { ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := cl.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -368,9 +208,6 @@ var _ = Describe("Client", func() { conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } var err error req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) @@ -384,11 +221,14 @@ var _ = Describe("Client", func() { It("hijacks an unidirectional stream of unknown stream type", func() { id := quic.ConnectionTracingID(100) streamTypeChan := make(chan StreamType, 1) - cl.opts.UniStreamHijacker = func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { - Expect(connTracingID).To(Equal(id)) - Expect(err).ToNot(HaveOccurred()) - streamTypeChan <- st - return true + rt := &SingleDestinationRoundTripper{ + Connection: conn, + UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { + Expect(connTracingID).To(Equal(id)) + Expect(err).ToNot(HaveOccurred()) + streamTypeChan <- st + return true + }, } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) @@ -403,7 +243,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := rt.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -413,12 +253,15 @@ var _ = Describe("Client", func() { testErr := errors.New("test error") done := make(chan struct{}) unknownStr := mockquic.NewMockStream(mockCtrl) - cl.opts.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool { - defer close(done) - Expect(st).To(BeZero()) - Expect(str).To(Equal(unknownStr)) - Expect(err).To(MatchError(testErr)) - return true + rt := &SingleDestinationRoundTripper{ + Connection: conn, + UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool { + defer close(done) + Expect(st).To(BeZero()) + Expect(str).To(Equal(unknownStr)) + Expect(err).To(MatchError(testErr)) + return true + }, } unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) @@ -429,7 +272,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := rt.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -437,10 +280,13 @@ var _ = Describe("Client", func() { It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { streamTypeChan := make(chan StreamType, 1) - cl.opts.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { - Expect(err).ToNot(HaveOccurred()) - streamTypeChan <- st - return false + rt := &SingleDestinationRoundTripper{ + Connection: conn, + UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { + Expect(err).ToNot(HaveOccurred()) + streamTypeChan <- st + return false + }, } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) @@ -456,7 +302,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := rt.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -467,6 +313,7 @@ var _ = Describe("Client", func() { var ( req *http.Request conn *mockquic.MockEarlyConnection + rt *SingleDestinationRoundTripper settingsFrameWritten chan struct{} ) testDone := make(chan struct{}, 1) @@ -482,9 +329,7 @@ var _ = Describe("Client", func() { conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().HandshakeComplete().Return(handshakeChan) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } + rt = &SingleDestinationRoundTripper{Connection: conn} var err error req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) @@ -511,9 +356,10 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }) conn.EXPECT().Context().Return(context.Background()) - _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { + _, err := rt.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { return errors.New("wrong settings") }}) + rt.Connection = conn Expect(err).To(MatchError("wrong settings")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -524,6 +370,7 @@ var _ = Describe("Client", func() { req *http.Request str *mockquic.MockStream conn *mockquic.MockEarlyConnection + cl *SingleDestinationRoundTripper settingsFrameWritten chan struct{} ) testDone := make(chan struct{}) @@ -566,9 +413,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - return conn, nil - } + cl = &SingleDestinationRoundTripper{Connection: conn} var err error req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) @@ -579,7 +424,7 @@ var _ = Describe("Client", func() { Eventually(settingsFrameWritten).Should(BeClosed()) }) - It("errors if it can't open a stream", func() { + It("errors if it can't open a request stream", func() { testErr := errors.New("stream open error") conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) @@ -749,6 +594,7 @@ var _ = Describe("Client", func() { }) It("cancels the stream when the HEADERS frame is too large", func() { + cl.MaxResponseHeaderBytes = 1337 b := (&headersFrame{Length: 1338}).Append(nil) r := bytes.NewReader(b) str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) @@ -760,6 +606,18 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError("http3: HEADERS frame too large: 1338 bytes (max: 1337)")) Eventually(closed).Should(BeClosed()) }) + + It("opens a request stream", func() { + cl.Connection.(quic.EarlyConnection).HandshakeComplete() + str, err := cl.OpenRequestStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(str.SendRequestHeader(req)).To(Succeed()) + str.Write([]byte("foobar")) + d := dataFrame{Length: 6} + data := d.Append([]byte{}) + data = append(data, []byte("foobar")...) + Expect(bytes.Contains(strBuf.Bytes(), data)).To(BeTrue()) + }) }) Context("request cancellations", func() { @@ -770,7 +628,7 @@ var _ = Describe("Client", func() { errChan := make(chan error) go func() { - _, err := cl.roundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) errChan <- err }() Consistently(errChan).ShouldNot(Receive()) @@ -853,8 +711,10 @@ var _ = Describe("Client", func() { }) It("doesn't add gzip if the header disable it", func() { - client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) - Expect(err).ToNot(HaveOccurred()) + client := &SingleDestinationRoundTripper{ + Connection: conn, + DisableCompression: true, + } conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) @@ -866,7 +726,7 @@ var _ = Describe("Client", func() { ) testErr := errors.New("test done") str.EXPECT().Read(gomock.Any()).Return(0, testErr) - _, err = client.RoundTripOpt(req, RoundTripOpt{}) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) hfs := decodeHeader(buf) Expect(hfs).ToNot(HaveKey("accept-encoding")) diff --git a/http3/mock_roundtripcloser_test.go b/http3/mock_roundtripcloser_test.go deleted file mode 100644 index 26c14899..00000000 --- a/http3/mock_roundtripcloser_test.go +++ /dev/null @@ -1,195 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/http3 (interfaces: RoundTripCloser) -// -// Generated by this command: -// -// mockgen -typed -build_flags=-tags=gomock -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser -// - -// Package http3 is a generated GoMock package. -package http3 - -import ( - context "context" - http "net/http" - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockRoundTripCloser is a mock of RoundTripCloser interface. -type MockRoundTripCloser struct { - ctrl *gomock.Controller - recorder *MockRoundTripCloserMockRecorder -} - -// MockRoundTripCloserMockRecorder is the mock recorder for MockRoundTripCloser. -type MockRoundTripCloserMockRecorder struct { - mock *MockRoundTripCloser -} - -// NewMockRoundTripCloser creates a new mock instance. -func NewMockRoundTripCloser(ctrl *gomock.Controller) *MockRoundTripCloser { - mock := &MockRoundTripCloser{ctrl: ctrl} - mock.recorder = &MockRoundTripCloserMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockRoundTripCloser) EXPECT() *MockRoundTripCloserMockRecorder { - return m.recorder -} - -// Close mocks base method. -func (m *MockRoundTripCloser) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockRoundTripCloserMockRecorder) Close() *MockRoundTripCloserCloseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoundTripCloser)(nil).Close)) - return &MockRoundTripCloserCloseCall{Call: call} -} - -// MockRoundTripCloserCloseCall wrap *gomock.Call -type MockRoundTripCloserCloseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockRoundTripCloserCloseCall) Return(arg0 error) *MockRoundTripCloserCloseCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockRoundTripCloserCloseCall) Do(f func() error) *MockRoundTripCloserCloseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRoundTripCloserCloseCall) DoAndReturn(f func() error) *MockRoundTripCloserCloseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// HandshakeComplete mocks base method. -func (m *MockRoundTripCloser) HandshakeComplete() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandshakeComplete") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockRoundTripCloserMockRecorder) HandshakeComplete() *MockRoundTripCloserHandshakeCompleteCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockRoundTripCloser)(nil).HandshakeComplete)) - return &MockRoundTripCloserHandshakeCompleteCall{Call: call} -} - -// MockRoundTripCloserHandshakeCompleteCall wrap *gomock.Call -type MockRoundTripCloserHandshakeCompleteCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockRoundTripCloserHandshakeCompleteCall) Return(arg0 bool) *MockRoundTripCloserHandshakeCompleteCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockRoundTripCloserHandshakeCompleteCall) Do(f func() bool) *MockRoundTripCloserHandshakeCompleteCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRoundTripCloserHandshakeCompleteCall) DoAndReturn(f func() bool) *MockRoundTripCloserHandshakeCompleteCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// OpenStream mocks base method. -func (m *MockRoundTripCloser) OpenStream(arg0 context.Context) (RequestStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStream", arg0) - ret0, _ := ret[0].(RequestStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStream indicates an expected call of OpenStream. -func (mr *MockRoundTripCloserMockRecorder) OpenStream(arg0 any) *MockRoundTripCloserOpenStreamCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockRoundTripCloser)(nil).OpenStream), arg0) - return &MockRoundTripCloserOpenStreamCall{Call: call} -} - -// MockRoundTripCloserOpenStreamCall wrap *gomock.Call -type MockRoundTripCloserOpenStreamCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockRoundTripCloserOpenStreamCall) Return(arg0 RequestStream, arg1 error) *MockRoundTripCloserOpenStreamCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockRoundTripCloserOpenStreamCall) Do(f func(context.Context) (RequestStream, error)) *MockRoundTripCloserOpenStreamCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRoundTripCloserOpenStreamCall) DoAndReturn(f func(context.Context) (RequestStream, error)) *MockRoundTripCloserOpenStreamCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// RoundTripOpt mocks base method. -func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt) (*http.Response, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RoundTripOpt", arg0, arg1) - ret0, _ := ret[0].(*http.Response) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// RoundTripOpt indicates an expected call of RoundTripOpt. -func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 any) *MockRoundTripCloserRoundTripOptCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1) - return &MockRoundTripCloserRoundTripOptCall{Call: call} -} - -// MockRoundTripCloserRoundTripOptCall wrap *gomock.Call -type MockRoundTripCloserRoundTripOptCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockRoundTripCloserRoundTripOptCall) Return(arg0 *http.Response, arg1 error) *MockRoundTripCloserRoundTripOptCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockRoundTripCloserRoundTripOptCall) Do(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockRoundTripCloserRoundTripOptCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRoundTripCloserRoundTripOptCall) DoAndReturn(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockRoundTripCloserRoundTripOptCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/http3/mock_singleroundtripper_test.go b/http3/mock_singleroundtripper_test.go new file mode 100644 index 00000000..1549a1be --- /dev/null +++ b/http3/mock_singleroundtripper_test.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/http3 (interfaces: SingleRoundTripper) +// +// Generated by this command: +// +// mockgen -typed -build_flags=-tags=gomock -package http3 -destination mock_singleroundtripper_test.go github.com/quic-go/quic-go/http3 SingleRoundTripper +// + +// Package http3 is a generated GoMock package. +package http3 + +import ( + context "context" + http "net/http" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockSingleRoundTripper is a mock of SingleRoundTripper interface. +type MockSingleRoundTripper struct { + ctrl *gomock.Controller + recorder *MockSingleRoundTripperMockRecorder +} + +// MockSingleRoundTripperMockRecorder is the mock recorder for MockSingleRoundTripper. +type MockSingleRoundTripperMockRecorder struct { + mock *MockSingleRoundTripper +} + +// NewMockSingleRoundTripper creates a new mock instance. +func NewMockSingleRoundTripper(ctrl *gomock.Controller) *MockSingleRoundTripper { + mock := &MockSingleRoundTripper{ctrl: ctrl} + mock.recorder = &MockSingleRoundTripperMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSingleRoundTripper) EXPECT() *MockSingleRoundTripperMockRecorder { + return m.recorder +} + +// OpenRequestStream mocks base method. +func (m *MockSingleRoundTripper) OpenRequestStream(arg0 context.Context) (RequestStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenRequestStream", arg0) + ret0, _ := ret[0].(RequestStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenRequestStream indicates an expected call of OpenRequestStream. +func (mr *MockSingleRoundTripperMockRecorder) OpenRequestStream(arg0 any) *MockSingleRoundTripperOpenRequestStreamCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenRequestStream", reflect.TypeOf((*MockSingleRoundTripper)(nil).OpenRequestStream), arg0) + return &MockSingleRoundTripperOpenRequestStreamCall{Call: call} +} + +// MockSingleRoundTripperOpenRequestStreamCall wrap *gomock.Call +type MockSingleRoundTripperOpenRequestStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSingleRoundTripperOpenRequestStreamCall) Return(arg0 RequestStream, arg1 error) *MockSingleRoundTripperOpenRequestStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSingleRoundTripperOpenRequestStreamCall) Do(f func(context.Context) (RequestStream, error)) *MockSingleRoundTripperOpenRequestStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSingleRoundTripperOpenRequestStreamCall) DoAndReturn(f func(context.Context) (RequestStream, error)) *MockSingleRoundTripperOpenRequestStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// RoundTripOpt mocks base method. +func (m *MockSingleRoundTripper) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt) (*http.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RoundTripOpt", arg0, arg1) + ret0, _ := ret[0].(*http.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RoundTripOpt indicates an expected call of RoundTripOpt. +func (mr *MockSingleRoundTripperMockRecorder) RoundTripOpt(arg0, arg1 any) *MockSingleRoundTripperRoundTripOptCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockSingleRoundTripper)(nil).RoundTripOpt), arg0, arg1) + return &MockSingleRoundTripperRoundTripOptCall{Call: call} +} + +// MockSingleRoundTripperRoundTripOptCall wrap *gomock.Call +type MockSingleRoundTripperRoundTripOptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSingleRoundTripperRoundTripOptCall) Return(arg0 *http.Response, arg1 error) *MockSingleRoundTripperRoundTripOptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSingleRoundTripperRoundTripOptCall) Do(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockSingleRoundTripperRoundTripOptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSingleRoundTripperRoundTripOptCall) DoAndReturn(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockSingleRoundTripperRoundTripOptCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/http3/mockgen.go b/http3/mockgen.go index 57af7972..83a3974f 100644 --- a/http3/mockgen.go +++ b/http3/mockgen.go @@ -2,7 +2,7 @@ package http3 -//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser" -type RoundTripCloser = roundTripCloser +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package http3 -destination mock_singleroundtripper_test.go github.com/quic-go/quic-go/http3 SingleRoundTripper" +type SingleRoundTripper = singleRoundTripper //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener" diff --git a/http3/roundtrip.go b/http3/roundtrip.go index d3b4c360..02014c47 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "net/url" "strings" "sync" "sync/atomic" @@ -16,6 +15,7 @@ import ( "golang.org/x/net/http/httpguts" "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/internal/protocol" ) // Settings are HTTP/3 settings that apply to the underlying connection. @@ -39,15 +39,17 @@ type RoundTripOpt struct { CheckSettings func(Settings) error } -type roundTripCloser interface { +type singleRoundTripper interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) - HandshakeComplete() bool - OpenStream(context.Context) (RequestStream, error) - io.Closer + OpenRequestStream(context.Context) (RequestStream, error) } -type roundTripCloserWithCount struct { - roundTripCloser +type roundTripperWithCount struct { + dialing chan struct{} // closed as soon as quic.Dial(Early) returned + dialErr error + conn quic.EarlyConnection + rt singleRoundTripper + useCount atomic.Int64 } @@ -55,16 +57,6 @@ type roundTripCloserWithCount struct { type RoundTripper struct { mutex sync.Mutex - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool - // TLSClientConfig specifies the TLS configuration to use with // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config @@ -73,6 +65,12 @@ type RoundTripper struct { // If nil, reasonable default values will be used. QUICConfig *quic.Config + // Dial specifies an optional dial function for creating QUIC + // connections for requests. + // If Dial is nil, a UDPConn will be created at the first request + // and will be reused for subsequent connections to other servers. + Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + // Enable support for HTTP/3 datagrams (RFC 9297). // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. EnableDatagrams bool @@ -81,33 +79,24 @@ type RoundTripper struct { // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 - // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. - // It is called right after parsing the frame type. - // If parsing the frame type fails, the error is passed to the callback. - // In that case, the frame type will not be set. - // Callers can either ignore the frame and return control of the stream back to HTTP/3 - // (by returning hijacked false). - // Alternatively, callers can take over the QUIC stream (by returning hijacked true). - StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) - - // When set, this callback is called for unknown unidirectional stream of unknown stream type. - // If parsing the stream type fails, the error is passed to the callback. - // In that case, the stream type will not be set. - UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) - - // Dial specifies an optional dial function for creating QUIC - // connections for requests. - // If Dial is nil, a UDPConn will be created at the first request - // and will be reused for subsequent connections to other servers. - Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) - // MaxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. // Zero means to use a default limit. MaxResponseHeaderBytes int64 - newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests - clients map[string]*roundTripCloserWithCount + // DisableCompression, if true, prevents the Transport from requesting compression with an + // "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value. + // If the Transport requests gzip on its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. + // However, if the user explicitly requested gzip it is not automatically uncompressed. + DisableCompression bool + + initOnce sync.Once + initErr error + + newClient func(quic.EarlyConnection) singleRoundTripper + + clients map[string]*roundTripperWithCount transport *quic.Transport } @@ -121,6 +110,11 @@ var ErrNoCachedConn = errors.New("http3: no cached connection was available") // RoundTripOpt is like RoundTrip, but takes options. func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + r.initOnce.Do(func() { r.initErr = r.init() }) + if r.initErr != nil { + return nil, r.initErr + } + if req.URL == nil { closeRequestBody(req) return nil, errors.New("http3: nil Request.URL") @@ -154,12 +148,22 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } hostname := authorityAddr(hostnameFromURL(req.URL)) - cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) + cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn) if err != nil { return nil, err } + + select { + case <-cl.dialing: + case <-req.Context().Done(): + return nil, context.Cause(req.Context()) + } + + if cl.dialErr != nil { + return nil, cl.dialErr + } defer cl.useCount.Add(-1) - rsp, err := cl.RoundTripOpt(req, opt) + rsp, err := cl.rt.RoundTripOpt(req, opt) if err != nil { r.removeClient(hostname) if isReused { @@ -176,68 +180,123 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } -func (r *RoundTripper) OpenStream(ctx context.Context, url *url.URL) (RequestStream, error) { - hostname := authorityAddr(hostnameFromURL(url)) - cl, _, err := r.getClient(hostname, false) - if err != nil { - return nil, err +func (r *RoundTripper) init() error { + if r.newClient == nil { + r.newClient = func(conn quic.EarlyConnection) singleRoundTripper { + return &SingleDestinationRoundTripper{ + Connection: conn, + EnableDatagrams: r.EnableDatagrams, + DisableCompression: r.DisableCompression, + AdditionalSettings: r.AdditionalSettings, + MaxResponseHeaderBytes: r.MaxResponseHeaderBytes, + } + } } - return cl.OpenStream(ctx) + if r.QUICConfig == nil { + r.QUICConfig = defaultQuicConfig.Clone() + r.QUICConfig.EnableDatagrams = r.EnableDatagrams + } + if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams { + return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") + } + if len(r.QUICConfig.Versions) == 0 { + r.QUICConfig = r.QUICConfig.Clone() + r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]} + } + if len(r.QUICConfig.Versions) != 1 { + return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") + } + if r.QUICConfig.MaxIncomingStreams == 0 { + r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams + } + return nil } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { +func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() if r.clients == nil { - r.clients = make(map[string]*roundTripCloserWithCount) + r.clients = make(map[string]*roundTripperWithCount) } - client, ok := r.clients[hostname] + cl, ok := r.clients[hostname] if !ok { if onlyCached { return nil, false, ErrNoCachedConn } - var err error - newCl := newClient - if r.newClient != nil { - newCl = r.newClient + cl = &roundTripperWithCount{ + dialing: make(chan struct{}), } - dial := r.Dial - if dial == nil { - if r.transport == nil { - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, false, err - } - r.transport = &quic.Transport{Conn: udpConn} + go func() { + defer close(cl.dialing) + conn, rt, err := r.dial(ctx, hostname) + if err != nil { + cl.dialErr = err + return } - dial = r.makeDialer() - } - c, err := newCl( - hostname, - r.TLSClientConfig, - &roundTripperOpts{ - EnableDatagram: r.EnableDatagrams, - DisableCompression: r.DisableCompression, - MaxHeaderBytes: r.MaxResponseHeaderBytes, - StreamHijacker: r.StreamHijacker, - UniStreamHijacker: r.UniStreamHijacker, - AdditionalSettings: r.AdditionalSettings, - }, - r.QUICConfig, - dial, - ) - if err != nil { - return nil, false, err - } - client = &roundTripCloserWithCount{roundTripCloser: c} - r.clients[hostname] = client - } else if client.HandshakeComplete() { - isReused = true + cl.conn = conn + cl.rt = rt + }() + r.clients[hostname] = cl } - client.useCount.Add(1) - return client, isReused, nil + select { + case <-cl.dialing: + if cl.dialErr != nil { + return nil, false, cl.dialErr + } + select { + case <-cl.conn.HandshakeComplete(): + isReused = true + default: + } + default: + } + cl.useCount.Add(1) + return cl, isReused, nil +} + +func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) { + var tlsConf *tls.Config + if r.TLSClientConfig == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = r.TLSClientConfig.Clone() + } + if tlsConf.ServerName == "" { + sni, _, err := net.SplitHostPort(hostname) + if err != nil { + // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. + sni = hostname + } + tlsConf.ServerName = sni + } + // Replace existing ALPNs by H3 + tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])} + + dial := r.Dial + if dial == nil { + if r.transport == nil { + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, nil, err + } + r.transport = &quic.Transport{Conn: udpConn} + } + dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } + } + + conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig) + if err != nil { + return nil, nil, err + } + return conn, r.newClient(conn), nil } func (r *RoundTripper) removeClient(hostname string) { @@ -255,7 +314,7 @@ func (r *RoundTripper) Close() error { r.mutex.Lock() defer r.mutex.Unlock() for _, client := range r.clients { - if err := client.Close(); err != nil { + if err := client.conn.CloseWithError(0, ""); err != nil { return err } } @@ -300,23 +359,12 @@ func isNotToken(r rune) bool { return !httpguts.IsTokenRune(r) } -// makeDialer makes a QUIC dialer using r.udpConn. -func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) - } -} - func (r *RoundTripper) CloseIdleConnections() { r.mutex.Lock() defer r.mutex.Unlock() for hostname, client := range r.clients { if client.useCount.Load() == 0 { - client.Close() + client.conn.CloseWithError(0, "") delete(r.clients, hostname) } } diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 2df4294e..f3ff1e55 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -7,10 +7,11 @@ import ( "errors" "io" "net/http" - "sync/atomic" "time" "github.com/quic-go/quic-go" + mockquic "github.com/quic-go/quic-go/internal/mocks/quic" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" . "github.com/onsi/ginkgo/v2" @@ -45,138 +46,258 @@ func (m *mockBody) Close() error { } var _ = Describe("RoundTripper", func() { - var ( - rt *RoundTripper - req *http.Request - ) + var req *http.Request BeforeEach(func() { - rt = &RoundTripper{} var err error req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) Expect(err).ToNot(HaveOccurred()) }) - Context("dialing hosts", func() { - It("creates new clients", func() { - testErr := errors.New("test err") - req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) - return cl, nil - } - _, err = rt.RoundTrip(req) - Expect(err).To(MatchError(testErr)) - }) + It("rejects quic.Configs that allow multiple QUIC versions", func() { + qconf := &quic.Config{ + Versions: []quic.Version{protocol.Version2, protocol.Version1}, + } + rt := &RoundTripper{QUICConfig: qconf} + _, err := rt.RoundTrip(req) + Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) + }) - It("creates new clients with additional settings", func() { - testErr := errors.New("test err") - req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - rt.AdditionalSettings = map[uint64]uint64{1337: 42} - rt.newClient = func(_ string, _ *tls.Config, opts *roundTripperOpts, conf *quic.Config, _ dialFunc) (roundTripCloser, error) { - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) - Expect(opts.AdditionalSettings).To(HaveKeyWithValue(uint64(1337), uint64(42))) - return cl, nil - } - _, err = rt.RoundTrip(req) - Expect(err).To(MatchError(testErr)) - }) + It("uses the default QUIC and TLS config if none is give", func() { + var dialAddrCalled bool + rt := &RoundTripper{ + Dial: func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { + defer GinkgoRecover() + Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) + Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) + Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1})) + dialAddrCalled = true + return nil, errors.New("test done") + }, + } + _, err := rt.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + Expect(dialAddrCalled).To(BeTrue()) + }) - It("uses the quic.Config, if provided", func() { - config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} - var receivedConfig *quic.Config - rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { - receivedConfig = config + It("adds the port to the hostname, if none is given", func() { + var dialAddrCalled bool + rt := &RoundTripper{ + Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + defer GinkgoRecover() + Expect(hostname).To(Equal("quic.clemente.io:443")) + dialAddrCalled = true + return nil, errors.New("test done") + }, + } + req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + Expect(dialAddrCalled).To(BeTrue()) + }) + + It("sets the ServerName in the tls.Config, if not set", func() { + const host = "foo.bar" + var dialCalled bool + rt := &RoundTripper{ + Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + defer GinkgoRecover() + Expect(tlsCfg.ServerName).To(Equal(host)) + dialCalled = true + return nil, errors.New("test done") + }, + } + req, err := http.NewRequest("GET", "https://foo.bar", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + Expect(dialCalled).To(BeTrue()) + }) + + It("uses the TLS config and QUIC config", func() { + tlsConf := &tls.Config{ + ServerName: "foo.bar", + NextProtos: []string{"proto foo", "proto bar"}, + } + quicConf := &quic.Config{MaxIdleTimeout: 3 * time.Nanosecond} + var dialAddrCalled bool + rt := &RoundTripper{ + Dial: func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { + defer GinkgoRecover() + Expect(host).To(Equal("www.example.org:443")) + Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) + Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) + Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) + dialAddrCalled = true + return nil, errors.New("test done") + }, + QUICConfig: quicConf, + TLSClientConfig: tlsConf, + } + _, err := rt.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("test done")) + Expect(dialAddrCalled).To(BeTrue()) + // make sure the original tls.Config was not modified + Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) + }) + + It("uses the custom dialer, if provided", func() { + testErr := errors.New("test done") + tlsConf := &tls.Config{ServerName: "foo.bar"} + quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + var dialerCalled bool + rt := &RoundTripper{ + Dial: func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { + defer GinkgoRecover() + Expect(ctxP).To(Equal(ctx)) + Expect(address).To(Equal("www.example.org:443")) + Expect(tlsConfP.ServerName).To(Equal("foo.bar")) + Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) + dialerCalled = true + return nil, testErr + }, + TLSClientConfig: tlsConf, + QUICConfig: quicConf, + } + _, err := rt.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + Expect(dialerCalled).To(BeTrue()) + }) + + It("enables HTTP/3 Datagrams", func() { + testErr := errors.New("handshake error") + rt := &RoundTripper{ + EnableDatagrams: true, + Dial: func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { + defer GinkgoRecover() + Expect(quicConf.EnableDatagrams).To(BeTrue()) + return nil, testErr + }, + } + _, err := rt.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError(testErr)) + }) + + It("requires quic.Config.EnableDatagram if HTTP/3 datagrams are enabled", func() { + rt := &RoundTripper{ + QUICConfig: &quic.Config{EnableDatagrams: false}, + EnableDatagrams: true, + Dial: func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { return nil, errors.New("handshake error") - } - rt.QUICConfig = config - _, err := rt.RoundTrip(req) - Expect(err).To(MatchError("handshake error")) - Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) - }) + }, + } + _, err := rt.RoundTrip(req) + Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled")) + }) - It("requires quic.Config.EnableDatagram if HTTP datagrams are enabled", func() { - rt.QUICConfig = &quic.Config{EnableDatagrams: false} - rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { - return nil, errors.New("handshake error") - } - rt.EnableDatagrams = true - _, err := rt.RoundTrip(req) - Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled")) - rt.QUICConfig.EnableDatagrams = true - _, err = rt.RoundTrip(req) - Expect(err).To(MatchError("handshake error")) - }) - - It("uses the custom dialer, if provided", func() { - var dialed bool - dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - dialed = true - return nil, errors.New("handshake error") - } - rt.Dial = dialer - _, err := rt.RoundTrip(req) - Expect(err).To(MatchError("handshake error")) - Expect(dialed).To(BeTrue()) - }) + It("creates new clients", func() { + testErr := errors.New("test err") + req1, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + req2, err := http.NewRequest("GET", "https://example.com/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + var hostsDialed []string + rt := &RoundTripper{ + Dial: func(_ context.Context, host string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { + hostsDialed = append(hostsDialed, host) + return nil, testErr + }, + } + _, err = rt.RoundTrip(req1) + Expect(err).To(MatchError(testErr)) + _, err = rt.RoundTrip(req2) + Expect(err).To(MatchError(testErr)) + Expect(hostsDialed).To(Equal([]string{"quic-go.net:443", "example.com:443"})) }) Context("reusing clients", func() { - var req1, req2 *http.Request + var ( + rt *RoundTripper + req1, req2 *http.Request + clientChan chan *MockSingleRoundTripper + ) BeforeEach(func() { + clientChan = make(chan *MockSingleRoundTripper, 16) + rt = &RoundTripper{ + newClient: func(quic.EarlyConnection) singleRoundTripper { + select { + case c := <-clientChan: + return c + default: + Fail("no client") + return nil + } + }, + } var err error - req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) + req1, err = http.NewRequest("GET", "https://quic-go.net/file1.html", nil) Expect(err).ToNot(HaveOccurred()) - req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) + req2, err = http.NewRequest("GET", "https://quic-go.net/file2.html", nil) Expect(err).ToNot(HaveOccurred()) Expect(req1.URL).ToNot(Equal(req2.URL)) }) It("reuses existing clients", func() { + cl := NewMockSingleRoundTripper(mockCtrl) + clientChan <- cl + conn := mockquic.NewMockEarlyConnection(mockCtrl) + handshakeChan := make(chan struct{}) + close(handshakeChan) + conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2) + + cl.EXPECT().RoundTripOpt(req1, gomock.Any()).Return(&http.Response{Request: req1}, nil) + cl.EXPECT().RoundTripOpt(req2, gomock.Any()).Return(&http.Response{Request: req2}, nil) var count int - rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { count++ - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { - return &http.Response{Request: req}, nil - }).Times(2) - cl.EXPECT().HandshakeComplete().Return(true) - return cl, nil + return conn, nil } - rsp1, err := rt.RoundTrip(req1) + rsp, err := rt.RoundTrip(req1) Expect(err).ToNot(HaveOccurred()) - Expect(rsp1.Request.URL).To(Equal(req1.URL)) - rsp2, err := rt.RoundTrip(req2) + Expect(rsp.Request).To(Equal(req1)) + rsp, err = rt.RoundTrip(req2) Expect(err).ToNot(HaveOccurred()) - Expect(rsp2.Request.URL).To(Equal(req2.URL)) + Expect(rsp.Request).To(Equal(req2)) Expect(count).To(Equal(1)) }) It("immediately removes a clients when a request errored", func() { - testErr := errors.New("test err") + cl1 := NewMockSingleRoundTripper(mockCtrl) + clientChan <- cl1 + cl2 := NewMockSingleRoundTripper(mockCtrl) + clientChan <- cl2 + req1, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + req2, err := http.NewRequest("GET", "https://quic-go.net/bar.html", nil) + Expect(err).ToNot(HaveOccurred()) + + conn := mockquic.NewMockEarlyConnection(mockCtrl) + handshakeChan := make(chan struct{}) + close(handshakeChan) var count int - rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { count++ - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) - return cl, nil + return conn, nil } - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError(testErr)) - _, err = rt.RoundTrip(req2) + testErr := errors.New("test err") + cl1.EXPECT().RoundTripOpt(req1, gomock.Any()).Return(nil, testErr) + cl2.EXPECT().RoundTripOpt(req2, gomock.Any()).Return(&http.Response{Request: req2}, nil) + _, err = rt.RoundTrip(req1) Expect(err).To(MatchError(testErr)) + rsp, err := rt.RoundTrip(req2) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.Request).To(Equal(req2)) Expect(count).To(Equal(2)) }) It("recreates a client when a request times out", func() { var reqCount int - cl1 := NewMockRoundTripCloser(mockCtrl) + cl1 := NewMockSingleRoundTripper(mockCtrl) cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { reqCount++ if reqCount == 1 { // the first request is successful... @@ -187,19 +308,21 @@ var _ = Describe("RoundTripper", func() { Expect(req.URL).To(Equal(req2.URL)) return nil, &qerr.IdleTimeoutError{} }).Times(2) - cl1.EXPECT().HandshakeComplete().Return(true) - cl2 := NewMockRoundTripCloser(mockCtrl) + cl2 := NewMockSingleRoundTripper(mockCtrl) cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { return &http.Response{Request: req}, nil }) + clientChan <- cl1 + clientChan <- cl2 + conn := mockquic.NewMockEarlyConnection(mockCtrl) + handshakeChan := make(chan struct{}) + close(handshakeChan) + conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2) var count int - rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { count++ - if count == 1 { - return cl1, nil - } - return cl2, nil + return conn, nil } rsp1, err := rt.RoundTrip(req1) Expect(err).ToNot(HaveOccurred()) @@ -211,11 +334,14 @@ var _ = Describe("RoundTripper", func() { It("only issues a request once, even if a timeout error occurs", func() { var count int - rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { count++ - cl := NewMockRoundTripCloser(mockCtrl) + return mockquic.NewMockEarlyConnection(mockCtrl), nil + } + rt.newClient = func(quic.EarlyConnection) singleRoundTripper { + cl := NewMockSingleRoundTripper(mockCtrl) cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{}) - return cl, nil + return cl } _, err := rt.RoundTrip(req1) Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) @@ -225,18 +351,23 @@ var _ = Describe("RoundTripper", func() { It("handles a burst of requests", func() { wait := make(chan struct{}) reqs := make(chan struct{}, 2) + + cl := NewMockSingleRoundTripper(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + reqs <- struct{}{} + <-wait + return nil, &qerr.IdleTimeoutError{} + }).Times(2) + clientChan <- cl + + conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().HandshakeComplete().Return(wait).AnyTimes() var count int - rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { count++ - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { - reqs <- struct{}{} - <-wait - return nil, &qerr.IdleTimeoutError{} - }).Times(2) - cl.EXPECT().HandshakeComplete() - return cl, nil + return conn, nil } + done := make(chan struct{}, 2) go func() { defer GinkgoRecover() @@ -244,14 +375,14 @@ var _ = Describe("RoundTripper", func() { _, err := rt.RoundTrip(req1) Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) }() + // wait for the first requests to be issued + Eventually(reqs).Should(Receive()) go func() { defer GinkgoRecover() defer func() { done <- struct{}{} }() _, err := rt.RoundTrip(req2) Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) }() - // wait for both requests to be issued - Eventually(reqs).Should(Receive()) Eventually(reqs).Should(Receive()) close(wait) // now return the requests Eventually(done).Should(Receive()) @@ -260,7 +391,7 @@ var _ = Describe("RoundTripper", func() { }) It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { - req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) + req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) Expect(err).To(MatchError(ErrNoCachedConn)) @@ -268,6 +399,8 @@ var _ = Describe("RoundTripper", func() { }) Context("validating request", func() { + var rt RoundTripper + It("rejects plain HTTP requests", func() { req, err := http.NewRequest("GET", "http://www.example.org/", nil) req.Body = &mockBody{} @@ -331,24 +464,41 @@ var _ = Describe("RoundTripper", func() { Context("closing", func() { It("closes", func() { - rt.clients = make(map[string]*roundTripCloserWithCount) - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().Close() - rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}} - err := rt.Close() + conn := mockquic.NewMockEarlyConnection(mockCtrl) + rt := &RoundTripper{ + Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + }, + newClient: func(quic.EarlyConnection) singleRoundTripper { + cl := NewMockSingleRoundTripper(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(&http.Response{}, nil) + return cl + }, + } + req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - Expect(len(rt.clients)).To(BeZero()) - }) - - It("closes a RoundTripper that has never been used", func() { - Expect(len(rt.clients)).To(BeZero()) - err := rt.Close() + _, err = rt.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) - Expect(len(rt.clients)).To(BeZero()) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(0), "") + Expect(rt.Close()).To(Succeed()) }) It("closes idle connections", func() { - Expect(len(rt.clients)).To(Equal(0)) + conn1 := mockquic.NewMockEarlyConnection(mockCtrl) + conn2 := mockquic.NewMockEarlyConnection(mockCtrl) + rt := &RoundTripper{ + Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + switch hostname { + case "site1.com:443": + return conn1, nil + case "site2.com:443": + return conn2, nil + default: + Fail("unexpected hostname") + return nil, errors.New("unexpected hostname") + } + }, + } req1, err := http.NewRequest("GET", "https://site1.com", nil) Expect(err).ToNot(HaveOccurred()) req2, err := http.NewRequest("GET", "https://site2.com", nil) @@ -360,15 +510,14 @@ var _ = Describe("RoundTripper", func() { req2 = req2.WithContext(ctx2) roundTripCalled := make(chan struct{}) reqFinished := make(chan struct{}) - rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { - cl := NewMockRoundTripCloser(mockCtrl) - cl.EXPECT().Close() + rt.newClient = func(quic.EarlyConnection) singleRoundTripper { + cl := NewMockSingleRoundTripper(mockCtrl) cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) { roundTripCalled <- struct{}{} <-r.Context().Done() return nil, nil }) - return cl, nil + return cl } go func() { rt.RoundTrip(req1) @@ -381,17 +530,16 @@ var _ = Describe("RoundTripper", func() { <-roundTripCalled <-roundTripCalled // Both two requests are started. - Expect(len(rt.clients)).To(Equal(2)) cancel1() <-reqFinished // req1 is finished + conn1.EXPECT().CloseWithError(gomock.Any(), gomock.Any()) rt.CloseIdleConnections() - Expect(len(rt.clients)).To(Equal(1)) cancel2() <-reqFinished // all requests are finished + conn2.EXPECT().CloseWithError(gomock.Any(), gomock.Any()) rt.CloseIdleConnections() - Expect(len(rt.clients)).To(Equal(0)) }) }) }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 9042678c..327d61d7 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -183,6 +183,7 @@ var _ = Describe("HTTP tests", func() { group, ctx := errgroup.WithContext(context.Background()) for i := 0; i < 2; i++ { group.Go(func() error { + defer GinkgoRecover() req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil) Expect(err).ToNot(HaveOccurred()) resp, err := client.Do(req) @@ -431,7 +432,18 @@ var _ = Describe("HTTP tests", func() { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/httpstreamer", port), nil) Expect(err).ToNot(HaveOccurred()) - str, err := client.Transport.(*http3.RoundTripper).OpenStream(context.Background(), req.URL) + tlsConf := getTLSClientConfigWithoutServerName() + tlsConf.NextProtos = []string{http3.NextProtoH3} + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + rt := http3.SingleDestinationRoundTripper{Connection: conn} + str, err := rt.OpenRequestStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.SendRequestHeader(req)).To(Succeed()) rsp, err := str.ReadResponse() @@ -452,10 +464,10 @@ var _ = Describe("HTTP tests", func() { Expect(repl).To(Equal(data)) }) - It("serves other QUIC connections", func() { + It("serves QUIC connections", func() { tlsConf := getTLSConfig() tlsConf.NextProtos = []string{http3.NextProtoH3} - ln, err := quic.ListenAddr("localhost:0", tlsConf, nil) + ln, err := quic.ListenAddr("localhost:0", tlsConf, getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) defer ln.Close() done := make(chan struct{}) @@ -464,7 +476,7 @@ var _ = Describe("HTTP tests", func() { defer close(done) conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(server.ServeQUICConn(conn)).To(Succeed()) + server.ServeQUICConn(conn) // returns once the client closes }() resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", ln.Addr().(*net.UDPAddr).Port))