forked from quic-go/quic-go
add option to disable compresson to QuicRoundTripper
This commit is contained in:
@@ -29,6 +29,8 @@ type Client struct {
|
||||
mutex sync.RWMutex
|
||||
cryptoChangedCond sync.Cond
|
||||
|
||||
t *QuicRoundTripper
|
||||
|
||||
hostname string
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
|
||||
@@ -44,8 +46,9 @@ type Client struct {
|
||||
var _ h2quicClient = &Client{}
|
||||
|
||||
// NewClient creates a new client
|
||||
func NewClient(hostname string) (*Client, error) {
|
||||
func NewClient(t *QuicRoundTripper, hostname string) (*Client, error) {
|
||||
c := &Client{
|
||||
t: t,
|
||||
hostname: authorityAddr("https", hostname),
|
||||
highestOpenedStream: 3,
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
@@ -170,7 +173,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
var requestedGzip bool
|
||||
if req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
||||
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
||||
requestedGzip = true
|
||||
}
|
||||
err = c.requestWriter.WriteRequest(req, dataStreamID, requestedGzip)
|
||||
|
||||
@@ -43,15 +43,17 @@ var _ quicClient = &mockQuicClient{}
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
client *Client
|
||||
qClient *mockQuicClient
|
||||
headerStream *mockStream
|
||||
client *Client
|
||||
qClient *mockQuicClient
|
||||
headerStream *mockStream
|
||||
quicTransport *QuicRoundTripper
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
quicTransport = &QuicRoundTripper{}
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client, err = NewClient(hostname)
|
||||
client, err = NewClient(quicTransport, hostname)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
qClient = newMockQuicClient()
|
||||
@@ -65,7 +67,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
var err error
|
||||
client, err = NewClient("quic.clemente.io")
|
||||
client, err = NewClient(quicTransport, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
@@ -183,7 +185,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("adds the port for request URLs without one", func(done Done) {
|
||||
var err error
|
||||
client, err = NewClient("quic.clemente.io")
|
||||
client, err = NewClient(quicTransport, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -239,6 +241,18 @@ var _ = Describe("Client", func() {
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("doesn't add gzip if the header disable it", func() {
|
||||
quicTransport.DisableCompression = true
|
||||
var doRsp *http.Response
|
||||
var doErr error
|
||||
go func() { doRsp, doErr = client.Do(request) }()
|
||||
|
||||
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
|
||||
Expect(doErr).ToNot(HaveOccurred())
|
||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
||||
Expect(headers).ToNot(HaveKey("accept-encoding"))
|
||||
})
|
||||
|
||||
It("only decompresses the response if the response contains the right content-encoding header", func() {
|
||||
var doRsp *http.Response
|
||||
var doErr error
|
||||
|
||||
@@ -13,6 +13,16 @@ type h2quicClient interface {
|
||||
type QuicRoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
// DisableCompression, if true, prevents the Transport from
|
||||
// requesting compression with an "Accept-Encoding: gzip"
|
||||
// request header when the Request contains no existing
|
||||
// Accept-Encoding value. If the Transport requests gzip on
|
||||
// its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body. However, if the user
|
||||
// explicitly requested gzip it is not automatically
|
||||
// uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
clients map[string]h2quicClient
|
||||
}
|
||||
|
||||
@@ -39,7 +49,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
var err error
|
||||
client, err = NewClient(hostname)
|
||||
client, err = NewClient(r, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -47,3 +57,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *QuicRoundTripper) disableCompression() bool {
|
||||
return r.DisableCompression
|
||||
}
|
||||
|
||||
@@ -33,4 +33,10 @@ var _ = Describe("RoundTripper", func() {
|
||||
Expect(rsp.Request).To(Equal(req1))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("disable compression", func() {
|
||||
Expect(rt.disableCompression()).To(BeFalse())
|
||||
rt.DisableCompression = true
|
||||
Expect(rt.disableCompression()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user