http3: simplify tracking of content length (#4438)

This commit is contained in:
Marten Seemann
2024-04-16 06:35:31 -07:00
committed by GitHub
parent afd0b061d5
commit 29785b29bf

View File

@@ -237,22 +237,21 @@ var errTooMuchData = errors.New("peer sent too much data")
type lengthLimitedStream struct { type lengthLimitedStream struct {
*stream *stream
contentLength int64 remainingContentLength int64
read int64 resetStream bool
resetStream bool
} }
var _ Stream = &lengthLimitedStream{} var _ Stream = &lengthLimitedStream{}
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream { func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
return &lengthLimitedStream{ return &lengthLimitedStream{
stream: str, stream: str,
contentLength: contentLength, remainingContentLength: contentLength,
} }
} }
func (s *lengthLimitedStream) checkContentLengthViolation() error { func (s *lengthLimitedStream) checkContentLengthViolation() error {
if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() { if s.remainingContentLength < 0 || s.remainingContentLength == 0 && s.hasMoreData() {
if !s.resetStream { if !s.resetStream {
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
@@ -267,8 +266,8 @@ func (s *lengthLimitedStream) Read(b []byte) (int, error) {
if err := s.checkContentLengthViolation(); err != nil { if err := s.checkContentLengthViolation(); err != nil {
return 0, err return 0, err
} }
n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)]) n, err := s.stream.Read(b[:min(int64(len(b)), s.remainingContentLength)])
s.read += int64(n) s.remainingContentLength -= int64(n)
if err := s.checkContentLengthViolation(); err != nil { if err := s.checkContentLengthViolation(); err != nil {
return n, err return n, err
} }