Merge pull request #736 from lucas-clemente/fix-735

implement a function to close the h2quic.RoundTripper and run it as a finalizer
This commit is contained in:
Marten Seemann
2017-07-12 20:44:59 +07:00
committed by GitHub
3 changed files with 82 additions and 10 deletions

View File

@@ -166,7 +166,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
responseChan := make(chan *http.Response) responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync() dataStream, err := c.session.OpenStreamSync()
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
c.mutex.Lock() c.mutex.Lock()
@@ -181,7 +181,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
endStream := !hasBody endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
@@ -215,7 +215,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
} }
case <-c.headerErrored: case <-c.headerErrored:
// an error occured on the header stream // an error occured on the header stream
c.Close(c.headerErr) _ = c.CloseWithError(c.headerErr)
return nil, c.headerErr return nil, c.headerErr
} }
} }
@@ -261,8 +261,15 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
} }
// Close closes the client // Close closes the client
func (c *client) Close(e error) { func (c *client) CloseWithError(e error) error {
_ = c.session.Close(e) if c.session == nil {
return nil
}
return c.session.Close(e)
}
func (c *client) Close() error {
return c.CloseWithError(nil)
} }
// copied from net/transport.go // copied from net/transport.go

View File

@@ -4,15 +4,23 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"runtime"
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
) )
type roundTripCloser interface {
http.RoundTripper
io.Closer
}
// RoundTripper implements the http.RoundTripper interface // RoundTripper implements the http.RoundTripper interface
type RoundTripper struct { type RoundTripper struct {
mutex sync.Mutex mutex sync.Mutex
@@ -35,10 +43,10 @@ type RoundTripper struct {
// If nil, reasonable default values will be used. // If nil, reasonable default values will be used.
QuicConfig *quic.Config QuicConfig *quic.Config
clients map[string]http.RoundTripper clients map[string]roundTripCloser
} }
var _ http.RoundTripper = &RoundTripper{} var _ roundTripCloser = &RoundTripper{}
// RoundTrip does a round trip // RoundTrip does a round trip
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -85,7 +93,8 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.clients == nil { if r.clients == nil {
r.clients = make(map[string]http.RoundTripper) runtime.SetFinalizer(r, finalizer)
r.clients = make(map[string]roundTripCloser)
} }
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
@@ -96,6 +105,23 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
return client return client
} }
func finalizer(r *RoundTripper) {
_ = r.Close()
}
// Close closes the QUIC connections that this RoundTripper has used
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
}
func closeRequestBody(req *http.Request) { func closeRequestBody(req *http.Request) {
if req.Body != nil { if req.Body != nil {
req.Body.Close() req.Body.Close()

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"io" "io"
"net/http" "net/http"
"runtime"
"time" "time"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
@@ -13,11 +14,19 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type mockRoundTripper struct{} type mockClient struct {
closed bool
}
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{Request: req}, nil return &http.Response{Request: req}, nil
} }
func (m *mockClient) Close() error {
m.closed = true
return nil
}
var _ roundTripCloser = &mockClient{}
type mockBody struct { type mockBody struct {
reader bytes.Reader reader bytes.Reader
@@ -170,4 +179,34 @@ var _ = Describe("RoundTripper", func() {
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req1.Body.(*mockBody).closed).To(BeTrue())
}) })
}) })
Context("closing", func() {
It("closes", func() {
rt.clients = make(map[string]roundTripCloser)
cl := &mockClient{}
rt.clients["foo.bar"] = cl
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero())
Expect(cl.closed).To(BeTrue())
})
It("closes a RoundTripper that has never been used", func() {
Expect(len(rt.clients)).To(BeZero())
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero())
})
It("runs Close when the RoundTripper is garbage collected", func() {
// this is set by getClient, but we can't do that while at the same time injecting the mockClient
runtime.SetFinalizer(rt, finalizer)
rt.clients = make(map[string]roundTripCloser)
cl := &mockClient{}
rt.clients["foo.bar"] = cl
rt = nil // lose the references to the RoundTripper, such that it can be garbage collected
runtime.GC()
Eventually(func() bool { return cl.closed }).Should(BeTrue())
})
})
}) })