From bf6030a855e708f1334f6693d700bbb6d03e23e5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 12 Jul 2017 18:09:30 +0700 Subject: [PATCH 1/2] implement a function to close the h2quic.RoundTripper h2quic.RoundTripper.Close() closes all QUIC connections that this roundtripper has used. --- h2quic/client.go | 17 ++++++++++++----- h2quic/roundtrip.go | 25 ++++++++++++++++++++++--- h2quic/roundtrip_test.go | 31 +++++++++++++++++++++++++++++-- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index bba706bf..866b11ab 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -166,7 +166,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { responseChan := make(chan *http.Response) dataStream, err := c.session.OpenStreamSync() if err != nil { - c.Close(err) + _ = c.CloseWithError(err) return nil, err } c.mutex.Lock() @@ -181,7 +181,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { endStream := !hasBody err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) if err != nil { - c.Close(err) + _ = c.CloseWithError(err) return nil, err } @@ -215,7 +215,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { } case <-c.headerErrored: // an error occured on the header stream - c.Close(c.headerErr) + _ = c.CloseWithError(c.headerErr) return nil, c.headerErr } } @@ -261,8 +261,15 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e } // Close closes the client -func (c *client) Close(e error) { - _ = c.session.Close(e) +func (c *client) CloseWithError(e error) error { + if c.session == nil { + return nil + } + return c.session.Close(e) +} + +func (c *client) Close() error { + return c.CloseWithError(nil) } // copied from net/transport.go diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index ce2bbe96..e725ea7a 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net/http" "strings" "sync" @@ -13,6 +14,11 @@ import ( "golang.org/x/net/lex/httplex" ) +type roundTripCloser interface { + http.RoundTripper + io.Closer +} + // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { mutex sync.Mutex @@ -35,10 +41,10 @@ type RoundTripper struct { // If nil, reasonable default values will be used. QuicConfig *quic.Config - clients map[string]http.RoundTripper + clients map[string]roundTripCloser } -var _ http.RoundTripper = &RoundTripper{} +var _ roundTripCloser = &RoundTripper{} // RoundTrip does a round trip func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -85,7 +91,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { defer r.mutex.Unlock() if r.clients == nil { - r.clients = make(map[string]http.RoundTripper) + r.clients = make(map[string]roundTripCloser) } client, ok := r.clients[hostname] @@ -96,6 +102,19 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { return client } +// Close closes the QUIC connections that this RoundTripper has used +func (r *RoundTripper) Close() error { + r.mutex.Lock() + defer r.mutex.Unlock() + for _, client := range r.clients { + if err := client.Close(); err != nil { + return err + } + } + r.clients = nil + return nil +} + func closeRequestBody(req *http.Request) { if req.Body != nil { req.Body.Close() diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index 635a50a3..b612d8bb 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -13,11 +13,19 @@ import ( . "github.com/onsi/gomega" ) -type mockRoundTripper struct{} +type mockClient struct { + closed bool +} -func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { +func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) { return &http.Response{Request: req}, nil } +func (m *mockClient) Close() error { + m.closed = true + return nil +} + +var _ roundTripCloser = &mockClient{} type mockBody struct { reader bytes.Reader @@ -170,4 +178,23 @@ var _ = Describe("RoundTripper", func() { Expect(req1.Body.(*mockBody).closed).To(BeTrue()) }) }) + + Context("closing", func() { + It("closes", func() { + rt.clients = make(map[string]roundTripCloser) + cl := &mockClient{} + rt.clients["foo.bar"] = cl + err := rt.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(len(rt.clients)).To(BeZero()) + Expect(cl.closed).To(BeTrue()) + }) + + It("closes a RoundTripper that has never been used", func() { + Expect(len(rt.clients)).To(BeZero()) + err := rt.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(len(rt.clients)).To(BeZero()) + }) + }) }) From 65cea185bd1762e73a4111e440b9b27d895a4994 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 12 Jul 2017 19:20:21 +0700 Subject: [PATCH 2/2] use a finalizer to close the h2quic.RoundTripper The finalizer is executed when the RoundTripper is garbage collected. This is not a perfect solution, since there are situations when an unneeded RoundTripper is not garbage collected, e.g. when the program exits before the GC ran. In those cases, the server will run into the idle timeout and eventually close the connection on its side. --- h2quic/roundtrip.go | 7 +++++++ h2quic/roundtrip_test.go | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index e725ea7a..6b195b6f 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -11,6 +11,8 @@ import ( quic "github.com/lucas-clemente/quic-go" + "runtime" + "golang.org/x/net/lex/httplex" ) @@ -91,6 +93,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { defer r.mutex.Unlock() if r.clients == nil { + runtime.SetFinalizer(r, finalizer) r.clients = make(map[string]roundTripCloser) } @@ -102,6 +105,10 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { return client } +func finalizer(r *RoundTripper) { + _ = r.Close() +} + // Close closes the QUIC connections that this RoundTripper has used func (r *RoundTripper) Close() error { r.mutex.Lock() diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index b612d8bb..66e3eb8a 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net/http" + "runtime" "time" quic "github.com/lucas-clemente/quic-go" @@ -196,5 +197,16 @@ var _ = Describe("RoundTripper", func() { Expect(err).ToNot(HaveOccurred()) Expect(len(rt.clients)).To(BeZero()) }) + + It("runs Close when the RoundTripper is garbage collected", func() { + // this is set by getClient, but we can't do that while at the same time injecting the mockClient + runtime.SetFinalizer(rt, finalizer) + rt.clients = make(map[string]roundTripCloser) + cl := &mockClient{} + rt.clients["foo.bar"] = cl + rt = nil // lose the references to the RoundTripper, such that it can be garbage collected + runtime.GC() + Eventually(func() bool { return cl.closed }).Should(BeTrue()) + }) }) })