diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go index c46e87b47..42b455bfe 100644 --- a/internal/wire/streams_blocked_frame.go +++ b/internal/wire/streams_blocked_frame.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -31,7 +32,9 @@ func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*Strea return nil, err } f.StreamLimit = protocol.StreamNum(streamLimit) - + if f.StreamLimit > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) + } return f, nil } diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go index 9c942fd69..97820a2f1 100644 --- a/internal/wire/streams_blocked_frame_test.go +++ b/internal/wire/streams_blocked_frame_test.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "io" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -44,6 +45,33 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { Expect(err).To(MatchError(io.EOF)) } }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } }) Context("writing", func() {