From ee6ca8dfb4784599e8db754d5654f3b003147b60 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Jul 2017 20:15:09 +0800 Subject: [PATCH] expose the quic.Config in the h2quic.RoundTripper --- h2quic/client.go | 25 ++++++++++++++++++------- h2quic/client_test.go | 30 ++++++++++++++++++++++-------- h2quic/roundtrip.go | 8 +++++++- h2quic/roundtrip_test.go | 13 +++++++++++++ 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 8b577850..bba706bf 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -50,19 +50,30 @@ type client struct { var _ http.RoundTripper = &client{} +var defaultQuicConfig = &quic.Config{ + RequestConnectionIDTruncation: true, + KeepAlive: true, +} + // newClient creates a new client -func newClient(hostname string, tlsConfig *tls.Config, opts *roundTripperOpts) *client { +func newClient( + hostname string, + tlsConfig *tls.Config, + opts *roundTripperOpts, + quicConfig *quic.Config, +) *client { + config := defaultQuicConfig + if quicConfig != nil { + config = quicConfig + } return &client{ hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), encryptionLevel: protocol.EncryptionUnencrypted, tlsConf: tlsConfig, - config: &quic.Config{ - RequestConnectionIDTruncation: true, - KeepAlive: true, - }, - opts: opts, - headerErrored: make(chan struct{}), + config: config, + opts: opts, + headerErrored: make(chan struct{}), } } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index f8ced8cc..5be4b2f4 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -15,6 +15,8 @@ import ( "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" + "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -31,7 +33,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" - client = newClient(hostname, nil, &roundTripperOpts{}) + client = newClient(hostname, nil, &roundTripperOpts{}, nil) Expect(client.hostname).To(Equal(hostname)) session = &mockSession{} client.session = session @@ -50,17 +52,29 @@ var _ = Describe("Client", func() { It("saves the TLS config", func() { tlsConf := &tls.Config{InsecureSkipVerify: true} - client = newClient("", tlsConf, &roundTripperOpts{}) + client = newClient("", tlsConf, &roundTripperOpts{}, nil) Expect(client.tlsConf).To(Equal(tlsConf)) }) + It("saves the QUIC config", func() { + quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond} + client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf) + Expect(client.config).To(Equal(quicConf)) + }) + + It("uses the default QUIC config if none is give", func() { + client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil) + Expect(client.config).ToNot(BeNil()) + Expect(client.config).To(Equal(defaultQuicConfig)) + }) + It("adds the port to the hostname, if none is given", func() { - client = newClient("quic.clemente.io", nil, &roundTripperOpts{}) + client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) Expect(client.hostname).To(Equal("quic.clemente.io:443")) }) It("dials", func(done Done) { - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)} dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -73,7 +87,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } @@ -82,7 +96,7 @@ var _ = Describe("Client", func() { }) It("errors if the header stream has the wrong stream ID", func() { - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) session.streamsToOpen = []quic.Stream{&mockStream{id: 2}} dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -93,7 +107,7 @@ var _ = Describe("Client", func() { It("errors if it can't open a stream", func() { testErr := errors.New("you shall not pass") - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) session.streamOpenErr = testErr dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -252,7 +266,7 @@ var _ = Describe("Client", func() { It("adds the port for request URLs without one", func(done Done) { var err error - client = newClient("quic.clemente.io", nil, &roundTripperOpts{}) + client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index b8ab70be..ce2bbe96 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -8,6 +8,8 @@ import ( "strings" "sync" + quic "github.com/lucas-clemente/quic-go" + "golang.org/x/net/lex/httplex" ) @@ -29,6 +31,10 @@ type RoundTripper struct { // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config + // QuicConfig is the quic.Config used for dialing new connections. + // If nil, reasonable default values will be used. + QuicConfig *quic.Config + clients map[string]http.RoundTripper } @@ -84,7 +90,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { client, ok := r.clients[hostname] if !ok { - client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}) + client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) r.clients[hostname] = client } return client diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index fb88ac8a..635a50a3 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net/http" + "time" quic "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" @@ -82,6 +83,18 @@ var _ = Describe("RoundTripper", func() { Expect(rt.clients).To(HaveLen(1)) }) + It("uses the quic.Config, if provided", func() { + config := &quic.Config{HandshakeTimeout: time.Millisecond} + var receivedConfig *quic.Config + dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { + receivedConfig = config + return nil, errors.New("err") + } + rt.QuicConfig = config + rt.RoundTrip(req1) + Expect(receivedConfig).To(Equal(config)) + }) + It("reuses existing clients", func() { req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) Expect(err).ToNot(HaveOccurred())