forked from quic-go/quic-go
implement a Dial function for the h2quic.RoundTripper
If Dial is set, it will be used for dialing new QUIC connections. If it is nil, quic.DialAddr will be used.
This commit is contained in:
@@ -37,6 +37,7 @@ type client struct {
|
||||
hostname string
|
||||
handshakeErr error
|
||||
dialOnce sync.Once
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
session quic.Session
|
||||
headerStream quic.Stream
|
||||
@@ -60,6 +61,7 @@ func newClient(
|
||||
tlsConfig *tls.Config,
|
||||
opts *roundTripperOpts,
|
||||
quicConfig *quic.Config,
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||
) *client {
|
||||
config := defaultQuicConfig
|
||||
if quicConfig != nil {
|
||||
@@ -72,13 +74,18 @@ func newClient(
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
}
|
||||
}
|
||||
|
||||
// dial dials the connection
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
if c.dialer != nil {
|
||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ var _ = Describe("Client", func() {
|
||||
BeforeEach(func() {
|
||||
origDialAddr = dialAddr
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
|
||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
session = &mockSession{}
|
||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
||||
@@ -54,42 +54,50 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("saves the TLS config", func() {
|
||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
||||
client = newClient("", tlsConf, &roundTripperOpts{}, nil)
|
||||
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.tlsConf).To(Equal(tlsConf))
|
||||
})
|
||||
|
||||
It("saves the QUIC config", func() {
|
||||
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
|
||||
Expect(client.config).To(Equal(quicConf))
|
||||
})
|
||||
|
||||
It("uses the default QUIC config if none is give", func() {
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.config).ToNot(BeNil())
|
||||
Expect(client.config).To(Equal(defaultQuicConfig))
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
|
||||
It("dials", func(done Done) {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
It("dials", func() {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
close(headerStream.unblockRead)
|
||||
go client.RoundTrip(req)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
||||
close(done)
|
||||
}, 2)
|
||||
// make the go routine return
|
||||
client.responses[5] <- &http.Response{}
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors when dialing fails", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
@@ -97,9 +105,35 @@ var _ = Describe("Client", func() {
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("uses the custom dialer, if provided", func() {
|
||||
var tlsCfg *tls.Config
|
||||
var qCfg *quic.Config
|
||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||
tlsCfg = tlsCfgP
|
||||
qCfg = cfg
|
||||
return session, nil
|
||||
}
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
||||
Expect(qCfg).To(Equal(client.config))
|
||||
Expect(tlsCfg).To(Equal(client.tlsConf))
|
||||
// make the go routine return
|
||||
client.responses[5] <- &http.Response{}
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
})
|
||||
|
||||
It("errors if it can't open a stream", func() {
|
||||
testErr := errors.New("you shall not pass")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
session.streamOpenErr = testErr
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
@@ -247,7 +281,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("adds the port for request URLs without one", func(done Done) {
|
||||
var err error
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
|
||||
@@ -41,6 +41,11 @@ type RoundTripper struct {
|
||||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, quic.DialAddr will be used.
|
||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
@@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
||||
if onlyCached {
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
|
||||
client = newClient(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||
r.QuicConfig,
|
||||
r.Dial,
|
||||
)
|
||||
r.clients[hostname] = client
|
||||
}
|
||||
return client, nil
|
||||
|
||||
@@ -103,6 +103,17 @@ var _ = Describe("RoundTripper", func() {
|
||||
Expect(receivedConfig).To(Equal(config))
|
||||
})
|
||||
|
||||
It("uses the custom dialer, if provided", func() {
|
||||
var dialed bool
|
||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||
dialed = true
|
||||
return nil, errors.New("err")
|
||||
}
|
||||
rt.Dial = dialer
|
||||
rt.RoundTrip(req1)
|
||||
Expect(dialed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("reuses existing clients", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Reference in New Issue
Block a user