forked from quic-go/quic-go
The finalizer is executed when the RoundTripper is garbage collected. This is not a perfect solution, since there are situations when an unneeded RoundTripper is not garbage collected, e.g. when the program exits before the GC ran. In those cases, the server will run into the idle timeout and eventually close the connection on its side.
213 lines
5.9 KiB
Go
213 lines
5.9 KiB
Go
package h2quic
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"runtime"
|
|
"time"
|
|
|
|
quic "github.com/lucas-clemente/quic-go"
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
type mockClient struct {
|
|
closed bool
|
|
}
|
|
|
|
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{Request: req}, nil
|
|
}
|
|
func (m *mockClient) Close() error {
|
|
m.closed = true
|
|
return nil
|
|
}
|
|
|
|
var _ roundTripCloser = &mockClient{}
|
|
|
|
type mockBody struct {
|
|
reader bytes.Reader
|
|
readErr error
|
|
closeErr error
|
|
closed bool
|
|
}
|
|
|
|
func (m *mockBody) Read(p []byte) (int, error) {
|
|
if m.readErr != nil {
|
|
return 0, m.readErr
|
|
}
|
|
return m.reader.Read(p)
|
|
}
|
|
|
|
func (m *mockBody) SetData(data []byte) {
|
|
m.reader = *bytes.NewReader(data)
|
|
}
|
|
|
|
func (m *mockBody) Close() error {
|
|
m.closed = true
|
|
return m.closeErr
|
|
}
|
|
|
|
// make sure the mockBody can be used as a http.Request.Body
|
|
var _ io.ReadCloser = &mockBody{}
|
|
|
|
var _ = Describe("RoundTripper", func() {
|
|
var (
|
|
rt *RoundTripper
|
|
req1 *http.Request
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
rt = &RoundTripper{}
|
|
var err error
|
|
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
Context("dialing hosts", func() {
|
|
origDialAddr := dialAddr
|
|
streamOpenErr := errors.New("error opening stream")
|
|
|
|
BeforeEach(func() {
|
|
origDialAddr = dialAddr
|
|
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
|
// return an error when trying to open a stream
|
|
// we don't want to test all the dial logic here, just that dialing happens at all
|
|
return &mockSession{streamOpenErr: streamOpenErr}, nil
|
|
}
|
|
})
|
|
|
|
AfterEach(func() {
|
|
dialAddr = origDialAddr
|
|
})
|
|
|
|
It("creates new clients", func() {
|
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = rt.RoundTrip(req)
|
|
Expect(err).To(MatchError(streamOpenErr))
|
|
Expect(rt.clients).To(HaveLen(1))
|
|
})
|
|
|
|
It("uses the quic.Config, if provided", func() {
|
|
config := &quic.Config{HandshakeTimeout: time.Millisecond}
|
|
var receivedConfig *quic.Config
|
|
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
|
receivedConfig = config
|
|
return nil, errors.New("err")
|
|
}
|
|
rt.QuicConfig = config
|
|
rt.RoundTrip(req1)
|
|
Expect(receivedConfig).To(Equal(config))
|
|
})
|
|
|
|
It("reuses existing clients", func() {
|
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = rt.RoundTrip(req)
|
|
Expect(err).To(MatchError(streamOpenErr))
|
|
Expect(rt.clients).To(HaveLen(1))
|
|
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = rt.RoundTrip(req2)
|
|
Expect(err).To(MatchError(streamOpenErr))
|
|
Expect(rt.clients).To(HaveLen(1))
|
|
})
|
|
})
|
|
|
|
Context("validating request", func() {
|
|
It("rejects plain HTTP requests", func() {
|
|
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
|
req.Body = &mockBody{}
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("quic: unsupported protocol scheme: http"))
|
|
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("rejects requests without a URL", func() {
|
|
req1.URL = nil
|
|
req1.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError("quic: nil Request.URL"))
|
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("rejects request without a URL Host", func() {
|
|
req1.URL.Host = ""
|
|
req1.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError("quic: no Host in request URL"))
|
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("doesn't try to close the body if the request doesn't have one", func() {
|
|
req1.URL = nil
|
|
Expect(req1.Body).To(BeNil())
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError("quic: nil Request.URL"))
|
|
})
|
|
|
|
It("rejects requests without a header", func() {
|
|
req1.Header = nil
|
|
req1.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError("quic: nil Request.Header"))
|
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("rejects requests with invalid header name fields", func() {
|
|
req1.Header.Add("foobär", "value")
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError("quic: invalid http header field name \"foobär\""))
|
|
})
|
|
|
|
It("rejects requests with invalid header name values", func() {
|
|
req1.Header.Add("foo", string([]byte{0x7}))
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value"))
|
|
})
|
|
|
|
It("rejects requests with an invalid request method", func() {
|
|
req1.Method = "foobär"
|
|
req1.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError("quic: invalid method \"foobär\""))
|
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
})
|
|
|
|
Context("closing", func() {
|
|
It("closes", func() {
|
|
rt.clients = make(map[string]roundTripCloser)
|
|
cl := &mockClient{}
|
|
rt.clients["foo.bar"] = cl
|
|
err := rt.Close()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(rt.clients)).To(BeZero())
|
|
Expect(cl.closed).To(BeTrue())
|
|
})
|
|
|
|
It("closes a RoundTripper that has never been used", func() {
|
|
Expect(len(rt.clients)).To(BeZero())
|
|
err := rt.Close()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(rt.clients)).To(BeZero())
|
|
})
|
|
|
|
It("runs Close when the RoundTripper is garbage collected", func() {
|
|
// this is set by getClient, but we can't do that while at the same time injecting the mockClient
|
|
runtime.SetFinalizer(rt, finalizer)
|
|
rt.clients = make(map[string]roundTripCloser)
|
|
cl := &mockClient{}
|
|
rt.clients["foo.bar"] = cl
|
|
rt = nil // lose the references to the RoundTripper, such that it can be garbage collected
|
|
runtime.GC()
|
|
Eventually(func() bool { return cl.closed }).Should(BeTrue())
|
|
})
|
|
})
|
|
})
|