forked from quic-go/quic-go
implement a function to close the h2quic.RoundTripper
h2quic.RoundTripper.Close() closes all QUIC connections that this roundtripper has used.
This commit is contained in:
@@ -166,7 +166,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
responseChan := make(chan *http.Response)
|
||||
dataStream, err := c.session.OpenStreamSync()
|
||||
if err != nil {
|
||||
c.Close(err)
|
||||
_ = c.CloseWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
c.mutex.Lock()
|
||||
@@ -181,7 +181,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
endStream := !hasBody
|
||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||
if err != nil {
|
||||
c.Close(err)
|
||||
_ = c.CloseWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
case <-c.headerErrored:
|
||||
// an error occured on the header stream
|
||||
c.Close(c.headerErr)
|
||||
_ = c.CloseWithError(c.headerErr)
|
||||
return nil, c.headerErr
|
||||
}
|
||||
}
|
||||
@@ -261,8 +261,15 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
|
||||
}
|
||||
|
||||
// Close closes the client
|
||||
func (c *client) Close(e error) {
|
||||
_ = c.session.Close(e)
|
||||
func (c *client) CloseWithError(e error) error {
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.Close(e)
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
return c.CloseWithError(nil)
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -13,6 +14,11 @@ import (
|
||||
"golang.org/x/net/lex/httplex"
|
||||
)
|
||||
|
||||
type roundTripCloser interface {
|
||||
http.RoundTripper
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
@@ -35,10 +41,10 @@ type RoundTripper struct {
|
||||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
clients map[string]http.RoundTripper
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = &RoundTripper{}
|
||||
var _ roundTripCloser = &RoundTripper{}
|
||||
|
||||
// RoundTrip does a round trip
|
||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -85,7 +91,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if r.clients == nil {
|
||||
r.clients = make(map[string]http.RoundTripper)
|
||||
r.clients = make(map[string]roundTripCloser)
|
||||
}
|
||||
|
||||
client, ok := r.clients[hostname]
|
||||
@@ -96,6 +102,19 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
|
||||
return client
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this RoundTripper has used
|
||||
func (r *RoundTripper) Close() error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for _, client := range r.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.clients = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeRequestBody(req *http.Request) {
|
||||
if req.Body != nil {
|
||||
req.Body.Close()
|
||||
|
||||
@@ -13,11 +13,19 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockRoundTripper struct{}
|
||||
type mockClient struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{Request: req}, nil
|
||||
}
|
||||
func (m *mockClient) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &mockClient{}
|
||||
|
||||
type mockBody struct {
|
||||
reader bytes.Reader
|
||||
@@ -170,4 +178,23 @@ var _ = Describe("RoundTripper", func() {
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
It("closes", func() {
|
||||
rt.clients = make(map[string]roundTripCloser)
|
||||
cl := &mockClient{}
|
||||
rt.clients["foo.bar"] = cl
|
||||
err := rt.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
Expect(cl.closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("closes a RoundTripper that has never been used", func() {
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
err := rt.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user