From 79c7ed4ed1cec5f2f36cba1bcf9501b4df6ff399 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Jul 2017 16:02:34 +0800 Subject: [PATCH 1/5] dependency-inject quic.Listen and quic.ListenAddr in h2quic.Server --- h2quic/server.go | 10 +++++++-- h2quic/server_test.go | 50 ++++++++++++++++--------------------------- 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/h2quic/server.go b/h2quic/server.go index a2f13bf7..57b68e56 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -29,6 +29,12 @@ type remoteCloser interface { CloseRemote(protocol.ByteCount) } +// allows mocking of quic.Listen and quic.ListenAddr +var ( + quicListen = quic.Listen + quicListenAddr = quic.ListenAddr +) + // Server is a HTTP2 server listening for QUIC connections. type Server struct { *http.Server @@ -90,9 +96,9 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { var ln quic.Listener var err error if conn == nil { - ln, err = quic.ListenAddr(s.Addr, tlsConfig, &config) + ln, err = quicListenAddr(s.Addr, tlsConfig, &config) } else { - ln, err = quic.Listen(conn, tlsConfig, &config) + ln, err = quicListen(conn, tlsConfig, &config) } if err != nil { s.listenerMutex.Unlock() diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 527ee9ba..08d3a218 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -2,13 +2,12 @@ package h2quic import ( "bytes" + "crypto/tls" + "errors" "io" "net" "net/http" - "os" - "runtime" "sync" - "syscall" "time" "golang.org/x/net/http2" @@ -66,9 +65,10 @@ func (s *mockSession) WaitUntilClosed() { panic("not implemented") } var _ = Describe("H2 server", func() { var ( - s *Server - session *mockSession - dataStream *mockStream + s *Server + session *mockSession + dataStream *mockStream + origQuicListenAddr = quicListenAddr ) BeforeEach(func() { @@ -80,6 +80,11 @@ var _ = Describe("H2 server", func() { dataStream = newMockStream(0) close(dataStream.unblockRead) session = &mockSession{dataStream: dataStream} + origQuicListenAddr = quicListenAddr + }) + + AfterEach(func() { + quicListenAddr = origQuicListenAddr }) Context("handling requests", func() { @@ -399,7 +404,6 @@ var _ = Describe("H2 server", func() { Expect(err).To(MatchError("ListenAndServe may only be called once")) err = s.Close() Expect(err).NotTo(HaveOccurred()) - }, 0.5) }) @@ -436,31 +440,13 @@ var _ = Describe("H2 server", func() { Expect(err).NotTo(HaveOccurred()) }) - It("at least errors in global ListenAndServeQUIC", func() { - // It's quite hard to test this, since we cannot properly shutdown the server - // once it's started. So, we open a socket on the same port before the test, - // so that ListenAndServeQUIC definitely fails. This way we know it at least - // created a socket on the proper address :) - const addr = "127.0.0.1:4826" - udpAddr, err := net.ResolveUDPAddr("udp", addr) - Expect(err).NotTo(HaveOccurred()) - c, err := net.ListenUDP("udp", udpAddr) - Expect(err).NotTo(HaveOccurred()) - defer c.Close() - fullpem, privkey := testdata.GetCertificatePaths() - err = ListenAndServeQUIC(addr, fullpem, privkey, nil) - // Check that it's an EADDRINUSE - Expect(err).ToNot(BeNil()) - opErr, ok := err.(*net.OpError) - Expect(ok).To(BeTrue()) - syscallErr, ok := opErr.Err.(*os.SyscallError) - Expect(ok).To(BeTrue()) - if runtime.GOOS == "windows" { - // for some reason, Windows return a different error number, corresponding to an WSAEADDRINUSE error - // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681391(v=vs.85).aspx - Expect(syscallErr.Err).To(Equal(syscall.Errno(0x2740))) - } else { - Expect(syscallErr.Err).To(MatchError(syscall.EADDRINUSE)) + It("errors when listening fails", func() { + testErr := errors.New("listen error") + quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) { + return nil, testErr } + fullpem, privkey := testdata.GetCertificatePaths() + err := ListenAndServeQUIC("", fullpem, privkey, nil) + Expect(err).To(MatchError(testErr)) }) }) From d94b57fe296aece6bcc7ff8865d9556df45c13e8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Jul 2017 16:21:20 +0800 Subject: [PATCH 2/5] expose the quic.Config in the h2quic.Server --- h2quic/server.go | 12 ++++++------ h2quic/server_test.go | 15 +++++++++++++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/h2quic/server.go b/h2quic/server.go index 57b68e56..ebce5dae 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -39,6 +39,10 @@ var ( type Server struct { *http.Server + // By providing a quic.Config, it is possible to set parameters of the QUIC connection. + // If nil, it uses reasonable default values. + QuicConfig *quic.Config + // Private flag for demo, do not use CloseAfterFirstRequest bool @@ -89,16 +93,12 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { return errors.New("ListenAndServe may only be called once") } - config := quic.Config{ - Versions: protocol.SupportedVersions, - } - var ln quic.Listener var err error if conn == nil { - ln, err = quicListenAddr(s.Addr, tlsConfig, &config) + ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig) } else { - ln, err = quicListen(conn, tlsConfig, &config) + ln, err = quicListen(conn, tlsConfig, s.QuicConfig) } if err != nil { s.listenerMutex.Unlock() diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 08d3a218..799ab44d 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -385,8 +385,7 @@ var _ = Describe("H2 server", func() { }) AfterEach(func() { - err := s.Close() - Expect(err).NotTo(HaveOccurred()) + Expect(s.Close()).To(Succeed()) }) It("may only be called once", func() { @@ -405,6 +404,18 @@ var _ = Describe("H2 server", func() { err = s.Close() Expect(err).NotTo(HaveOccurred()) }, 0.5) + + It("uses the quic.Config to start the quic server", func() { + conf := &quic.Config{HandshakeTimeout: time.Nanosecond} + var receivedConf *quic.Config + quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) { + receivedConf = config + return nil, errors.New("listen err") + } + s.QuicConfig = conf + go s.ListenAndServe() + Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf)) + }) }) Context("ListenAndServeTLS", func() { From cb81a95ceb6b2d3b1d93eb0b4ee24446038c4c26 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Jul 2017 16:40:53 +0800 Subject: [PATCH 3/5] make the dependency-injected dialAddr in h2quic.client a global variable It's only used for testing, so there's no need to have in each h2quic.client instance. --- h2quic/client.go | 12 +++++----- h2quic/client_test.go | 18 ++++++++++----- h2quic/roundtrip_test.go | 47 ++++++++++++++++++++++++++++++++++------ 3 files changed, 58 insertions(+), 19 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 813f466c..79c005bb 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -24,14 +24,15 @@ type roundTripperOpts struct { DisableCompression bool } +var dialAddr = quic.DialAddr + // client is a HTTP2 client doing QUIC requests type client struct { mutex sync.RWMutex - dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) - tlsConf *tls.Config - config *quic.Config - opts *roundTripperOpts + tlsConf *tls.Config + config *quic.Config + opts *roundTripperOpts hostname string encryptionLevel protocol.EncryptionLevel @@ -52,7 +53,6 @@ var _ http.RoundTripper = &client{} // newClient creates a new client func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client { return &client{ - dialAddr: quic.DialAddr, hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), encryptionLevel: protocol.EncryptionUnencrypted, @@ -69,7 +69,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) * // dial dials the connection func (c *client) dial() error { var err error - c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config) + c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) if err != nil { return err } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index fa61d142..890be12d 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -25,9 +25,11 @@ var _ = Describe("Client", func() { session *mockSession headerStream *mockStream req *http.Request + origDialAddr = dialAddr ) BeforeEach(func() { + origDialAddr = dialAddr hostname := "quic.clemente.io:1337" client = newClient(nil, hostname, &roundTripperOpts{}) Expect(client.hostname).To(Equal(hostname)) @@ -42,6 +44,10 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) }) + AfterEach(func() { + dialAddr = origDialAddr + }) + It("saves the TLS config", func() { tlsConf := &tls.Config{InsecureSkipVerify: true} client = newClient(tlsConf, "", &roundTripperOpts{}) @@ -56,7 +62,7 @@ var _ = Describe("Client", func() { It("dials", func(done Done) { client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)} - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } close(headerStream.unblockRead) @@ -68,7 +74,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") client = newClient(nil, "localhost:1337", &roundTripperOpts{}) - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } _, err := client.RoundTrip(req) @@ -78,7 +84,7 @@ var _ = Describe("Client", func() { It("errors if the header stream has the wrong stream ID", func() { client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{&mockStream{id: 2}} - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } _, err := client.RoundTrip(req) @@ -89,7 +95,7 @@ var _ = Describe("Client", func() { testErr := errors.New("you shall not pass") client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamOpenErr = testErr - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } _, err := client.RoundTrip(req) @@ -98,7 +104,7 @@ var _ = Describe("Client", func() { It("returns a request when dial fails", func() { testErr := errors.New("dial error") - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) @@ -140,7 +146,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { var err error client.encryptionLevel = protocol.EncryptionForwardSecure - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } dataStream = newMockStream(5) diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index 3b91a3bf..fb88ac8a 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -2,9 +2,12 @@ package h2quic import ( "bytes" + "crypto/tls" + "errors" "io" "net/http" + quic "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -54,13 +57,43 @@ var _ = Describe("RoundTripper", func() { Expect(err).ToNot(HaveOccurred()) }) - It("reuses existing clients", func() { - rt.clients = make(map[string]http.RoundTripper) - rt.clients["www.example.org:443"] = &mockRoundTripper{} - rsp, err := rt.RoundTrip(req1) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp.Request).To(Equal(req1)) - Expect(rt.clients).To(HaveLen(1)) + Context("dialing hosts", func() { + origDialAddr := dialAddr + streamOpenErr := errors.New("error opening stream") + + BeforeEach(func() { + origDialAddr = dialAddr + dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { + // return an error when trying to open a stream + // we don't want to test all the dial logic here, just that dialing happens at all + return &mockSession{streamOpenErr: streamOpenErr}, nil + } + }) + + AfterEach(func() { + dialAddr = origDialAddr + }) + + It("creates new clients", func() { + req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError(streamOpenErr)) + Expect(rt.clients).To(HaveLen(1)) + }) + + It("reuses existing clients", func() { + req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError(streamOpenErr)) + Expect(rt.clients).To(HaveLen(1)) + req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req2) + Expect(err).To(MatchError(streamOpenErr)) + Expect(rt.clients).To(HaveLen(1)) + }) }) Context("validating request", func() { From abb9594af8ee50e2db579bf307315e09b6d235b5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Jul 2017 16:51:31 +0800 Subject: [PATCH 4/5] change order of function parameters for the h2quic.client constructor --- h2quic/client.go | 2 +- h2quic/client_test.go | 16 ++++++++-------- h2quic/roundtrip.go | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 79c005bb..8b577850 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -51,7 +51,7 @@ type client struct { var _ http.RoundTripper = &client{} // newClient creates a new client -func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client { +func newClient(hostname string, tlsConfig *tls.Config, opts *roundTripperOpts) *client { return &client{ hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 890be12d..f8ced8cc 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -31,7 +31,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" - client = newClient(nil, hostname, &roundTripperOpts{}) + client = newClient(hostname, nil, &roundTripperOpts{}) Expect(client.hostname).To(Equal(hostname)) session = &mockSession{} client.session = session @@ -50,17 +50,17 @@ var _ = Describe("Client", func() { It("saves the TLS config", func() { tlsConf := &tls.Config{InsecureSkipVerify: true} - client = newClient(tlsConf, "", &roundTripperOpts{}) + client = newClient("", tlsConf, &roundTripperOpts{}) Expect(client.tlsConf).To(Equal(tlsConf)) }) It("adds the port to the hostname, if none is given", func() { - client = newClient(nil, "quic.clemente.io", &roundTripperOpts{}) + client = newClient("quic.clemente.io", nil, &roundTripperOpts{}) Expect(client.hostname).To(Equal("quic.clemente.io:443")) }) It("dials", func(done Done) { - client = newClient(nil, "localhost:1337", &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)} dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -73,7 +73,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") - client = newClient(nil, "localhost:1337", &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}) dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } @@ -82,7 +82,7 @@ var _ = Describe("Client", func() { }) It("errors if the header stream has the wrong stream ID", func() { - client = newClient(nil, "localhost:1337", &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{&mockStream{id: 2}} dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -93,7 +93,7 @@ var _ = Describe("Client", func() { It("errors if it can't open a stream", func() { testErr := errors.New("you shall not pass") - client = newClient(nil, "localhost:1337", &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}) session.streamOpenErr = testErr dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -252,7 +252,7 @@ var _ = Describe("Client", func() { It("adds the port for request URLs without one", func(done Done) { var err error - client = newClient(nil, "quic.clemente.io", &roundTripperOpts{}) + client = newClient("quic.clemente.io", nil, &roundTripperOpts{}) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index 332411c3..b8ab70be 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -84,7 +84,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { client, ok := r.clients[hostname] if !ok { - client = newClient(r.TLSClientConfig, hostname, &roundTripperOpts{DisableCompression: r.DisableCompression}) + client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}) r.clients[hostname] = client } return client From ee6ca8dfb4784599e8db754d5654f3b003147b60 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Jul 2017 20:15:09 +0800 Subject: [PATCH 5/5] expose the quic.Config in the h2quic.RoundTripper --- h2quic/client.go | 25 ++++++++++++++++++------- h2quic/client_test.go | 30 ++++++++++++++++++++++-------- h2quic/roundtrip.go | 8 +++++++- h2quic/roundtrip_test.go | 13 +++++++++++++ 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 8b577850..bba706bf 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -50,19 +50,30 @@ type client struct { var _ http.RoundTripper = &client{} +var defaultQuicConfig = &quic.Config{ + RequestConnectionIDTruncation: true, + KeepAlive: true, +} + // newClient creates a new client -func newClient(hostname string, tlsConfig *tls.Config, opts *roundTripperOpts) *client { +func newClient( + hostname string, + tlsConfig *tls.Config, + opts *roundTripperOpts, + quicConfig *quic.Config, +) *client { + config := defaultQuicConfig + if quicConfig != nil { + config = quicConfig + } return &client{ hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), encryptionLevel: protocol.EncryptionUnencrypted, tlsConf: tlsConfig, - config: &quic.Config{ - RequestConnectionIDTruncation: true, - KeepAlive: true, - }, - opts: opts, - headerErrored: make(chan struct{}), + config: config, + opts: opts, + headerErrored: make(chan struct{}), } } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index f8ced8cc..5be4b2f4 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -15,6 +15,8 @@ import ( "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" + "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -31,7 +33,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" - client = newClient(hostname, nil, &roundTripperOpts{}) + client = newClient(hostname, nil, &roundTripperOpts{}, nil) Expect(client.hostname).To(Equal(hostname)) session = &mockSession{} client.session = session @@ -50,17 +52,29 @@ var _ = Describe("Client", func() { It("saves the TLS config", func() { tlsConf := &tls.Config{InsecureSkipVerify: true} - client = newClient("", tlsConf, &roundTripperOpts{}) + client = newClient("", tlsConf, &roundTripperOpts{}, nil) Expect(client.tlsConf).To(Equal(tlsConf)) }) + It("saves the QUIC config", func() { + quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond} + client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf) + Expect(client.config).To(Equal(quicConf)) + }) + + It("uses the default QUIC config if none is give", func() { + client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil) + Expect(client.config).ToNot(BeNil()) + Expect(client.config).To(Equal(defaultQuicConfig)) + }) + It("adds the port to the hostname, if none is given", func() { - client = newClient("quic.clemente.io", nil, &roundTripperOpts{}) + client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) Expect(client.hostname).To(Equal("quic.clemente.io:443")) }) It("dials", func(done Done) { - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)} dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -73,7 +87,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } @@ -82,7 +96,7 @@ var _ = Describe("Client", func() { }) It("errors if the header stream has the wrong stream ID", func() { - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) session.streamsToOpen = []quic.Stream{&mockStream{id: 2}} dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -93,7 +107,7 @@ var _ = Describe("Client", func() { It("errors if it can't open a stream", func() { testErr := errors.New("you shall not pass") - client = newClient("localhost:1337", nil, &roundTripperOpts{}) + client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) session.streamOpenErr = testErr dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -252,7 +266,7 @@ var _ = Describe("Client", func() { It("adds the port for request URLs without one", func(done Done) { var err error - client = newClient("quic.clemente.io", nil, &roundTripperOpts{}) + client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index b8ab70be..ce2bbe96 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -8,6 +8,8 @@ import ( "strings" "sync" + quic "github.com/lucas-clemente/quic-go" + "golang.org/x/net/lex/httplex" ) @@ -29,6 +31,10 @@ type RoundTripper struct { // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config + // QuicConfig is the quic.Config used for dialing new connections. + // If nil, reasonable default values will be used. + QuicConfig *quic.Config + clients map[string]http.RoundTripper } @@ -84,7 +90,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { client, ok := r.clients[hostname] if !ok { - client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}) + client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) r.clients[hostname] = client } return client diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index fb88ac8a..635a50a3 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net/http" + "time" quic "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" @@ -82,6 +83,18 @@ var _ = Describe("RoundTripper", func() { Expect(rt.clients).To(HaveLen(1)) }) + It("uses the quic.Config, if provided", func() { + config := &quic.Config{HandshakeTimeout: time.Millisecond} + var receivedConfig *quic.Config + dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { + receivedConfig = config + return nil, errors.New("err") + } + rt.QuicConfig = config + rt.RoundTrip(req1) + Expect(receivedConfig).To(Equal(config)) + }) + It("reuses existing clients", func() { req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) Expect(err).ToNot(HaveOccurred())