forked from quic-go/quic-go
http3: move length limiting logic to the body (#4439)
The length limit is a property of the (request and response) body. As such, it's better implemented there than by wrapping the HTTP stream.
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
@@ -100,6 +101,10 @@ func (s *stream) Write(b []byte) (int, error) {
|
||||
return s.Stream.Write(b)
|
||||
}
|
||||
|
||||
func (s *stream) StreamID() protocol.StreamID {
|
||||
return s.Stream.StreamID()
|
||||
}
|
||||
|
||||
// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
|
||||
// from the QUIC stream, it writes to and reads from the HTTP stream.
|
||||
type requestStream struct {
|
||||
@@ -197,13 +202,11 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
|
||||
|
||||
// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
contentLength := int64(-1)
|
||||
if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(s.stream, res.ContentLength)
|
||||
} else {
|
||||
httpStr = s.stream
|
||||
contentLength = res.ContentLength
|
||||
}
|
||||
respBody := newResponseBody(httpStr, s.reqDone)
|
||||
respBody := newResponseBody(s.stream, contentLength, s.reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
|
||||
@@ -231,45 +234,3 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
|
||||
res.Body = s.responseBody
|
||||
return res, nil
|
||||
}
|
||||
|
||||
var errTooMuchData = errors.New("peer sent too much data")
|
||||
|
||||
type lengthLimitedStream struct {
|
||||
*stream
|
||||
|
||||
remainingContentLength int64
|
||||
resetStream bool
|
||||
}
|
||||
|
||||
var _ Stream = &lengthLimitedStream{}
|
||||
|
||||
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
|
||||
return &lengthLimitedStream{
|
||||
stream: str,
|
||||
remainingContentLength: contentLength,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *lengthLimitedStream) checkContentLengthViolation() error {
|
||||
if s.remainingContentLength < 0 || s.remainingContentLength == 0 && s.hasMoreData() {
|
||||
if !s.resetStream {
|
||||
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
s.resetStream = true
|
||||
}
|
||||
return errTooMuchData
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *lengthLimitedStream) Read(b []byte) (int, error) {
|
||||
if err := s.checkContentLengthViolation(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := s.stream.Read(b[:min(int64(len(b)), s.remainingContentLength)])
|
||||
s.remainingContentLength -= int64(n)
|
||||
if err := s.checkContentLengthViolation(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user