diff --git a/http3/transport.go b/http3/transport.go index 245d7157..4a863913 100644 --- a/http3/transport.go +++ b/http3/transport.go @@ -161,27 +161,37 @@ func (t *Transport) init() error { // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + rsp, err := t.roundTripOpt(req, opt) + if err != nil { + if req.Body != nil { + req.Body.Close() + } + return nil, err + } + return rsp, nil +} + +func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { t.initOnce.Do(func() { t.initErr = t.init() }) if t.initErr != nil { return nil, t.initErr } if req.URL == nil { - closeRequestBody(req) return nil, errors.New("http3: nil Request.URL") } if req.URL.Scheme != "https" { - closeRequestBody(req) return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) } if req.URL.Host == "" { - closeRequestBody(req) return nil, errors.New("http3: no Host in request URL") } if req.Header == nil { - closeRequestBody(req) return nil, errors.New("http3: nil Request.Header") } + if req.Method != "" && !validMethod(req.Method) { + return nil, fmt.Errorf("http3: invalid method %q", req.Method) + } for k, vv := range req.Header { if !httpguts.ValidHeaderFieldName(k) { return nil, fmt.Errorf("http3: invalid http header field name %q", k) @@ -193,11 +203,6 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res } } - if req.Method != "" && !validMethod(req.Method) { - closeRequestBody(req) - return nil, fmt.Errorf("http3: invalid method %q", req.Method) - } - trace := httptrace.ContextClientTrace(req.Context()) hostname := authorityAddr(hostnameFromURL(req.URL)) traceGetConn(trace, hostname) @@ -415,12 +420,6 @@ func (t *Transport) Close() error { return nil } -func closeRequestBody(req *http.Request) { - if req.Body != nil { - req.Body.Close() - } -} - func validMethod(method string) bool { /* Method = "OPTIONS" ; Section 9.2 diff --git a/http3/transport_test.go b/http3/transport_test.go index 9c4e06b0..bf136617 100644 --- a/http3/transport_test.go +++ b/http3/transport_test.go @@ -512,14 +512,18 @@ var _ = Describe("Transport", func() { It("rejects requests with invalid header name fields", func() { req.Header.Add("foobär", "value") + req.Body = &mockBody{} _, err := tr.RoundTrip(req) Expect(err).To(MatchError("http3: invalid http header field name \"foobär\"")) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) }) It("rejects requests with invalid header name values", func() { req.Header.Add("foo", string([]byte{0x7})) + req.Body = &mockBody{} _, err := tr.RoundTrip(req) Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value")) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) }) It("rejects requests with an invalid request method", func() {