From cb81a95ceb6b2d3b1d93eb0b4ee24446038c4c26 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Jul 2017 16:40:53 +0800 Subject: [PATCH] make the dependency-injected dialAddr in h2quic.client a global variable It's only used for testing, so there's no need to have in each h2quic.client instance. --- h2quic/client.go | 12 +++++----- h2quic/client_test.go | 18 ++++++++++----- h2quic/roundtrip_test.go | 47 ++++++++++++++++++++++++++++++++++------ 3 files changed, 58 insertions(+), 19 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 813f466c..79c005bb 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -24,14 +24,15 @@ type roundTripperOpts struct { DisableCompression bool } +var dialAddr = quic.DialAddr + // client is a HTTP2 client doing QUIC requests type client struct { mutex sync.RWMutex - dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) - tlsConf *tls.Config - config *quic.Config - opts *roundTripperOpts + tlsConf *tls.Config + config *quic.Config + opts *roundTripperOpts hostname string encryptionLevel protocol.EncryptionLevel @@ -52,7 +53,6 @@ var _ http.RoundTripper = &client{} // newClient creates a new client func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client { return &client{ - dialAddr: quic.DialAddr, hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), encryptionLevel: protocol.EncryptionUnencrypted, @@ -69,7 +69,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) * // dial dials the connection func (c *client) dial() error { var err error - c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config) + c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) if err != nil { return err } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index fa61d142..890be12d 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -25,9 +25,11 @@ var _ = Describe("Client", func() { session *mockSession headerStream *mockStream req *http.Request + origDialAddr = dialAddr ) BeforeEach(func() { + origDialAddr = dialAddr hostname := "quic.clemente.io:1337" client = newClient(nil, hostname, &roundTripperOpts{}) Expect(client.hostname).To(Equal(hostname)) @@ -42,6 +44,10 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) }) + AfterEach(func() { + dialAddr = origDialAddr + }) + It("saves the TLS config", func() { tlsConf := &tls.Config{InsecureSkipVerify: true} client = newClient(tlsConf, "", &roundTripperOpts{}) @@ -56,7 +62,7 @@ var _ = Describe("Client", func() { It("dials", func(done Done) { client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)} - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } close(headerStream.unblockRead) @@ -68,7 +74,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") client = newClient(nil, "localhost:1337", &roundTripperOpts{}) - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } _, err := client.RoundTrip(req) @@ -78,7 +84,7 @@ var _ = Describe("Client", func() { It("errors if the header stream has the wrong stream ID", func() { client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{&mockStream{id: 2}} - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } _, err := client.RoundTrip(req) @@ -89,7 +95,7 @@ var _ = Describe("Client", func() { testErr := errors.New("you shall not pass") client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamOpenErr = testErr - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } _, err := client.RoundTrip(req) @@ -98,7 +104,7 @@ var _ = Describe("Client", func() { It("returns a request when dial fails", func() { testErr := errors.New("dial error") - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) @@ -140,7 +146,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { var err error client.encryptionLevel = protocol.EncryptionForwardSecure - client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } dataStream = newMockStream(5) diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index 3b91a3bf..fb88ac8a 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -2,9 +2,12 @@ package h2quic import ( "bytes" + "crypto/tls" + "errors" "io" "net/http" + quic "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -54,13 +57,43 @@ var _ = Describe("RoundTripper", func() { Expect(err).ToNot(HaveOccurred()) }) - It("reuses existing clients", func() { - rt.clients = make(map[string]http.RoundTripper) - rt.clients["www.example.org:443"] = &mockRoundTripper{} - rsp, err := rt.RoundTrip(req1) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp.Request).To(Equal(req1)) - Expect(rt.clients).To(HaveLen(1)) + Context("dialing hosts", func() { + origDialAddr := dialAddr + streamOpenErr := errors.New("error opening stream") + + BeforeEach(func() { + origDialAddr = dialAddr + dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { + // return an error when trying to open a stream + // we don't want to test all the dial logic here, just that dialing happens at all + return &mockSession{streamOpenErr: streamOpenErr}, nil + } + }) + + AfterEach(func() { + dialAddr = origDialAddr + }) + + It("creates new clients", func() { + req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError(streamOpenErr)) + Expect(rt.clients).To(HaveLen(1)) + }) + + It("reuses existing clients", func() { + req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError(streamOpenErr)) + Expect(rt.clients).To(HaveLen(1)) + req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req2) + Expect(err).To(MatchError(streamOpenErr)) + Expect(rt.clients).To(HaveLen(1)) + }) }) Context("validating request", func() {