diff --git a/packet_packer.go b/packet_packer.go index 8e2e4235..d0faed8f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -233,3 +233,7 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra return payloadFrames, nil } + +func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) { + p.controlFrames = append(p.controlFrames, f) +} diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index 9451fdc3..1e87d6cc 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -42,9 +42,6 @@ const MaxStreamsMultiplier = 1.1 // TODO: set a reasonable value here const MaxIdleConnectionStateLifetime = 60 * time.Second -// WindowUpdateNumRepetitions is the number of times the same WindowUpdate frame will be sent to the client -const WindowUpdateNumRepetitions uint8 = 2 - // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. const MaxSessionUnprocessedPackets = 128 diff --git a/session.go b/session.go index d06c53df..4fb960ad 100644 --- a/session.go +++ b/session.go @@ -56,7 +56,6 @@ type Session struct { sentPacketHandler ackhandlerlegacy.SentPacketHandler receivedPacketHandler ackhandlerlegacy.ReceivedPacketHandler stopWaitingManager ackhandlerlegacy.StopWaitingManager - windowUpdateManager *windowUpdateManager streamFramer *streamFramer flowControlManager flowcontrol.FlowControlManager @@ -108,7 +107,6 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol receivedPacketHandler: ackhandlerlegacy.NewReceivedPacketHandler(), stopWaitingManager: stopWaitingManager, flowControlManager: flowControlManager, - windowUpdateManager: newWindowUpdateManager(), receivedPackets: make(chan receivedPacket, protocol.MaxSessionUnprocessedPackets), closeChan: make(chan *qerr.QuicError, 1), sendingScheduled: make(chan struct{}, 1), @@ -510,7 +508,10 @@ func (s *Session) sendPacket() error { } } - windowUpdateFrames := s.windowUpdateManager.GetWindowUpdateFrames() + windowUpdateFrames, err := s.getWindowUpdateFrames() + if err != nil { + return err + } for _, wuf := range windowUpdateFrames { controlFrames = append(controlFrames, wuf) @@ -534,6 +535,10 @@ func (s *Session) sendPacket() error { return nil } + for _, f := range windowUpdateFrames { + s.packer.QueueControlFrameForNextPacket(f) + } + err = s.sentPacketHandler.SentPacket(&ackhandlerlegacy.Packet{ PacketNumber: packet.number, Frames: packet.frames, @@ -585,12 +590,6 @@ func (s *Session) logPacket(packet *packedPacket) { } } -// updateReceiveFlowControlWindow updates the flow control window for a stream -func (s *Session) updateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - s.windowUpdateManager.SetStreamOffset(streamID, byteOffset) - return nil -} - // OpenStream creates a new stream open for reading and writing func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) { s.streamsMutex.Lock() @@ -644,9 +643,6 @@ func (s *Session) garbageCollectStreams() { if v == nil { continue } - if v.finishedReading() { - s.windowUpdateManager.RemoveStream(k) - } if v.finished() { utils.Debugf("Garbage-collecting stream %d", k) atomic.AddUint32(&s.openStreamsCount, ^uint32(0)) // decrement @@ -682,3 +678,31 @@ func (s *Session) tryDecryptingQueuedPackets() { } s.undecryptablePackets = s.undecryptablePackets[:0] } + +func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { + s.streamsMutex.RLock() + defer s.streamsMutex.RUnlock() + + var res []*frames.WindowUpdateFrame + + for id, str := range s.streams { + if str == nil { + continue + } + + doUpdate, offset, err := s.flowControlManager.MaybeTriggerStreamWindowUpdate(id) + if err != nil { + return nil, err + } + if doUpdate { + res = append(res, &frames.WindowUpdateFrame{StreamID: id, ByteOffset: offset}) + } + } + + doUpdate, offset := s.flowControlManager.MaybeTriggerConnectionWindowUpdate() + if doUpdate { + res = append(res, &frames.WindowUpdateFrame{StreamID: 0, ByteOffset: offset}) + } + + return res, nil +} diff --git a/session_test.go b/session_test.go index d2d52171..fd0a5af1 100644 --- a/session_test.go +++ b/session_test.go @@ -229,17 +229,6 @@ var _ = Describe("Session", func() { Expect(session.streams[5]).To(BeNil()) }) - It("removes closed streams from WindowUpdateManager", func() { - session.handleStreamFrame(&frames.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca, 0xfb, 0xad}, - }) - session.updateReceiveFlowControlWindow(5, 0x1337) - session.streams[5].eof = 1 - session.garbageCollectStreams() - Expect(session.windowUpdateManager.streamOffsets).ToNot(HaveKey(protocol.StreamID(5))) - }) - It("closes empty streams with error", func() { testErr := errors.New("test") session.newStreamImpl(5) @@ -445,31 +434,19 @@ var _ = Describe("Session", func() { Expect(conn.written[0]).To(ContainSubstring(string([]byte{byte(entropy), 0x35, 0x01}))) }) - It("sends a WindowUpdate frame", func() { + It("sends two WindowUpdate frames", func() { _, err := session.OpenStream(5) Expect(err).ToNot(HaveOccurred()) - err = session.updateReceiveFlowControlWindow(5, 0xDECAFBAD) - Expect(err).ToNot(HaveOccurred()) + session.flowControlManager.AddBytesRead(5, protocol.ReceiveStreamFlowControlWindow) err = session.sendPacket() Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(1)) + err = session.sendPacket() + Expect(err).NotTo(HaveOccurred()) + err = session.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(conn.written).To(HaveLen(2)) Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) - }) - - It("repeats a WindowUpdate frame in WindowUpdateNumRepetitions packets", func() { - _, err := session.OpenStream(5) - Expect(err).ToNot(HaveOccurred()) - err = session.updateReceiveFlowControlWindow(5, 0xDECAFBAD) - Expect(err).ToNot(HaveOccurred()) - for i := uint8(0); i < protocol.WindowUpdateNumRepetitions; i++ { - err = session.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(conn.written[i]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) - } - Expect(conn.written).To(HaveLen(int(protocol.WindowUpdateNumRepetitions))) - err = session.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(int(protocol.WindowUpdateNumRepetitions))) // no packet was sent + Expect(conn.written[1]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) }) It("sends public reset", func() { @@ -759,4 +736,28 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) }) }) + + Context("window updates", func() { + It("gets stream level window updates", func() { + err := session.flowControlManager.AddBytesRead(1, protocol.ReceiveStreamFlowControlWindow) + Expect(err).NotTo(HaveOccurred()) + frames, err := session.getWindowUpdateFrames() + Expect(err).NotTo(HaveOccurred()) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].StreamID).To(Equal(protocol.StreamID(1))) + Expect(frames[0].ByteOffset).To(Equal(protocol.ReceiveStreamFlowControlWindow * 2)) + }) + + It("gets connection level window updates", func() { + _, err := session.OpenStream(5) + Expect(err).NotTo(HaveOccurred()) + err = session.flowControlManager.AddBytesRead(5, protocol.ReceiveConnectionFlowControlWindow) + Expect(err).NotTo(HaveOccurred()) + frames, err := session.getWindowUpdateFrames() + Expect(err).NotTo(HaveOccurred()) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].StreamID).To(Equal(protocol.StreamID(0))) + Expect(frames[0].ByteOffset).To(Equal(protocol.ReceiveConnectionFlowControlWindow * 2)) + }) + }) }) diff --git a/stream.go b/stream.go index 39743b58..71398658 100644 --- a/stream.go +++ b/stream.go @@ -14,7 +14,6 @@ import ( ) type streamHandler interface { - updateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error scheduleSending() } @@ -132,7 +131,6 @@ func (s *stream) Read(p []byte) (int, error) { s.readOffset += protocol.ByteCount(m) s.flowControlManager.AddBytesRead(s.streamID, protocol.ByteCount(m)) - s.maybeTriggerWindowUpdate() if s.readPosInFrame >= int(frame.DataLen()) { fin := frame.FinBit @@ -250,25 +248,6 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) { s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) } -func (s *stream) maybeTriggerWindowUpdate() error { - // check for stream level window updates - doUpdate, byteOffset, err := s.flowControlManager.MaybeTriggerStreamWindowUpdate(s.streamID) - if err != nil { - return err - } - if doUpdate { - s.session.updateReceiveFlowControlWindow(s.streamID, byteOffset) - } - - // check for connection level window updates - doUpdate, byteOffset = s.flowControlManager.MaybeTriggerConnectionWindowUpdate() - if doUpdate { - s.session.updateReceiveFlowControlWindow(0, byteOffset) - } - - return nil -} - // RegisterError is called by session to indicate that an error occurred and the // stream should be closed. func (s *stream) RegisterError(err error) { diff --git a/stream_test.go b/stream_test.go index fdcaef62..b5cb42ce 100644 --- a/stream_test.go +++ b/stream_test.go @@ -14,17 +14,9 @@ import ( ) type mockStreamHandler struct { - receiveFlowControlWindowCalled bool - receiveFlowControlWindowCalledForStream protocol.StreamID - scheduledSending bool } -func (m *mockStreamHandler) updateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - m.receiveFlowControlWindowCalled = true - m.receiveFlowControlWindowCalledForStream = streamID - return nil -} func (m *mockStreamHandler) scheduleSending() { m.scheduledSending = true } type mockFlowControlHandler struct { @@ -416,40 +408,6 @@ var _ = Describe("Stream", func() { Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(str.streamID)) Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(2 + 6))) }) - - It("updates the flow control window", func() { - str.flowControlManager.(*mockFlowControlHandler).triggerStreamWindowUpdate = true - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte("foobar"), - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(handler.receiveFlowControlWindowCalled).To(BeTrue()) - Expect(handler.receiveFlowControlWindowCalledForStream).To(Equal(str.streamID)) - }) - - It("updates the connection level flow control window", func() { - str.flowControlManager.(*mockFlowControlHandler).triggerConnectionWindowUpdate = true - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte("foobar"), - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(handler.receiveFlowControlWindowCalled).To(BeTrue()) - Expect(handler.receiveFlowControlWindowCalledForStream).To(Equal(protocol.StreamID(0))) - }) - - // TODO: think about flow control violation }) Context("closing", func() { diff --git a/window_update_manager.go b/window_update_manager.go deleted file mode 100644 index 70e3bde3..00000000 --- a/window_update_manager.go +++ /dev/null @@ -1,74 +0,0 @@ -package quic - -import ( - "sync" - - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" -) - -type windowUpdateItem struct { - Offset protocol.ByteCount - Counter uint8 -} - -// windowUpdateManager manages window update frames for receiving data -type windowUpdateManager struct { - streamOffsets map[protocol.StreamID]*windowUpdateItem - mutex sync.RWMutex -} - -// newWindowUpdateManager returns a new windowUpdateManager -func newWindowUpdateManager() *windowUpdateManager { - return &windowUpdateManager{ - streamOffsets: make(map[protocol.StreamID]*windowUpdateItem), - } -} - -// SetStreamOffset sets an offset for a stream -func (m *windowUpdateManager) SetStreamOffset(streamID protocol.StreamID, n protocol.ByteCount) { - m.mutex.Lock() - defer m.mutex.Unlock() - - entry, ok := m.streamOffsets[streamID] - if !ok { - m.streamOffsets[streamID] = &windowUpdateItem{Offset: n} - return - } - - if n > entry.Offset { - entry.Offset = n - entry.Counter = 0 - } -} - -// GetWindowUpdateFrames gets all the WindowUpdate frames that need to be sent -func (m *windowUpdateManager) GetWindowUpdateFrames() []*frames.WindowUpdateFrame { - m.mutex.RLock() - defer m.mutex.RUnlock() - - var wuf []*frames.WindowUpdateFrame - - for key, value := range m.streamOffsets { - if value.Counter >= protocol.WindowUpdateNumRepetitions { - continue - } - - frame := frames.WindowUpdateFrame{ - StreamID: key, - ByteOffset: value.Offset, - } - value.Counter++ - wuf = append(wuf, &frame) - } - - return wuf -} - -// RemoveStream should be called when a stream is closed for receiving -func (m *windowUpdateManager) RemoveStream(streamID protocol.StreamID) { - m.mutex.Lock() - defer m.mutex.Unlock() - - delete(m.streamOffsets, streamID) -} diff --git a/window_update_manager_test.go b/window_update_manager_test.go deleted file mode 100644 index 0e6607b2..00000000 --- a/window_update_manager_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package quic - -import ( - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("WindowUpdateManager", func() { - var wum *windowUpdateManager - - BeforeEach(func() { - wum = newWindowUpdateManager() - }) - - Context("queueing new window updates", func() { - It("queues a window update for a new stream", func() { - wum.SetStreamOffset(5, 0x1000) - Expect(wum.streamOffsets).To(HaveKey(protocol.StreamID(5))) - Expect(wum.streamOffsets[5].Offset).To(Equal(protocol.ByteCount(0x1000))) - }) - - It("updates the offset for an existing stream", func() { - wum.SetStreamOffset(5, 0x1000) - wum.SetStreamOffset(5, 0x2000) - Expect(wum.streamOffsets).To(HaveKey(protocol.StreamID(5))) - Expect(wum.streamOffsets[5].Offset).To(Equal(protocol.ByteCount(0x2000))) - }) - - It("does not decrease the offset for an existing stream", func() { - wum.SetStreamOffset(5, 0x1000) - wum.SetStreamOffset(5, 0x500) - Expect(wum.streamOffsets).To(HaveKey(protocol.StreamID(5))) - Expect(wum.streamOffsets[5].Offset).To(Equal(protocol.ByteCount(0x1000))) - }) - - It("resets the counter after increasing the offset", func() { - wum.streamOffsets[5] = &windowUpdateItem{ - Offset: 0x1000, - Counter: 1, - } - wum.SetStreamOffset(5, 0x2000) - Expect(wum.streamOffsets[5].Offset).To(Equal(protocol.ByteCount(0x2000))) - Expect(wum.streamOffsets[5].Counter).To(Equal(uint8(0))) - }) - }) - - Context("dequeueing window updates", func() { - BeforeEach(func() { - wum.SetStreamOffset(7, 0x1000) - wum.SetStreamOffset(9, 0x500) - }) - - It("gets the window update frames", func() { - f := wum.GetWindowUpdateFrames() - Expect(f).To(HaveLen(2)) - Expect(f).To(ContainElement(&frames.WindowUpdateFrame{StreamID: 7, ByteOffset: 0x1000})) - Expect(f).To(ContainElement(&frames.WindowUpdateFrame{StreamID: 9, ByteOffset: 0x500})) - }) - - It("increases the counter", func() { - _ = wum.GetWindowUpdateFrames() - Expect(wum.streamOffsets[7].Counter).To(Equal(uint8(1))) - Expect(wum.streamOffsets[9].Counter).To(Equal(uint8(1))) - }) - - It("only sends out a window update frame WindowUpdateNumRepetitions times", func() { - for i := uint8(0); i < protocol.WindowUpdateNumRepetitions; i++ { - frames := wum.GetWindowUpdateFrames() - Expect(frames).To(HaveLen(2)) - } - frames := wum.GetWindowUpdateFrames() - Expect(frames).To(BeEmpty()) - }) - }) - - Context("removing streams", func() { - It("deletes the map entry", func() { - wum.SetStreamOffset(7, 0x1000) - Expect(wum.streamOffsets).To(HaveKey(protocol.StreamID(7))) - wum.RemoveStream(7) - Expect(wum.streamOffsets).ToNot(HaveKey(protocol.StreamID(7))) - }) - }) -})