diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index e725ea7a..6b195b6f 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -11,6 +11,8 @@ import ( quic "github.com/lucas-clemente/quic-go" + "runtime" + "golang.org/x/net/lex/httplex" ) @@ -91,6 +93,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { defer r.mutex.Unlock() if r.clients == nil { + runtime.SetFinalizer(r, finalizer) r.clients = make(map[string]roundTripCloser) } @@ -102,6 +105,10 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper { return client } +func finalizer(r *RoundTripper) { + _ = r.Close() +} + // Close closes the QUIC connections that this RoundTripper has used func (r *RoundTripper) Close() error { r.mutex.Lock() diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index b612d8bb..66e3eb8a 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net/http" + "runtime" "time" quic "github.com/lucas-clemente/quic-go" @@ -196,5 +197,16 @@ var _ = Describe("RoundTripper", func() { Expect(err).ToNot(HaveOccurred()) Expect(len(rt.clients)).To(BeZero()) }) + + It("runs Close when the RoundTripper is garbage collected", func() { + // this is set by getClient, but we can't do that while at the same time injecting the mockClient + runtime.SetFinalizer(rt, finalizer) + rt.clients = make(map[string]roundTripCloser) + cl := &mockClient{} + rt.clients["foo.bar"] = cl + rt = nil // lose the references to the RoundTripper, such that it can be garbage collected + runtime.GC() + Eventually(func() bool { return cl.closed }).Should(BeTrue()) + }) }) })