From ba1bcf6e0c6b75ba166d1eb1a9a961112e0be5d2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 4 Sep 2019 16:51:58 +0700 Subject: [PATCH] use the STREAM frame buffer for receiving data --- crypto_stream.go | 4 +- frame_sorter.go | 59 +++-- frame_sorter_test.go | 543 +++++++++++++++++++++++++++---------------- receive_stream.go | 9 +- 4 files changed, 394 insertions(+), 221 deletions(-) diff --git a/crypto_stream.go b/crypto_stream.go index 095feb1d..7f43a2a2 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -80,11 +80,11 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { return nil } s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) - if err := s.queue.Push(f.Data, f.Offset); err != nil { + if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { return err } for { - _, data := s.queue.Pop() + _, data, _ := s.queue.Pop() if data == nil { return nil } diff --git a/frame_sorter.go b/frame_sorter.go index fa5ee884..dee0a445 100644 --- a/frame_sorter.go +++ b/frame_sorter.go @@ -7,8 +7,13 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) +type frameSorterEntry struct { + Data []byte + DoneCb func() +} + type frameSorter struct { - queue map[protocol.ByteCount][]byte + queue map[protocol.ByteCount]frameSorterEntry readPos protocol.ByteCount gaps *utils.ByteIntervalList } @@ -18,30 +23,38 @@ var errDuplicateStreamData = errors.New("Duplicate Stream Data") func newFrameSorter() *frameSorter { s := frameSorter{ gaps: utils.NewByteIntervalList(), - queue: make(map[protocol.ByteCount][]byte), + queue: make(map[protocol.ByteCount]frameSorterEntry), } s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) return &s } -func (s *frameSorter) Push(data []byte, offset protocol.ByteCount) error { - err := s.push(data, offset) +func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error { + err := s.push(data, offset, doneCb) if err == errDuplicateStreamData { + if doneCb != nil { + doneCb() + } return nil } return err } -func (s *frameSorter) push(data []byte, offset protocol.ByteCount) error { +func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error { if len(data) == 0 { - return nil + return errDuplicateStreamData } - if oldData, ok := s.queue[offset]; ok { - if len(data) <= len(oldData) { + if oldEntry, ok := s.queue[offset]; ok { + if len(data) <= len(oldEntry.Data) { return errDuplicateStreamData } - s.queue[offset] = data + // The data we currently have is shorter than the new data. + // Replace it. + if oldEntry.DoneCb != nil { + oldEntry.DoneCb() + } + s.queue[offset] = frameSorterEntry{Data: data, DoneCb: doneCb} } start := offset @@ -82,11 +95,17 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount) error { if endGap != gap { s.gaps.Remove(endGap) } - if end <= nextEndGap.Value.Start { + if end < nextEndGap.Value.Start { break } // delete queued frames completely covered by the current frame - delete(s.queue, endGap.Value.End) + end := endGap.Value.End + if end != offset { + if cb := s.queue[end].DoneCb; cb != nil { + cb() + } + delete(s.queue, end) + } endGap = nextEndGap } @@ -130,25 +149,29 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount) error { return errors.New("Too many gaps in received data") } - if wasCut { + if wasCut && len(data) < protocol.MinStreamFrameBufferSize { newData := make([]byte, len(data)) copy(newData, data) data = newData + if doneCb != nil { + doneCb() + doneCb = nil + } } - s.queue[offset] = data + s.queue[offset] = frameSorterEntry{Data: data, DoneCb: doneCb} return nil } -func (s *frameSorter) Pop() (protocol.ByteCount, []byte) { - data, ok := s.queue[s.readPos] +func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) { + entry, ok := s.queue[s.readPos] if !ok { - return s.readPos, nil + return s.readPos, nil, nil } delete(s.queue, s.readPos) offset := s.readPos - s.readPos += protocol.ByteCount(len(data)) - return offset, data + s.readPos += protocol.ByteCount(len(entry.Data)) + return offset, entry.Data, entry.DoneCb } // HasMoreData says if there is any more data queued at *any* offset. diff --git a/frame_sorter_test.go b/frame_sorter_test.go index 2dfcf770..4a2aa1bc 100644 --- a/frame_sorter_test.go +++ b/frame_sorter_test.go @@ -21,58 +21,81 @@ var _ = Describe("frame sorter", func() { } } + getCallback := func() (func(), *bool) { + var called bool + return func() { called = true }, &called + } + + checkCallback := func(cb func(), called *bool) { + ExpectWithOffset(1, cb).ToNot(BeNil()) + ExpectWithOffset(1, *called).To(BeFalse()) + cb() + ExpectWithOffset(1, *called).To(BeTrue()) + } + BeforeEach(func() { s = newFrameSorter() + _ = checkGaps }) It("returns nil when empty", func() { - _, data := s.Pop() + _, data, doneCb := s.Pop() Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) }) Context("Push", func() { It("inserts and pops a single frame", func() { - Expect(s.Push([]byte("foobar"), 0)).To(Succeed()) - offset, data := s.Pop() + cb, called := getCallback() + Expect(s.Push([]byte("foobar"), 0, cb)).To(Succeed()) + offset, data, doneCb := s.Pop() Expect(offset).To(BeZero()) Expect(data).To(Equal([]byte("foobar"))) - offset, data = s.Pop() + checkCallback(doneCb, called) + offset, data, doneCb = s.Pop() Expect(offset).To(Equal(protocol.ByteCount(6))) Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) }) It("inserts and pops two consecutive frame", func() { - Expect(s.Push([]byte("foo"), 0)).To(Succeed()) - Expect(s.Push([]byte("bar"), 3)).To(Succeed()) - offset, data := s.Pop() + cb1, called1 := getCallback() + cb2, called2 := getCallback() + Expect(s.Push([]byte("bar"), 3, cb2)).To(Succeed()) + Expect(s.Push([]byte("foo"), 0, cb1)).To(Succeed()) + offset, data, doneCb := s.Pop() Expect(offset).To(BeZero()) Expect(data).To(Equal([]byte("foo"))) - offset, data = s.Pop() + checkCallback(doneCb, called1) + offset, data, doneCb = s.Pop() Expect(offset).To(Equal(protocol.ByteCount(3))) Expect(data).To(Equal([]byte("bar"))) - offset, data = s.Pop() + checkCallback(doneCb, called2) + offset, data, doneCb = s.Pop() Expect(offset).To(Equal(protocol.ByteCount(6))) Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) }) It("ignores empty frames", func() { - Expect(s.Push(nil, 0)).To(Succeed()) - _, data := s.Pop() + Expect(s.Push(nil, 0, nil)).To(Succeed()) + _, data, doneCb := s.Pop() Expect(data).To(BeNil()) + Expect(doneCb).To(BeNil()) }) It("says if has more data", func() { Expect(s.HasMoreData()).To(BeFalse()) - Expect(s.Push([]byte("foo"), 0)).To(Succeed()) + Expect(s.Push([]byte("foo"), 0, nil)).To(Succeed()) Expect(s.HasMoreData()).To(BeTrue()) - _, data := s.Pop() + _, data, _ := s.Pop() Expect(data).To(Equal([]byte("foo"))) Expect(s.HasMoreData()).To(BeFalse()) }) Context("Gap handling", func() { It("finds the first gap", func() { - Expect(s.Push([]byte("foobar"), 10)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 10, nil)).To(Succeed()) checkGaps([]utils.ByteInterval{ {Start: 0, End: 10}, {Start: 16, End: protocol.MaxByteCount}, @@ -80,15 +103,15 @@ var _ = Describe("frame sorter", func() { }) It("correctly sets the first gap for a frame with offset 0", func() { - Expect(s.Push([]byte("foobar"), 0)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) checkGaps([]utils.ByteInterval{ {Start: 6, End: protocol.MaxByteCount}, }) }) It("finds the two gaps", func() { - Expect(s.Push([]byte("foobar"), 10)).To(Succeed()) - Expect(s.Push([]byte("foobar"), 20)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 10, nil)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 20, nil)).To(Succeed()) checkGaps([]utils.ByteInterval{ {Start: 0, End: 10}, {Start: 16, End: 20}, @@ -97,8 +120,8 @@ var _ = Describe("frame sorter", func() { }) It("finds the two gaps in reverse order", func() { - Expect(s.Push([]byte("foobar"), 20)).To(Succeed()) - Expect(s.Push([]byte("foobar"), 10)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 20, nil)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 10, nil)).To(Succeed()) checkGaps([]utils.ByteInterval{ {Start: 0, End: 10}, {Start: 16, End: 20}, @@ -107,8 +130,8 @@ var _ = Describe("frame sorter", func() { }) It("shrinks a gap when it is partially filled", func() { - Expect(s.Push([]byte("test"), 10)).To(Succeed()) - Expect(s.Push([]byte("foobar"), 4)).To(Succeed()) + Expect(s.Push([]byte("test"), 10, nil)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 4, nil)).To(Succeed()) checkGaps([]utils.ByteInterval{ {Start: 0, End: 4}, {Start: 14, End: protocol.MaxByteCount}, @@ -116,17 +139,17 @@ var _ = Describe("frame sorter", func() { }) It("deletes a gap at the beginning, when it is filled", func() { - Expect(s.Push([]byte("test"), 6)).To(Succeed()) - Expect(s.Push([]byte("foobar"), 0)).To(Succeed()) + Expect(s.Push([]byte("test"), 6, nil)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 0, nil)).To(Succeed()) checkGaps([]utils.ByteInterval{ {Start: 10, End: protocol.MaxByteCount}, }) }) It("deletes a gap in the middle, when it is filled", func() { - Expect(s.Push([]byte("test"), 0)).To(Succeed()) - Expect(s.Push([]byte("test2"), 10)).To(Succeed()) - Expect(s.Push([]byte("foobar"), 4)).To(Succeed()) + Expect(s.Push([]byte("test"), 0, nil)).To(Succeed()) + Expect(s.Push([]byte("test2"), 10, nil)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 4, nil)).To(Succeed()) Expect(s.queue).To(HaveLen(3)) checkGaps([]utils.ByteInterval{ {Start: 15, End: protocol.MaxByteCount}, @@ -134,8 +157,8 @@ var _ = Describe("frame sorter", func() { }) It("splits a gap into two", func() { - Expect(s.Push([]byte("test"), 100)).To(Succeed()) - Expect(s.Push([]byte("foobar"), 50)).To(Succeed()) + Expect(s.Push([]byte("test"), 100, nil)).To(Succeed()) + Expect(s.Push([]byte("foobar"), 50, nil)).To(Succeed()) Expect(s.queue).To(HaveLen(2)) checkGaps([]utils.ByteInterval{ {Start: 0, End: 50}, @@ -145,271 +168,393 @@ var _ = Describe("frame sorter", func() { }) Context("Overlapping Stream Data detection", func() { - // create gaps: 0-5, 10-15, 20-25, 30-inf + var initialCb1, initialCb2, initialCb3 func() + var initialCb1Called, initialCb2Called, initialCb3Called *bool + + // create gaps: 0-500, 1000-1500, 2000-2500, 3000-inf BeforeEach(func() { - Expect(s.Push([]byte("12345"), 5)).To(Succeed()) - Expect(s.Push([]byte("12345"), 15)).To(Succeed()) - Expect(s.Push([]byte("12345"), 25)).To(Succeed()) + // make sure frames are not cut when we overlap a little bit + Expect(protocol.MinStreamFrameBufferSize).To(BeNumerically("<", 500/2)) + initialCb1, initialCb1Called = getCallback() + initialCb2, initialCb2Called = getCallback() + initialCb3, initialCb3Called = getCallback() + Expect(s.Push(bytes.Repeat([]byte{1}, 500), 500, initialCb1)).To(Succeed()) + Expect(s.Push(bytes.Repeat([]byte{2}, 500), 1500, initialCb2)).To(Succeed()) + Expect(s.Push(bytes.Repeat([]byte{3}, 500), 2500, initialCb3)).To(Succeed()) checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 10, End: 15}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 500}, + {Start: 1000, End: 1500}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) }) It("cuts a frame with offset 0 that overlaps at the end", func() { - Expect(s.Push([]byte("foobar"), 0)).To(Succeed()) + cb, called := getCallback() + // 0 - 505 + Expect(s.Push(bytes.Repeat([]byte{9}, 505), 0, cb)).To(Succeed()) Expect(s.queue).To(HaveKey(protocol.ByteCount(0))) - Expect(s.queue[0]).To(Equal([]byte("fooba"))) - Expect(s.queue[0]).To(HaveCap(5)) + Expect(s.queue[0].Data).To(Equal(bytes.Repeat([]byte{9}, 500))) // 0 to 500 checkGaps([]utils.ByteInterval{ - {Start: 10, End: 15}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 1000, End: 1500}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) + checkCallback(initialCb3, initialCb3Called) }) It("cuts a frame that overlaps at the end", func() { - // 4 to 7 - Expect(s.Push([]byte("foo"), 4)).To(Succeed()) - Expect(s.queue).To(HaveKey(protocol.ByteCount(4))) - Expect(s.queue[4]).To(Equal([]byte("f"))) - Expect(s.queue[4]).To(HaveCap(1)) + cb, called := getCallback() + // 100 to 600 + Expect(s.Push(bytes.Repeat([]byte{9}, 500), 100, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(100))) + Expect(s.queue[100].Data).To(Equal(bytes.Repeat([]byte{9}, 400))) // 100 to 500 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 4}, - {Start: 10, End: 15}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 100}, + {Start: 1000, End: 1500}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) + checkCallback(initialCb3, initialCb3Called) }) It("cuts a frame that completely fills a gap, but overlaps at the end", func() { - // 10 to 16 - Expect(s.Push([]byte("foobar"), 10)).To(Succeed()) - Expect(s.queue).To(HaveKey(protocol.ByteCount(10))) - Expect(s.queue[10]).To(Equal([]byte("fooba"))) - Expect(s.queue[10]).To(HaveCap(5)) + // 1000 to 1600 + cb, called := getCallback() + Expect(s.Push(bytes.Repeat([]byte{9}, 600), 1000, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(1000))) + Expect(s.queue[1000].Data).To(Equal(bytes.Repeat([]byte{9}, 500))) // 1000 to 15000 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 500}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) + checkCallback(initialCb3, initialCb3Called) }) It("cuts a frame that overlaps at the beginning", func() { - // 8 to 14 - Expect(s.Push([]byte("foobar"), 8)).To(Succeed()) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(10))) - Expect(s.queue[10]).To(Equal([]byte("obar"))) - Expect(s.queue[10]).To(HaveCap(4)) + cb, called := getCallback() + // 900 to 1400 + Expect(s.Push(bytes.Repeat([]byte{9}, 500), 900, cb)).To(Succeed()) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(900))) + Expect(s.queue).To(HaveKey(protocol.ByteCount(1000))) + Expect(s.queue[1000].Data).To(Equal(bytes.Repeat([]byte{9}, 400))) // 1000 to 1400 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 14, End: 15}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 500}, + {Start: 1400, End: 1500}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) + checkCallback(initialCb3, initialCb3Called) }) It("processes a frame that overlaps at the beginning and at the end, starting in a gap", func() { - // 2 to 12 - Expect(s.Push([]byte("1234567890"), 2)).To(Succeed()) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(2))) - Expect(s.queue[2]).To(Equal([]byte("1234567890"))) + cb, called := getCallback() + // 300 to 1100 + Expect(s.Push(bytes.Repeat([]byte{9}, 800), 300, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(300))) + Expect(s.queue[300].Data).To(Equal(bytes.Repeat([]byte{9}, 800))) // 300 to 1100 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 2}, - {Start: 12, End: 15}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 300}, + {Start: 1100, End: 1500}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + // initial1 spanned from 500 - 1000, and should have been deleted + Expect(*initialCb1Called).To(BeTrue()) + checkCallback(initialCb2, initialCb2Called) + checkCallback(initialCb3, initialCb3Called) }) It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() { - // 2 to 17 - Expect(s.Push([]byte("123456789012345"), 2)).To(Succeed()) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(2))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(15))) - Expect(s.queue[2]).To(Equal([]byte("1234567890123"))) - Expect(s.queue[2]).To(HaveCap(13)) + cb, called := getCallback() + // 400 to 1600 + Expect(s.Push(bytes.Repeat([]byte{9}, 1200), 400, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(400))) + Expect(s.queue).To(HaveKey(protocol.ByteCount(1500))) + Expect(s.queue[400].Data).To(Equal(bytes.Repeat([]byte{9}, 1100))) // 400 to 1500 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 2}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 400}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + // initial1 spans from 500 - 1000, and should have been deleted + Expect(*initialCb1Called).To(BeTrue()) + checkCallback(initialCb2, initialCb2Called) + checkCallback(initialCb3, initialCb3Called) }) It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() { - // 5 to 22 - Expect(s.Push([]byte("12345678901234567"), 5)).To(Succeed()) - Expect(s.queue).To(HaveKey(protocol.ByteCount(5))) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(10))) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15))) + cb, called := getCallback() + // 500 to 2100 + Expect(s.Push(bytes.Repeat([]byte{9}, 1600), 500, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(500))) + Expect(s.queue[500].Data).To(Equal(bytes.Repeat([]byte{9}, 1600))) // 500 to 2100 + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1000))) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1500))) checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 22, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 500}, + {Start: 2100, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + // initial1 spans from 500 - 1000, and should have been deleted + Expect(*initialCb1Called).To(BeTrue()) + // initial2 spans from 1500 - 2000, and should have been deleted + Expect(*initialCb2Called).To(BeTrue()) + checkCallback(initialCb3, initialCb3Called) }) - It("processes a frame that closes multiple gaps", func() { - // 2 to 27 - Expect(s.Push(bytes.Repeat([]byte{'e'}, 25), 2)).To(Succeed()) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5))) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(25))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(2))) - Expect(s.queue[2]).To(Equal(bytes.Repeat([]byte{'e'}, 23))) - Expect(s.queue[2]).To(HaveCap(23)) + It("processes a frame that closes multiple gaps, beginning in a gap", func() { + cb, called := getCallback() + // 400 to 3100 + Expect(s.Push(bytes.Repeat([]byte{9}, 2700), 400, cb)).To(Succeed()) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(500))) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1500))) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(2500))) + Expect(s.queue).To(HaveKey(protocol.ByteCount(400))) + Expect(s.queue[400].Data).To(Equal(bytes.Repeat([]byte{9}, 2700))) // 400 to 3100 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 2}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 400}, + {Start: 3100, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + // initial1 spans from 500 - 1000, and should have been deleted + Expect(*initialCb1Called).To(BeTrue()) + // initial2 spans from 1500 - 2000, and should have been deleted + Expect(*initialCb2Called).To(BeTrue()) + // initial3 spans from 2500 - 3100, and should have been deleted + Expect(*initialCb3Called).To(BeTrue()) }) - It("processes a frame that closes multiple gaps", func() { - // 5 to 27 - Expect(s.Push(bytes.Repeat([]byte{'d'}, 22), 5)).To(Succeed()) - Expect(s.queue).To(HaveKey(protocol.ByteCount(5))) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(10))) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(25))) - Expect(s.queue[5]).To(Equal(bytes.Repeat([]byte{'d'}, 20))) - Expect(s.queue[5]).To(HaveCap(20)) + It("processes a frame that closes multiple gaps, beginning at the end of a gap", func() { + cb, called := getCallback() + // 500 to 2600 + Expect(s.Push(bytes.Repeat([]byte{9}, 2100), 500, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(500))) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1000))) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1500))) + Expect(s.queue).To(HaveKey(protocol.ByteCount(2500))) + Expect(s.queue[500].Data).To(Equal(bytes.Repeat([]byte{9}, 2000))) // 500 to 2500 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + // initial1 spans from 500 - 1000, and should have been deleted + Expect(*initialCb1Called).To(BeTrue()) + // initial2 spans from 1500 - 2000, and should have been deleted + Expect(*initialCb2Called).To(BeTrue()) + checkCallback(initialCb3, initialCb3Called) }) It("processes a frame that covers multiple gaps and ends at the end of a gap", func() { - data := bytes.Repeat([]byte{'e'}, 14) - // 1 to 15 - Expect(s.Push(data, 1)).To(Succeed()) - Expect(s.queue).To(HaveKey(protocol.ByteCount(1))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(15))) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5))) - Expect(s.queue[1]).To(Equal(data)) + cb, called := getCallback() + // 100 to 1500 + Expect(s.Push(bytes.Repeat([]byte{9}, 1400), 100, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(100))) + Expect(s.queue).To(HaveKey(protocol.ByteCount(1500))) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(500))) + Expect(s.queue[100].Data).To(Equal(bytes.Repeat([]byte{9}, 1400))) // 100 to 1500 checkGaps([]utils.ByteInterval{ - {Start: 0, End: 1}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 0, End: 100}, + {Start: 2000, End: 2500}, + {Start: 3000, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + // initial1 spans from 500 - 1000, and should have been deleted + Expect(*initialCb1Called).To(BeTrue()) + checkCallback(initialCb2, initialCb2Called) + checkCallback(initialCb3, initialCb3Called) }) It("processes a frame that closes all gaps (except for the last one)", func() { - data := bytes.Repeat([]byte{'f'}, 32) - // 0 to 32 - Expect(s.Push(data, 0)).To(Succeed()) + cb, called := getCallback() + // 0 to 3100 + Expect(s.Push(bytes.Repeat([]byte{9}, 3100), 0, cb)).To(Succeed()) Expect(s.queue).To(HaveLen(1)) Expect(s.queue).To(HaveKey(protocol.ByteCount(0))) - Expect(s.queue[0]).To(Equal(data)) + Expect(s.queue[0].Data).To(Equal(bytes.Repeat([]byte{9}, 3100))) // 0 to 3100 checkGaps([]utils.ByteInterval{ - {Start: 32, End: protocol.MaxByteCount}, - }) - }) - - It("cuts a frame that overlaps at the beginning and at the end, starting in data already received", func() { - // 8 to 17 - Expect(s.Push([]byte("123456789"), 8)).To(Succeed()) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8))) - Expect(s.queue).To(HaveKey(protocol.ByteCount(10))) - Expect(s.queue[10]).To(Equal([]byte("34567"))) - Expect(s.queue[10]).To(HaveCap(5)) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, - }) - }) - - It("cuts a frame that completely covers two gaps", func() { - // 10 to 20 - Expect(s.Push([]byte("1234567890"), 10)).To(Succeed()) - Expect(s.queue).To(HaveKey(protocol.ByteCount(10))) - Expect(s.queue[10]).To(Equal([]byte("12345"))) - Expect(s.queue[10]).To(HaveCap(5)) - checkGaps([]utils.ByteInterval{ - {Start: 0, End: 5}, - {Start: 20, End: 25}, - {Start: 30, End: protocol.MaxByteCount}, + {Start: 3100, End: protocol.MaxByteCount}, }) + checkCallback(cb, called) + // initial1 spans from 500 - 1000, and should have been deleted + Expect(*initialCb1Called).To(BeTrue()) + Expect(*initialCb2Called).To(BeTrue()) + Expect(*initialCb3Called).To(BeTrue()) }) }) Context("duplicate data", func() { - expectedGaps := []utils.ByteInterval{ - {Start: 5, End: 10}, - {Start: 15, End: protocol.MaxByteCount}, - } + var initialCb1, initialCb2 func() + var initialCb1Called, initialCb2Called *bool BeforeEach(func() { - // create gaps: 5-10, 15-inf - Expect(s.Push([]byte("12345"), 0)).To(Succeed()) - Expect(s.Push([]byte("12345"), 10)).To(Succeed()) - checkGaps(expectedGaps) + // make sure frames are not cut when we overlap a little bit + Expect(protocol.MinStreamFrameBufferSize).To(BeNumerically("<", 500/2)) + initialCb1, initialCb1Called = getCallback() + initialCb2, initialCb2Called = getCallback() + // create gaps: 500 - 1000, 1500 - inf + Expect(s.Push(bytes.Repeat([]byte{1}, 500), 0, initialCb1)).To(Succeed()) + Expect(s.Push(bytes.Repeat([]byte{2}, 500), 1000, initialCb1)).To(Succeed()) + checkGaps([]utils.ByteInterval{ + {Start: 500, End: 1000}, + {Start: 1500, End: protocol.MaxByteCount}, + }) }) AfterEach(func() { // check that the gaps were not modified - checkGaps(expectedGaps) + checkGaps([]utils.ByteInterval{ + {Start: 500, End: 1000}, + {Start: 1500, End: protocol.MaxByteCount}, + }) }) It("does not modify data when receiving a duplicate", func() { - err := s.push([]byte("fffff"), 0) - Expect(err).To(MatchError(errDuplicateStreamData)) - Expect(s.queue[0]).ToNot(Equal([]byte("fffff"))) + cb, called := getCallback() + // 0 to 500 + Expect(s.Push(bytes.Repeat([]byte{9}, 500), 0, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(0))) + Expect(s.queue[0].Data).ToNot(Equal(bytes.Repeat([]byte{9}, 500))) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) }) It("detects a duplicate frame that is smaller than the original, starting at the beginning", func() { - // 10 to 12 - err := s.push([]byte("12"), 10) - Expect(err).To(MatchError(errDuplicateStreamData)) - Expect(s.queue[10]).To(HaveLen(5)) + cb, called := getCallback() + // 1000 to 1200 + Expect(s.Push(bytes.Repeat([]byte{9}, 200), 1000, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(1000))) + Expect(s.queue[1000].Data).ToNot(Equal(bytes.Repeat([]byte{9}, 200))) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) }) It("detects a duplicate frame that is smaller than the original, somewhere in the middle", func() { - // 1 to 4 - err := s.push([]byte("123"), 1) - Expect(err).To(MatchError(errDuplicateStreamData)) - Expect(s.queue[0]).To(HaveLen(5)) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1))) + cb, called := getCallback() + // 100 to 400 + Expect(s.Push(bytes.Repeat([]byte{9}, 300), 100, cb)).To(Succeed()) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(100))) + Expect(s.queue[0].Data).To(Equal(bytes.Repeat([]byte{1}, 500))) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) }) It("detects a duplicate frame that is smaller than the original, somewhere in the middle in the last block", func() { - // 11 to 14 - err := s.push([]byte("123"), 11) - Expect(err).To(MatchError(errDuplicateStreamData)) - Expect(s.queue[10]).To(HaveLen(5)) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11))) + cb, called := getCallback() + // 1100 to 1400 + Expect(s.Push(bytes.Repeat([]byte{9}, 300), 1100, cb)).To(Succeed()) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1100))) + Expect(s.queue[1000].Data).To(Equal(bytes.Repeat([]byte{2}, 500))) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) }) - It("detects a duplicate frame that is smaller than the original, with aligned end in the last block", func() { - // 11 to 15 - err := s.push([]byte("1234"), 1) - Expect(err).To(MatchError(errDuplicateStreamData)) - Expect(s.queue[10]).To(HaveLen(5)) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11))) + It("detects a duplicate frame that is smaller than the original, somewhere in the middle in the last block", func() { + cb, called := getCallback() + // 1100 to 1500 + Expect(s.Push(bytes.Repeat([]byte{9}, 400), 1100, cb)).To(Succeed()) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1100))) + Expect(s.queue[1000].Data).To(Equal(bytes.Repeat([]byte{2}, 500))) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) }) It("detects a duplicate frame that is smaller than the original, with aligned end", func() { - // 3 to 5 - err := s.push([]byte("12"), 3) - Expect(err).To(MatchError(errDuplicateStreamData)) - Expect(s.queue[0]).To(HaveLen(5)) - Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(3))) + cb, called := getCallback() + // 300 to 500 + Expect(s.Push(bytes.Repeat([]byte{9}, 200), 300, cb)).To(Succeed()) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(300))) + Expect(s.queue[0].Data).To(Equal(bytes.Repeat([]byte{1}, 500))) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) + }) + }) + + Context("cutting short frames", func() { + var initialCb1, initialCb2 func() + var initialCb1Called, initialCb2Called *bool + + // create gaps: 0-5, 10-15, 2000-inf + BeforeEach(func() { + // make sure frames are not cut when we overlap a little bit + Expect(protocol.MinStreamFrameBufferSize).To(BeNumerically(">", 10)) + initialCb1, initialCb1Called = getCallback() + initialCb2, initialCb2Called = getCallback() + Expect(s.Push(bytes.Repeat([]byte{1}, 5), 5, initialCb1)).To(Succeed()) + Expect(s.Push(bytes.Repeat([]byte{2}, 5), 15, initialCb2)).To(Succeed()) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 10, End: 15}, + {Start: 20, End: protocol.MaxByteCount}, + }) + }) + + It("cuts a frame that overlaps with received data at the beginning", func() { + cb, called := getCallback() + // 9 to 12 + Expect(s.Push(bytes.Repeat([]byte{9}, 3), 9, cb)).To(Succeed()) + Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(9))) + Expect(s.queue).To(HaveKey(protocol.ByteCount(10))) + Expect(s.queue[10].Data).To(Equal(bytes.Repeat([]byte{9}, 2))) // 10 to 12 + Expect(s.queue[10].Data).To(HaveCap(2)) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 12, End: 15}, + {Start: 20, End: protocol.MaxByteCount}, + }) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) + }) + + It("cuts a frame that overlaps with received data at the end", func() { + cb, called := getCallback() + // 12 to 19 + Expect(s.Push(bytes.Repeat([]byte{9}, 7), 12, cb)).To(Succeed()) + Expect(s.queue).To(HaveKey(protocol.ByteCount(12))) + Expect(s.queue[12].Data).To(Equal(bytes.Repeat([]byte{9}, 3))) // 12 to 15 + Expect(s.queue[12].Data).To(HaveCap(3)) + checkGaps([]utils.ByteInterval{ + {Start: 0, End: 5}, + {Start: 10, End: 12}, + {Start: 20, End: protocol.MaxByteCount}, + }) + Expect(*called).To(BeTrue()) + checkCallback(initialCb1, initialCb1Called) + checkCallback(initialCb2, initialCb2Called) }) }) Context("DoS protection", func() { It("errors when too many gaps are created", func() { for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ { - Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7))).To(Succeed()) + Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), nil)).To(Succeed()) } Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps)) - err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100) + err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, nil) Expect(err).To(MatchError("Too many gaps in received data")) }) }) diff --git a/receive_stream.go b/receive_stream.go index de76335e..04338f65 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -33,6 +33,7 @@ type receiveStream struct { finalOffset protocol.ByteCount currentFrame []byte + currentFrameDone func() currentFrameIsLast bool // is the currentFrame the last frame on this stream readPosInFrame int @@ -185,7 +186,11 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err func (s *receiveStream) dequeueNextFrame() { var offset protocol.ByteCount - offset, s.currentFrame = s.frameQueue.Pop() + // We're done with the last frame. Release the buffer. + if s.currentFrameDone != nil { + s.currentFrameDone() + } + offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop() s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset s.readPosInFrame = 0 } @@ -237,7 +242,7 @@ func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* if s.canceledRead { return frame.FinBit, nil } - if err := s.frameQueue.Push(frame.Data, frame.Offset); err != nil { + if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil { return false, err } s.signalRead()