don't pass the roundtripper to the h2quic client

This commit is contained in:
Marten Seemann
2017-06-02 22:35:16 +02:00
parent 4c3d4960bb
commit 9054e5205f
4 changed files with 21 additions and 30 deletions

View File

@@ -20,14 +20,17 @@ import (
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/utils"
) )
type roundTripperOpts struct {
DisableCompression bool
}
// client is a HTTP2 client doing QUIC requests // client is a HTTP2 client doing QUIC requests
type client struct { type client struct {
mutex sync.RWMutex mutex sync.RWMutex
dialAddr func(hostname string, config *quic.Config) (quic.Session, error) dialAddr func(hostname string, config *quic.Config) (quic.Session, error)
config *quic.Config config *quic.Config
opts *roundTripperOpts
t *QuicRoundTripper
hostname string hostname string
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
@@ -45,9 +48,8 @@ type client struct {
var _ h2quicClient = &client{} var _ h2quicClient = &client{}
// newClient creates a new client // newClient creates a new client
func newClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *client { func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client {
return &client{ return &client{
t: t,
dialAddr: quic.DialAddr, dialAddr: quic.DialAddr,
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response), responses: make(map[protocol.StreamID]chan *http.Response),
@@ -56,6 +58,7 @@ func newClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *cli
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
RequestConnectionIDTruncation: true, RequestConnectionIDTruncation: true,
}, },
opts: opts,
dialChan: make(chan struct{}), dialChan: make(chan struct{}),
} }
} }
@@ -164,7 +167,7 @@ func (c *client) Do(req *http.Request) (*http.Response, error) {
c.mutex.Unlock() c.mutex.Unlock()
var requestedGzip bool var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true requestedGzip = true
} }
// TODO: add support for trailers // TODO: add support for trailers

View File

@@ -23,13 +23,11 @@ var _ = Describe("Client", func() {
client *client client *client
session *mockSession session *mockSession
headerStream *mockStream headerStream *mockStream
quicTransport *QuicRoundTripper
) )
BeforeEach(func() { BeforeEach(func() {
quicTransport = &QuicRoundTripper{}
hostname := "quic.clemente.io:1337" hostname := "quic.clemente.io:1337"
client = newClient(quicTransport, nil, hostname) client = newClient(nil, hostname, &roundTripperOpts{})
Expect(client.hostname).To(Equal(hostname)) Expect(client.hostname).To(Equal(hostname))
session = &mockSession{} session = &mockSession{}
client.session = session client.session = session
@@ -41,17 +39,17 @@ var _ = Describe("Client", func() {
It("saves the TLS config", func() { It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true} tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient(&QuicRoundTripper{}, tlsConf, "") client = newClient(tlsConf, "", &roundTripperOpts{})
Expect(client.config.TLSConfig).To(Equal(tlsConf)) Expect(client.config.TLSConfig).To(Equal(tlsConf))
}) })
It("adds the port to the hostname, if none is given", func() { It("adds the port to the hostname, if none is given", func() {
client = newClient(quicTransport, nil, "quic.clemente.io") client = newClient(nil, "quic.clemente.io", &roundTripperOpts{})
Expect(client.hostname).To(Equal("quic.clemente.io:443")) Expect(client.hostname).To(Equal("quic.clemente.io:443"))
}) })
It("dials", func() { It("dials", func() {
client = newClient(quicTransport, nil, "localhost") client = newClient(nil, "localhost", &roundTripperOpts{})
session.streamToOpen = &mockStream{id: 3} session.streamToOpen = &mockStream{id: 3}
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
return session, nil return session, nil
@@ -63,7 +61,7 @@ var _ = Describe("Client", func() {
It("errors when dialing fails", func() { It("errors when dialing fails", func() {
testErr := errors.New("handshake error") testErr := errors.New("handshake error")
client = newClient(quicTransport, nil, "localhost") client = newClient(nil, "localhost", &roundTripperOpts{})
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
return nil, testErr return nil, testErr
} }
@@ -72,7 +70,7 @@ var _ = Describe("Client", func() {
}) })
It("errors if the header stream has the wrong stream ID", func() { It("errors if the header stream has the wrong stream ID", func() {
client = newClient(quicTransport, nil, "localhost") client = newClient(nil, "localhost", &roundTripperOpts{})
session.streamToOpen = &mockStream{id: 2} session.streamToOpen = &mockStream{id: 2}
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
return session, nil return session, nil
@@ -83,7 +81,7 @@ var _ = Describe("Client", func() {
It("errors if it can't open a stream", func() { It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass") testErr := errors.New("you shall not pass")
client = newClient(quicTransport, nil, "localhost") client = newClient(nil, "localhost", &roundTripperOpts{})
session.streamOpenErr = testErr session.streamOpenErr = testErr
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
return session, nil return session, nil
@@ -226,7 +224,7 @@ var _ = Describe("Client", func() {
It("adds the port for request URLs without one", func(done Done) { It("adds the port for request URLs without one", func(done Done) {
var err error var err error
client = newClient(quicTransport, nil, "quic.clemente.io") client = newClient(nil, "quic.clemente.io", &roundTripperOpts{})
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@@ -365,7 +363,7 @@ var _ = Describe("Client", func() {
}) })
It("doesn't add gzip if the header disable it", func() { It("doesn't add gzip if the header disable it", func() {
quicTransport.DisableCompression = true client.opts.DisableCompression = true
var doErr error var doErr error
go func() { _, doErr = client.Do(request) }() go func() { _, doErr = client.Do(request) }()

View File

@@ -93,7 +93,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
if !ok { if !ok {
client = newClient(r, r.TLSClientConfig, hostname) client = newClient(r.TLSClientConfig, hostname, &roundTripperOpts{DisableCompression: r.DisableCompression})
err := client.Dial() err := client.Dial()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -103,10 +103,6 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
return client, nil return client, nil
} }
func (r *QuicRoundTripper) disableCompression() bool {
return r.DisableCompression
}
func closeRequestBody(req *http.Request) { func closeRequestBody(req *http.Request) {
if req.Body != nil { if req.Body != nil {
req.Body.Close() req.Body.Close()

View File

@@ -66,12 +66,6 @@ var _ = Describe("RoundTripper", func() {
Expect(rt.clients).To(HaveLen(1)) Expect(rt.clients).To(HaveLen(1))
}) })
It("disable compression", func() {
Expect(rt.disableCompression()).To(BeFalse())
rt.DisableCompression = true
Expect(rt.disableCompression()).To(BeTrue())
})
Context("validating request", func() { Context("validating request", func() {
It("rejects plain HTTP requests", func() { It("rejects plain HTTP requests", func() {
req, err := http.NewRequest("GET", "http://www.example.org/", nil) req, err := http.NewRequest("GET", "http://www.example.org/", nil)