From 6bd6594003f3d7609773d94717c42789b0243bcf Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 24 Jul 2017 14:29:37 +0700 Subject: [PATCH] implement a h2quic.RoundTripOpt that allow to only use cached QUIC conns --- h2quic/roundtrip.go | 34 +++++++++++++++++++++++++++++----- h2quic/roundtrip_test.go | 7 +++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index 6b195b6f..7570acdf 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -46,10 +46,22 @@ type RoundTripper struct { clients map[string]roundTripCloser } +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type RoundTripOpt struct { + // OnlyCachedConn controls whether the RoundTripper may + // create a new QUIC connection. If set true and + // no cached connection is available, RoundTrip + // will return ErrNoCachedConn. + OnlyCachedConn bool +} + var _ roundTripCloser = &RoundTripper{} -// RoundTrip does a round trip -func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { +// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set +var ErrNoCachedConn = errors.New("h2quic: no cached connection was available") + +// RoundTripOpt is like RoundTrip, but takes options. +func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if req.URL == nil { closeRequestBody(req) return nil, errors.New("quic: nil Request.URL") @@ -85,10 +97,19 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { } hostname := authorityAddr("https", hostnameFromRequest(req)) - return r.getClient(hostname).RoundTrip(req) + cl, err := r.getClient(hostname, opt.OnlyCachedConn) + if err != nil { + return nil, err + } + return cl.RoundTrip(req) } -func (r *RoundTripper) getClient(hostname string) http.RoundTripper { +// RoundTrip does a round trip. +func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return r.RoundTripOpt(req, RoundTripOpt{}) +} + +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) { r.mutex.Lock() defer r.mutex.Unlock() @@ -99,10 +120,13 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { client, ok := r.clients[hostname] if !ok { + if onlyCached { + return nil, ErrNoCachedConn + } client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) r.clients[hostname] = client } - return client + return client, nil } func finalizer(r *RoundTripper) { diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index 66e3eb8a..76bb4bc3 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -116,6 +116,13 @@ var _ = Describe("RoundTripper", func() { Expect(err).To(MatchError(streamOpenErr)) Expect(rt.clients).To(HaveLen(1)) }) + + It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { + req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) + Expect(err).To(MatchError(ErrNoCachedConn)) + }) }) Context("validating request", func() {