forked from quic-go/quic-go
http3: don't send more than http.Request.ContentLength bytes (#3960)
This commit is contained in:
@@ -327,31 +327,43 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
||||
return rsp, rerr.err
|
||||
}
|
||||
|
||||
func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
|
||||
defer body.Close()
|
||||
b := make([]byte, bodyCopyBufferSize)
|
||||
for {
|
||||
n, rerr := body.Read(b)
|
||||
if n == 0 {
|
||||
if rerr == nil {
|
||||
continue
|
||||
}
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
if _, err := str.Write(b[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
if rerr != nil {
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return rerr
|
||||
}
|
||||
// cancelingReader reads from the io.Reader.
|
||||
// It cancels writing on the stream if any error other than io.EOF occurs.
|
||||
type cancelingReader struct {
|
||||
r io.Reader
|
||||
str Stream
|
||||
}
|
||||
|
||||
func (r *cancelingReader) Read(b []byte) (int, error) {
|
||||
n, err := r.r.Read(b)
|
||||
if err != nil && err != io.EOF {
|
||||
r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
}
|
||||
return nil
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
|
||||
defer body.Close()
|
||||
buf := make([]byte, bodyCopyBufferSize)
|
||||
sr := &cancelingReader{str: str, r: body}
|
||||
if contentLength == -1 {
|
||||
_, err := io.CopyBuffer(str, sr, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// make sure we don't send more bytes than the content length
|
||||
n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var extra int64
|
||||
extra, err = io.CopyBuffer(io.Discard, sr, buf)
|
||||
n += extra
|
||||
if n > contentLength {
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
||||
@@ -371,7 +383,13 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
|
||||
if req.Body != nil {
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
if err := c.sendRequestBody(hstr, req.Body); err != nil {
|
||||
contentLength := int64(-1)
|
||||
// According to the documentation for http.Request.ContentLength,
|
||||
// a value of 0 with a non-nil Body is also treated as unknown content length.
|
||||
if req.ContentLength > 0 {
|
||||
contentLength = req.ContentLength
|
||||
}
|
||||
if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
|
||||
c.logger.Errorf("Error writing request: %s", err)
|
||||
}
|
||||
if !opt.DontCloseRequestStream {
|
||||
|
||||
Reference in New Issue
Block a user