refactor frame packing to logic to not access the streams map (#4596)

* avoid accessing the streams map when packing stream data

* avoid accessing the streams map when packing flow control frames

* remove streamGetter interface
This commit is contained in:
Marten Seemann
2024-07-28 12:32:54 -07:00
committed by GitHub
parent fc79a77ffe
commit 42f04d4e02
20 changed files with 224 additions and 390 deletions

View File

@@ -28,11 +28,6 @@ type unpacker interface {
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}
type streamGetter interface {
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
}
type streamManager interface {
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
@@ -464,7 +459,6 @@ func (s *connection) preSetup() {
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
protocol.ByteCount(s.config.MaxConnectionReceiveWindow),
s.onHasConnectionWindowUpdate,
func(size protocol.ByteCount) bool {
if s.config.AllowConnectionWindowIncrease == nil {
return true
@@ -483,7 +477,7 @@ func (s *connection) preSetup() {
uint64(s.config.MaxIncomingUniStreams),
s.perspective,
)
s.framer = newFramer(s.streamsMap)
s.framer = newFramer()
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
@@ -493,7 +487,7 @@ func (s *connection) preSetup() {
s.lastPacketReceivedTime = now
s.creationTime = now
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame)
s.windowUpdateQueue = newWindowUpdateQueue(s.connFlowController, s.framer.QueueControlFrame)
s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger)
s.connState.Version = s.version
}
@@ -2214,7 +2208,6 @@ func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamF
protocol.ByteCount(s.config.InitialStreamReceiveWindow),
protocol.ByteCount(s.config.MaxStreamReceiveWindow),
initialSendWindow,
s.onHasStreamWindowUpdate,
s.rttStats,
s.logger,
)
@@ -2253,18 +2246,13 @@ func (s *connection) queueControlFrame(f wire.Frame) {
s.scheduleSending()
}
func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID) {
s.windowUpdateQueue.AddStream(id)
func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID, str receiveStreamI) {
s.windowUpdateQueue.AddStream(id, str)
s.scheduleSending()
}
func (s *connection) onHasConnectionWindowUpdate() {
s.windowUpdateQueue.AddConnection()
s.scheduleSending()
}
func (s *connection) onHasStreamData(id protocol.StreamID) {
s.framer.AddActiveStream(id)
func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) {
s.framer.AddActiveStream(id, str)
s.scheduleSending()
}
@@ -2272,6 +2260,8 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
if err := s.streamsMap.DeleteStream(id); err != nil {
s.closeLocal(err)
}
s.framer.RemoveActiveStream(id)
s.windowUpdateQueue.RemoveStream(id)
}
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {

View File

@@ -19,9 +19,7 @@ const (
type framer struct {
mutex sync.Mutex
streamGetter streamGetter
activeStreams map[protocol.StreamID]struct{}
activeStreams map[protocol.StreamID]sendStreamI
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
controlFrameMutex sync.Mutex
@@ -30,11 +28,8 @@ type framer struct {
queuedTooManyControlFrames bool
}
func newFramer(streamGetter streamGetter) *framer {
return &framer{
streamGetter: streamGetter,
activeStreams: make(map[protocol.StreamID]struct{}),
}
func newFramer() *framer {
return &framer{activeStreams: make(map[protocol.StreamID]sendStreamI)}
}
func (f *framer) HasData() bool {
@@ -109,15 +104,25 @@ func (f *framer) QueuedTooManyControlFrames() bool {
return f.queuedTooManyControlFrames
}
func (f *framer) AddActiveStream(id protocol.StreamID) {
func (f *framer) AddActiveStream(id protocol.StreamID, str sendStreamI) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue.PushBack(id)
f.activeStreams[id] = struct{}{}
f.activeStreams[id] = str
}
f.mutex.Unlock()
}
// RemoveActiveStream is called when a stream completes.
func (f *framer) RemoveActiveStream(id protocol.StreamID) {
f.mutex.Lock()
delete(f.activeStreams, id)
// We don't delete the stream from the streamQueue,
// since we'd have to iterate over the ringbuffer.
// Instead, we check if the stream is still in activeStreams in AppendStreamFrames.
f.mutex.Unlock()
}
func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames)
var length protocol.ByteCount
@@ -131,10 +136,9 @@ func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen prot
id := f.streamQueue.PopFront()
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
str, ok := f.activeStreams[id]
// The stream might have been removed after being enqueued.
if !ok {
continue
}
remainingLen := maxLen - length

View File

@@ -23,17 +23,15 @@ var _ = Describe("Framer", func() {
var (
framer *framer
stream1, stream2 *MockSendStreamI
streamGetter *MockStreamGetter
version protocol.Version
)
BeforeEach(func() {
streamGetter = NewMockStreamGetter(mockCtrl)
stream1 = NewMockSendStreamI(mockCtrl)
stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes()
stream2 = NewMockSendStreamI(mockCtrl)
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
framer = newFramer(streamGetter)
framer = newFramer()
})
Context("handling control frames", func() {
@@ -183,7 +181,6 @@ var _ = Describe("Framer", func() {
})
It("returns STREAM frames", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{
StreamID: id1,
Data: []byte("foobar"),
@@ -191,7 +188,7 @@ var _ = Describe("Framer", func() {
DataLenPresent: true,
}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
fs, length := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
Expect(fs).To(HaveLen(1))
Expect(fs[0].Frame.DataLenPresent).To(BeFalse())
@@ -199,9 +196,8 @@ var _ = Describe("Framer", func() {
})
It("says if it has data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
Expect(framer.HasData()).To(BeFalse())
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
Expect(framer.HasData()).To(BeTrue())
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foo")}
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("bar")}
@@ -218,14 +214,13 @@ var _ = Describe("Framer", func() {
})
It("appends to a frame slice", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{
StreamID: id1,
Data: []byte("foobar"),
DataLenPresent: true,
}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
f0 := ackhandler.StreamFrame{Frame: &wire.StreamFrame{StreamID: 9999}}
frames := []ackhandler.StreamFrame{f0}
fs, length := framer.AppendStreamFrames(frames, 1000, protocol.Version1)
@@ -237,24 +232,21 @@ var _ = Describe("Framer", func() {
})
It("skips a stream that was reported active, but was completed shortly after", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
DataLenPresent: true,
}
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
framer.AddActiveStream(id1, stream1)
framer.AddActiveStream(id2, stream2)
framer.RemoveActiveStream(id1)
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
Expect(frames).To(HaveLen(1))
Expect(frames[0].Frame).To(Equal(f))
})
It("skips a stream that was reported active, but doesn't have any data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
@@ -262,20 +254,19 @@ var _ = Describe("Framer", func() {
}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{}, false, false)
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
framer.AddActiveStream(id1, stream1)
framer.AddActiveStream(id2, stream2)
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
Expect(frames).To(HaveLen(1))
Expect(frames[0].Frame).To(Equal(f))
})
It("pops from a stream multiple times, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, true)
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false)
framer.AddActiveStream(id1) // only add it once
framer.AddActiveStream(id1, stream1) // only add it once
frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1)
Expect(frames).To(HaveLen(1))
Expect(frames[0].Frame).To(Equal(f1))
@@ -288,16 +279,14 @@ var _ = Describe("Framer", func() {
})
It("re-queues a stream at the end, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f11}, true, true)
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f12}, true, false)
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false)
framer.AddActiveStream(id1) // only add it once
framer.AddActiveStream(id2)
framer.AddActiveStream(id1, stream1) // only add it once
framer.AddActiveStream(id2, stream2)
// first a frame from stream 1
frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1)
Expect(frames).To(HaveLen(1))
@@ -313,15 +302,13 @@ var _ = Describe("Framer", func() {
})
It("only dequeues data from each stream once per packet", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
// both streams have more data, and will be re-queued
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, true)
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, true)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
framer.AddActiveStream(id1, stream1)
framer.AddActiveStream(id2, stream2)
frames, length := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
Expect(frames).To(HaveLen(2))
Expect(frames[0].Frame).To(Equal(f1))
@@ -330,14 +317,12 @@ var _ = Describe("Framer", func() {
})
It("returns multiple normal frames in the order they were reported active", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{Data: []byte("foobar")}
f2 := &wire.StreamFrame{Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, false)
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false)
framer.AddActiveStream(id2)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2, stream2)
framer.AddActiveStream(id1, stream1)
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
Expect(frames).To(HaveLen(2))
Expect(frames[0].Frame).To(Equal(f2))
@@ -345,11 +330,10 @@ var _ = Describe("Framer", func() {
})
It("only asks a stream for data once, even if it was reported active multiple times", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) // only one call to this function
framer.AddActiveStream(id1)
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
framer.AddActiveStream(id1, stream1)
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
Expect(frames).To(HaveLen(1))
})
@@ -362,7 +346,6 @@ var _ = Describe("Framer", func() {
It("pops maximum size STREAM frames", func() {
for i := protocol.MinStreamFrameSize; i < 2000; i++ {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) {
f := &wire.StreamFrame{
StreamID: id1,
@@ -372,7 +355,7 @@ var _ = Describe("Framer", func() {
Expect(f.Length(version)).To(Equal(size))
return ackhandler.StreamFrame{Frame: f}, true, false
})
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
frames, _ := framer.AppendStreamFrames(nil, i, protocol.Version1)
Expect(frames).To(HaveLen(1))
f := frames[0].Frame
@@ -383,8 +366,6 @@ var _ = Describe("Framer", func() {
It("pops multiple STREAM frames", func() {
for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) {
f := &wire.StreamFrame{
StreamID: id2,
@@ -402,8 +383,8 @@ var _ = Describe("Framer", func() {
Expect(f.Length(version)).To(Equal(size))
return ackhandler.StreamFrame{Frame: f}, true, false
})
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
framer.AddActiveStream(id1, stream1)
framer.AddActiveStream(id2, stream2)
frames, _ := framer.AppendStreamFrames(nil, i, protocol.Version1)
Expect(frames).To(HaveLen(2))
f1 := frames[0].Frame
@@ -415,10 +396,9 @@ var _ = Describe("Framer", func() {
})
It("pops frames that when asked for the the minimum STREAM frame size", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1)
})
@@ -428,7 +408,6 @@ var _ = Describe("Framer", func() {
})
It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
// pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
f := &wire.StreamFrame{
StreamID: id1,
@@ -436,7 +415,7 @@ var _ = Describe("Framer", func() {
DataLenPresent: true,
}
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
fs, length := framer.AppendStreamFrames(nil, 500, protocol.Version1)
Expect(fs).To(HaveLen(1))
Expect(fs[0].Frame).To(Equal(f))
@@ -444,7 +423,7 @@ var _ = Describe("Framer", func() {
})
It("drops all STREAM frames when 0-RTT is rejected", func() {
framer.AddActiveStream(id1)
framer.AddActiveStream(id1, stream1)
Expect(framer.Handle0RTTRejection()).To(Succeed())
fs, length := framer.AppendStreamFrames(nil, protocol.MaxByteCount, protocol.Version1)
Expect(fs).To(BeEmpty())

View File

@@ -12,8 +12,6 @@ import (
type connectionFlowController struct {
baseFlowController
queueWindowUpdate func()
}
var _ ConnectionFlowController = &connectionFlowController{}
@@ -23,7 +21,6 @@ var _ ConnectionFlowController = &connectionFlowController{}
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
allowWindowIncrease func(size protocol.ByteCount) bool,
rttStats *utils.RTTStats,
logger utils.Logger,
@@ -37,7 +34,6 @@ func NewConnectionFlowController(
allowWindowIncrease: allowWindowIncrease,
logger: logger,
},
queueWindowUpdate: queueWindowUpdate,
}
}
@@ -63,18 +59,14 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
if c.logger.Debug() && oldWindowSize < c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()

View File

@@ -11,10 +11,7 @@ import (
)
var _ = Describe("Connection Flow controller", func() {
var (
controller *connectionFlowController
queuedWindowUpdate bool
)
var controller *connectionFlowController
// update the congestion such that it returns a given value for the smoothed RTT
setRtt := func(t time.Duration) {
@@ -23,11 +20,9 @@ var _ = Describe("Connection Flow controller", func() {
}
BeforeEach(func() {
queuedWindowUpdate = false
controller = &connectionFlowController{}
controller.rttStats = &utils.RTTStats{}
controller.logger = utils.DefaultLogger
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
controller.allowWindowIncrease = func(protocol.ByteCount) bool { return true }
})
@@ -41,7 +36,6 @@ var _ = Describe("Connection Flow controller", func() {
fc := NewConnectionFlowController(
receiveWindow,
maxReceiveWindow,
nil,
func(protocol.ByteCount) bool { return true },
rttStats,
utils.DefaultLogger).(*connectionFlowController)
@@ -67,13 +61,11 @@ var _ = Describe("Connection Flow controller", func() {
It("queues window updates", func() {
controller.AddBytesRead(1)
Expect(queuedWindowUpdate).To(BeFalse())
Expect(controller.GetWindowUpdate()).To(BeZero())
controller.AddBytesRead(29)
Expect(queuedWindowUpdate).To(BeTrue())
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
queuedWindowUpdate = false
controller.AddBytesRead(1)
Expect(queuedWindowUpdate).To(BeFalse())
Expect(controller.GetWindowUpdate()).To(BeZero())
})
It("gets a window update", func() {

View File

@@ -8,7 +8,6 @@ type flowController interface {
UpdateSendWindow(protocol.ByteCount) (updated bool)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
IsNewlyBlocked() (bool, protocol.ByteCount)
}
@@ -16,6 +15,7 @@ type flowController interface {
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
AddBytesRead(protocol.ByteCount) (shouldQueueWindowUpdate bool)
// UpdateHighestReceived is called when a new highest offset is received
// final has to be to true if this is the final offset of the stream,
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
@@ -28,6 +28,7 @@ type StreamFlowController interface {
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
AddBytesRead(protocol.ByteCount)
Reset() error
}

View File

@@ -13,8 +13,6 @@ type streamFlowController struct {
streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI
receivedFinalOffset bool
@@ -29,14 +27,12 @@ func NewStreamFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *utils.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
@@ -97,15 +93,13 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount,
return c.connection.IncrementHighestReceived(increment)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (shouldQueueWindowUpdate bool) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
shouldQueueWindowUpdate = c.shouldQueueWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
c.connection.AddBytesRead(n)
return
}
func (c *streamFlowController) Abandon() {

View File

@@ -12,20 +12,15 @@ import (
)
var _ = Describe("Stream Flow controller", func() {
var (
controller *streamFlowController
queuedWindowUpdate bool
)
var controller *streamFlowController
BeforeEach(func() {
queuedWindowUpdate = false
rttStats := &utils.RTTStats{}
controller = &streamFlowController{
streamID: 10,
connection: NewConnectionFlowController(
1000,
1000,
func() {},
func(protocol.ByteCount) bool { return true },
rttStats,
utils.DefaultLogger,
@@ -34,7 +29,6 @@ var _ = Describe("Stream Flow controller", func() {
controller.maxReceiveWindowSize = 10000
controller.rttStats = rttStats
controller.logger = utils.DefaultLogger
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
})
Context("Constructor", func() {
@@ -44,25 +38,18 @@ var _ = Describe("Stream Flow controller", func() {
const sendWindow protocol.ByteCount = 4000
It("sets the send and receive windows", func() {
cc := NewConnectionFlowController(0, 0, nil, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
cc := NewConnectionFlowController(0, 0, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController)
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
Expect(fc.sendWindow).To(Equal(sendWindow))
})
It("queues window updates with the correct stream ID", func() {
var queued bool
queueWindowUpdate := func(id protocol.StreamID) {
Expect(id).To(Equal(protocol.StreamID(5)))
queued = true
}
cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func() {}, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
fc.AddBytesRead(receiveWindow)
Expect(queued).To(BeTrue())
It("queues window updates", func() {
cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController)
Expect(fc.AddBytesRead(receiveWindow)).To(BeTrue())
})
})
@@ -195,14 +182,10 @@ var _ = Describe("Stream Flow controller", func() {
})
It("queues window updates", func() {
controller.AddBytesRead(1)
Expect(queuedWindowUpdate).To(BeFalse())
controller.AddBytesRead(29)
Expect(queuedWindowUpdate).To(BeTrue())
Expect(controller.AddBytesRead(1)).To(BeFalse())
Expect(controller.AddBytesRead(29)).To(BeTrue())
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
queuedWindowUpdate = false
controller.AddBytesRead(1)
Expect(queuedWindowUpdate).To(BeFalse())
Expect(controller.AddBytesRead(1)).To(BeFalse())
})
It("tells the connection flow controller when the window was auto-tuned", func() {
@@ -246,10 +229,8 @@ var _ = Describe("Stream Flow controller", func() {
It("doesn't increase the window after a final offset was already received", func() {
Expect(controller.UpdateHighestReceived(90, true)).To(Succeed())
controller.AddBytesRead(30)
Expect(queuedWindowUpdate).To(BeFalse())
offset := controller.GetWindowUpdate()
Expect(offset).To(BeZero())
Expect(controller.AddBytesRead(30)).To(BeFalse())
Expect(controller.GetWindowUpdate()).To(BeZero())
})
})
})

View File

@@ -76,9 +76,11 @@ func (c *MockStreamFlowControllerAbandonCall) DoAndReturn(f func()) *MockStreamF
}
// AddBytesRead mocks base method.
func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) {
func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) bool {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddBytesRead", arg0)
ret := m.ctrl.Call(m, "AddBytesRead", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// AddBytesRead indicates an expected call of AddBytesRead.
@@ -94,19 +96,19 @@ type MockStreamFlowControllerAddBytesReadCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamFlowControllerAddBytesReadCall) Return() *MockStreamFlowControllerAddBytesReadCall {
c.Call = c.Call.Return()
func (c *MockStreamFlowControllerAddBytesReadCall) Return(arg0 bool) *MockStreamFlowControllerAddBytesReadCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesReadCall {
func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount) bool) *MockStreamFlowControllerAddBytesReadCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesReadCall {
func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockStreamFlowControllerAddBytesReadCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -1,118 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: StreamGetter)
//
// Generated by this command:
//
// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter
//
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock"
)
// MockStreamGetter is a mock of StreamGetter interface.
type MockStreamGetter struct {
ctrl *gomock.Controller
recorder *MockStreamGetterMockRecorder
}
// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter.
type MockStreamGetterMockRecorder struct {
mock *MockStreamGetter
}
// NewMockStreamGetter creates a new mock instance.
func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter {
mock := &MockStreamGetter{ctrl: ctrl}
mock.recorder = &MockStreamGetterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder {
return m.recorder
}
// GetOrOpenReceiveStream mocks base method.
func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
ret0, _ := ret[0].(receiveStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream.
func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 any) *MockStreamGetterGetOrOpenReceiveStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0)
return &MockStreamGetterGetOrOpenReceiveStreamCall{Call: call}
}
// MockStreamGetterGetOrOpenReceiveStreamCall wrap *gomock.Call
type MockStreamGetterGetOrOpenReceiveStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamGetterGetOrOpenReceiveStreamCall) Return(arg0 receiveStreamI, arg1 error) *MockStreamGetterGetOrOpenReceiveStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamGetterGetOrOpenReceiveStreamCall) Do(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamGetterGetOrOpenReceiveStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamGetterGetOrOpenReceiveStreamCall) DoAndReturn(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamGetterGetOrOpenReceiveStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// GetOrOpenSendStream mocks base method.
func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
ret0, _ := ret[0].(sendStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream.
func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 any) *MockStreamGetterGetOrOpenSendStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0)
return &MockStreamGetterGetOrOpenSendStreamCall{Call: call}
}
// MockStreamGetterGetOrOpenSendStreamCall wrap *gomock.Call
type MockStreamGetterGetOrOpenSendStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamGetterGetOrOpenSendStreamCall) Return(arg0 sendStreamI, arg1 error) *MockStreamGetterGetOrOpenSendStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamGetterGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamGetterGetOrOpenSendStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamGetterGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamGetterGetOrOpenSendStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -41,15 +41,15 @@ func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder {
}
// onHasStreamData mocks base method.
func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) {
func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID, arg1 sendStreamI) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "onHasStreamData", arg0)
m.ctrl.Call(m, "onHasStreamData", arg0, arg1)
}
// onHasStreamData indicates an expected call of onHasStreamData.
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 any) *MockStreamSenderonHasStreamDataCall {
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0, arg1 any) *MockStreamSenderonHasStreamDataCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0, arg1)
return &MockStreamSenderonHasStreamDataCall{Call: call}
}
@@ -65,13 +65,49 @@ func (c *MockStreamSenderonHasStreamDataCall) Return() *MockStreamSenderonHasStr
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID)) *MockStreamSenderonHasStreamDataCall {
func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID, sendStreamI)) *MockStreamSenderonHasStreamDataCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID)) *MockStreamSenderonHasStreamDataCall {
func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID, sendStreamI)) *MockStreamSenderonHasStreamDataCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// onHasStreamWindowUpdate mocks base method.
func (m *MockStreamSender) onHasStreamWindowUpdate(arg0 protocol.StreamID, arg1 receiveStreamI) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "onHasStreamWindowUpdate", arg0, arg1)
}
// onHasStreamWindowUpdate indicates an expected call of onHasStreamWindowUpdate.
func (mr *MockStreamSenderMockRecorder) onHasStreamWindowUpdate(arg0, arg1 any) *MockStreamSenderonHasStreamWindowUpdateCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamWindowUpdate), arg0, arg1)
return &MockStreamSenderonHasStreamWindowUpdateCall{Call: call}
}
// MockStreamSenderonHasStreamWindowUpdateCall wrap *gomock.Call
type MockStreamSenderonHasStreamWindowUpdateCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockStreamSenderonHasStreamWindowUpdateCall) Return() *MockStreamSenderonHasStreamWindowUpdateCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockStreamSenderonHasStreamWindowUpdateCall) Do(f func(protocol.StreamID, receiveStreamI)) *MockStreamSenderonHasStreamWindowUpdateCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockStreamSenderonHasStreamWindowUpdateCall) DoAndReturn(f func(protocol.StreamID, receiveStreamI)) *MockStreamSenderonHasStreamWindowUpdateCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -23,9 +23,6 @@ type ReceiveStreamI = receiveStreamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI"
type SendStreamI = sendStreamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter"
type StreamGetter = streamGetter
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender"
type StreamSender = streamSender

View File

@@ -85,7 +85,7 @@ var _ = Describe("Packet packer", func() {
rand.Seed(uint64(GinkgoRandomSeed()))
retransmissionQueue = newRetransmissionQueue()
mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any()).AnyTimes()
initialStream = NewMockCryptoStream(mockCtrl)
handshakeStream = NewMockCryptoStream(mockCtrl)
framer = NewMockFrameSource(mockCtrl)

View File

@@ -197,7 +197,9 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
if queueWindowUpdate := s.flowController.AddBytesRead(protocol.ByteCount(m)); queueWindowUpdate {
s.sender.onHasStreamWindowUpdate(s.streamID, s)
}
}
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {

View File

@@ -414,7 +414,10 @@ var _ = Describe("Receive Stream", func() {
It("handles concurrent reads", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes()
var bytesRead protocol.ByteCount
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes()
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) bool {
bytesRead += n
return false
}).AnyTimes()
var numCompleted int32
mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) {

View File

@@ -172,7 +172,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
s.mutex.Unlock()
if !notifiedSender {
s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
s.sender.onHasStreamData(s.streamID, s) // must be called without holding the mutex
notifiedSender = true
}
if copied {
@@ -407,7 +407,7 @@ func (s *sendStream) Close() error {
if cancelWriteErr != nil {
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
s.ctxCancel(nil)
return nil
@@ -453,7 +453,7 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
s.mutex.Unlock()
if hasStreamData {
s.sender.onHasStreamData(s.streamID)
s.sender.onHasStreamData(s.streamID, s)
}
}
@@ -530,5 +530,5 @@ func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
}
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID)
s.sender.onHasStreamData(s.streamID, (*sendStream)(s))
}

View File

@@ -80,7 +80,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
@@ -104,7 +104,7 @@ var _ = Describe("Send Stream", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
@@ -136,7 +136,7 @@ var _ = Describe("Send Stream", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
mockSender.EXPECT().onHasStreamData(streamID).Times(2)
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
n, err := strWithTimeout.Write([]byte("foo"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
@@ -163,7 +163,7 @@ var _ = Describe("Send Stream", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
n, err := strWithTimeout.Write(getData(5000))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(5000))
@@ -188,7 +188,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
_, err := strWithTimeout.Write(getData(protocol.MaxPacketBufferSize + 3))
Expect(err).ToNot(HaveOccurred())
}()
@@ -211,14 +211,14 @@ var _ = Describe("Send Stream", func() {
})
It("only unblocks Write once a previously buffered STREAM frame has been fully dequeued", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
_, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
_, err := str.Write(getData(protocol.MaxPacketBufferSize))
Expect(err).ToNot(HaveOccurred())
}()
@@ -251,7 +251,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(100))
@@ -282,7 +282,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
n, err := strWithTimeout.Write(s)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
@@ -315,7 +315,7 @@ var _ = Describe("Send Stream", func() {
})
It("cancels the context when Close is called", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
Expect(str.Context().Done()).ToNot(BeClosed())
Expect(str.Close()).To(Succeed())
Expect(str.Context().Done()).To(BeClosed())
@@ -334,7 +334,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
}()
@@ -352,7 +352,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
}()
@@ -392,7 +392,7 @@ var _ = Describe("Send Stream", func() {
})
It("unblocks after the deadline", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
n, err := strWithTimeout.Write(getData(5000))
@@ -402,7 +402,7 @@ var _ = Describe("Send Stream", func() {
})
It("unblocks when the deadline is changed to the past", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
str.SetWriteDeadline(time.Now().Add(time.Hour))
done := make(chan struct{})
go func() {
@@ -426,7 +426,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(writeReturned)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
var err error
n, err = strWithTimeout.Write(getData(5000))
Expect(err).To(MatchError(errDeadline))
@@ -450,7 +450,7 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(writeReturned)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
_, err := strWithTimeout.Write(getData(5000))
Expect(err).To(MatchError(errDeadline))
}()
@@ -466,7 +466,7 @@ var _ = Describe("Send Stream", func() {
})
It("doesn't unblock if the deadline is changed before the first one expires", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
str.SetWriteDeadline(deadline1)
@@ -488,7 +488,7 @@ var _ = Describe("Send Stream", func() {
})
It("unblocks earlier, when a new deadline is set", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
done := make(chan struct{})
@@ -509,7 +509,7 @@ var _ = Describe("Send Stream", func() {
})
It("doesn't unblock if the deadline is removed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
deadlineUnset := make(chan struct{})
@@ -539,14 +539,14 @@ var _ = Describe("Send Stream", func() {
Context("closing", func() {
It("doesn't allow writes after it has been closed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
str.Close()
_, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError("write on closed stream 1337"))
})
It("allows FIN", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
str.Close()
frame, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1)
Expect(ok).To(BeTrue())
@@ -560,7 +560,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't send a FIN when there's still data", func() {
const frameHeaderLen protocol.ByteCount = 4
mockSender.EXPECT().onHasStreamData(streamID).Times(2)
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
_, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
@@ -584,10 +584,10 @@ var _ = Describe("Send Stream", func() {
go func() {
defer GinkgoRecover()
defer close(done)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
_, err := strWithTimeout.Write(getData(5000))
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
Expect(str.Close()).To(Succeed())
}()
waitForWrite()
@@ -619,7 +619,7 @@ var _ = Describe("Send Stream", func() {
})
It("doesn't allow FIN twice", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
str.Close()
frame, ok, _ := str.popStreamFrame(1000, protocol.Version1)
Expect(ok).To(BeTrue())
@@ -646,7 +646,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't get data for writing if an error occurred", func() {
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -676,7 +676,7 @@ var _ = Describe("Send Stream", func() {
It("says when it has data for sending", func() {
mockFC.EXPECT().UpdateSendWindow(gomock.Any()).Return(true)
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -685,7 +685,7 @@ var _ = Describe("Send Stream", func() {
close(done)
}()
waitForWrite()
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
str.updateSendWindow(42)
// make sure the Write go routine returns
str.closeForShutdown(nil)
@@ -694,7 +694,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't say it has data for sending if the MAX_STREAM_DATA frame was reordered", func() {
mockFC.EXPECT().UpdateSendWindow(gomock.Any()).Return(false) // reordered frame
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -731,7 +731,7 @@ var _ = Describe("Send Stream", func() {
// for reliable results it has to be run many times.
It("returns a nil error when the whole slice has been sent out", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(1)
mockSender.EXPECT().onHasStreamData(streamID).MaxTimes(1)
mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(1)
mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1)
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1)
mockFC.EXPECT().AddBytesSent(gomock.Any()).MaxTimes(1)
@@ -754,7 +754,7 @@ var _ = Describe("Send Stream", func() {
It("unblocks Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
writeReturned := make(chan struct{})
@@ -782,7 +782,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't pop STREAM frames after being canceled", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
writeReturned := make(chan struct{})
@@ -806,7 +806,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't pop STREAM frames after being canceled, for large writes", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
writeReturned := make(chan struct{})
@@ -835,7 +835,7 @@ var _ = Describe("Send Stream", func() {
It("ignores acknowledgements for STREAM frames after it was cancelled", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
writeReturned := make(chan struct{})
@@ -884,7 +884,7 @@ var _ = Describe("Send Stream", func() {
})
It("queues a RESET_STREAM frame, even if the stream was already closed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{}))
})
@@ -910,7 +910,7 @@ var _ = Describe("Send Stream", func() {
})
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockSender.EXPECT().queueControlFrame(gomock.Any())
done := make(chan struct{})
go func() {
@@ -958,7 +958,7 @@ var _ = Describe("Send Stream", func() {
})
It("handles STOP_SENDING after sending the FIN", func() {
mockSender.EXPECT().onHasStreamData(gomock.Any())
mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any())
str.Close()
_, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
Expect(ok).To(BeTrue())
@@ -973,7 +973,7 @@ var _ = Describe("Send Stream", func() {
})
It("handles STOP_SENDING after Close, but before sending the FIN", func() {
mockSender.EXPECT().onHasStreamData(gomock.Any())
mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any())
str.Close()
gomock.InOrder(
mockSender.EXPECT().queueControlFrame(gomock.Any()),
@@ -995,7 +995,7 @@ var _ = Describe("Send Stream", func() {
Offset: 0x42,
DataLenPresent: false,
}
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
(*sendStreamAckHandler)(str).OnLost(f)
frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
Expect(ok).To(BeTrue())
@@ -1013,7 +1013,7 @@ var _ = Describe("Send Stream", func() {
Offset: 0x42,
DataLenPresent: false,
}
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
(*sendStreamAckHandler)(str).OnLost(sf)
frame, ok, hasMoreData := str.popStreamFrame(sf.Length(protocol.Version1)-3, protocol.Version1)
Expect(ok).To(BeTrue())
@@ -1039,7 +1039,7 @@ var _ = Describe("Send Stream", func() {
Offset: 0x42,
DataLenPresent: false,
}
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
(*sendStreamAckHandler)(str).OnLost(f)
_, ok, hasMoreData := str.popStreamFrame(2, protocol.Version1)
Expect(ok).To(BeFalse())
@@ -1047,7 +1047,7 @@ var _ = Describe("Send Stream", func() {
})
It("queues lost STREAM frames", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
done := make(chan struct{})
@@ -1065,7 +1065,7 @@ var _ = Describe("Send Stream", func() {
Expect(frame.Frame.Data).To(Equal([]byte("foobar")))
// now lose the frame
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
frame.Handler.OnLost(frame.Frame)
newFrame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
Expect(ok).To(BeTrue())
@@ -1074,7 +1074,7 @@ var _ = Describe("Send Stream", func() {
})
It("doesn't queue retransmissions for a stream that was canceled", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
done := make(chan struct{})
@@ -1107,7 +1107,7 @@ var _ = Describe("Send Stream", func() {
})
It("says when a stream is completed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -1138,7 +1138,7 @@ var _ = Describe("Send Stream", func() {
}
// Now close the stream and acknowledge the FIN.
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
Expect(str.Close()).To(Succeed())
frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
Expect(ok).To(BeTrue())
@@ -1148,7 +1148,7 @@ var _ = Describe("Send Stream", func() {
})
It("says when a stream is completed, if Close() is called before popping the frame", func() {
mockSender.EXPECT().onHasStreamData(streamID).Times(2)
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -1171,13 +1171,13 @@ var _ = Describe("Send Stream", func() {
})
It("doesn't say it's completed when there are frames waiting to be retransmitted", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Write(getData(100))
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
Expect(str.Close()).To(Succeed())
close(done)
}()
@@ -1201,7 +1201,7 @@ var _ = Describe("Send Stream", func() {
for _, f := range frames[1:] {
f.Handler.OnAcked(f.Frame)
}
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onHasStreamData(streamID, str)
frames[0].Handler.OnLost(frames[0].Frame)
// get the retransmission and acknowledge it
@@ -1218,7 +1218,7 @@ var _ = Describe("Send Stream", func() {
// and has to be retransmitted.
It("retransmits data until everything has been acknowledged", func() {
const dataLen = 1 << 22 // 4 MB
mockSender.EXPECT().onHasStreamData(streamID).AnyTimes()
mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes()
mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount {
return protocol.ByteCount(mrand.Intn(500)) + 50
}).AnyTimes()

View File

@@ -25,7 +25,8 @@ var errDeadline net.Error = &deadlineError{}
// The streamSender is notified by the stream about various events.
type streamSender interface {
queueControlFrame(wire.Frame)
onHasStreamData(protocol.StreamID)
onHasStreamData(protocol.StreamID, sendStreamI)
onHasStreamWindowUpdate(protocol.StreamID, receiveStreamI)
// must be called without holding the mutex that is acquired by closeForShutdown
onStreamCompleted(protocol.StreamID)
}
@@ -37,17 +38,11 @@ type uniStreamSender struct {
onStreamCompletedImpl func()
}
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
s.streamSender.queueControlFrame(f)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
s.streamSender.onHasStreamData(id)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
s.onStreamCompletedImpl()
func (s *uniStreamSender) queueControlFrame(f wire.Frame) { s.streamSender.queueControlFrame(f) }
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str sendStreamI) {
s.streamSender.onHasStreamData(id, str)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() }
var _ streamSender = &uniStreamSender{}

View File

@@ -11,53 +11,44 @@ import (
type windowUpdateQueue struct {
mutex sync.Mutex
queue map[protocol.StreamID]struct{} // used as a set
queuedConn bool // connection-level window update
queue map[protocol.StreamID]receiveStreamI
streamGetter streamGetter
connFlowController flowcontrol.ConnectionFlowController
callback func(wire.Frame)
}
func newWindowUpdateQueue(
streamGetter streamGetter,
connFC flowcontrol.ConnectionFlowController,
cb func(wire.Frame),
) *windowUpdateQueue {
return &windowUpdateQueue{
queue: make(map[protocol.StreamID]struct{}),
streamGetter: streamGetter,
queue: make(map[protocol.StreamID]receiveStreamI),
connFlowController: connFC,
callback: cb,
}
}
func (q *windowUpdateQueue) AddStream(id protocol.StreamID) {
func (q *windowUpdateQueue) AddStream(id protocol.StreamID, str receiveStreamI) {
q.mutex.Lock()
q.queue[id] = struct{}{}
q.queue[id] = str
q.mutex.Unlock()
}
func (q *windowUpdateQueue) AddConnection() {
func (q *windowUpdateQueue) RemoveStream(id protocol.StreamID) {
q.mutex.Lock()
q.queuedConn = true
delete(q.queue, id)
q.mutex.Unlock()
}
func (q *windowUpdateQueue) QueueAll() {
q.mutex.Lock()
// queue a connection-level window update
if q.queuedConn {
q.callback(&wire.MaxDataFrame{MaximumData: q.connFlowController.GetWindowUpdate()})
q.queuedConn = false
if offset := q.connFlowController.GetWindowUpdate(); offset > 0 {
q.callback(&wire.MaxDataFrame{MaximumData: offset})
}
// queue all stream-level window updates
for id := range q.queue {
for id, str := range q.queue {
delete(q.queue, id)
str, err := q.streamGetter.GetOrOpenReceiveStream(id)
if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update
continue
}
offset := str.getWindowUpdate()
if offset == 0 { // can happen if we received a final offset, right after queueing the window update
continue

View File

@@ -12,39 +12,36 @@ import (
var _ = Describe("Window Update Queue", func() {
var (
q *windowUpdateQueue
streamGetter *MockStreamGetter
connFC *mocks.MockConnectionFlowController
queuedFrames []wire.Frame
)
BeforeEach(func() {
streamGetter = NewMockStreamGetter(mockCtrl)
connFC = mocks.NewMockConnectionFlowController(mockCtrl)
queuedFrames = queuedFrames[:0]
q = newWindowUpdateQueue(streamGetter, connFC, func(f wire.Frame) {
q = newWindowUpdateQueue(connFC, func(f wire.Frame) {
queuedFrames = append(queuedFrames, f)
})
})
It("adds stream offsets and gets MAX_STREAM_DATA frames", func() {
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
stream1 := NewMockStreamI(mockCtrl)
stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10))
stream3 := NewMockStreamI(mockCtrl)
stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30))
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(stream3, nil)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(stream1, nil)
q.AddStream(3)
q.AddStream(1)
q.AddStream(3, stream3)
q.AddStream(1, stream1)
q.QueueAll()
Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, MaximumStreamData: 10}))
Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, MaximumStreamData: 30}))
})
It("deletes the entry after getting the MAX_STREAM_DATA frame", func() {
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
stream10 := NewMockStreamI(mockCtrl)
stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100))
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil)
q.AddStream(10)
q.AddStream(10, stream10)
q.QueueAll()
Expect(queuedFrames).To(HaveLen(1))
q.QueueAll()
@@ -52,36 +49,37 @@ var _ = Describe("Window Update Queue", func() {
})
It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() {
q.AddStream(12)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil)
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
stream12 := NewMockStreamI(mockCtrl)
q.AddStream(12, stream12)
q.RemoveStream(12)
q.QueueAll()
Expect(queuedFrames).To(BeEmpty())
})
It("removes closed streams from the queue", func() {
q.AddStream(12)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil)
q.QueueAll()
Expect(queuedFrames).To(BeEmpty())
// don't EXPECT any further calls to GetOrOpenReceiveStream
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
stream12 := NewMockStreamI(mockCtrl)
q.AddStream(12, stream12)
q.RemoveStream(12)
q.QueueAll()
Expect(queuedFrames).To(BeEmpty())
})
It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() {
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0))
stream5 := NewMockStreamI(mockCtrl)
stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0))
q.AddStream(5)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil)
q.AddStream(5, stream5)
q.QueueAll()
Expect(queuedFrames).To(BeEmpty())
})
It("removes streams for which the flow controller returns an offset of 0 from the queue", func() {
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
stream5 := NewMockStreamI(mockCtrl)
stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0))
q.AddStream(5)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil)
q.AddStream(5, stream5)
q.QueueAll()
Expect(queuedFrames).To(BeEmpty())
// don't EXPECT any further calls to GetOrOpenReveiveStream and to getWindowUpdate
@@ -91,22 +89,17 @@ var _ = Describe("Window Update Queue", func() {
It("queues MAX_DATA frames", func() {
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337))
q.AddConnection()
q.QueueAll()
Expect(queuedFrames).To(Equal([]wire.Frame{
&wire.MaxDataFrame{MaximumData: 0x1337},
}))
Expect(queuedFrames).To(Equal([]wire.Frame{&wire.MaxDataFrame{MaximumData: 0x1337}}))
})
It("deduplicates", func() {
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0))
stream10 := NewMockStreamI(mockCtrl)
stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200))
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil)
q.AddStream(10)
q.AddStream(10)
q.AddStream(10, stream10)
q.AddStream(10, stream10)
q.QueueAll()
Expect(queuedFrames).To(Equal([]wire.Frame{
&wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 200},
}))
Expect(queuedFrames).To(Equal([]wire.Frame{&wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 200}}))
})
})