diff --git a/flowcontrol/flow_control_manager.go b/flowcontrol/flow_control_manager.go index a48c8b13..d9f4af21 100644 --- a/flowcontrol/flow_control_manager.go +++ b/flowcontrol/flow_control_manager.go @@ -17,6 +17,7 @@ type flowControlManager struct { rttStats *congestion.RTTStats streamFlowController map[protocol.StreamID]*flowController + connFlowController *flowController mutex sync.RWMutex } @@ -26,14 +27,12 @@ var errMapAccess = errors.New("Error accessing the flowController map.") // NewFlowControlManager creates a new flow control manager func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager { - fcm := flowControlManager{ + return &flowControlManager{ connectionParameters: connectionParameters, rttStats: rttStats, streamFlowController: make(map[protocol.StreamID]*flowController), + connFlowController: newFlowController(0, false, connectionParameters, rttStats), } - // initialize connection level flow controller - fcm.streamFlowController[0] = newFlowController(0, false, connectionParameters, rttStats) - return &fcm } // NewStream creates new flow controllers for a stream @@ -77,10 +76,9 @@ func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset } if streamFlowController.ContributesToConnection() { - connFlowController := f.streamFlowController[0] - connFlowController.IncrementHighestReceived(increment) - if connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connFlowController.receiveWindow)) + f.connFlowController.IncrementHighestReceived(increment) + if f.connFlowController.CheckFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) } } @@ -107,10 +105,9 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b } if streamFlowController.ContributesToConnection() { - connFlowController := f.streamFlowController[0] - connFlowController.IncrementHighestReceived(increment) - if connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connFlowController.receiveWindow)) + f.connFlowController.IncrementHighestReceived(increment) + if f.connFlowController.CheckFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) } } @@ -129,7 +126,7 @@ func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol fc.AddBytesRead(n) if fc.ContributesToConnection() { - f.streamFlowController[0].AddBytesRead(n) + f.connFlowController.AddBytesRead(n) } return nil @@ -139,22 +136,17 @@ func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { f.mutex.Lock() defer f.mutex.Unlock() - connFlowController := f.streamFlowController[0] - // get WindowUpdates for streams for id, fc := range f.streamFlowController { - if id == 0 { // connection-level updates are dealt with later - continue - } if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary { res = append(res, WindowUpdate{StreamID: id, Offset: offset}) if fc.ContributesToConnection() && newIncrement != 0 { - connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier)) + f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier)) } } } // get a WindowUpdate for the connection - if necessary, _, offset := connFlowController.MaybeUpdateWindow(); necessary { + if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary { res = append(res, WindowUpdate{StreamID: 0, Offset: offset}) } @@ -184,7 +176,7 @@ func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol fc.AddBytesSent(n) if fc.ContributesToConnection() { - f.streamFlowController[0].AddBytesSent(n) + f.connFlowController.AddBytesSent(n) } return nil @@ -202,7 +194,7 @@ func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protoco res := fc.SendWindowSize() if fc.ContributesToConnection() { - res = utils.MinByteCount(res, f.streamFlowController[0].SendWindowSize()) + res = utils.MinByteCount(res, f.connFlowController.SendWindowSize()) } return res, nil @@ -212,7 +204,7 @@ func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount f.mutex.RLock() defer f.mutex.RUnlock() - return f.streamFlowController[0].SendWindowSize() + return f.connFlowController.SendWindowSize() } // streamID may be 0 here @@ -220,9 +212,15 @@ func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset pro f.mutex.Lock() defer f.mutex.Unlock() - fc, err := f.getFlowController(streamID) - if err != nil { - return false, err + var fc *flowController + if streamID == 0 { + fc = f.connFlowController + } else { + var err error + fc, err = f.getFlowController(streamID) + if err != nil { + return false, err + } } return fc.UpdateSendWindow(offset), nil diff --git a/flowcontrol/flow_control_manager_test.go b/flowcontrol/flow_control_manager_test.go index c05e4105..92fde3b2 100644 --- a/flowcontrol/flow_control_manager_test.go +++ b/flowcontrol/flow_control_manager_test.go @@ -26,8 +26,8 @@ var _ = Describe("Flow Control Manager", func() { }) It("creates a connection level flow controller", func() { - Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(0))) - Expect(fcm.streamFlowController[0].ContributesToConnection()).To(BeFalse()) + Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(0))) + Expect(fcm.connFlowController.ContributesToConnection()).To(BeFalse()) }) Context("creating new streams", func() { @@ -67,7 +67,7 @@ var _ = Describe("Flow Control Manager", func() { It("updates the connection level flow controller if the stream contributes", func() { err := fcm.UpdateHighestReceived(4, 100) Expect(err).ToNot(HaveOccurred()) - Expect(fcm.streamFlowController[0].highestReceived).To(Equal(protocol.ByteCount(100))) + Expect(fcm.connFlowController.highestReceived).To(Equal(protocol.ByteCount(100))) Expect(fcm.streamFlowController[4].highestReceived).To(Equal(protocol.ByteCount(100))) }) @@ -76,14 +76,14 @@ var _ = Describe("Flow Control Manager", func() { Expect(err).ToNot(HaveOccurred()) err = fcm.UpdateHighestReceived(6, 50) Expect(err).ToNot(HaveOccurred()) - Expect(fcm.streamFlowController[0].highestReceived).To(Equal(protocol.ByteCount(100 + 50))) + Expect(fcm.connFlowController.highestReceived).To(Equal(protocol.ByteCount(100 + 50))) }) It("does not update the connection level flow controller if the stream does not contribute", func() { err := fcm.UpdateHighestReceived(1, 100) // fcm.streamFlowController[4].receiveWindow = 0x1000 Expect(err).ToNot(HaveOccurred()) - Expect(fcm.streamFlowController[0].highestReceived).To(BeZero()) + Expect(fcm.connFlowController.highestReceived).To(BeZero()) Expect(fcm.streamFlowController[1].highestReceived).To(Equal(protocol.ByteCount(100))) }) @@ -195,14 +195,14 @@ var _ = Describe("Flow Control Manager", func() { It("updates the connection level flow controller if the stream contributes", func() { err := fcm.ResetStream(4, 100) Expect(err).ToNot(HaveOccurred()) - Expect(fcm.streamFlowController[0].highestReceived).To(Equal(protocol.ByteCount(100))) + Expect(fcm.connFlowController.highestReceived).To(Equal(protocol.ByteCount(100))) Expect(fcm.streamFlowController[4].highestReceived).To(Equal(protocol.ByteCount(100))) }) It("does not update the connection level flow controller if the stream does not contribute", func() { err := fcm.ResetStream(1, 100) Expect(err).ToNot(HaveOccurred()) - Expect(fcm.streamFlowController[0].highestReceived).To(BeZero()) + Expect(fcm.connFlowController.highestReceived).To(BeZero()) Expect(fcm.streamFlowController[1].highestReceived).To(Equal(protocol.ByteCount(100))) }) @@ -243,7 +243,7 @@ var _ = Describe("Flow Control Manager", func() { Expect(err).ToNot(HaveOccurred()) err = fcm.AddBytesSent(5, 500) Expect(err).ToNot(HaveOccurred()) - Expect(fcm.streamFlowController[0].bytesSent).To(Equal(protocol.ByteCount(200 + 500))) + Expect(fcm.connFlowController.bytesSent).To(Equal(protocol.ByteCount(200 + 500))) }) It("errors when called for a stream doesn't exist", func() {