diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index 43c9f23a..51d5b423 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -1,6 +1,7 @@ package h2quic import ( + "errors" "net/http" "sync" ) @@ -30,6 +31,19 @@ var _ http.RoundTripper = &QuicRoundTripper{} // RoundTrip does a round trip func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if req.URL == nil { + closeRequestBody(req) + return nil, errors.New("quic: nil Request.URL") + } + if req.URL.Host == "" { + closeRequestBody(req) + return nil, errors.New("quic: no Host in request URL") + } + if req.Header == nil { + closeRequestBody(req) + return nil, errors.New("quic: nil Request.Header") + } + hostname := authorityAddr("https", hostnameFromRequest(req)) client, err := r.getClient(hostname) if err != nil { @@ -61,3 +75,9 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { func (r *QuicRoundTripper) disableCompression() bool { return r.DisableCompression } + +func closeRequestBody(req *http.Request) { + if req.Body != nil { + req.Body.Close() + } +} diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index cb038b97..6ff4aca5 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -13,6 +13,19 @@ func (m *mockQuicRoundTripper) Do(req *http.Request) (*http.Response, error) { return &http.Response{Request: req}, nil } +type mockBody struct { + closed bool +} + +func (m *mockBody) Read([]byte) (int, error) { + panic("not implemented") +} + +func (m *mockBody) Close() error { + m.closed = true + return nil +} + var _ = Describe("RoundTripper", func() { var ( rt *QuicRoundTripper @@ -26,6 +39,34 @@ var _ = Describe("RoundTripper", func() { Expect(err).ToNot(HaveOccurred()) }) + It("rejects requests without a URL", func() { + req1.URL = nil + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: nil Request.URL")) + }) + + It("rejects request without a URL Host", func() { + req1.URL.Host = "" + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: no Host in request URL")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("closes the body for rejected requests", func() { + req1.URL = nil + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: nil Request.URL")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("rejects requests without a header", func() { + req1.Header = nil + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: nil Request.Header")) + }) + It("reuses existing clients", func() { rt.clients = make(map[string]h2quicClient) rt.clients["www.example.org:443"] = &mockQuicRoundTripper{}