forked from quic-go/quic-go
don't pass the roundtripper to the h2quic client
This commit is contained in:
@@ -20,14 +20,17 @@ import (
|
|||||||
"github.com/lucas-clemente/quic-go/utils"
|
"github.com/lucas-clemente/quic-go/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type roundTripperOpts struct {
|
||||||
|
DisableCompression bool
|
||||||
|
}
|
||||||
|
|
||||||
// client is a HTTP2 client doing QUIC requests
|
// client is a HTTP2 client doing QUIC requests
|
||||||
type client struct {
|
type client struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
dialAddr func(hostname string, config *quic.Config) (quic.Session, error)
|
dialAddr func(hostname string, config *quic.Config) (quic.Session, error)
|
||||||
config *quic.Config
|
config *quic.Config
|
||||||
|
opts *roundTripperOpts
|
||||||
t *QuicRoundTripper
|
|
||||||
|
|
||||||
hostname string
|
hostname string
|
||||||
encryptionLevel protocol.EncryptionLevel
|
encryptionLevel protocol.EncryptionLevel
|
||||||
@@ -45,9 +48,8 @@ type client struct {
|
|||||||
var _ h2quicClient = &client{}
|
var _ h2quicClient = &client{}
|
||||||
|
|
||||||
// newClient creates a new client
|
// newClient creates a new client
|
||||||
func newClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *client {
|
func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client {
|
||||||
return &client{
|
return &client{
|
||||||
t: t,
|
|
||||||
dialAddr: quic.DialAddr,
|
dialAddr: quic.DialAddr,
|
||||||
hostname: authorityAddr("https", hostname),
|
hostname: authorityAddr("https", hostname),
|
||||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||||
@@ -56,6 +58,7 @@ func newClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *cli
|
|||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
RequestConnectionIDTruncation: true,
|
RequestConnectionIDTruncation: true,
|
||||||
},
|
},
|
||||||
|
opts: opts,
|
||||||
dialChan: make(chan struct{}),
|
dialChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -164,7 +167,7 @@ func (c *client) Do(req *http.Request) (*http.Response, error) {
|
|||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
|
|
||||||
var requestedGzip bool
|
var requestedGzip bool
|
||||||
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
||||||
requestedGzip = true
|
requestedGzip = true
|
||||||
}
|
}
|
||||||
// TODO: add support for trailers
|
// TODO: add support for trailers
|
||||||
|
|||||||
@@ -23,13 +23,11 @@ var _ = Describe("Client", func() {
|
|||||||
client *client
|
client *client
|
||||||
session *mockSession
|
session *mockSession
|
||||||
headerStream *mockStream
|
headerStream *mockStream
|
||||||
quicTransport *QuicRoundTripper
|
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
quicTransport = &QuicRoundTripper{}
|
|
||||||
hostname := "quic.clemente.io:1337"
|
hostname := "quic.clemente.io:1337"
|
||||||
client = newClient(quicTransport, nil, hostname)
|
client = newClient(nil, hostname, &roundTripperOpts{})
|
||||||
Expect(client.hostname).To(Equal(hostname))
|
Expect(client.hostname).To(Equal(hostname))
|
||||||
session = &mockSession{}
|
session = &mockSession{}
|
||||||
client.session = session
|
client.session = session
|
||||||
@@ -41,17 +39,17 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
It("saves the TLS config", func() {
|
It("saves the TLS config", func() {
|
||||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
||||||
client = newClient(&QuicRoundTripper{}, tlsConf, "")
|
client = newClient(tlsConf, "", &roundTripperOpts{})
|
||||||
Expect(client.config.TLSConfig).To(Equal(tlsConf))
|
Expect(client.config.TLSConfig).To(Equal(tlsConf))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("adds the port to the hostname, if none is given", func() {
|
It("adds the port to the hostname, if none is given", func() {
|
||||||
client = newClient(quicTransport, nil, "quic.clemente.io")
|
client = newClient(nil, "quic.clemente.io", &roundTripperOpts{})
|
||||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("dials", func() {
|
It("dials", func() {
|
||||||
client = newClient(quicTransport, nil, "localhost")
|
client = newClient(nil, "localhost", &roundTripperOpts{})
|
||||||
session.streamToOpen = &mockStream{id: 3}
|
session.streamToOpen = &mockStream{id: 3}
|
||||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
@@ -63,7 +61,7 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
It("errors when dialing fails", func() {
|
It("errors when dialing fails", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client = newClient(quicTransport, nil, "localhost")
|
client = newClient(nil, "localhost", &roundTripperOpts{})
|
||||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
@@ -72,7 +70,7 @@ var _ = Describe("Client", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("errors if the header stream has the wrong stream ID", func() {
|
It("errors if the header stream has the wrong stream ID", func() {
|
||||||
client = newClient(quicTransport, nil, "localhost")
|
client = newClient(nil, "localhost", &roundTripperOpts{})
|
||||||
session.streamToOpen = &mockStream{id: 2}
|
session.streamToOpen = &mockStream{id: 2}
|
||||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
@@ -83,7 +81,7 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
It("errors if it can't open a stream", func() {
|
It("errors if it can't open a stream", func() {
|
||||||
testErr := errors.New("you shall not pass")
|
testErr := errors.New("you shall not pass")
|
||||||
client = newClient(quicTransport, nil, "localhost")
|
client = newClient(nil, "localhost", &roundTripperOpts{})
|
||||||
session.streamOpenErr = testErr
|
session.streamOpenErr = testErr
|
||||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
@@ -226,7 +224,7 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
It("adds the port for request URLs without one", func(done Done) {
|
It("adds the port for request URLs without one", func(done Done) {
|
||||||
var err error
|
var err error
|
||||||
client = newClient(quicTransport, nil, "quic.clemente.io")
|
client = newClient(nil, "quic.clemente.io", &roundTripperOpts{})
|
||||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
@@ -365,7 +363,7 @@ var _ = Describe("Client", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't add gzip if the header disable it", func() {
|
It("doesn't add gzip if the header disable it", func() {
|
||||||
quicTransport.DisableCompression = true
|
client.opts.DisableCompression = true
|
||||||
var doErr error
|
var doErr error
|
||||||
go func() { _, doErr = client.Do(request) }()
|
go func() { _, doErr = client.Do(request) }()
|
||||||
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
|
|||||||
|
|
||||||
client, ok := r.clients[hostname]
|
client, ok := r.clients[hostname]
|
||||||
if !ok {
|
if !ok {
|
||||||
client = newClient(r, r.TLSClientConfig, hostname)
|
client = newClient(r.TLSClientConfig, hostname, &roundTripperOpts{DisableCompression: r.DisableCompression})
|
||||||
err := client.Dial()
|
err := client.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -103,10 +103,6 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *QuicRoundTripper) disableCompression() bool {
|
|
||||||
return r.DisableCompression
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeRequestBody(req *http.Request) {
|
func closeRequestBody(req *http.Request) {
|
||||||
if req.Body != nil {
|
if req.Body != nil {
|
||||||
req.Body.Close()
|
req.Body.Close()
|
||||||
|
|||||||
@@ -66,12 +66,6 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
Expect(rt.clients).To(HaveLen(1))
|
Expect(rt.clients).To(HaveLen(1))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("disable compression", func() {
|
|
||||||
Expect(rt.disableCompression()).To(BeFalse())
|
|
||||||
rt.DisableCompression = true
|
|
||||||
Expect(rt.disableCompression()).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("validating request", func() {
|
Context("validating request", func() {
|
||||||
It("rejects plain HTTP requests", func() {
|
It("rejects plain HTTP requests", func() {
|
||||||
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
||||||
|
|||||||
Reference in New Issue
Block a user