forked from quic-go/quic-go
refactor frame packing to logic to not access the streams map (#4596)
* avoid accessing the streams map when packing stream data * avoid accessing the streams map when packing flow control frames * remove streamGetter interface
This commit is contained in:
@@ -28,11 +28,6 @@ type unpacker interface {
|
||||
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
|
||||
}
|
||||
|
||||
type streamGetter interface {
|
||||
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
|
||||
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
|
||||
}
|
||||
|
||||
type streamManager interface {
|
||||
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
|
||||
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
|
||||
@@ -464,7 +459,6 @@ func (s *connection) preSetup() {
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
|
||||
protocol.ByteCount(s.config.MaxConnectionReceiveWindow),
|
||||
s.onHasConnectionWindowUpdate,
|
||||
func(size protocol.ByteCount) bool {
|
||||
if s.config.AllowConnectionWindowIncrease == nil {
|
||||
return true
|
||||
@@ -483,7 +477,7 @@ func (s *connection) preSetup() {
|
||||
uint64(s.config.MaxIncomingUniStreams),
|
||||
s.perspective,
|
||||
)
|
||||
s.framer = newFramer(s.streamsMap)
|
||||
s.framer = newFramer()
|
||||
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
@@ -493,7 +487,7 @@ func (s *connection) preSetup() {
|
||||
s.lastPacketReceivedTime = now
|
||||
s.creationTime = now
|
||||
|
||||
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame)
|
||||
s.windowUpdateQueue = newWindowUpdateQueue(s.connFlowController, s.framer.QueueControlFrame)
|
||||
s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger)
|
||||
s.connState.Version = s.version
|
||||
}
|
||||
@@ -2214,7 +2208,6 @@ func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamF
|
||||
protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
protocol.ByteCount(s.config.MaxStreamReceiveWindow),
|
||||
initialSendWindow,
|
||||
s.onHasStreamWindowUpdate,
|
||||
s.rttStats,
|
||||
s.logger,
|
||||
)
|
||||
@@ -2253,18 +2246,13 @@ func (s *connection) queueControlFrame(f wire.Frame) {
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID) {
|
||||
s.windowUpdateQueue.AddStream(id)
|
||||
func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID, str receiveStreamI) {
|
||||
s.windowUpdateQueue.AddStream(id, str)
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *connection) onHasConnectionWindowUpdate() {
|
||||
s.windowUpdateQueue.AddConnection()
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *connection) onHasStreamData(id protocol.StreamID) {
|
||||
s.framer.AddActiveStream(id)
|
||||
func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) {
|
||||
s.framer.AddActiveStream(id, str)
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
@@ -2272,6 +2260,8 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
|
||||
if err := s.streamsMap.DeleteStream(id); err != nil {
|
||||
s.closeLocal(err)
|
||||
}
|
||||
s.framer.RemoveActiveStream(id)
|
||||
s.windowUpdateQueue.RemoveStream(id)
|
||||
}
|
||||
|
||||
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
|
||||
|
||||
32
framer.go
32
framer.go
@@ -19,9 +19,7 @@ const (
|
||||
type framer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
streamGetter streamGetter
|
||||
|
||||
activeStreams map[protocol.StreamID]struct{}
|
||||
activeStreams map[protocol.StreamID]sendStreamI
|
||||
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
|
||||
|
||||
controlFrameMutex sync.Mutex
|
||||
@@ -30,11 +28,8 @@ type framer struct {
|
||||
queuedTooManyControlFrames bool
|
||||
}
|
||||
|
||||
func newFramer(streamGetter streamGetter) *framer {
|
||||
return &framer{
|
||||
streamGetter: streamGetter,
|
||||
activeStreams: make(map[protocol.StreamID]struct{}),
|
||||
}
|
||||
func newFramer() *framer {
|
||||
return &framer{activeStreams: make(map[protocol.StreamID]sendStreamI)}
|
||||
}
|
||||
|
||||
func (f *framer) HasData() bool {
|
||||
@@ -109,15 +104,25 @@ func (f *framer) QueuedTooManyControlFrames() bool {
|
||||
return f.queuedTooManyControlFrames
|
||||
}
|
||||
|
||||
func (f *framer) AddActiveStream(id protocol.StreamID) {
|
||||
func (f *framer) AddActiveStream(id protocol.StreamID, str sendStreamI) {
|
||||
f.mutex.Lock()
|
||||
if _, ok := f.activeStreams[id]; !ok {
|
||||
f.streamQueue.PushBack(id)
|
||||
f.activeStreams[id] = struct{}{}
|
||||
f.activeStreams[id] = str
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// RemoveActiveStream is called when a stream completes.
|
||||
func (f *framer) RemoveActiveStream(id protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
delete(f.activeStreams, id)
|
||||
// We don't delete the stream from the streamQueue,
|
||||
// since we'd have to iterate over the ringbuffer.
|
||||
// Instead, we check if the stream is still in activeStreams in AppendStreamFrames.
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
|
||||
startLen := len(frames)
|
||||
var length protocol.ByteCount
|
||||
@@ -131,10 +136,9 @@ func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen prot
|
||||
id := f.streamQueue.PopFront()
|
||||
// This should never return an error. Better check it anyway.
|
||||
// The stream will only be in the streamQueue, if it enqueued itself there.
|
||||
str, err := f.streamGetter.GetOrOpenSendStream(id)
|
||||
// The stream can be nil if it completed after it said it had data.
|
||||
if str == nil || err != nil {
|
||||
delete(f.activeStreams, id)
|
||||
str, ok := f.activeStreams[id]
|
||||
// The stream might have been removed after being enqueued.
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
remainingLen := maxLen - length
|
||||
|
||||
@@ -23,17 +23,15 @@ var _ = Describe("Framer", func() {
|
||||
var (
|
||||
framer *framer
|
||||
stream1, stream2 *MockSendStreamI
|
||||
streamGetter *MockStreamGetter
|
||||
version protocol.Version
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
streamGetter = NewMockStreamGetter(mockCtrl)
|
||||
stream1 = NewMockSendStreamI(mockCtrl)
|
||||
stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes()
|
||||
stream2 = NewMockSendStreamI(mockCtrl)
|
||||
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
|
||||
framer = newFramer(streamGetter)
|
||||
framer = newFramer()
|
||||
})
|
||||
|
||||
Context("handling control frames", func() {
|
||||
@@ -183,7 +181,6 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("returns STREAM frames", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id1,
|
||||
Data: []byte("foobar"),
|
||||
@@ -191,7 +188,7 @@ var _ = Describe("Framer", func() {
|
||||
DataLenPresent: true,
|
||||
}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
fs, length := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].Frame.DataLenPresent).To(BeFalse())
|
||||
@@ -199,9 +196,8 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("says if it has data", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
|
||||
Expect(framer.HasData()).To(BeFalse())
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
Expect(framer.HasData()).To(BeTrue())
|
||||
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foo")}
|
||||
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("bar")}
|
||||
@@ -218,14 +214,13 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("appends to a frame slice", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id1,
|
||||
Data: []byte("foobar"),
|
||||
DataLenPresent: true,
|
||||
}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
f0 := ackhandler.StreamFrame{Frame: &wire.StreamFrame{StreamID: 9999}}
|
||||
frames := []ackhandler.StreamFrame{f0}
|
||||
fs, length := framer.AppendStreamFrames(frames, 1000, protocol.Version1)
|
||||
@@ -237,24 +232,21 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("skips a stream that was reported active, but was completed shortly after", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id2,
|
||||
Data: []byte("foobar"),
|
||||
DataLenPresent: true,
|
||||
}
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
framer.AddActiveStream(id2, stream2)
|
||||
framer.RemoveActiveStream(id1)
|
||||
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0].Frame).To(Equal(f))
|
||||
})
|
||||
|
||||
It("skips a stream that was reported active, but doesn't have any data", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id2,
|
||||
Data: []byte("foobar"),
|
||||
@@ -262,20 +254,19 @@ var _ = Describe("Framer", func() {
|
||||
}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{}, false, false)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
framer.AddActiveStream(id2, stream2)
|
||||
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0].Frame).To(Equal(f))
|
||||
})
|
||||
|
||||
It("pops from a stream multiple times, if it has enough data", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
|
||||
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
|
||||
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, true)
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false)
|
||||
framer.AddActiveStream(id1) // only add it once
|
||||
framer.AddActiveStream(id1, stream1) // only add it once
|
||||
frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0].Frame).To(Equal(f1))
|
||||
@@ -288,16 +279,14 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("re-queues a stream at the end, if it has enough data", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
|
||||
f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
|
||||
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f11}, true, true)
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f12}, true, false)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false)
|
||||
framer.AddActiveStream(id1) // only add it once
|
||||
framer.AddActiveStream(id2)
|
||||
framer.AddActiveStream(id1, stream1) // only add it once
|
||||
framer.AddActiveStream(id2, stream2)
|
||||
// first a frame from stream 1
|
||||
frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
@@ -313,15 +302,13 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("only dequeues data from each stream once per packet", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
|
||||
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
|
||||
// both streams have more data, and will be re-queued
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, true)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, true)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
framer.AddActiveStream(id2, stream2)
|
||||
frames, length := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(2))
|
||||
Expect(frames[0].Frame).To(Equal(f1))
|
||||
@@ -330,14 +317,12 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("returns multiple normal frames in the order they were reported active", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f1 := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
f2 := &wire.StreamFrame{Data: []byte("foobaz")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, false)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false)
|
||||
framer.AddActiveStream(id2)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2, stream2)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(2))
|
||||
Expect(frames[0].Frame).To(Equal(f2))
|
||||
@@ -345,11 +330,10 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("only asks a stream for data once, even if it was reported active multiple times", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
f := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) // only one call to this function
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
})
|
||||
@@ -362,7 +346,6 @@ var _ = Describe("Framer", func() {
|
||||
|
||||
It("pops maximum size STREAM frames", func() {
|
||||
for i := protocol.MinStreamFrameSize; i < 2000; i++ {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id1,
|
||||
@@ -372,7 +355,7 @@ var _ = Describe("Framer", func() {
|
||||
Expect(f.Length(version)).To(Equal(size))
|
||||
return ackhandler.StreamFrame{Frame: f}, true, false
|
||||
})
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
frames, _ := framer.AppendStreamFrames(nil, i, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
f := frames[0].Frame
|
||||
@@ -383,8 +366,6 @@ var _ = Describe("Framer", func() {
|
||||
|
||||
It("pops multiple STREAM frames", func() {
|
||||
for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id2,
|
||||
@@ -402,8 +383,8 @@ var _ = Describe("Framer", func() {
|
||||
Expect(f.Length(version)).To(Equal(size))
|
||||
return ackhandler.StreamFrame{Frame: f}, true, false
|
||||
})
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
framer.AddActiveStream(id2, stream2)
|
||||
frames, _ := framer.AppendStreamFrames(nil, i, protocol.Version1)
|
||||
Expect(frames).To(HaveLen(2))
|
||||
f1 := frames[0].Frame
|
||||
@@ -415,10 +396,9 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("pops frames that when asked for the the minimum STREAM frame size", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
f := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1)
|
||||
})
|
||||
|
||||
@@ -428,7 +408,6 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
// pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id1,
|
||||
@@ -436,7 +415,7 @@ var _ = Describe("Framer", func() {
|
||||
DataLenPresent: true,
|
||||
}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
fs, length := framer.AppendStreamFrames(nil, 500, protocol.Version1)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].Frame).To(Equal(f))
|
||||
@@ -444,7 +423,7 @@ var _ = Describe("Framer", func() {
|
||||
})
|
||||
|
||||
It("drops all STREAM frames when 0-RTT is rejected", func() {
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1, stream1)
|
||||
Expect(framer.Handle0RTTRejection()).To(Succeed())
|
||||
fs, length := framer.AppendStreamFrames(nil, protocol.MaxByteCount, protocol.Version1)
|
||||
Expect(fs).To(BeEmpty())
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
|
||||
type connectionFlowController struct {
|
||||
baseFlowController
|
||||
|
||||
queueWindowUpdate func()
|
||||
}
|
||||
|
||||
var _ ConnectionFlowController = &connectionFlowController{}
|
||||
@@ -23,7 +21,6 @@ var _ ConnectionFlowController = &connectionFlowController{}
|
||||
func NewConnectionFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(),
|
||||
allowWindowIncrease func(size protocol.ByteCount) bool,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
@@ -37,7 +34,6 @@ func NewConnectionFlowController(
|
||||
allowWindowIncrease: allowWindowIncrease,
|
||||
logger: logger,
|
||||
},
|
||||
queueWindowUpdate: queueWindowUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,18 +59,14 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
|
||||
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
c.baseFlowController.addBytesRead(n)
|
||||
shouldQueueWindowUpdate := c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if shouldQueueWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if oldWindowSize < c.receiveWindowSize {
|
||||
if c.logger.Debug() && oldWindowSize < c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
|
||||
@@ -11,10 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var _ = Describe("Connection Flow controller", func() {
|
||||
var (
|
||||
controller *connectionFlowController
|
||||
queuedWindowUpdate bool
|
||||
)
|
||||
var controller *connectionFlowController
|
||||
|
||||
// update the congestion such that it returns a given value for the smoothed RTT
|
||||
setRtt := func(t time.Duration) {
|
||||
@@ -23,11 +20,9 @@ var _ = Describe("Connection Flow controller", func() {
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
queuedWindowUpdate = false
|
||||
controller = &connectionFlowController{}
|
||||
controller.rttStats = &utils.RTTStats{}
|
||||
controller.logger = utils.DefaultLogger
|
||||
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
|
||||
controller.allowWindowIncrease = func(protocol.ByteCount) bool { return true }
|
||||
})
|
||||
|
||||
@@ -41,7 +36,6 @@ var _ = Describe("Connection Flow controller", func() {
|
||||
fc := NewConnectionFlowController(
|
||||
receiveWindow,
|
||||
maxReceiveWindow,
|
||||
nil,
|
||||
func(protocol.ByteCount) bool { return true },
|
||||
rttStats,
|
||||
utils.DefaultLogger).(*connectionFlowController)
|
||||
@@ -67,13 +61,11 @@ var _ = Describe("Connection Flow controller", func() {
|
||||
|
||||
It("queues window updates", func() {
|
||||
controller.AddBytesRead(1)
|
||||
Expect(queuedWindowUpdate).To(BeFalse())
|
||||
Expect(controller.GetWindowUpdate()).To(BeZero())
|
||||
controller.AddBytesRead(29)
|
||||
Expect(queuedWindowUpdate).To(BeTrue())
|
||||
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
|
||||
queuedWindowUpdate = false
|
||||
controller.AddBytesRead(1)
|
||||
Expect(queuedWindowUpdate).To(BeFalse())
|
||||
Expect(controller.GetWindowUpdate()).To(BeZero())
|
||||
})
|
||||
|
||||
It("gets a window update", func() {
|
||||
|
||||
@@ -8,7 +8,6 @@ type flowController interface {
|
||||
UpdateSendWindow(protocol.ByteCount) (updated bool)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
@@ -16,6 +15,7 @@ type flowController interface {
|
||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
AddBytesRead(protocol.ByteCount) (shouldQueueWindowUpdate bool)
|
||||
// UpdateHighestReceived is called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream,
|
||||
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||
@@ -28,6 +28,7 @@ type StreamFlowController interface {
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
Reset() error
|
||||
}
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@ type streamFlowController struct {
|
||||
|
||||
streamID protocol.StreamID
|
||||
|
||||
queueWindowUpdate func()
|
||||
|
||||
connection connectionFlowControllerI
|
||||
|
||||
receivedFinalOffset bool
|
||||
@@ -29,14 +27,12 @@ func NewStreamFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(protocol.StreamID),
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
streamID: streamID,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
@@ -97,15 +93,13 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount,
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (shouldQueueWindowUpdate bool) {
|
||||
c.mutex.Lock()
|
||||
c.baseFlowController.addBytesRead(n)
|
||||
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
|
||||
shouldQueueWindowUpdate = c.shouldQueueWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if shouldQueueWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
c.connection.AddBytesRead(n)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *streamFlowController) Abandon() {
|
||||
|
||||
@@ -12,20 +12,15 @@ import (
|
||||
)
|
||||
|
||||
var _ = Describe("Stream Flow controller", func() {
|
||||
var (
|
||||
controller *streamFlowController
|
||||
queuedWindowUpdate bool
|
||||
)
|
||||
var controller *streamFlowController
|
||||
|
||||
BeforeEach(func() {
|
||||
queuedWindowUpdate = false
|
||||
rttStats := &utils.RTTStats{}
|
||||
controller = &streamFlowController{
|
||||
streamID: 10,
|
||||
connection: NewConnectionFlowController(
|
||||
1000,
|
||||
1000,
|
||||
func() {},
|
||||
func(protocol.ByteCount) bool { return true },
|
||||
rttStats,
|
||||
utils.DefaultLogger,
|
||||
@@ -34,7 +29,6 @@ var _ = Describe("Stream Flow controller", func() {
|
||||
controller.maxReceiveWindowSize = 10000
|
||||
controller.rttStats = rttStats
|
||||
controller.logger = utils.DefaultLogger
|
||||
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
|
||||
})
|
||||
|
||||
Context("Constructor", func() {
|
||||
@@ -44,25 +38,18 @@ var _ = Describe("Stream Flow controller", func() {
|
||||
const sendWindow protocol.ByteCount = 4000
|
||||
|
||||
It("sets the send and receive windows", func() {
|
||||
cc := NewConnectionFlowController(0, 0, nil, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
|
||||
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
cc := NewConnectionFlowController(0, 0, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
|
||||
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
||||
Expect(fc.receiveWindow).To(Equal(receiveWindow))
|
||||
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
|
||||
Expect(fc.sendWindow).To(Equal(sendWindow))
|
||||
})
|
||||
|
||||
It("queues window updates with the correct stream ID", func() {
|
||||
var queued bool
|
||||
queueWindowUpdate := func(id protocol.StreamID) {
|
||||
Expect(id).To(Equal(protocol.StreamID(5)))
|
||||
queued = true
|
||||
}
|
||||
|
||||
cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func() {}, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
|
||||
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
fc.AddBytesRead(receiveWindow)
|
||||
Expect(queued).To(BeTrue())
|
||||
It("queues window updates", func() {
|
||||
cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func(protocol.ByteCount) bool { return true }, nil, utils.DefaultLogger)
|
||||
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
Expect(fc.AddBytesRead(receiveWindow)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -195,14 +182,10 @@ var _ = Describe("Stream Flow controller", func() {
|
||||
})
|
||||
|
||||
It("queues window updates", func() {
|
||||
controller.AddBytesRead(1)
|
||||
Expect(queuedWindowUpdate).To(BeFalse())
|
||||
controller.AddBytesRead(29)
|
||||
Expect(queuedWindowUpdate).To(BeTrue())
|
||||
Expect(controller.AddBytesRead(1)).To(BeFalse())
|
||||
Expect(controller.AddBytesRead(29)).To(BeTrue())
|
||||
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
|
||||
queuedWindowUpdate = false
|
||||
controller.AddBytesRead(1)
|
||||
Expect(queuedWindowUpdate).To(BeFalse())
|
||||
Expect(controller.AddBytesRead(1)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("tells the connection flow controller when the window was auto-tuned", func() {
|
||||
@@ -246,10 +229,8 @@ var _ = Describe("Stream Flow controller", func() {
|
||||
|
||||
It("doesn't increase the window after a final offset was already received", func() {
|
||||
Expect(controller.UpdateHighestReceived(90, true)).To(Succeed())
|
||||
controller.AddBytesRead(30)
|
||||
Expect(queuedWindowUpdate).To(BeFalse())
|
||||
offset := controller.GetWindowUpdate()
|
||||
Expect(offset).To(BeZero())
|
||||
Expect(controller.AddBytesRead(30)).To(BeFalse())
|
||||
Expect(controller.GetWindowUpdate()).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -76,9 +76,11 @@ func (c *MockStreamFlowControllerAbandonCall) DoAndReturn(f func()) *MockStreamF
|
||||
}
|
||||
|
||||
// AddBytesRead mocks base method.
|
||||
func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) {
|
||||
func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) bool {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddBytesRead", arg0)
|
||||
ret := m.ctrl.Call(m, "AddBytesRead", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddBytesRead indicates an expected call of AddBytesRead.
|
||||
@@ -94,19 +96,19 @@ type MockStreamFlowControllerAddBytesReadCall struct {
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamFlowControllerAddBytesReadCall) Return() *MockStreamFlowControllerAddBytesReadCall {
|
||||
c.Call = c.Call.Return()
|
||||
func (c *MockStreamFlowControllerAddBytesReadCall) Return(arg0 bool) *MockStreamFlowControllerAddBytesReadCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesReadCall {
|
||||
func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount) bool) *MockStreamFlowControllerAddBytesReadCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesReadCall {
|
||||
func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockStreamFlowControllerAddBytesReadCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/quic-go/quic-go (interfaces: StreamGetter)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter
|
||||
//
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
protocol "github.com/quic-go/quic-go/internal/protocol"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockStreamGetter is a mock of StreamGetter interface.
|
||||
type MockStreamGetter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStreamGetterMockRecorder
|
||||
}
|
||||
|
||||
// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter.
|
||||
type MockStreamGetterMockRecorder struct {
|
||||
mock *MockStreamGetter
|
||||
}
|
||||
|
||||
// NewMockStreamGetter creates a new mock instance.
|
||||
func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter {
|
||||
mock := &MockStreamGetter{ctrl: ctrl}
|
||||
mock.recorder = &MockStreamGetterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream mocks base method.
|
||||
func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
|
||||
ret0, _ := ret[0].(receiveStreamI)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream.
|
||||
func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 any) *MockStreamGetterGetOrOpenReceiveStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0)
|
||||
return &MockStreamGetterGetOrOpenReceiveStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamGetterGetOrOpenReceiveStreamCall wrap *gomock.Call
|
||||
type MockStreamGetterGetOrOpenReceiveStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamGetterGetOrOpenReceiveStreamCall) Return(arg0 receiveStreamI, arg1 error) *MockStreamGetterGetOrOpenReceiveStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamGetterGetOrOpenReceiveStreamCall) Do(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamGetterGetOrOpenReceiveStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamGetterGetOrOpenReceiveStreamCall) DoAndReturn(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamGetterGetOrOpenReceiveStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream mocks base method.
|
||||
func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
|
||||
ret0, _ := ret[0].(sendStreamI)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream.
|
||||
func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 any) *MockStreamGetterGetOrOpenSendStreamCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0)
|
||||
return &MockStreamGetterGetOrOpenSendStreamCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamGetterGetOrOpenSendStreamCall wrap *gomock.Call
|
||||
type MockStreamGetterGetOrOpenSendStreamCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamGetterGetOrOpenSendStreamCall) Return(arg0 sendStreamI, arg1 error) *MockStreamGetterGetOrOpenSendStreamCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamGetterGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamGetterGetOrOpenSendStreamCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamGetterGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamGetterGetOrOpenSendStreamCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
@@ -41,15 +41,15 @@ func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder {
|
||||
}
|
||||
|
||||
// onHasStreamData mocks base method.
|
||||
func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) {
|
||||
func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID, arg1 sendStreamI) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "onHasStreamData", arg0)
|
||||
m.ctrl.Call(m, "onHasStreamData", arg0, arg1)
|
||||
}
|
||||
|
||||
// onHasStreamData indicates an expected call of onHasStreamData.
|
||||
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 any) *MockStreamSenderonHasStreamDataCall {
|
||||
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0, arg1 any) *MockStreamSenderonHasStreamDataCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0, arg1)
|
||||
return &MockStreamSenderonHasStreamDataCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -65,13 +65,49 @@ func (c *MockStreamSenderonHasStreamDataCall) Return() *MockStreamSenderonHasStr
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID)) *MockStreamSenderonHasStreamDataCall {
|
||||
func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID, sendStreamI)) *MockStreamSenderonHasStreamDataCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID)) *MockStreamSenderonHasStreamDataCall {
|
||||
func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID, sendStreamI)) *MockStreamSenderonHasStreamDataCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// onHasStreamWindowUpdate mocks base method.
|
||||
func (m *MockStreamSender) onHasStreamWindowUpdate(arg0 protocol.StreamID, arg1 receiveStreamI) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "onHasStreamWindowUpdate", arg0, arg1)
|
||||
}
|
||||
|
||||
// onHasStreamWindowUpdate indicates an expected call of onHasStreamWindowUpdate.
|
||||
func (mr *MockStreamSenderMockRecorder) onHasStreamWindowUpdate(arg0, arg1 any) *MockStreamSenderonHasStreamWindowUpdateCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamWindowUpdate), arg0, arg1)
|
||||
return &MockStreamSenderonHasStreamWindowUpdateCall{Call: call}
|
||||
}
|
||||
|
||||
// MockStreamSenderonHasStreamWindowUpdateCall wrap *gomock.Call
|
||||
type MockStreamSenderonHasStreamWindowUpdateCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockStreamSenderonHasStreamWindowUpdateCall) Return() *MockStreamSenderonHasStreamWindowUpdateCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockStreamSenderonHasStreamWindowUpdateCall) Do(f func(protocol.StreamID, receiveStreamI)) *MockStreamSenderonHasStreamWindowUpdateCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockStreamSenderonHasStreamWindowUpdateCall) DoAndReturn(f func(protocol.StreamID, receiveStreamI)) *MockStreamSenderonHasStreamWindowUpdateCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -23,9 +23,6 @@ type ReceiveStreamI = receiveStreamI
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI"
|
||||
type SendStreamI = sendStreamI
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter"
|
||||
type StreamGetter = streamGetter
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender"
|
||||
type StreamSender = streamSender
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ var _ = Describe("Packet packer", func() {
|
||||
rand.Seed(uint64(GinkgoRandomSeed()))
|
||||
retransmissionQueue = newRetransmissionQueue()
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
|
||||
mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
initialStream = NewMockCryptoStream(mockCtrl)
|
||||
handshakeStream = NewMockCryptoStream(mockCtrl)
|
||||
framer = NewMockFrameSource(mockCtrl)
|
||||
|
||||
@@ -197,7 +197,9 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
|
||||
// when a RESET_STREAM was received, the flow controller was already
|
||||
// informed about the final byteOffset for this stream
|
||||
if !s.cancelledRemotely {
|
||||
s.flowController.AddBytesRead(protocol.ByteCount(m))
|
||||
if queueWindowUpdate := s.flowController.AddBytesRead(protocol.ByteCount(m)); queueWindowUpdate {
|
||||
s.sender.onHasStreamWindowUpdate(s.streamID, s)
|
||||
}
|
||||
}
|
||||
|
||||
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
|
||||
|
||||
@@ -414,7 +414,10 @@ var _ = Describe("Receive Stream", func() {
|
||||
It("handles concurrent reads", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes()
|
||||
var bytesRead protocol.ByteCount
|
||||
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes()
|
||||
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) bool {
|
||||
bytesRead += n
|
||||
return false
|
||||
}).AnyTimes()
|
||||
|
||||
var numCompleted int32
|
||||
mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) {
|
||||
|
||||
@@ -172,7 +172,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
|
||||
|
||||
s.mutex.Unlock()
|
||||
if !notifiedSender {
|
||||
s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
|
||||
s.sender.onHasStreamData(s.streamID, s) // must be called without holding the mutex
|
||||
notifiedSender = true
|
||||
}
|
||||
if copied {
|
||||
@@ -407,7 +407,7 @@ func (s *sendStream) Close() error {
|
||||
if cancelWriteErr != nil {
|
||||
return fmt.Errorf("close called for canceled stream %d", s.streamID)
|
||||
}
|
||||
s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
|
||||
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
|
||||
|
||||
s.ctxCancel(nil)
|
||||
return nil
|
||||
@@ -453,7 +453,7 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
|
||||
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
|
||||
s.mutex.Unlock()
|
||||
if hasStreamData {
|
||||
s.sender.onHasStreamData(s.streamID)
|
||||
s.sender.onHasStreamData(s.streamID, s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,5 +530,5 @@ func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
s.sender.onHasStreamData(s.streamID)
|
||||
s.sender.onHasStreamData(s.streamID, (*sendStream)(s))
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
n, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
@@ -104,7 +104,7 @@ var _ = Describe("Send Stream", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
n, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
@@ -136,7 +136,7 @@ var _ = Describe("Send Stream", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
mockSender.EXPECT().onHasStreamData(streamID).Times(2)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
|
||||
n, err := strWithTimeout.Write([]byte("foo"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
@@ -163,7 +163,7 @@ var _ = Describe("Send Stream", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
n, err := strWithTimeout.Write(getData(5000))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(5000))
|
||||
@@ -188,7 +188,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
_, err := strWithTimeout.Write(getData(protocol.MaxPacketBufferSize + 3))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
@@ -211,14 +211,14 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("only unblocks Write once a previously buffered STREAM frame has been fully dequeued", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
_, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
_, err := str.Write(getData(protocol.MaxPacketBufferSize))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
@@ -251,7 +251,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(100))
|
||||
@@ -282,7 +282,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
n, err := strWithTimeout.Write(s)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
@@ -315,7 +315,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("cancels the context when Close is called", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
Expect(str.Context().Done()).ToNot(BeClosed())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
Expect(str.Context().Done()).To(BeClosed())
|
||||
@@ -334,7 +334,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
@@ -352,7 +352,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
@@ -392,7 +392,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("unblocks after the deadline", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline)
|
||||
n, err := strWithTimeout.Write(getData(5000))
|
||||
@@ -402,7 +402,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("unblocks when the deadline is changed to the past", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
str.SetWriteDeadline(time.Now().Add(time.Hour))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
@@ -426,7 +426,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(writeReturned)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
var err error
|
||||
n, err = strWithTimeout.Write(getData(5000))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
@@ -450,7 +450,7 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(writeReturned)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
_, err := strWithTimeout.Write(getData(5000))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
}()
|
||||
@@ -466,7 +466,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("doesn't unblock if the deadline is changed before the first one expires", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline1)
|
||||
@@ -488,7 +488,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("unblocks earlier, when a new deadline is set", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
|
||||
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
done := make(chan struct{})
|
||||
@@ -509,7 +509,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("doesn't unblock if the deadline is removed", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline)
|
||||
deadlineUnset := make(chan struct{})
|
||||
@@ -539,14 +539,14 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
Context("closing", func() {
|
||||
It("doesn't allow writes after it has been closed", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
str.Close()
|
||||
_, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError("write on closed stream 1337"))
|
||||
})
|
||||
|
||||
It("allows FIN", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
str.Close()
|
||||
frame, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -560,7 +560,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't send a FIN when there's still data", func() {
|
||||
const frameHeaderLen protocol.ByteCount = 4
|
||||
mockSender.EXPECT().onHasStreamData(streamID).Times(2)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
|
||||
_, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
@@ -584,10 +584,10 @@ var _ = Describe("Send Stream", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
_, err := strWithTimeout.Write(getData(5000))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
waitForWrite()
|
||||
@@ -619,7 +619,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("doesn't allow FIN twice", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
str.Close()
|
||||
frame, ok, _ := str.popStreamFrame(1000, protocol.Version1)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -646,7 +646,7 @@ var _ = Describe("Send Stream", func() {
|
||||
It("doesn't get data for writing if an error occurred", func() {
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -676,7 +676,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("says when it has data for sending", func() {
|
||||
mockFC.EXPECT().UpdateSendWindow(gomock.Any()).Return(true)
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -685,7 +685,7 @@ var _ = Describe("Send Stream", func() {
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
str.updateSendWindow(42)
|
||||
// make sure the Write go routine returns
|
||||
str.closeForShutdown(nil)
|
||||
@@ -694,7 +694,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't say it has data for sending if the MAX_STREAM_DATA frame was reordered", func() {
|
||||
mockFC.EXPECT().UpdateSendWindow(gomock.Any()).Return(false) // reordered frame
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -731,7 +731,7 @@ var _ = Describe("Send Stream", func() {
|
||||
// for reliable results it has to be run many times.
|
||||
It("returns a nil error when the whole slice has been sent out", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(1)
|
||||
mockSender.EXPECT().onHasStreamData(streamID).MaxTimes(1)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(1)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any()).MaxTimes(1)
|
||||
@@ -754,7 +754,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("unblocks Write", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
writeReturned := make(chan struct{})
|
||||
@@ -782,7 +782,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't pop STREAM frames after being canceled", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
writeReturned := make(chan struct{})
|
||||
@@ -806,7 +806,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't pop STREAM frames after being canceled, for large writes", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
writeReturned := make(chan struct{})
|
||||
@@ -835,7 +835,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("ignores acknowledgements for STREAM frames after it was cancelled", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
writeReturned := make(chan struct{})
|
||||
@@ -884,7 +884,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("queues a RESET_STREAM frame, even if the stream was already closed", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||
Expect(f).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{}))
|
||||
})
|
||||
@@ -910,7 +910,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("unblocks Write", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
@@ -958,7 +958,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("handles STOP_SENDING after sending the FIN", func() {
|
||||
mockSender.EXPECT().onHasStreamData(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any())
|
||||
str.Close()
|
||||
_, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -973,7 +973,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("handles STOP_SENDING after Close, but before sending the FIN", func() {
|
||||
mockSender.EXPECT().onHasStreamData(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any())
|
||||
str.Close()
|
||||
gomock.InOrder(
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()),
|
||||
@@ -995,7 +995,7 @@ var _ = Describe("Send Stream", func() {
|
||||
Offset: 0x42,
|
||||
DataLenPresent: false,
|
||||
}
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
(*sendStreamAckHandler)(str).OnLost(f)
|
||||
frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -1013,7 +1013,7 @@ var _ = Describe("Send Stream", func() {
|
||||
Offset: 0x42,
|
||||
DataLenPresent: false,
|
||||
}
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
(*sendStreamAckHandler)(str).OnLost(sf)
|
||||
frame, ok, hasMoreData := str.popStreamFrame(sf.Length(protocol.Version1)-3, protocol.Version1)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -1039,7 +1039,7 @@ var _ = Describe("Send Stream", func() {
|
||||
Offset: 0x42,
|
||||
DataLenPresent: false,
|
||||
}
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
(*sendStreamAckHandler)(str).OnLost(f)
|
||||
_, ok, hasMoreData := str.popStreamFrame(2, protocol.Version1)
|
||||
Expect(ok).To(BeFalse())
|
||||
@@ -1047,7 +1047,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("queues lost STREAM frames", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
|
||||
done := make(chan struct{})
|
||||
@@ -1065,7 +1065,7 @@ var _ = Describe("Send Stream", func() {
|
||||
Expect(frame.Frame.Data).To(Equal([]byte("foobar")))
|
||||
|
||||
// now lose the frame
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
frame.Handler.OnLost(frame.Frame)
|
||||
newFrame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -1074,7 +1074,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("doesn't queue retransmissions for a stream that was canceled", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
|
||||
done := make(chan struct{})
|
||||
@@ -1107,7 +1107,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("says when a stream is completed", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -1138,7 +1138,7 @@ var _ = Describe("Send Stream", func() {
|
||||
}
|
||||
|
||||
// Now close the stream and acknowledge the FIN.
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
Expect(str.Close()).To(Succeed())
|
||||
frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -1148,7 +1148,7 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("says when a stream is completed, if Close() is called before popping the frame", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID).Times(2)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -1171,13 +1171,13 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
|
||||
It("doesn't say it's completed when there are frames waiting to be retransmitted", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := strWithTimeout.Write(getData(100))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
Expect(str.Close()).To(Succeed())
|
||||
close(done)
|
||||
}()
|
||||
@@ -1201,7 +1201,7 @@ var _ = Describe("Send Stream", func() {
|
||||
for _, f := range frames[1:] {
|
||||
f.Handler.OnAcked(f.Frame)
|
||||
}
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str)
|
||||
frames[0].Handler.OnLost(frames[0].Frame)
|
||||
|
||||
// get the retransmission and acknowledge it
|
||||
@@ -1218,7 +1218,7 @@ var _ = Describe("Send Stream", func() {
|
||||
// and has to be retransmitted.
|
||||
It("retransmits data until everything has been acknowledged", func() {
|
||||
const dataLen = 1 << 22 // 4 MB
|
||||
mockSender.EXPECT().onHasStreamData(streamID).AnyTimes()
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes()
|
||||
mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount {
|
||||
return protocol.ByteCount(mrand.Intn(500)) + 50
|
||||
}).AnyTimes()
|
||||
|
||||
17
stream.go
17
stream.go
@@ -25,7 +25,8 @@ var errDeadline net.Error = &deadlineError{}
|
||||
// The streamSender is notified by the stream about various events.
|
||||
type streamSender interface {
|
||||
queueControlFrame(wire.Frame)
|
||||
onHasStreamData(protocol.StreamID)
|
||||
onHasStreamData(protocol.StreamID, sendStreamI)
|
||||
onHasStreamWindowUpdate(protocol.StreamID, receiveStreamI)
|
||||
// must be called without holding the mutex that is acquired by closeForShutdown
|
||||
onStreamCompleted(protocol.StreamID)
|
||||
}
|
||||
@@ -37,17 +38,11 @@ type uniStreamSender struct {
|
||||
onStreamCompletedImpl func()
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
|
||||
s.streamSender.queueControlFrame(f)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
|
||||
s.streamSender.onHasStreamData(id)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
|
||||
s.onStreamCompletedImpl()
|
||||
func (s *uniStreamSender) queueControlFrame(f wire.Frame) { s.streamSender.queueControlFrame(f) }
|
||||
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str sendStreamI) {
|
||||
s.streamSender.onHasStreamData(id, str)
|
||||
}
|
||||
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() }
|
||||
|
||||
var _ streamSender = &uniStreamSender{}
|
||||
|
||||
|
||||
@@ -11,53 +11,44 @@ import (
|
||||
type windowUpdateQueue struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
queue map[protocol.StreamID]struct{} // used as a set
|
||||
queuedConn bool // connection-level window update
|
||||
queue map[protocol.StreamID]receiveStreamI
|
||||
|
||||
streamGetter streamGetter
|
||||
connFlowController flowcontrol.ConnectionFlowController
|
||||
callback func(wire.Frame)
|
||||
}
|
||||
|
||||
func newWindowUpdateQueue(
|
||||
streamGetter streamGetter,
|
||||
connFC flowcontrol.ConnectionFlowController,
|
||||
cb func(wire.Frame),
|
||||
) *windowUpdateQueue {
|
||||
return &windowUpdateQueue{
|
||||
queue: make(map[protocol.StreamID]struct{}),
|
||||
streamGetter: streamGetter,
|
||||
queue: make(map[protocol.StreamID]receiveStreamI),
|
||||
connFlowController: connFC,
|
||||
callback: cb,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *windowUpdateQueue) AddStream(id protocol.StreamID) {
|
||||
func (q *windowUpdateQueue) AddStream(id protocol.StreamID, str receiveStreamI) {
|
||||
q.mutex.Lock()
|
||||
q.queue[id] = struct{}{}
|
||||
q.queue[id] = str
|
||||
q.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (q *windowUpdateQueue) AddConnection() {
|
||||
func (q *windowUpdateQueue) RemoveStream(id protocol.StreamID) {
|
||||
q.mutex.Lock()
|
||||
q.queuedConn = true
|
||||
delete(q.queue, id)
|
||||
q.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (q *windowUpdateQueue) QueueAll() {
|
||||
q.mutex.Lock()
|
||||
// queue a connection-level window update
|
||||
if q.queuedConn {
|
||||
q.callback(&wire.MaxDataFrame{MaximumData: q.connFlowController.GetWindowUpdate()})
|
||||
q.queuedConn = false
|
||||
if offset := q.connFlowController.GetWindowUpdate(); offset > 0 {
|
||||
q.callback(&wire.MaxDataFrame{MaximumData: offset})
|
||||
}
|
||||
// queue all stream-level window updates
|
||||
for id := range q.queue {
|
||||
for id, str := range q.queue {
|
||||
delete(q.queue, id)
|
||||
str, err := q.streamGetter.GetOrOpenReceiveStream(id)
|
||||
if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update
|
||||
continue
|
||||
}
|
||||
offset := str.getWindowUpdate()
|
||||
if offset == 0 { // can happen if we received a final offset, right after queueing the window update
|
||||
continue
|
||||
|
||||
@@ -12,39 +12,36 @@ import (
|
||||
var _ = Describe("Window Update Queue", func() {
|
||||
var (
|
||||
q *windowUpdateQueue
|
||||
streamGetter *MockStreamGetter
|
||||
connFC *mocks.MockConnectionFlowController
|
||||
queuedFrames []wire.Frame
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
streamGetter = NewMockStreamGetter(mockCtrl)
|
||||
connFC = mocks.NewMockConnectionFlowController(mockCtrl)
|
||||
queuedFrames = queuedFrames[:0]
|
||||
q = newWindowUpdateQueue(streamGetter, connFC, func(f wire.Frame) {
|
||||
q = newWindowUpdateQueue(connFC, func(f wire.Frame) {
|
||||
queuedFrames = append(queuedFrames, f)
|
||||
})
|
||||
})
|
||||
|
||||
It("adds stream offsets and gets MAX_STREAM_DATA frames", func() {
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
|
||||
stream1 := NewMockStreamI(mockCtrl)
|
||||
stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10))
|
||||
stream3 := NewMockStreamI(mockCtrl)
|
||||
stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30))
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(stream3, nil)
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(stream1, nil)
|
||||
q.AddStream(3)
|
||||
q.AddStream(1)
|
||||
q.AddStream(3, stream3)
|
||||
q.AddStream(1, stream1)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, MaximumStreamData: 10}))
|
||||
Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, MaximumStreamData: 30}))
|
||||
})
|
||||
|
||||
It("deletes the entry after getting the MAX_STREAM_DATA frame", func() {
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
|
||||
stream10 := NewMockStreamI(mockCtrl)
|
||||
stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100))
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil)
|
||||
q.AddStream(10)
|
||||
q.AddStream(10, stream10)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(HaveLen(1))
|
||||
q.QueueAll()
|
||||
@@ -52,36 +49,37 @@ var _ = Describe("Window Update Queue", func() {
|
||||
})
|
||||
|
||||
It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() {
|
||||
q.AddStream(12)
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil)
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
|
||||
stream12 := NewMockStreamI(mockCtrl)
|
||||
q.AddStream(12, stream12)
|
||||
q.RemoveStream(12)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("removes closed streams from the queue", func() {
|
||||
q.AddStream(12)
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(BeEmpty())
|
||||
// don't EXPECT any further calls to GetOrOpenReceiveStream
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
|
||||
stream12 := NewMockStreamI(mockCtrl)
|
||||
q.AddStream(12, stream12)
|
||||
q.RemoveStream(12)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() {
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0))
|
||||
stream5 := NewMockStreamI(mockCtrl)
|
||||
stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0))
|
||||
q.AddStream(5)
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil)
|
||||
q.AddStream(5, stream5)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("removes streams for which the flow controller returns an offset of 0 from the queue", func() {
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes()
|
||||
stream5 := NewMockStreamI(mockCtrl)
|
||||
stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0))
|
||||
q.AddStream(5)
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil)
|
||||
q.AddStream(5, stream5)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(BeEmpty())
|
||||
// don't EXPECT any further calls to GetOrOpenReveiveStream and to getWindowUpdate
|
||||
@@ -91,22 +89,17 @@ var _ = Describe("Window Update Queue", func() {
|
||||
|
||||
It("queues MAX_DATA frames", func() {
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337))
|
||||
q.AddConnection()
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(Equal([]wire.Frame{
|
||||
&wire.MaxDataFrame{MaximumData: 0x1337},
|
||||
}))
|
||||
Expect(queuedFrames).To(Equal([]wire.Frame{&wire.MaxDataFrame{MaximumData: 0x1337}}))
|
||||
})
|
||||
|
||||
It("deduplicates", func() {
|
||||
connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0))
|
||||
stream10 := NewMockStreamI(mockCtrl)
|
||||
stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200))
|
||||
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil)
|
||||
q.AddStream(10)
|
||||
q.AddStream(10)
|
||||
q.AddStream(10, stream10)
|
||||
q.AddStream(10, stream10)
|
||||
q.QueueAll()
|
||||
Expect(queuedFrames).To(Equal([]wire.Frame{
|
||||
&wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 200},
|
||||
}))
|
||||
Expect(queuedFrames).To(Equal([]wire.Frame{&wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 200}}))
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user