implement a Dial function for the h2quic.RoundTripper

If Dial is set, it will be used for dialing new QUIC connections. If it
is nil, quic.DialAddr will be used.
This commit is contained in:
Marten Seemann
2018-01-15 12:18:47 +07:00
parent 15bcc2579f
commit ef56cae9dc
4 changed files with 78 additions and 15 deletions

View File

@@ -37,6 +37,7 @@ type client struct {
hostname string hostname string
handshakeErr error handshakeErr error
dialOnce sync.Once dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session session quic.Session
headerStream quic.Stream headerStream quic.Stream
@@ -60,6 +61,7 @@ func newClient(
tlsConfig *tls.Config, tlsConfig *tls.Config,
opts *roundTripperOpts, opts *roundTripperOpts,
quicConfig *quic.Config, quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client { ) *client {
config := defaultQuicConfig config := defaultQuicConfig
if quicConfig != nil { if quicConfig != nil {
@@ -72,13 +74,18 @@ func newClient(
config: config, config: config,
opts: opts, opts: opts,
headerErrored: make(chan struct{}), headerErrored: make(chan struct{}),
dialer: dialer,
} }
} }
// dial dials the connection // dial dials the connection
func (c *client) dial() error { func (c *client) dial() error {
var err error var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil { if err != nil {
return err return err
} }

View File

@@ -34,7 +34,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() { BeforeEach(func() {
origDialAddr = dialAddr origDialAddr = dialAddr
hostname := "quic.clemente.io:1337" hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil) client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal(hostname)) Expect(client.hostname).To(Equal(hostname))
session = &mockSession{} session = &mockSession{}
session.ctx, session.ctxCancel = context.WithCancel(context.Background()) session.ctx, session.ctxCancel = context.WithCancel(context.Background())
@@ -54,42 +54,50 @@ var _ = Describe("Client", func() {
It("saves the TLS config", func() { It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true} tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient("", tlsConf, &roundTripperOpts{}, nil) client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
Expect(client.tlsConf).To(Equal(tlsConf)) Expect(client.tlsConf).To(Equal(tlsConf))
}) })
It("saves the QUIC config", func() { It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond} quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf) client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
Expect(client.config).To(Equal(quicConf)) Expect(client.config).To(Equal(quicConf))
}) })
It("uses the default QUIC config if none is give", func() { It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil) client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
Expect(client.config).ToNot(BeNil()) Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig)) Expect(client.config).To(Equal(defaultQuicConfig))
}) })
It("adds the port to the hostname, if none is given", func() { It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443")) Expect(client.hostname).To(Equal("quic.clemente.io:443"))
}) })
It("dials", func(done Done) { It("dials", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)} session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil return session, nil
} }
close(headerStream.unblockRead) close(headerStream.unblockRead)
go client.RoundTrip(req) done := make(chan struct{})
Eventually(func() quic.Session { return client.session }).Should(Equal(session)) go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done) close(done)
}, 2) }()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
// make the go routine return
client.responses[5] <- &http.Response{}
Eventually(done).Should(BeClosed())
})
It("errors when dialing fails", func() { It("errors when dialing fails", func() {
testErr := errors.New("handshake error") testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr return nil, testErr
} }
@@ -97,9 +105,35 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
It("uses the custom dialer, if provided", func() {
var tlsCfg *tls.Config
var qCfg *quic.Config
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
tlsCfg = tlsCfgP
qCfg = cfg
return session, nil
}
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
Expect(qCfg).To(Equal(client.config))
Expect(tlsCfg).To(Equal(client.tlsConf))
// make the go routine return
client.responses[5] <- &http.Response{}
Eventually(done).Should(BeClosed())
})
It("errors if it can't open a stream", func() { It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass") testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamOpenErr = testErr session.streamOpenErr = testErr
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil return session, nil
@@ -247,7 +281,7 @@ var _ = Describe("Client", func() {
It("adds the port for request URLs without one", func(done Done) { It("adds the port for request URLs without one", func(done Done) {
var err error var err error
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View File

@@ -41,6 +41,11 @@ type RoundTripper struct {
// If nil, reasonable default values will be used. // If nil, reasonable default values will be used.
QuicConfig *quic.Config QuicConfig *quic.Config
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
clients map[string]roundTripCloser clients map[string]roundTripCloser
} }
@@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
if onlyCached { if onlyCached {
return nil, ErrNoCachedConn return nil, ErrNoCachedConn
} }
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) client = newClient(
hostname,
r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression},
r.QuicConfig,
r.Dial,
)
r.clients[hostname] = client r.clients[hostname] = client
} }
return client, nil return client, nil

View File

@@ -103,6 +103,17 @@ var _ = Describe("RoundTripper", func() {
Expect(receivedConfig).To(Equal(config)) Expect(receivedConfig).To(Equal(config))
}) })
It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
dialed = true
return nil, errors.New("err")
}
rt.Dial = dialer
rt.RoundTrip(req1)
Expect(dialed).To(BeTrue())
})
It("reuses existing clients", func() { It("reuses existing clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())