forked from quic-go/quic-go
http3: simplify tracking of content length (#4438)
This commit is contained in:
@@ -237,22 +237,21 @@ var errTooMuchData = errors.New("peer sent too much data")
|
|||||||
type lengthLimitedStream struct {
|
type lengthLimitedStream struct {
|
||||||
*stream
|
*stream
|
||||||
|
|
||||||
contentLength int64
|
remainingContentLength int64
|
||||||
read int64
|
resetStream bool
|
||||||
resetStream bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Stream = &lengthLimitedStream{}
|
var _ Stream = &lengthLimitedStream{}
|
||||||
|
|
||||||
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
|
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
|
||||||
return &lengthLimitedStream{
|
return &lengthLimitedStream{
|
||||||
stream: str,
|
stream: str,
|
||||||
contentLength: contentLength,
|
remainingContentLength: contentLength,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *lengthLimitedStream) checkContentLengthViolation() error {
|
func (s *lengthLimitedStream) checkContentLengthViolation() error {
|
||||||
if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() {
|
if s.remainingContentLength < 0 || s.remainingContentLength == 0 && s.hasMoreData() {
|
||||||
if !s.resetStream {
|
if !s.resetStream {
|
||||||
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||||
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||||
@@ -267,8 +266,8 @@ func (s *lengthLimitedStream) Read(b []byte) (int, error) {
|
|||||||
if err := s.checkContentLengthViolation(); err != nil {
|
if err := s.checkContentLengthViolation(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)])
|
n, err := s.stream.Read(b[:min(int64(len(b)), s.remainingContentLength)])
|
||||||
s.read += int64(n)
|
s.remainingContentLength -= int64(n)
|
||||||
if err := s.checkContentLengthViolation(); err != nil {
|
if err := s.checkContentLengthViolation(); err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user