From 9054e5205f1c8d8861b7c6876785a5c20053f7d5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 2 Jun 2017 22:35:16 +0200 Subject: [PATCH] don't pass the roundtripper to the h2quic client --- h2quic/client.go | 13 ++++++++----- h2quic/client_test.go | 26 ++++++++++++-------------- h2quic/roundtrip.go | 6 +----- h2quic/roundtrip_test.go | 6 ------ 4 files changed, 21 insertions(+), 30 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 5118fdd6..91e33791 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -20,14 +20,17 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) +type roundTripperOpts struct { + DisableCompression bool +} + // client is a HTTP2 client doing QUIC requests type client struct { mutex sync.RWMutex dialAddr func(hostname string, config *quic.Config) (quic.Session, error) config *quic.Config - - t *QuicRoundTripper + opts *roundTripperOpts hostname string encryptionLevel protocol.EncryptionLevel @@ -45,9 +48,8 @@ type client struct { var _ h2quicClient = &client{} // newClient creates a new client -func newClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *client { +func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client { return &client{ - t: t, dialAddr: quic.DialAddr, hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), @@ -56,6 +58,7 @@ func newClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *cli TLSConfig: tlsConfig, RequestConnectionIDTruncation: true, }, + opts: opts, dialChan: make(chan struct{}), } } @@ -164,7 +167,7 @@ func (c *client) Do(req *http.Request) (*http.Response, error) { c.mutex.Unlock() var requestedGzip bool - if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { + if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { requestedGzip = true } // TODO: add support for trailers diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 5cf4e54c..a9451f5a 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -20,16 +20,14 @@ import ( var _ = Describe("Client", func() { var ( - client *client - session *mockSession - headerStream *mockStream - quicTransport *QuicRoundTripper + client *client + session *mockSession + headerStream *mockStream ) BeforeEach(func() { - quicTransport = &QuicRoundTripper{} hostname := "quic.clemente.io:1337" - client = newClient(quicTransport, nil, hostname) + client = newClient(nil, hostname, &roundTripperOpts{}) Expect(client.hostname).To(Equal(hostname)) session = &mockSession{} client.session = session @@ -41,17 +39,17 @@ var _ = Describe("Client", func() { It("saves the TLS config", func() { tlsConf := &tls.Config{InsecureSkipVerify: true} - client = newClient(&QuicRoundTripper{}, tlsConf, "") + client = newClient(tlsConf, "", &roundTripperOpts{}) Expect(client.config.TLSConfig).To(Equal(tlsConf)) }) It("adds the port to the hostname, if none is given", func() { - client = newClient(quicTransport, nil, "quic.clemente.io") + client = newClient(nil, "quic.clemente.io", &roundTripperOpts{}) Expect(client.hostname).To(Equal("quic.clemente.io:443")) }) It("dials", func() { - client = newClient(quicTransport, nil, "localhost") + client = newClient(nil, "localhost", &roundTripperOpts{}) session.streamToOpen = &mockStream{id: 3} client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { return session, nil @@ -63,7 +61,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") - client = newClient(quicTransport, nil, "localhost") + client = newClient(nil, "localhost", &roundTripperOpts{}) client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { return nil, testErr } @@ -72,7 +70,7 @@ var _ = Describe("Client", func() { }) It("errors if the header stream has the wrong stream ID", func() { - client = newClient(quicTransport, nil, "localhost") + client = newClient(nil, "localhost", &roundTripperOpts{}) session.streamToOpen = &mockStream{id: 2} client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { return session, nil @@ -83,7 +81,7 @@ var _ = Describe("Client", func() { It("errors if it can't open a stream", func() { testErr := errors.New("you shall not pass") - client = newClient(quicTransport, nil, "localhost") + client = newClient(nil, "localhost", &roundTripperOpts{}) session.streamOpenErr = testErr client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { return session, nil @@ -226,7 +224,7 @@ var _ = Describe("Client", func() { It("adds the port for request URLs without one", func(done Done) { var err error - client = newClient(quicTransport, nil, "quic.clemente.io") + client = newClient(nil, "quic.clemente.io", &roundTripperOpts{}) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) @@ -365,7 +363,7 @@ var _ = Describe("Client", func() { }) It("doesn't add gzip if the header disable it", func() { - quicTransport.DisableCompression = true + client.opts.DisableCompression = true var doErr error go func() { _, doErr = client.Do(request) }() diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index 5a60fcde..15201bed 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -93,7 +93,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { client, ok := r.clients[hostname] if !ok { - client = newClient(r, r.TLSClientConfig, hostname) + client = newClient(r.TLSClientConfig, hostname, &roundTripperOpts{DisableCompression: r.DisableCompression}) err := client.Dial() if err != nil { return nil, err @@ -103,10 +103,6 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { return client, nil } -func (r *QuicRoundTripper) disableCompression() bool { - return r.DisableCompression -} - func closeRequestBody(req *http.Request) { if req.Body != nil { req.Body.Close() diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index 3a48fa2f..bc212a01 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -66,12 +66,6 @@ var _ = Describe("RoundTripper", func() { Expect(rt.clients).To(HaveLen(1)) }) - It("disable compression", func() { - Expect(rt.disableCompression()).To(BeFalse()) - rt.DisableCompression = true - Expect(rt.disableCompression()).To(BeTrue()) - }) - Context("validating request", func() { It("rejects plain HTTP requests", func() { req, err := http.NewRequest("GET", "http://www.example.org/", nil)