From 0d4dd8869d26bdfb56b0a05465e756ecd5140ab8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 16 May 2016 18:15:40 +0700 Subject: [PATCH] detect stream flow control violations fixes #97 --- session.go | 1 + stream.go | 9 ++++- stream_test.go | 101 +++++++++++++++++++++++++++++++++++++------------ 3 files changed, 86 insertions(+), 25 deletions(-) diff --git a/session.go b/session.go index b086e378..de64655d 100644 --- a/session.go +++ b/session.go @@ -212,6 +212,7 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da case *frames.StreamFrame: utils.Debugf("\t<- &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", frame.StreamID, frame.FinBit, frame.Offset, len(frame.Data), frame.Offset+protocol.ByteCount(len(frame.Data))) err = s.handleStreamFrame(frame) + // TODO: send error for flow control violation // TODO: send RstStreamFrame case *frames.AckFrame: err = s.handleAckFrame(frame) diff --git a/stream.go b/stream.go index 6106767f..975a133e 100644 --- a/stream.go +++ b/stream.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "io" "sync" "sync/atomic" @@ -16,6 +17,8 @@ type streamHandler interface { updateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error } +var errFlowControlViolation = errors.New("flow control violation") + // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface type stream struct { streamID protocol.StreamID @@ -208,7 +211,11 @@ func (s *stream) Close() error { // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { - // TODO: return flow control window violation here + maxOffset := frame.Offset + protocol.ByteCount(len(frame.Data)) + if maxOffset > s.receiveFlowControlWindow { + return errFlowControlViolation + } + s.mutex.Lock() s.frameQueue.Push(frame) s.mutex.Unlock() diff --git a/stream_test.go b/stream_test.go index b466e5c2..0fcc4f25 100644 --- a/stream_test.go +++ b/stream_test.go @@ -48,7 +48,8 @@ var _ = Describe("Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -61,7 +62,8 @@ var _ = Describe("Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 2) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -78,7 +80,8 @@ var _ = Describe("Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) b, err := str.ReadByte() Expect(err).ToNot(HaveOccurred()) Expect(b).To(Equal(byte(0xDE))) @@ -102,8 +105,10 @@ var _ = Describe("Stream", func() { Offset: 2, Data: []byte{0xBE, 0xEF}, } - str.AddStreamFrame(&frame1) - str.AddStreamFrame(&frame2) + err := str.AddStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 6) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -120,8 +125,10 @@ var _ = Describe("Stream", func() { Offset: 2, Data: []byte{0xBE, 0xEF}, } - str.AddStreamFrame(&frame1) - str.AddStreamFrame(&frame2) + err := str.AddStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -136,7 +143,8 @@ var _ = Describe("Stream", func() { Data: []byte{0xDE, 0xAD}, } time.Sleep(time.Millisecond) - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) }() b := make([]byte, 2) n, err := str.Read(b) @@ -153,8 +161,10 @@ var _ = Describe("Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD}, } - str.AddStreamFrame(&frame1) - str.AddStreamFrame(&frame2) + err := str.AddStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -175,9 +185,12 @@ var _ = Describe("Stream", func() { Offset: 2, Data: []byte{0xBE, 0xEF}, } - str.AddStreamFrame(&frame1) - str.AddStreamFrame(&frame2) - str.AddStreamFrame(&frame3) + err := str.AddStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame3) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -198,9 +211,12 @@ var _ = Describe("Stream", func() { Offset: 2, Data: []byte("cd"), } - str.AddStreamFrame(&frame1) - str.AddStreamFrame(&frame2) - str.AddStreamFrame(&frame3) + err := str.AddStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame3) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -358,7 +374,8 @@ var _ = Describe("Stream", func() { Offset: 0, Data: bytes.Repeat([]byte{'f'}, len), } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, len) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -373,13 +390,43 @@ var _ = Describe("Stream", func() { Offset: 0, Data: bytes.Repeat([]byte{'f'}, len), } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, len) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(len)) Expect(str.receiveFlowControlWindow).To(Equal(receiveFlowControlWindow)) }) + + It("accepts frames that completely fill the flow control window", func() { + len := int(protocol.ReceiveStreamFlowControlWindow) + frame := frames.StreamFrame{ + Offset: 0, + Data: bytes.Repeat([]byte{'f'}, len), + } + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + }) + + It("rejects too large frames that would violate the flow control window", func() { + len := int(protocol.ReceiveStreamFlowControlWindow) + 1 + frame := frames.StreamFrame{ + Offset: 0, + Data: bytes.Repeat([]byte{'f'}, len), + } + err := str.AddStreamFrame(&frame) + Expect(err).To(MatchError(errFlowControlViolation)) + }) + + It("rejects a small frames that would violate the flow control window", func() { + frame := frames.StreamFrame{ + Offset: protocol.ReceiveStreamFlowControlWindow - 1, + Data: []byte{0x13, 0x37}, + } + err := str.AddStreamFrame(&frame) + Expect(err).To(MatchError(errFlowControlViolation)) + }) }) Context("closing", func() { @@ -411,8 +458,10 @@ var _ = Describe("Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD}, } - str.AddStreamFrame(&frame1) - str.AddStreamFrame(&frame2) + err := str.AddStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(err).To(Equal(io.EOF)) @@ -429,7 +478,8 @@ var _ = Describe("Stream", func() { Data: []byte{0xDE, 0xAD}, FinBit: true, } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(err).To(Equal(io.EOF)) @@ -443,7 +493,8 @@ var _ = Describe("Stream", func() { Data: []byte{}, FinBit: true, } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) b := make([]byte, 4) n, err := str.Read(b) Expect(n).To(BeZero()) @@ -460,7 +511,8 @@ var _ = Describe("Stream", func() { Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, FinBit: true, } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) str.RegisterError(testErr) b := make([]byte, 4) n, err := str.Read(b) @@ -477,7 +529,8 @@ var _ = Describe("Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } - str.AddStreamFrame(&frame) + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) str.RegisterError(testErr) b := make([]byte, 4) n, err := str.Read(b)