From feec325083977708028fe1f128c60909779cf318 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 22 Dec 2016 14:13:21 +0700 Subject: [PATCH] validate HTTP headers and request scheme in RoundTripper --- h2quic/roundtrip.go | 19 ++++++++++ h2quic/roundtrip_test.go | 82 ++++++++++++++++++++++++++-------------- 2 files changed, 73 insertions(+), 28 deletions(-) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index 51d5b4230..c8ddbd344 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -2,8 +2,11 @@ package h2quic import ( "errors" + "fmt" "net/http" "sync" + + "golang.org/x/net/lex/httplex" ) type h2quicClient interface { @@ -44,6 +47,22 @@ func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return nil, errors.New("quic: nil Request.Header") } + if req.URL.Scheme == "https" { + for k, vv := range req.Header { + if !httplex.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("quic: invalid http header field name %q", k) + } + for _, v := range vv { + if !httplex.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) + } + } + } + } else { + closeRequestBody(req) + return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) + } + hostname := authorityAddr("https", hostnameFromRequest(req)) client, err := r.getClient(hostname) if err != nil { diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index 6ff4aca54..7d038d5fc 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -39,34 +39,6 @@ var _ = Describe("RoundTripper", func() { Expect(err).ToNot(HaveOccurred()) }) - It("rejects requests without a URL", func() { - req1.URL = nil - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("quic: nil Request.URL")) - }) - - It("rejects request without a URL Host", func() { - req1.URL.Host = "" - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("quic: no Host in request URL")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) - }) - - It("closes the body for rejected requests", func() { - req1.URL = nil - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("quic: nil Request.URL")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) - }) - - It("rejects requests without a header", func() { - req1.Header = nil - _, err := rt.RoundTrip(req1) - Expect(err).To(MatchError("quic: nil Request.Header")) - }) - It("reuses existing clients", func() { rt.clients = make(map[string]h2quicClient) rt.clients["www.example.org:443"] = &mockQuicRoundTripper{} @@ -80,4 +52,58 @@ var _ = Describe("RoundTripper", func() { rt.DisableCompression = true Expect(rt.disableCompression()).To(BeTrue()) }) + + Context("validating request", func() { + It("rejects plain HTTP requests", func() { + req, err := http.NewRequest("GET", "http://www.example.org/", nil) + req.Body = &mockBody{} + Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError("quic: unsupported protocol scheme: http")) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("rejects requests without a URL", func() { + req1.URL = nil + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: nil Request.URL")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("rejects request without a URL Host", func() { + req1.URL.Host = "" + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: no Host in request URL")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("doesn't try to close the body if the request doesn't have one", func() { + req1.URL = nil + Expect(req1.Body).To(BeNil()) + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: nil Request.URL")) + }) + + It("rejects requests without a header", func() { + req1.Header = nil + req1.Body = &mockBody{} + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: nil Request.Header")) + Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + }) + + It("rejects requests with invalid header name fields", func() { + req1.Header.Add("foobär", "value") + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError("quic: invalid http header field name \"foobär\"")) + }) + + It("rejects requests with invalid header name values", func() { + req1.Header.Add("foo", string([]byte{0x7})) + _, err := rt.RoundTrip(req1) + Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value")) + }) + }) })