forked from quic-go/quic-go
Merge pull request #736 from lucas-clemente/fix-735
implement a function to close the h2quic.RoundTripper and run it as a finalizer
This commit is contained in:
@@ -166,7 +166,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
responseChan := make(chan *http.Response)
|
responseChan := make(chan *http.Response)
|
||||||
dataStream, err := c.session.OpenStreamSync()
|
dataStream, err := c.session.OpenStreamSync()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Close(err)
|
_ = c.CloseWithError(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
@@ -181,7 +181,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
endStream := !hasBody
|
endStream := !hasBody
|
||||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Close(err)
|
_ = c.CloseWithError(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +215,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
}
|
}
|
||||||
case <-c.headerErrored:
|
case <-c.headerErrored:
|
||||||
// an error occured on the header stream
|
// an error occured on the header stream
|
||||||
c.Close(c.headerErr)
|
_ = c.CloseWithError(c.headerErr)
|
||||||
return nil, c.headerErr
|
return nil, c.headerErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -261,8 +261,15 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the client
|
// Close closes the client
|
||||||
func (c *client) Close(e error) {
|
func (c *client) CloseWithError(e error) error {
|
||||||
_ = c.session.Close(e)
|
if c.session == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.session.Close(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) Close() error {
|
||||||
|
return c.CloseWithError(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// copied from net/transport.go
|
// copied from net/transport.go
|
||||||
|
|||||||
@@ -4,15 +4,23 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
|
||||||
|
"runtime"
|
||||||
|
|
||||||
"golang.org/x/net/lex/httplex"
|
"golang.org/x/net/lex/httplex"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type roundTripCloser interface {
|
||||||
|
http.RoundTripper
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
// RoundTripper implements the http.RoundTripper interface
|
// RoundTripper implements the http.RoundTripper interface
|
||||||
type RoundTripper struct {
|
type RoundTripper struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
@@ -35,10 +43,10 @@ type RoundTripper struct {
|
|||||||
// If nil, reasonable default values will be used.
|
// If nil, reasonable default values will be used.
|
||||||
QuicConfig *quic.Config
|
QuicConfig *quic.Config
|
||||||
|
|
||||||
clients map[string]http.RoundTripper
|
clients map[string]roundTripCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ http.RoundTripper = &RoundTripper{}
|
var _ roundTripCloser = &RoundTripper{}
|
||||||
|
|
||||||
// RoundTrip does a round trip
|
// RoundTrip does a round trip
|
||||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
@@ -85,7 +93,8 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
|
|||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
if r.clients == nil {
|
if r.clients == nil {
|
||||||
r.clients = make(map[string]http.RoundTripper)
|
runtime.SetFinalizer(r, finalizer)
|
||||||
|
r.clients = make(map[string]roundTripCloser)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, ok := r.clients[hostname]
|
client, ok := r.clients[hostname]
|
||||||
@@ -96,6 +105,23 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
|
|||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func finalizer(r *RoundTripper) {
|
||||||
|
_ = r.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the QUIC connections that this RoundTripper has used
|
||||||
|
func (r *RoundTripper) Close() error {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
for _, client := range r.clients {
|
||||||
|
if err := client.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.clients = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func closeRequestBody(req *http.Request) {
|
func closeRequestBody(req *http.Request) {
|
||||||
if req.Body != nil {
|
if req.Body != nil {
|
||||||
req.Body.Close()
|
req.Body.Close()
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
@@ -13,11 +14,19 @@ import (
|
|||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockRoundTripper struct{}
|
type mockClient struct {
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return &http.Response{Request: req}, nil
|
return &http.Response{Request: req}, nil
|
||||||
}
|
}
|
||||||
|
func (m *mockClient) Close() error {
|
||||||
|
m.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ roundTripCloser = &mockClient{}
|
||||||
|
|
||||||
type mockBody struct {
|
type mockBody struct {
|
||||||
reader bytes.Reader
|
reader bytes.Reader
|
||||||
@@ -170,4 +179,34 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("closing", func() {
|
||||||
|
It("closes", func() {
|
||||||
|
rt.clients = make(map[string]roundTripCloser)
|
||||||
|
cl := &mockClient{}
|
||||||
|
rt.clients["foo.bar"] = cl
|
||||||
|
err := rt.Close()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(len(rt.clients)).To(BeZero())
|
||||||
|
Expect(cl.closed).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes a RoundTripper that has never been used", func() {
|
||||||
|
Expect(len(rt.clients)).To(BeZero())
|
||||||
|
err := rt.Close()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(len(rt.clients)).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("runs Close when the RoundTripper is garbage collected", func() {
|
||||||
|
// this is set by getClient, but we can't do that while at the same time injecting the mockClient
|
||||||
|
runtime.SetFinalizer(rt, finalizer)
|
||||||
|
rt.clients = make(map[string]roundTripCloser)
|
||||||
|
cl := &mockClient{}
|
||||||
|
rt.clients["foo.bar"] = cl
|
||||||
|
rt = nil // lose the references to the RoundTripper, such that it can be garbage collected
|
||||||
|
runtime.GC()
|
||||||
|
Eventually(func() bool { return cl.closed }).Should(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user