forked from quic-go/quic-go
implement a Dial function for the h2quic.RoundTripper
If Dial is set, it will be used for dialing new QUIC connections. If it is nil, quic.DialAddr will be used.
This commit is contained in:
@@ -37,6 +37,7 @@ type client struct {
|
|||||||
hostname string
|
hostname string
|
||||||
handshakeErr error
|
handshakeErr error
|
||||||
dialOnce sync.Once
|
dialOnce sync.Once
|
||||||
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||||
|
|
||||||
session quic.Session
|
session quic.Session
|
||||||
headerStream quic.Stream
|
headerStream quic.Stream
|
||||||
@@ -60,6 +61,7 @@ func newClient(
|
|||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
opts *roundTripperOpts,
|
opts *roundTripperOpts,
|
||||||
quicConfig *quic.Config,
|
quicConfig *quic.Config,
|
||||||
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||||
) *client {
|
) *client {
|
||||||
config := defaultQuicConfig
|
config := defaultQuicConfig
|
||||||
if quicConfig != nil {
|
if quicConfig != nil {
|
||||||
@@ -72,13 +74,18 @@ func newClient(
|
|||||||
config: config,
|
config: config,
|
||||||
opts: opts,
|
opts: opts,
|
||||||
headerErrored: make(chan struct{}),
|
headerErrored: make(chan struct{}),
|
||||||
|
dialer: dialer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial dials the connection
|
// dial dials the connection
|
||||||
func (c *client) dial() error {
|
func (c *client) dial() error {
|
||||||
var err error
|
var err error
|
||||||
|
if c.dialer != nil {
|
||||||
|
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||||
|
} else {
|
||||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ var _ = Describe("Client", func() {
|
|||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
origDialAddr = dialAddr
|
origDialAddr = dialAddr
|
||||||
hostname := "quic.clemente.io:1337"
|
hostname := "quic.clemente.io:1337"
|
||||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
|
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(client.hostname).To(Equal(hostname))
|
Expect(client.hostname).To(Equal(hostname))
|
||||||
session = &mockSession{}
|
session = &mockSession{}
|
||||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
||||||
@@ -54,42 +54,50 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
It("saves the TLS config", func() {
|
It("saves the TLS config", func() {
|
||||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
||||||
client = newClient("", tlsConf, &roundTripperOpts{}, nil)
|
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(client.tlsConf).To(Equal(tlsConf))
|
Expect(client.tlsConf).To(Equal(tlsConf))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("saves the QUIC config", func() {
|
It("saves the QUIC config", func() {
|
||||||
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
||||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
|
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
|
||||||
Expect(client.config).To(Equal(quicConf))
|
Expect(client.config).To(Equal(quicConf))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("uses the default QUIC config if none is give", func() {
|
It("uses the default QUIC config if none is give", func() {
|
||||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
|
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(client.config).ToNot(BeNil())
|
Expect(client.config).ToNot(BeNil())
|
||||||
Expect(client.config).To(Equal(defaultQuicConfig))
|
Expect(client.config).To(Equal(defaultQuicConfig))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("adds the port to the hostname, if none is given", func() {
|
It("adds the port to the hostname, if none is given", func() {
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
|
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("dials", func(done Done) {
|
It("dials", func() {
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
close(headerStream.unblockRead)
|
close(headerStream.unblockRead)
|
||||||
go client.RoundTrip(req)
|
done := make(chan struct{})
|
||||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := client.RoundTrip(req)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(done)
|
close(done)
|
||||||
}, 2)
|
}()
|
||||||
|
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
||||||
|
// make the go routine return
|
||||||
|
client.responses[5] <- &http.Response{}
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
It("errors when dialing fails", func() {
|
It("errors when dialing fails", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
@@ -97,9 +105,35 @@ var _ = Describe("Client", func() {
|
|||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("uses the custom dialer, if provided", func() {
|
||||||
|
var tlsCfg *tls.Config
|
||||||
|
var qCfg *quic.Config
|
||||||
|
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||||
|
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||||
|
tlsCfg = tlsCfgP
|
||||||
|
qCfg = cfg
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := client.RoundTrip(req)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
||||||
|
Expect(qCfg).To(Equal(client.config))
|
||||||
|
Expect(tlsCfg).To(Equal(client.tlsConf))
|
||||||
|
// make the go routine return
|
||||||
|
client.responses[5] <- &http.Response{}
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
It("errors if it can't open a stream", func() {
|
It("errors if it can't open a stream", func() {
|
||||||
testErr := errors.New("you shall not pass")
|
testErr := errors.New("you shall not pass")
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
session.streamOpenErr = testErr
|
session.streamOpenErr = testErr
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
@@ -247,7 +281,7 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
It("adds the port for request URLs without one", func(done Done) {
|
It("adds the port for request URLs without one", func(done Done) {
|
||||||
var err error
|
var err error
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
|
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,11 @@ type RoundTripper struct {
|
|||||||
// If nil, reasonable default values will be used.
|
// If nil, reasonable default values will be used.
|
||||||
QuicConfig *quic.Config
|
QuicConfig *quic.Config
|
||||||
|
|
||||||
|
// Dial specifies an optional dial function for creating QUIC
|
||||||
|
// connections for requests.
|
||||||
|
// If Dial is nil, quic.DialAddr will be used.
|
||||||
|
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||||
|
|
||||||
clients map[string]roundTripCloser
|
clients map[string]roundTripCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
|||||||
if onlyCached {
|
if onlyCached {
|
||||||
return nil, ErrNoCachedConn
|
return nil, ErrNoCachedConn
|
||||||
}
|
}
|
||||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
|
client = newClient(
|
||||||
|
hostname,
|
||||||
|
r.TLSClientConfig,
|
||||||
|
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||||
|
r.QuicConfig,
|
||||||
|
r.Dial,
|
||||||
|
)
|
||||||
r.clients[hostname] = client
|
r.clients[hostname] = client
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
|
|||||||
@@ -103,6 +103,17 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
Expect(receivedConfig).To(Equal(config))
|
Expect(receivedConfig).To(Equal(config))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("uses the custom dialer, if provided", func() {
|
||||||
|
var dialed bool
|
||||||
|
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||||
|
dialed = true
|
||||||
|
return nil, errors.New("err")
|
||||||
|
}
|
||||||
|
rt.Dial = dialer
|
||||||
|
rt.RoundTrip(req1)
|
||||||
|
Expect(dialed).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
It("reuses existing clients", func() {
|
It("reuses existing clients", func() {
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|||||||
Reference in New Issue
Block a user