diff --git a/packet_packer.go b/packet_packer.go index 85b310b07..a2ca327d5 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -48,7 +48,7 @@ func (p *packetPacker) AddWindowUpdateFrame(f *frames.WindowUpdateFrame) { func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, includeStreamFrames bool) (*packedPacket, error) { // don't send out packets that only contain a StopWaitingFrame - if len(controlFrames) == 0 && (p.streamFrameQueue.Len() == 0 || !includeStreamFrames) { + if len(p.windowUpdateFrames) == 0 && len(controlFrames) == 0 && (p.streamFrameQueue.Len() == 0 || !includeStreamFrames) { return nil, nil } diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index 6ed22cda7..42206beec 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -28,9 +28,13 @@ const SmallPacketSendDelay = 500 * time.Microsecond // TODO: set a reasonable value here const ReceiveStreamFlowControlWindow ByteCount = (1 << 20) // 1 MB +// ReceiveStreamFlowControlWindowIncrement is the amount that the stream-level flow control window is increased when sending a WindowUpdate +const ReceiveStreamFlowControlWindowIncrement = ReceiveStreamFlowControlWindow + // ReceiveConnectionFlowControlWindow is the stream-level flow control window for receiving data +// temporarily set this to a very high value, until proper connection-level flow control is implemented // TODO: set a reasonable value here -const ReceiveConnectionFlowControlWindow ByteCount = (1 << 20) // 1 MB +const ReceiveConnectionFlowControlWindow ByteCount = (1 << 20) * 1024 * 2 // 2 GB // MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection // TODO: set a reasonable value here @@ -39,3 +43,6 @@ const MaxStreamsPerConnection uint32 = 100 // MaxIdleConnectionStateLifetime is the maximum value accepted for the idle connection state lifetime // TODO: set a reasonable value here const MaxIdleConnectionStateLifetime = 60 * time.Second + +// WindowUpdateThreshold is the size of the receive flow control window for which we send out a WindowUpdate frame +const WindowUpdateThreshold = ReceiveStreamFlowControlWindow / 2 diff --git a/session.go b/session.go index 9bfc175e3..cede862dd 100644 --- a/session.go +++ b/session.go @@ -290,7 +290,7 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error return errWindowUpdateOnInvalidStream } - stream.UpdateFlowControlWindow(frame.ByteOffset) + stream.UpdateSendFlowControlWindow(frame.ByteOffset) return nil } @@ -454,7 +454,6 @@ func (s *Session) sendPacket() error { } } - stopWaitingFrame := s.stopWaitingManager.GetStopWaitingFrame() ack, err := s.receivedPacketHandler.GetAckFrame(true) if err != nil { @@ -463,6 +462,8 @@ func (s *Session) sendPacket() error { if ack != nil { controlFrames = append(controlFrames, ack) } + + stopWaitingFrame := s.stopWaitingManager.GetStopWaitingFrame() packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, true) if err != nil { @@ -522,6 +523,16 @@ func (s *Session) QueueStreamFrame(frame *frames.StreamFrame) error { return nil } +// UpdateReceiveFlowControlWindow updates the flow control window for a stream +func (s *Session) UpdateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + wuf := frames.WindowUpdateFrame{ + StreamID: streamID, + ByteOffset: byteOffset, + } + s.packer.AddWindowUpdateFrame(&wuf) + return nil +} + // NewStream creates a new stream open for reading and writing func (s *Session) NewStream(id protocol.StreamID) (utils.Stream, error) { s.streamsMutex.Lock() diff --git a/session_test.go b/session_test.go index 37309ad7c..747f01c13 100644 --- a/session_test.go +++ b/session_test.go @@ -342,6 +342,17 @@ var _ = Describe("Session", func() { Expect(conn.written[0]).To(ContainSubstring(string("foobar"))) }) + It("sends a WindowUpdate frame", func() { + _, err := session.NewStream(5) + Expect(err).ToNot(HaveOccurred()) + err = session.UpdateReceiveFlowControlWindow(5, 0xDECAFBAD) + Expect(err).ToNot(HaveOccurred()) + err = session.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(conn.written).To(HaveLen(1)) + Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) + }) + It("sends public reset", func() { err := session.sendPublicReset(1) Expect(err).NotTo(HaveOccurred()) diff --git a/stream.go b/stream.go index bd20ee741..96be577b8 100644 --- a/stream.go +++ b/stream.go @@ -13,6 +13,7 @@ import ( type streamHandler interface { QueueStreamFrame(*frames.StreamFrame) error + UpdateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error } // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface @@ -33,8 +34,10 @@ type stream struct { frameQueue streamFrameSorter newFrameOrErrCond sync.Cond - flowControlWindow protocol.ByteCount - windowUpdateOrErrCond sync.Cond + sendFlowControlWindow protocol.ByteCount + receiveFlowControlWindow protocol.ByteCount + receiveFlowControlWindowIncrement protocol.ByteCount + windowUpdateOrErrCond sync.Cond } // newStream creates a new Stream @@ -47,7 +50,9 @@ func newStream(session streamHandler, connectionParameterManager *handshake.Conn s.newFrameOrErrCond.L = &s.mutex s.windowUpdateOrErrCond.L = &s.mutex - s.flowControlWindow = connectionParameterManager.GetSendStreamFlowControlWindow() + s.sendFlowControlWindow = connectionParameterManager.GetSendStreamFlowControlWindow() + s.receiveFlowControlWindow = connectionParameterManager.GetReceiveStreamFlowControlWindow() + s.receiveFlowControlWindowIncrement = protocol.ReceiveStreamFlowControlWindowIncrement return s, nil } @@ -104,6 +109,9 @@ func (s *stream) Read(p []byte) (int, error) { s.readPosInFrame += m bytesRead += m s.readOffset += protocol.ByteCount(m) + + s.maybeTriggerWindowUpdate() + if s.readPosInFrame >= len(frame.Data) { fin := frame.FinBit s.mutex.Lock() @@ -126,11 +134,17 @@ func (s *stream) ReadByte() (byte, error) { return p[0], err } -func (s *stream) UpdateFlowControlWindow(n protocol.ByteCount) { +func (s *stream) updateReceiveFlowControlWindow() { + n := s.receiveFlowControlWindow + s.receiveFlowControlWindowIncrement + s.receiveFlowControlWindow = n + s.session.UpdateReceiveFlowControlWindow(s.streamID, n) +} + +func (s *stream) UpdateSendFlowControlWindow(n protocol.ByteCount) { s.mutex.Lock() defer s.mutex.Unlock() - if n > s.flowControlWindow { - s.flowControlWindow = n + if n > s.sendFlowControlWindow { + s.sendFlowControlWindow = n s.windowUpdateOrErrCond.Broadcast() } } @@ -148,10 +162,10 @@ func (s *stream) Write(p []byte) (int, error) { for dataWritten < len(p) { s.mutex.Lock() - remainingBytesInWindow := int64(s.flowControlWindow) - int64(s.writeOffset) + remainingBytesInWindow := int64(s.sendFlowControlWindow) - int64(s.writeOffset) for remainingBytesInWindow == 0 && s.err == nil { s.windowUpdateOrErrCond.Wait() - remainingBytesInWindow = int64(s.flowControlWindow) - int64(s.writeOffset) + remainingBytesInWindow = int64(s.sendFlowControlWindow) - int64(s.writeOffset) } s.mutex.Unlock() @@ -190,6 +204,7 @@ func (s *stream) Close() error { // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { + // TODO: return flow control window violation here s.mutex.Lock() s.frameQueue.Push(frame) s.mutex.Unlock() @@ -197,6 +212,13 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { return nil } +func (s *stream) maybeTriggerWindowUpdate() { + diff := s.receiveFlowControlWindow - s.readOffset + if diff < protocol.WindowUpdateThreshold { + s.updateReceiveFlowControlWindow() + } +} + // RegisterError is called by session to indicate that an error occurred and the // stream should be closed. func (s *stream) RegisterError(err error) { diff --git a/stream_test.go b/stream_test.go index 2e71553cb..ef901960e 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "errors" "io" "time" @@ -21,6 +22,10 @@ func (m *mockStreamHandler) QueueStreamFrame(f *frames.StreamFrame) error { return nil } +func (m *mockStreamHandler) UpdateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + return nil +} + var _ = Describe("Stream", func() { var ( str *stream @@ -256,7 +261,7 @@ var _ = Describe("Stream", func() { Context("flow control", func() { It("writes everything if the flow control window is big enough", func() { - str.flowControlWindow = 4 + str.sendFlowControlWindow = 4 n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) Expect(n).To(Equal(4)) Expect(err).ToNot(HaveOccurred()) @@ -264,14 +269,14 @@ var _ = Describe("Stream", func() { It("waits for a flow control window update", func() { var b bool - str.flowControlWindow = 1 + str.sendFlowControlWindow = 1 _, err := str.Write([]byte{0x42}) Expect(err).ToNot(HaveOccurred()) go func() { time.Sleep(2 * time.Millisecond) b = true - str.UpdateFlowControlWindow(3) + str.UpdateSendFlowControlWindow(3) }() n, err := str.Write([]byte{0x13, 0x37}) Expect(b).To(BeTrue()) @@ -280,13 +285,13 @@ var _ = Describe("Stream", func() { }) It("splits writing of frames when given more data than the flow control windows size", func() { - str.flowControlWindow = 2 + str.sendFlowControlWindow = 2 var b bool go func() { time.Sleep(time.Millisecond) b = true - str.UpdateFlowControlWindow(4) + str.UpdateSendFlowControlWindow(4) }() n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) @@ -298,14 +303,14 @@ var _ = Describe("Stream", func() { It("writes after a flow control window update", func() { var b bool - str.flowControlWindow = 1 + str.sendFlowControlWindow = 1 _, err := str.Write([]byte{0x42}) Expect(err).ToNot(HaveOccurred()) go func() { time.Sleep(time.Millisecond) b = true - str.UpdateFlowControlWindow(3) + str.UpdateSendFlowControlWindow(3) }() n, err := str.Write([]byte{0xDE, 0xAD}) Expect(b).To(BeTrue()) @@ -315,7 +320,7 @@ var _ = Describe("Stream", func() { It("immediately returns on remote errors", func() { var b bool - str.flowControlWindow = 1 + str.sendFlowControlWindow = 1 testErr := errors.New("test error") @@ -332,17 +337,49 @@ var _ = Describe("Stream", func() { }) }) - Context("flow control window updating", func() { + Context("flow control window updating, for sending", func() { It("updates the flow control window", func() { - str.flowControlWindow = 3 - str.UpdateFlowControlWindow(4) - Expect(str.flowControlWindow).To(Equal(protocol.ByteCount(4))) + str.sendFlowControlWindow = 3 + str.UpdateSendFlowControlWindow(4) + Expect(str.sendFlowControlWindow).To(Equal(protocol.ByteCount(4))) }) It("never shrinks the flow control window", func() { - str.flowControlWindow = 100 - str.UpdateFlowControlWindow(50) - Expect(str.flowControlWindow).To(Equal(protocol.ByteCount(100))) + str.sendFlowControlWindow = 100 + str.UpdateSendFlowControlWindow(50) + Expect(str.sendFlowControlWindow).To(Equal(protocol.ByteCount(100))) + }) + }) + + Context("flow control window updating, for receiving", func() { + It("updates the flow control window", func() { + len := int(protocol.WindowUpdateThreshold) + 1 + receiveFlowControlWindow := str.receiveFlowControlWindow + frame := frames.StreamFrame{ + Offset: 0, + Data: bytes.Repeat([]byte{'f'}, len), + } + str.AddStreamFrame(&frame) + b := make([]byte, len) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len)) + Expect(str.receiveFlowControlWindow).To(Equal(receiveFlowControlWindow + str.receiveFlowControlWindowIncrement)) + }) + + It("does not update the flow control window when not enough data was received", func() { + len := int(protocol.WindowUpdateThreshold) - 1 + receiveFlowControlWindow := str.receiveFlowControlWindow + frame := frames.StreamFrame{ + Offset: 0, + Data: bytes.Repeat([]byte{'f'}, len), + } + str.AddStreamFrame(&frame) + b := make([]byte, len) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len)) + Expect(str.receiveFlowControlWindow).To(Equal(receiveFlowControlWindow)) }) })