implement a Dial function for the h2quic.RoundTripper

If Dial is set, it will be used for dialing new QUIC connections. If it
is nil, quic.DialAddr will be used.
This commit is contained in:
Marten Seemann
2018-01-15 12:18:47 +07:00
parent 15bcc2579f
commit ef56cae9dc
4 changed files with 78 additions and 15 deletions

View File

@@ -37,6 +37,7 @@ type client struct {
hostname string
handshakeErr error
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session
headerStream quic.Stream
@@ -60,6 +61,7 @@ func newClient(
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
config := defaultQuicConfig
if quicConfig != nil {
@@ -72,13 +74,18 @@ func newClient(
config: config,
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}

View File

@@ -34,7 +34,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal(hostname))
session = &mockSession{}
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
@@ -54,42 +54,50 @@ var _ = Describe("Client", func() {
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient("", tlsConf, &roundTripperOpts{}, nil)
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
Expect(client.config).To(Equal(quicConf))
})
It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig))
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func(done Done) {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
It("dials", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
close(headerStream.unblockRead)
go client.RoundTrip(req)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
close(done)
}, 2)
// make the go routine return
client.responses[5] <- &http.Response{}
Eventually(done).Should(BeClosed())
})
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
@@ -97,9 +105,35 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr))
})
It("uses the custom dialer, if provided", func() {
var tlsCfg *tls.Config
var qCfg *quic.Config
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
tlsCfg = tlsCfgP
qCfg = cfg
return session, nil
}
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
Expect(qCfg).To(Equal(client.config))
Expect(tlsCfg).To(Equal(client.tlsConf))
// make the go routine return
client.responses[5] <- &http.Response{}
Eventually(done).Should(BeClosed())
})
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamOpenErr = testErr
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
@@ -247,7 +281,7 @@ var _ = Describe("Client", func() {
It("adds the port for request URLs without one", func(done Done) {
var err error
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())

View File

@@ -41,6 +41,11 @@ type RoundTripper struct {
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
clients map[string]roundTripCloser
}
@@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
if onlyCached {
return nil, ErrNoCachedConn
}
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
client = newClient(
hostname,
r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression},
r.QuicConfig,
r.Dial,
)
r.clients[hostname] = client
}
return client, nil

View File

@@ -103,6 +103,17 @@ var _ = Describe("RoundTripper", func() {
Expect(receivedConfig).To(Equal(config))
})
It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
dialed = true
return nil, errors.New("err")
}
rt.Dial = dialer
rt.RoundTrip(req1)
Expect(dialed).To(BeTrue())
})
It("reuses existing clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())