implement a h2quic.RoundTripOpt that allow to only use cached QUIC conns

This commit is contained in:
Marten Seemann
2017-07-24 14:29:37 +07:00
parent 1060582a18
commit 6bd6594003
2 changed files with 36 additions and 5 deletions

View File

@@ -46,10 +46,22 @@ type RoundTripper struct {
clients map[string]roundTripCloser
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
var _ roundTripCloser = &RoundTripper{}
// RoundTrip does a round trip
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL")
@@ -85,10 +97,19 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
}
hostname := authorityAddr("https", hostnameFromRequest(req))
return r.getClient(hostname).RoundTrip(req)
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
return cl.RoundTrip(req)
}
func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
// RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
@@ -99,10 +120,13 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
client, ok := r.clients[hostname]
if !ok {
if onlyCached {
return nil, ErrNoCachedConn
}
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
r.clients[hostname] = client
}
return client
return client, nil
}
func finalizer(r *RoundTripper) {

View File

@@ -116,6 +116,13 @@ var _ = Describe("RoundTripper", func() {
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
})
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
Expect(err).To(MatchError(ErrNoCachedConn))
})
})
Context("validating request", func() {