From 137491916b9ea63a4e67c84a462c082aaf309830 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 25 Mar 2022 09:23:48 +0100 Subject: [PATCH] respect the request context when dialing --- http3/client.go | 38 +++++++++++++++++--------------------- http3/client_test.go | 13 ++++++++----- http3/roundtrip.go | 7 ++++--- http3/roundtrip_test.go | 2 +- 4 files changed, 30 insertions(+), 30 deletions(-) diff --git a/http3/client.go b/http3/client.go index 861eaf0ab..5ff2edc44 100644 --- a/http3/client.go +++ b/http3/client.go @@ -34,6 +34,8 @@ var defaultQuicConfig = &quic.Config{ Versions: []protocol.VersionNumber{protocol.VersionTLS}, } +type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + var dialAddr = quic.DialAddrEarly type roundTripperOpts struct { @@ -49,7 +51,7 @@ type client struct { opts *roundTripperOpts dialOnce sync.Once - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + dialer dialFunc handshakeErr error requestWriter *requestWriter @@ -62,24 +64,18 @@ type client struct { logger utils.Logger } -func newClient( - hostname string, - tlsConf *tls.Config, - opts *roundTripperOpts, - quicConfig *quic.Config, - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error), -) (*client, error) { - if quicConfig == nil { - quicConfig = defaultQuicConfig.Clone() - } else if len(quicConfig.Versions) == 0 { - quicConfig = quicConfig.Clone() - quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { + if conf == nil { + conf = defaultQuicConfig.Clone() + } else if len(conf.Versions) == 0 { + conf = conf.Clone() + conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} } - if len(quicConfig.Versions) != 1 { + if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") } - quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams - quicConfig.EnableDatagrams = opts.EnableDatagram + conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams + conf.EnableDatagrams = opts.EnableDatagram logger := utils.DefaultLogger.WithPrefix("h3 client") if tlsConf == nil { @@ -88,24 +84,24 @@ func newClient( tlsConf = tlsConf.Clone() } // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])} + tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} return &client{ hostname: authorityAddr("https", hostname), tlsConf: tlsConf, requestWriter: newRequestWriter(logger), decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: quicConfig, + config: conf, opts: opts, dialer: dialer, logger: logger, }, nil } -func (c *client) dial() error { +func (c *client) dial(ctx context.Context) error { var err error if c.dialer != nil { - c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config) + c.session, err = c.dialer(ctx, "udp", c.hostname, c.tlsConf, c.config) } else { c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) } @@ -212,7 +208,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { } c.dialOnce.Do(func() { - c.handshakeErr = c.dial() + c.handshakeErr = c.dial(req.Context()) }) if c.handshakeErr != nil { diff --git a/http3/client_test.go b/http3/client_test.go index 3f3b4a4b1..157d0c2f1 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -12,13 +12,13 @@ import ( "net/http" "time" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" - "github.com/lucas-clemente/quic-go/quicvarint" - "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" + + "github.com/golang/mock/gomock" "github.com/marten-seemann/qpack" . "github.com/onsi/ginkgo" @@ -122,8 +122,11 @@ var _ = Describe("Client", func() { testErr := errors.New("test done") tlsConf := &tls.Config{ServerName: "foo.bar"} quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() var dialerCalled bool - dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) { + dialer := func(ctxP context.Context, network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) { + Expect(ctxP).To(Equal(ctx)) Expect(network).To(Equal("udp")) Expect(address).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal("foo.bar")) @@ -133,7 +136,7 @@ var _ = Describe("Client", func() { } client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTrip(req) + _, err = client.RoundTrip(req.WithContext(ctx)) Expect(err).To(MatchError(testErr)) Expect(dialerCalled).To(BeTrue()) }) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 8e6f943e9..d301045a3 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -1,6 +1,7 @@ package http3 import ( + "context" "crypto/tls" "errors" "fmt" @@ -9,7 +10,7 @@ import ( "strings" "sync" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "golang.org/x/net/http/httpguts" ) @@ -48,8 +49,8 @@ type RoundTripper struct { // Dial specifies an optional dial function for creating QUIC // connections for requests. - // If Dial is nil, quic.DialAddrEarly will be used. - Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + // If Dial is nil, quic.DialAddrEarlyContext will be used. + Dial func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) // MaxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 184889f1c..201e9190b 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -127,7 +127,7 @@ var _ = Describe("RoundTripper", func() { It("uses the custom dialer, if provided", func() { var dialed bool - dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { + dialer := func(_ context.Context, _, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { dialed = true return nil, errors.New("handshake error") }