From 5d02033f0fa488ca1dfc6a17fe80c1ae01a7a6d8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 15 Jun 2016 19:39:50 +0700 Subject: [PATCH] use FlowControlManager in Stream for Reading data --- flowcontrol/flow_controller.go | 7 -- flowcontrol/interface.go | 1 - session.go | 13 ++- session_test.go | 9 ++ stream.go | 37 ++++--- stream_test.go | 171 +++++++++++++-------------------- 6 files changed, 110 insertions(+), 128 deletions(-) diff --git a/flowcontrol/flow_controller.go b/flowcontrol/flow_controller.go index 8c3981b0..d5de2b4f 100644 --- a/flowcontrol/flow_controller.go +++ b/flowcontrol/flow_controller.go @@ -161,10 +161,3 @@ func (c *flowController) CheckFlowControlViolation() bool { } return false } - -func (c *flowController) GetHighestReceived() protocol.ByteCount { - c.mutex.RLock() - defer c.mutex.RUnlock() - - return c.highestReceived -} diff --git a/flowcontrol/interface.go b/flowcontrol/interface.go index 8aa662c4..2591e008 100644 --- a/flowcontrol/interface.go +++ b/flowcontrol/interface.go @@ -13,7 +13,6 @@ type FlowController interface { MaybeTriggerBlocked() bool MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) CheckFlowControlViolation() bool - GetHighestReceived() protocol.ByteCount } // A FlowControlManager manages the flow control diff --git a/session.go b/session.go index c5cc3254..3be10b11 100644 --- a/session.go +++ b/session.go @@ -55,6 +55,8 @@ type Session struct { blockedManager *blockedManager streamFrameQueue *streamFrameQueue + flowControlManager flowcontrol.FlowControlManager + // TODO: remove flowController flowcontrol.FlowController // connection level flow controller unpacker *packetUnpacker @@ -99,6 +101,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager), receivedPacketHandler: ackhandler.NewReceivedPacketHandler(), stopWaitingManager: stopWaitingManager, + flowControlManager: flowcontrol.NewFlowControlManager(connectionParametersManager), flowController: flowcontrol.NewFlowController(0, connectionParametersManager), windowUpdateManager: newWindowUpdateManager(), blockedManager: newBlockedManager(), @@ -626,10 +629,18 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { if _, ok := s.streams[id]; ok { return nil, fmt.Errorf("Session: stream with ID %d already exists", id) } - stream, err := newStream(s, s.connectionParametersManager, s.flowController, id) + stream, err := newStream(s, s.connectionParametersManager, s.flowController, s.flowControlManager, id) if err != nil { return nil, err } + + // TODO: find a better solution for determining which streams contribute to connection level flow control + if id == 1 || id == 3 { + s.flowControlManager.NewStream(id, false) + } else { + s.flowControlManager.NewStream(id, true) + } + atomic.AddUint32(&s.openStreamsCount, 1) s.streams[id] = stream return stream, nil diff --git a/session_test.go b/session_test.go index bffeefc6..5e7a1a8c 100644 --- a/session_test.go +++ b/session_test.go @@ -241,6 +241,15 @@ var _ = Describe("Session", func() { Expect(session.streams[5]).To(BeNil()) }) + It("informs the FlowControlManager about new streams", func() { + // since the stream doesn't yet exist, this will throw an error + err := session.flowControlManager.UpdateHighestReceived(5, 1000) + Expect(err).To(HaveOccurred()) + session.newStreamImpl(5) + err = session.flowControlManager.UpdateHighestReceived(5, 2000) + Expect(err).ToNot(HaveOccurred()) + }) + It("ignores streams that existed previously", func() { session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, diff --git a/stream.go b/stream.go index 7cd0cd42..17073b93 100644 --- a/stream.go +++ b/stream.go @@ -47,6 +47,8 @@ type stream struct { frameQueue *streamFrameSorter newFrameOrErrCond sync.Cond + flowControlManager flowcontrol.FlowControlManager + // TODO: remove those flowController flowcontrol.FlowController connectionFlowController flowcontrol.FlowController contributesToConnectionFlowControl bool @@ -55,10 +57,11 @@ type stream struct { } // newStream creates a new Stream -func newStream(session streamHandler, connectionParameterManager *handshake.ConnectionParametersManager, connectionFlowController flowcontrol.FlowController, StreamID protocol.StreamID) (*stream, error) { +func newStream(session streamHandler, connectionParameterManager *handshake.ConnectionParametersManager, connectionFlowController flowcontrol.FlowController, flowControlManager flowcontrol.FlowControlManager, StreamID protocol.StreamID) (*stream, error) { s := &stream{ session: session, streamID: StreamID, + flowControlManager: flowControlManager, connectionFlowController: connectionFlowController, contributesToConnectionFlowControl: true, flowController: flowcontrol.NewFlowController(StreamID, connectionParameterManager), @@ -131,11 +134,7 @@ func (s *stream) Read(p []byte) (int, error) { bytesRead += m s.readOffset += protocol.ByteCount(m) - s.flowController.AddBytesRead(protocol.ByteCount(m)) - if s.contributesToConnectionFlowControl { - s.connectionFlowController.AddBytesRead(protocol.ByteCount(m)) - } - + s.flowControlManager.AddBytesRead(s.streamID, protocol.ByteCount(m)) s.maybeTriggerWindowUpdate() if s.readPosInFrame >= int(frame.DataLen()) { @@ -242,20 +241,21 @@ func (s *stream) Close() error { // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() - increment := s.flowController.UpdateHighestReceived(maxOffset) - if s.contributesToConnectionFlowControl { - s.connectionFlowController.IncrementHighestReceived(increment) - } - if s.flowController.CheckFlowControlViolation() { + err := s.flowControlManager.UpdateHighestReceived(s.streamID, maxOffset) + + if err == flowcontrol.ErrStreamFlowControlViolation { return errFlowControlViolation } - if s.connectionFlowController.CheckFlowControlViolation() { + if err == flowcontrol.ErrConnectionFlowControlViolation { return errConnectionFlowControlViolation } + if err != nil { + return err + } s.mutex.Lock() defer s.mutex.Unlock() - err := s.frameQueue.Push(frame) + err = s.frameQueue.Push(frame) if err != nil && err != errDuplicateStreamData { return err } @@ -268,18 +268,23 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) { s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) } -func (s *stream) maybeTriggerWindowUpdate() { +func (s *stream) maybeTriggerWindowUpdate() error { // check for stream level window updates - doUpdate, byteOffset := s.flowController.MaybeTriggerWindowUpdate() + doUpdate, byteOffset, err := s.flowControlManager.MaybeTriggerStreamWindowUpdate(s.streamID) + if err != nil { + return err + } if doUpdate { s.session.updateReceiveFlowControlWindow(s.streamID, byteOffset) } // check for connection level window updates - doUpdate, byteOffset = s.connectionFlowController.MaybeTriggerWindowUpdate() + doUpdate, byteOffset = s.flowControlManager.MaybeTriggerConnectionWindowUpdate() if doUpdate { s.session.updateReceiveFlowControlWindow(0, byteOffset) } + + return nil } func (s *stream) maybeTriggerBlocked() { diff --git a/stream_test.go b/stream_test.go index 1c208d76..55a3675a 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,12 +1,9 @@ package quic import ( - "bytes" "errors" "io" - "reflect" "time" - "unsafe" "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" @@ -26,11 +23,6 @@ type mockStreamHandler struct { receiveFlowControlWindowCalledForStream protocol.StreamID } -func (m *mockStreamHandler) queueStreamFrame(f *frames.StreamFrame) error { - m.frames = append(m.frames, f) - return nil -} - func (m *mockStreamHandler) streamBlocked(streamID protocol.StreamID, byteOffset protocol.ByteCount) { m.receivedBlockedCalled = true m.receivedBlockedForStream = streamID @@ -42,6 +34,46 @@ func (m *mockStreamHandler) updateReceiveFlowControlWindow(streamID protocol.Str return nil } +func (m *mockStreamHandler) queueStreamFrame(f *frames.StreamFrame) error { + m.frames = append(m.frames, f) + return nil +} + +type mockFlowControlHandler struct { + bytesReadForStream protocol.StreamID + bytesRead protocol.ByteCount + + highestReceivedForStream protocol.StreamID + highestReceived protocol.ByteCount + + triggerStreamWindowUpdate bool + triggerConnectionWindowUpdate bool +} + +func (m *mockFlowControlHandler) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) { + panic("not implemented") +} + +func (m *mockFlowControlHandler) MaybeTriggerStreamWindowUpdate(streamID protocol.StreamID) (bool, protocol.ByteCount, error) { + return m.triggerStreamWindowUpdate, 0x1337, nil +} + +func (m *mockFlowControlHandler) MaybeTriggerConnectionWindowUpdate() (bool, protocol.ByteCount) { + return m.triggerConnectionWindowUpdate, 0x1337 +} + +func (m *mockFlowControlHandler) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error { + m.bytesReadForStream = streamID + m.bytesRead = n + return nil +} + +func (m *mockFlowControlHandler) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + m.highestReceivedForStream = streamID + m.highestReceived = byteOffset + return nil +} + var _ = Describe("Stream", func() { var ( str *stream @@ -53,7 +85,9 @@ var _ = Describe("Stream", func() { handler = &mockStreamHandler{} cpm := handshake.NewConnectionParamatersManager() flowController := flowcontrol.NewFlowController(streamID, cpm) - str, _ = newStream(handler, cpm, flowController, streamID) + flowControlManager := flowcontrol.NewFlowControlManager(cpm) + flowControlManager.NewStream(streamID, true) + str, _ = newStream(handler, cpm, flowController, flowControlManager, streamID) }) It("gets stream id", func() { @@ -231,44 +265,6 @@ var _ = Describe("Stream", func() { Expect(err).To(MatchError(errOverlappingStreamData)) }) - Context("flow control", func() { - It("consumes bytes in the flow control window", func() { - str.contributesToConnectionFlowControl = false - frame := frames.StreamFrame{ - Offset: 2, - Data: []byte("foobar"), - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - Expect(str.flowController.GetHighestReceived()).To(Equal(protocol.ByteCount(8))) - }) - - It("updates the connection level flow controller", func() { - str.contributesToConnectionFlowControl = true - newVal := str.connectionFlowController.UpdateHighestReceived(10) - Expect(newVal).To(Equal(protocol.ByteCount(10))) - frame := frames.StreamFrame{ - Offset: 2, - Data: []byte("foobar"), - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - Expect(str.connectionFlowController.GetHighestReceived()).To(Equal(protocol.ByteCount(10 + 8))) - }) - - It("doesn't update the connection level flow controller if the stream doesn't contribute", func() { - str.contributesToConnectionFlowControl = false - newVal := str.connectionFlowController.UpdateHighestReceived(10) - Expect(newVal).To(Equal(protocol.ByteCount(10))) - frame := frames.StreamFrame{ - Offset: 2, - Data: []byte("foobar"), - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - Expect(str.connectionFlowController.GetHighestReceived()).To(Equal(protocol.ByteCount(10))) - }) - }) }) Context("writing", func() { @@ -529,87 +525,56 @@ var _ = Describe("Stream", func() { }) }) - Context("flow control window updating, for receiving", func() { - var receiveFlowControlWindow protocol.ByteCount = 1000 - var receiveFlowControlWindowIncrement protocol.ByteCount = 1000 + Context("flow control, for receiving", func() { BeforeEach(func() { - // set receiveFlowControlWindow and receiveFlowControlWindowIncrement in the stream-level flow controller - *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.flowController).Elem().FieldByName("receiveFlowControlWindow").UnsafeAddr())) = receiveFlowControlWindow - *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.flowController).Elem().FieldByName("receiveFlowControlWindowIncrement").UnsafeAddr())) = receiveFlowControlWindowIncrement + str.flowControlManager = &mockFlowControlHandler{} }) - It("updates the flow control window", func() { - len := int(receiveFlowControlWindow)/2 + 1 + It("updates the highestReceived value in the flow controller", func() { frame := frames.StreamFrame{ - Offset: 0, - Data: bytes.Repeat([]byte{'f'}, len), + Offset: 2, + Data: []byte("foobar"), } err := str.AddStreamFrame(&frame) Expect(err).ToNot(HaveOccurred()) - b := make([]byte, len) + Expect(err).ToNot(HaveOccurred()) + Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(str.streamID)) + Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(2 + 6))) + }) + + It("updates the flow control window", func() { + str.flowControlManager.(*mockFlowControlHandler).triggerStreamWindowUpdate = true + frame := frames.StreamFrame{ + Offset: 0, + Data: []byte("foobar"), + } + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(len)) + Expect(n).To(Equal(6)) Expect(handler.receiveFlowControlWindowCalled).To(BeTrue()) Expect(handler.receiveFlowControlWindowCalledForStream).To(Equal(str.streamID)) }) It("updates the connection level flow control window", func() { - var connectionReceiveFlowControlWindow protocol.ByteCount = 100 - var connectionReceiveFlowControlWindowIncrement protocol.ByteCount = 100 - // set receiveFlowControlWindow and receiveFlowControlWindowIncrement in the connection-level flow controller - *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.connectionFlowController).Elem().FieldByName("receiveFlowControlWindow").UnsafeAddr())) = connectionReceiveFlowControlWindow - *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.connectionFlowController).Elem().FieldByName("receiveFlowControlWindowIncrement").UnsafeAddr())) = connectionReceiveFlowControlWindowIncrement - - len := 100/2 + 1 + str.flowControlManager.(*mockFlowControlHandler).triggerConnectionWindowUpdate = true frame := frames.StreamFrame{ Offset: 0, - Data: bytes.Repeat([]byte{'f'}, len), + Data: []byte("foobar"), } err := str.AddStreamFrame(&frame) Expect(err).ToNot(HaveOccurred()) - b := make([]byte, len) + b := make([]byte, 6) n, err := str.Read(b) Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(len)) + Expect(n).To(Equal(6)) Expect(handler.receiveFlowControlWindowCalled).To(BeTrue()) Expect(handler.receiveFlowControlWindowCalledForStream).To(Equal(protocol.StreamID(0))) }) - It("does not update the flow control window when not enough data was received", func() { - len := int(receiveFlowControlWindow)/2 - 1 - frame := frames.StreamFrame{ - Offset: 0, - Data: bytes.Repeat([]byte{'f'}, len), - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, len) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(len)) - Expect(handler.receiveFlowControlWindowCalled).To(BeFalse()) - }) - - It("accepts frames that completely fill the flow control window", func() { - len := int(receiveFlowControlWindow) - frame := frames.StreamFrame{ - Offset: 0, - Data: bytes.Repeat([]byte{'f'}, len), - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - }) - - It("rejects too large frames that would violate the flow control window", func() { - len := int(protocol.ReceiveStreamFlowControlWindow) + 1 - frame := frames.StreamFrame{ - Offset: 0, - Data: bytes.Repeat([]byte{'f'}, len), - } - err := str.AddStreamFrame(&frame) - Expect(err).To(MatchError(errFlowControlViolation)) - }) + // TODO: think about flow control violation }) Context("closing", func() {