diff --git a/connection.go b/connection.go index 81d602c23..9007f7270 100644 --- a/connection.go +++ b/connection.go @@ -140,7 +140,6 @@ type connection struct { receivedPacketHandler ackhandler.ReceivedPacketHandler retransmissionQueue *retransmissionQueue framer *framer - windowUpdateQueue *windowUpdateQueue connFlowController flowcontrol.ConnectionFlowController tokenStoreKey string // only set for the client tokenGenerator *handshake.TokenGenerator // only set for the server @@ -472,6 +471,7 @@ func (s *connection) preSetup() { s.streamsMap = newStreamsMap( s.ctx, s, + s.queueControlFrame, s.newFlowController, uint64(s.config.MaxIncomingStreams), uint64(s.config.MaxIncomingUniStreams), @@ -487,7 +487,6 @@ func (s *connection) preSetup() { s.lastPacketReceivedTime = now s.creationTime = now - s.windowUpdateQueue = newWindowUpdateQueue(s.connFlowController, s.framer.QueueControlFrame) s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) s.connState.Version = s.version } @@ -1874,7 +1873,9 @@ func (s *connection) sendPackets(now time.Time) error { if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) } - s.windowUpdateQueue.QueueAll() + if offset := s.connFlowController.GetWindowUpdate(); offset > 0 { + s.framer.QueueControlFrame(&wire.MaxDataFrame{MaximumData: offset}) + } if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil { s.queueControlFrame(cf) } @@ -2244,13 +2245,13 @@ func (s *connection) queueControlFrame(f wire.Frame) { s.scheduleSending() } -func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID, str receiveStreamI) { - s.windowUpdateQueue.AddStream(id, str) +func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) { + s.framer.AddActiveStream(id, str) s.scheduleSending() } -func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) { - s.framer.AddActiveStream(id, str) +func (s *connection) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) { + s.framer.AddStreamWithControlFrames(id, str) s.scheduleSending() } @@ -2259,7 +2260,6 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { s.closeLocal(err) } s.framer.RemoveActiveStream(id) - s.windowUpdateQueue.RemoveStream(id) } func (s *connection) onMTUIncreased(mtu protocol.ByteCount) { diff --git a/connection_test.go b/connection_test.go index 73008173c..561b9a8fb 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1284,6 +1284,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) + fc.EXPECT().GetWindowUpdate() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 13}, []byte("foobar")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.connFlowController = fc diff --git a/framer.go b/framer.go index 76ecd399e..e162f6b8f 100644 --- a/framer.go +++ b/framer.go @@ -16,11 +16,20 @@ const ( maxControlFrames = 16 << 10 ) +// This is the largest possible size of a stream-related control frame +// (which is the RESET_STREAM frame). +const maxStreamControlFrameSize = 25 + +type streamControlFrameGetter interface { + getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) +} + type framer struct { mutex sync.Mutex - activeStreams map[protocol.StreamID]sendStreamI - streamQueue ringbuffer.RingBuffer[protocol.StreamID] + activeStreams map[protocol.StreamID]sendStreamI + streamQueue ringbuffer.RingBuffer[protocol.StreamID] + streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter controlFrameMutex sync.Mutex controlFrames []wire.Frame @@ -29,7 +38,10 @@ type framer struct { } func newFramer() *framer { - return &framer{activeStreams: make(map[protocol.StreamID]sendStreamI)} + return &framer{ + activeStreams: make(map[protocol.StreamID]sendStreamI), + streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter), + } } func (f *framer) HasData() bool { @@ -41,7 +53,7 @@ func (f *framer) HasData() bool { } f.controlFrameMutex.Lock() defer f.controlFrameMutex.Unlock() - return len(f.controlFrames) > 0 || len(f.pathResponses) > 0 + return len(f.streamsWithControlFrames) > 0 || len(f.controlFrames) > 0 || len(f.pathResponses) > 0 } func (f *framer) QueueControlFrame(frame wire.Frame) { @@ -82,6 +94,29 @@ func (f *framer) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol. } } + // add stream-related control frames + for id, str := range f.streamsWithControlFrames { + start: + remainingLen := maxLen - length + if remainingLen <= maxStreamControlFrameSize { + break + } + fr, ok, hasMore := str.getControlFrame() + if !hasMore { + delete(f.streamsWithControlFrames, id) + } + if !ok { + continue + } + frames = append(frames, fr) + length += fr.Frame.Length(v) + if hasMore { + // It is rare that a stream has more than one control frame to queue. + // We don't want to spawn another loop for just to cover that case. + goto start + } + } + for len(f.controlFrames) > 0 { frame := f.controlFrames[len(f.controlFrames)-1] frameLen := frame.Length(v) @@ -92,6 +127,7 @@ func (f *framer) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol. length += frameLen f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] } + return frames, length } @@ -113,6 +149,14 @@ func (f *framer) AddActiveStream(id protocol.StreamID, str sendStreamI) { f.mutex.Unlock() } +func (f *framer) AddStreamWithControlFrames(id protocol.StreamID, str streamControlFrameGetter) { + f.controlFrameMutex.Lock() + if _, ok := f.streamsWithControlFrames[id]; !ok { + f.streamsWithControlFrames[id] = str + } + f.controlFrameMutex.Unlock() +} + // RemoveActiveStream is called when a stream completes. func (f *framer) RemoveActiveStream(id protocol.StreamID) { f.mutex.Lock() @@ -127,7 +171,7 @@ func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen prot startLen := len(frames) var length protocol.ByteCount f.mutex.Lock() - // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet + // pop STREAM frames, until less than 128 bytes are left in the packet numActiveStreams := f.streamQueue.Len() for i := 0; i < numActiveStreams; i++ { if protocol.MinStreamFrameSize+length > maxLen { @@ -153,7 +197,7 @@ func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen prot delete(f.activeStreams, id) } // The frame can be "nil" - // * if the receiveStream was canceled after it said it had data + // * if the stream was canceled after it said it had data // * the remaining size doesn't allow us to add another STREAM frame if !ok { continue diff --git a/framer_test.go b/framer_test.go index b51ef1692..a8cf83287 100644 --- a/framer_test.go +++ b/framer_test.go @@ -36,16 +36,16 @@ var _ = Describe("Framer", func() { Context("handling control frames", func() { It("adds control frames", func() { - mdf := &wire.MaxDataFrame{MaximumData: 0x42} + pc := &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 6, 7, 8}} msf := &wire.MaxStreamsFrame{MaxStreamNum: 0x1337} - framer.QueueControlFrame(mdf) + framer.QueueControlFrame(pc) framer.QueueControlFrame(msf) frames, length := framer.AppendControlFrames(nil, 1000, protocol.Version1) Expect(frames).To(HaveLen(2)) fs := []wire.Frame{frames[0].Frame, frames[1].Frame} - Expect(fs).To(ContainElement(mdf)) + Expect(fs).To(ContainElement(pc)) Expect(fs).To(ContainElement(msf)) - Expect(length).To(Equal(mdf.Length(version) + msf.Length(version))) + Expect(length).To(Equal(pc.Length(version) + msf.Length(version))) }) It("says if it has data", func() { @@ -60,13 +60,43 @@ var _ = Describe("Framer", func() { It("appends to the slice given", func() { ping := &wire.PingFrame{} - mdf := &wire.MaxDataFrame{MaximumData: 0x42} - framer.QueueControlFrame(mdf) + pc := &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 6, 7, 8}} + framer.QueueControlFrame(pc) frames, length := framer.AppendControlFrames([]ackhandler.Frame{{Frame: ping}}, 1000, protocol.Version1) Expect(frames).To(HaveLen(2)) Expect(frames[0].Frame).To(Equal(ping)) - Expect(frames[1].Frame).To(Equal(mdf)) - Expect(length).To(Equal(mdf.Length(version))) + Expect(frames[1].Frame).To(Equal(pc)) + Expect(length).To(Equal(pc.Length(version))) + }) + + It("adds stream-related control frames", func() { + ping := &wire.PingFrame{} + framer.QueueControlFrame(ping) + str := NewMockStreamControlFrameGetter(mockCtrl) + framer.AddStreamWithControlFrames(10, str) + mdf1 := &wire.MaxStreamDataFrame{MaximumStreamData: 1337} + mdf2 := &wire.MaxStreamDataFrame{MaximumStreamData: 1338} + str.EXPECT().getControlFrame().Return(ackhandler.Frame{Frame: mdf1}, true, true) + str.EXPECT().getControlFrame().Return(ackhandler.Frame{Frame: mdf2}, true, false) + frames, l := framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1) + Expect(frames).To(HaveLen(3)) + Expect(frames[0].Frame).To(Equal(mdf1)) + Expect(frames[1].Frame).To(Equal(mdf2)) + Expect(frames[2].Frame).To(Equal(ping)) + Expect(l).To(Equal(ping.Length(protocol.Version1) + mdf1.Length(protocol.Version1) + mdf2.Length(protocol.Version1))) + }) + + It("doesn't enqueue more stream-related control frames if there are less than 25 bytes left", func() { + str := NewMockStreamControlFrameGetter(mockCtrl) + framer.AddStreamWithControlFrames(10, str) + mdf1 := &wire.MaxStreamDataFrame{MaximumStreamData: 1337} + str.EXPECT().getControlFrame().Return(ackhandler.Frame{Frame: mdf1}, true, true).AnyTimes() + frames, l := framer.AppendControlFrames(nil, 100, protocol.Version1) + Expect(l).To(Equal(protocol.ByteCount(len(frames)) * mdf1.Length(protocol.Version1))) + Expect(l).To(And( + BeNumerically(">", 100-maxStreamControlFrameSize), + BeNumerically("<=", 100), + )) }) It("adds the right number of frames", func() { @@ -211,6 +241,8 @@ var _ = Describe("Framer", func() { Expect(frames).To(HaveLen(1)) Expect(frames[0].Frame).To(Equal(f2)) Expect(framer.HasData()).To(BeFalse()) + framer.AddStreamWithControlFrames(id1, nil) + Expect(framer.HasData()).To(BeTrue()) }) It("appends to a frame slice", func() { diff --git a/mock_receive_stream_internal_test.go b/mock_receive_stream_internal_test.go index c2faf3800..793cdfe0d 100644 --- a/mock_receive_stream_internal_test.go +++ b/mock_receive_stream_internal_test.go @@ -229,44 +229,6 @@ func (c *MockReceiveStreamIcloseForShutdownCall) DoAndReturn(f func(error)) *Moc return c } -// getWindowUpdate mocks base method. -func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// getWindowUpdate indicates an expected call of getWindowUpdate. -func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *MockReceiveStreamIgetWindowUpdateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate)) - return &MockReceiveStreamIgetWindowUpdateCall{Call: call} -} - -// MockReceiveStreamIgetWindowUpdateCall wrap *gomock.Call -type MockReceiveStreamIgetWindowUpdateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockReceiveStreamIgetWindowUpdateCall) Return(arg0 protocol.ByteCount) *MockReceiveStreamIgetWindowUpdateCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockReceiveStreamIgetWindowUpdateCall) Do(f func() protocol.ByteCount) *MockReceiveStreamIgetWindowUpdateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockReceiveStreamIgetWindowUpdateCall) DoAndReturn(f func() protocol.ByteCount) *MockReceiveStreamIgetWindowUpdateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // handleResetStreamFrame mocks base method. func (m *MockReceiveStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error { m.ctrl.T.Helper() diff --git a/mock_stream_control_frame_getter_test.go b/mock_stream_control_frame_getter_test.go new file mode 100644 index 000000000..2fad6e7c0 --- /dev/null +++ b/mock_stream_control_frame_getter_test.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go (interfaces: StreamControlFrameGetter) +// +// Generated by this command: +// +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_control_frame_getter_test.go github.com/quic-go/quic-go StreamControlFrameGetter +// + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + ackhandler "github.com/quic-go/quic-go/internal/ackhandler" + gomock "go.uber.org/mock/gomock" +) + +// MockStreamControlFrameGetter is a mock of StreamControlFrameGetter interface. +type MockStreamControlFrameGetter struct { + ctrl *gomock.Controller + recorder *MockStreamControlFrameGetterMockRecorder +} + +// MockStreamControlFrameGetterMockRecorder is the mock recorder for MockStreamControlFrameGetter. +type MockStreamControlFrameGetterMockRecorder struct { + mock *MockStreamControlFrameGetter +} + +// NewMockStreamControlFrameGetter creates a new mock instance. +func NewMockStreamControlFrameGetter(ctrl *gomock.Controller) *MockStreamControlFrameGetter { + mock := &MockStreamControlFrameGetter{ctrl: ctrl} + mock.recorder = &MockStreamControlFrameGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamControlFrameGetter) EXPECT() *MockStreamControlFrameGetterMockRecorder { + return m.recorder +} + +// getControlFrame mocks base method. +func (m *MockStreamControlFrameGetter) getControlFrame() (ackhandler.Frame, bool, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getControlFrame") + ret0, _ := ret[0].(ackhandler.Frame) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 +} + +// getControlFrame indicates an expected call of getControlFrame. +func (mr *MockStreamControlFrameGetterMockRecorder) getControlFrame() *MockStreamControlFrameGettergetControlFrameCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getControlFrame", reflect.TypeOf((*MockStreamControlFrameGetter)(nil).getControlFrame)) + return &MockStreamControlFrameGettergetControlFrameCall{Call: call} +} + +// MockStreamControlFrameGettergetControlFrameCall wrap *gomock.Call +type MockStreamControlFrameGettergetControlFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamControlFrameGettergetControlFrameCall) Return(arg0 ackhandler.Frame, arg1, arg2 bool) *MockStreamControlFrameGettergetControlFrameCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamControlFrameGettergetControlFrameCall) Do(f func() (ackhandler.Frame, bool, bool)) *MockStreamControlFrameGettergetControlFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamControlFrameGettergetControlFrameCall) DoAndReturn(f func() (ackhandler.Frame, bool, bool)) *MockStreamControlFrameGettergetControlFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 212256d87..a27fbed81 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -458,44 +458,6 @@ func (c *MockStreamIcloseForShutdownCall) DoAndReturn(f func(error)) *MockStream return c } -// getWindowUpdate mocks base method. -func (m *MockStreamI) getWindowUpdate() protocol.ByteCount { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// getWindowUpdate indicates an expected call of getWindowUpdate. -func (mr *MockStreamIMockRecorder) getWindowUpdate() *MockStreamIgetWindowUpdateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate)) - return &MockStreamIgetWindowUpdateCall{Call: call} -} - -// MockStreamIgetWindowUpdateCall wrap *gomock.Call -type MockStreamIgetWindowUpdateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamIgetWindowUpdateCall) Return(arg0 protocol.ByteCount) *MockStreamIgetWindowUpdateCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamIgetWindowUpdateCall) Do(f func() protocol.ByteCount) *MockStreamIgetWindowUpdateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamIgetWindowUpdateCall) DoAndReturn(f func() protocol.ByteCount) *MockStreamIgetWindowUpdateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // handleResetStreamFrame mocks base method. func (m *MockStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error { m.ctrl.T.Helper() diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index 2b6083452..089061519 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -13,7 +13,6 @@ import ( reflect "reflect" protocol "github.com/quic-go/quic-go/internal/protocol" - wire "github.com/quic-go/quic-go/internal/wire" gomock "go.uber.org/mock/gomock" ) @@ -40,6 +39,42 @@ func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { return m.recorder } +// onHasStreamControlFrame mocks base method. +func (m *MockStreamSender) onHasStreamControlFrame(arg0 protocol.StreamID, arg1 streamControlFrameGetter) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "onHasStreamControlFrame", arg0, arg1) +} + +// onHasStreamControlFrame indicates an expected call of onHasStreamControlFrame. +func (mr *MockStreamSenderMockRecorder) onHasStreamControlFrame(arg0, arg1 any) *MockStreamSenderonHasStreamControlFrameCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamControlFrame", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamControlFrame), arg0, arg1) + return &MockStreamSenderonHasStreamControlFrameCall{Call: call} +} + +// MockStreamSenderonHasStreamControlFrameCall wrap *gomock.Call +type MockStreamSenderonHasStreamControlFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSenderonHasStreamControlFrameCall) Return() *MockStreamSenderonHasStreamControlFrameCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSenderonHasStreamControlFrameCall) Do(f func(protocol.StreamID, streamControlFrameGetter)) *MockStreamSenderonHasStreamControlFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSenderonHasStreamControlFrameCall) DoAndReturn(f func(protocol.StreamID, streamControlFrameGetter)) *MockStreamSenderonHasStreamControlFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // onHasStreamData mocks base method. func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID, arg1 sendStreamI) { m.ctrl.T.Helper() @@ -76,42 +111,6 @@ func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.Stream return c } -// onHasStreamWindowUpdate mocks base method. -func (m *MockStreamSender) onHasStreamWindowUpdate(arg0 protocol.StreamID, arg1 receiveStreamI) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "onHasStreamWindowUpdate", arg0, arg1) -} - -// onHasStreamWindowUpdate indicates an expected call of onHasStreamWindowUpdate. -func (mr *MockStreamSenderMockRecorder) onHasStreamWindowUpdate(arg0, arg1 any) *MockStreamSenderonHasStreamWindowUpdateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamWindowUpdate), arg0, arg1) - return &MockStreamSenderonHasStreamWindowUpdateCall{Call: call} -} - -// MockStreamSenderonHasStreamWindowUpdateCall wrap *gomock.Call -type MockStreamSenderonHasStreamWindowUpdateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamSenderonHasStreamWindowUpdateCall) Return() *MockStreamSenderonHasStreamWindowUpdateCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamSenderonHasStreamWindowUpdateCall) Do(f func(protocol.StreamID, receiveStreamI)) *MockStreamSenderonHasStreamWindowUpdateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamSenderonHasStreamWindowUpdateCall) DoAndReturn(f func(protocol.StreamID, receiveStreamI)) *MockStreamSenderonHasStreamWindowUpdateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // onStreamCompleted mocks base method. func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { m.ctrl.T.Helper() @@ -147,39 +146,3 @@ func (c *MockStreamSenderonStreamCompletedCall) DoAndReturn(f func(protocol.Stre c.Call = c.Call.DoAndReturn(f) return c } - -// queueControlFrame mocks base method. -func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "queueControlFrame", arg0) -} - -// queueControlFrame indicates an expected call of queueControlFrame. -func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 any) *MockStreamSenderqueueControlFrameCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) - return &MockStreamSenderqueueControlFrameCall{Call: call} -} - -// MockStreamSenderqueueControlFrameCall wrap *gomock.Call -type MockStreamSenderqueueControlFrameCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStreamSenderqueueControlFrameCall) Return() *MockStreamSenderqueueControlFrameCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStreamSenderqueueControlFrameCall) Do(f func(wire.Frame)) *MockStreamSenderqueueControlFrameCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStreamSenderqueueControlFrameCall) DoAndReturn(f func(wire.Frame)) *MockStreamSenderqueueControlFrameCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/mockgen.go b/mockgen.go index cf0d6ad21..65ec465aa 100644 --- a/mockgen.go +++ b/mockgen.go @@ -23,6 +23,9 @@ type SendStreamI = sendStreamI //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" type StreamSender = streamSender +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_control_frame_getter_test.go github.com/quic-go/quic-go StreamControlFrameGetter" +type StreamControlFrameGetter = streamControlFrameGetter + //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" type FrameSource = frameSource diff --git a/receive_stream.go b/receive_stream.go index 8f5764b01..803409235 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" @@ -19,7 +20,6 @@ type receiveStreamI interface { handleStreamFrame(*wire.StreamFrame) error handleResetStreamFrame(*wire.ResetStreamFrame) error closeForShutdown(error) - getWindowUpdate() protocol.ByteCount } type receiveStream struct { @@ -37,6 +37,9 @@ type receiveStream struct { readPosInFrame int currentFrameIsLast bool // is the currentFrame the last frame on this stream + queuedStopSending bool + queuedMaxStreamData bool + // Set once we read the io.EOF or the cancellation error. // Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet. errorRead bool @@ -54,8 +57,9 @@ type receiveStream struct { } var ( - _ ReceiveStream = &receiveStream{} - _ receiveStreamI = &receiveStream{} + _ ReceiveStream = &receiveStream{} + _ receiveStreamI = &receiveStream{} + _ streamControlFrameGetter = &receiveStream{} ) func newReceiveStream( @@ -87,13 +91,16 @@ func (s *receiveStream) Read(p []byte) (int, error) { defer func() { <-s.readOnce }() s.mutex.Lock() - n, err := s.readImpl(p) + queuedNewControlFrame, n, err := s.readImpl(p) completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { s.sender.onStreamCompleted(s.streamID) } + if queuedNewControlFrame { + s.sender.onHasStreamControlFrame(s.streamID, s) + } return n, err } @@ -118,19 +125,20 @@ func (s *receiveStream) isNewlyCompleted() bool { return false } -func (s *receiveStream) readImpl(p []byte) (int, error) { +func (s *receiveStream) readImpl(p []byte) (bool, int, error) { if s.currentFrameIsLast && s.currentFrame == nil { s.errorRead = true - return 0, io.EOF + return false, 0, io.EOF } if s.cancelledRemotely || s.cancelledLocally { s.errorRead = true - return 0, s.cancelErr + return false, 0, s.cancelErr } if s.closeForShutdownErr != nil { - return 0, s.closeForShutdownErr + return false, 0, s.closeForShutdownErr } + var queuedNewControlFrame bool var bytesRead int var deadlineTimer *utils.Timer for bytesRead < len(p) { @@ -138,23 +146,23 @@ func (s *receiveStream) readImpl(p []byte) (int, error) { s.dequeueNextFrame() } if s.currentFrame == nil && bytesRead > 0 { - return bytesRead, s.closeForShutdownErr + return queuedNewControlFrame, bytesRead, s.closeForShutdownErr } for { // Stop waiting on errors if s.closeForShutdownErr != nil { - return bytesRead, s.closeForShutdownErr + return queuedNewControlFrame, bytesRead, s.closeForShutdownErr } if s.cancelledRemotely || s.cancelledLocally { s.errorRead = true - return 0, s.cancelErr + return queuedNewControlFrame, 0, s.cancelErr } deadline := s.deadline if !deadline.IsZero() { if !time.Now().Before(deadline) { - return bytesRead, errDeadline + return queuedNewControlFrame, bytesRead, errDeadline } if deadlineTimer == nil { deadlineTimer = utils.NewTimer() @@ -184,10 +192,10 @@ func (s *receiveStream) readImpl(p []byte) (int, error) { } if bytesRead > len(p) { - return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) } if s.readPosInFrame > len(s.currentFrame) { - return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) + return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) } m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) @@ -197,8 +205,9 @@ func (s *receiveStream) readImpl(p []byte) (int, error) { // when a RESET_STREAM was received, the flow controller was already // informed about the final byteOffset for this stream if !s.cancelledRemotely { - if queueWindowUpdate := s.flowController.AddBytesRead(protocol.ByteCount(m)); queueWindowUpdate { - s.sender.onHasStreamWindowUpdate(s.streamID, s) + if queueMaxStreamData := s.flowController.AddBytesRead(protocol.ByteCount(m)); queueMaxStreamData { + s.queuedMaxStreamData = true + queuedNewControlFrame = true } } @@ -208,10 +217,10 @@ func (s *receiveStream) readImpl(p []byte) (int, error) { s.currentFrameDone() } s.errorRead = true - return bytesRead, io.EOF + return queuedNewControlFrame, bytesRead, io.EOF } } - return bytesRead, nil + return queuedNewControlFrame, bytesRead, nil } func (s *receiveStream) dequeueNextFrame() { @@ -227,30 +236,31 @@ func (s *receiveStream) dequeueNextFrame() { func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { s.mutex.Lock() - s.cancelReadImpl(errorCode) + queuedNewControlFrame := s.cancelReadImpl(errorCode) completed := s.isNewlyCompleted() s.mutex.Unlock() + if queuedNewControlFrame { + s.sender.onHasStreamControlFrame(s.streamID, s) + } if completed { s.flowController.Abandon() s.sender.onStreamCompleted(s.streamID) } } -func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) { +func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNewControlFrame bool) { if s.cancelledLocally { // duplicate call to CancelRead - return + return false } s.cancelledLocally = true if s.errorRead || s.cancelledRemotely { - return + return false } + s.queuedStopSending = true s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false} s.signalRead() - s.sender.queueControlFrame(&wire.StopSendingFrame{ - StreamID: s.streamID, - ErrorCode: errorCode, - }) + return true } func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { @@ -320,6 +330,26 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) return nil } +func (s *receiveStream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if !s.queuedStopSending && !s.queuedMaxStreamData { + return ackhandler.Frame{}, false, false + } + if s.queuedStopSending { + s.queuedStopSending = false + return ackhandler.Frame{ + Frame: &wire.StopSendingFrame{StreamID: s.streamID, ErrorCode: s.cancelErr.ErrorCode}, + }, true, s.queuedMaxStreamData + } + + s.queuedMaxStreamData = false + return ackhandler.Frame{ + Frame: &wire.MaxStreamDataFrame{StreamID: s.streamID, MaximumStreamData: s.flowController.GetWindowUpdate()}, + }, true, false +} + func (s *receiveStream) SetReadDeadline(t time.Time) error { s.mutex.Lock() s.deadline = t @@ -338,10 +368,6 @@ func (s *receiveStream) closeForShutdown(err error) { s.signalRead() } -func (s *receiveStream) getWindowUpdate() protocol.ByteCount { - return s.flowController.GetWindowUpdate() -} - // signalRead performs a non-blocking send on the readChan func (s *receiveStream) signalRead() { select { diff --git a/receive_stream_test.go b/receive_stream_test.go index 605ec1476..a01931aad 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -49,8 +49,7 @@ var _ = Describe("Receive Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } - err := str.handleStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) + Expect(str.handleStreamFrame(&frame)).To(Succeed()) b := make([]byte, 4) n, err := strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -66,8 +65,7 @@ var _ = Describe("Receive Stream", func() { Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } - err := str.handleStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) + Expect(str.handleStreamFrame(&frame)).To(Succeed()) b := make([]byte, 2) n, err := strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -79,6 +77,25 @@ var _ = Describe("Receive Stream", func() { Expect(b).To(Equal([]byte{0xBE, 0xEF})) }) + It("queues a flow control update", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Return(true) + frame := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xde, 0xad, 0xbe, 0xef}, + } + Expect(str.handleStreamFrame(&frame)).To(Succeed()) + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + n, err := strWithTimeout.Read(make([]byte, 3)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(1337)) + f, ok, hasMore := str.getControlFrame() + Expect(ok).To(BeTrue()) + Expect(f.Frame).To(Equal(&wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337})) + Expect(hasMore).To(BeFalse()) + }) + It("reads all data available", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) @@ -477,7 +494,7 @@ var _ = Describe("Receive Stream", func() { Context("stream cancellations", func() { Context("canceling read", func() { It("unblocks Read", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -495,7 +512,7 @@ var _ = Describe("Receive Stream", func() { }) It("doesn't allow further calls to Read", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) str.CancelRead(1234) _, err := strWithTimeout.Read([]byte{0}) Expect(err).To(Equal(&StreamError{ @@ -506,7 +523,7 @@ var _ = Describe("Receive Stream", func() { }) It("does nothing when CancelRead is called twice", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) str.CancelRead(1234) str.CancelRead(1234) _, err := strWithTimeout.Read([]byte{0}) @@ -518,11 +535,15 @@ var _ = Describe("Receive Stream", func() { }) It("queues a STOP_SENDING frame", func() { - mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{ + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) + str.CancelRead(1234) + f, ok, hasMore := str.getControlFrame() + Expect(ok).To(BeTrue()) + Expect(f.Frame).To(Equal(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 1234, - }) - str.CancelRead(1234) + })) + Expect(hasMore).To(BeFalse()) }) It("doesn't send a STOP_SENDING frame, if the FIN was already read", func() { @@ -568,7 +589,7 @@ var _ = Describe("Receive Stream", func() { Fin: true, })).To(Succeed()) mockFC.EXPECT().Abandon() - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) mockSender.EXPECT().onStreamCompleted(streamID) str.CancelRead(1234) // read the error @@ -578,7 +599,7 @@ var _ = Describe("Receive Stream", func() { }) It("completes the stream when receiving the Fin after the stream was canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) str.CancelRead(1234) gomock.InOrder( mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), @@ -592,7 +613,7 @@ var _ = Describe("Receive Stream", func() { }) It("handles duplicate FinBits after the stream was canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) str.CancelRead(1234) gomock.InOrder( mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true), @@ -692,7 +713,7 @@ var _ = Describe("Receive Stream", func() { It("handles RESET_STREAM after CancelRead", func() { mockFC.EXPECT().Abandon() - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) str.CancelRead(1234) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true) mockSender.EXPECT().onStreamCompleted(streamID) @@ -708,21 +729,14 @@ var _ = Describe("Receive Stream", func() { }) }) - Context("flow control", func() { - It("errors when a STREAM frame causes a flow control violation", func() { - testErr := errors.New("flow control violation") - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr) - frame := wire.StreamFrame{ - Offset: 2, - Data: []byte("foobar"), - } - err := str.handleStreamFrame(&frame) - Expect(err).To(MatchError(testErr)) - }) - - It("gets a window update", func() { - mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100)) - Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100))) - }) + It("errors when a STREAM frame causes a flow control violation", func() { + testErr := errors.New("flow control violation") + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr) + frame := wire.StreamFrame{ + Offset: 2, + Data: []byte("foobar"), + } + err := str.handleStreamFrame(&frame) + Expect(err).To(MatchError(testErr)) }) }) diff --git a/send_stream.go b/send_stream.go index c89424e41..a7bb2cd6a 100644 --- a/send_stream.go +++ b/send_stream.go @@ -37,9 +37,12 @@ type sendStream struct { writeOffset protocol.ByteCount - cancelWriteErr error + cancelWriteErr *StreamError closeForShutdownErr error + queuedResetStreamFrame bool + queuedBlockedFrame bool + finishedWriting bool // set once Close() is called finSent bool // set when a STREAM_FRAME with FIN bit has been sent // Set when the application knows about the cancellation. @@ -59,8 +62,9 @@ type sendStream struct { } var ( - _ SendStream = &sendStream{} - _ sendStreamI = &sendStream{} + _ SendStream = &sendStream{} + _ sendStreamI = &sendStream{} + _ streamControlFrameGetter = &sendStream{} ) func newSendStream( @@ -215,12 +219,15 @@ func (s *sendStream) canBufferStreamFrame() bool { // maxBytes is the maximum length this frame (including frame header) will have. func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) { s.mutex.Lock() - f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) + f, hasMoreData, queuedControlFrame := s.popNewOrRetransmittedStreamFrame(maxBytes, v) if f != nil { s.numOutstandingFrames++ } s.mutex.Unlock() + if queuedControlFrame { + s.sender.onHasStreamControlFrame(s.streamID, s) + } if f == nil { return ackhandler.StreamFrame{}, false, hasMoreData } @@ -230,20 +237,20 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers }, true, hasMoreData } -func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) { +func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, hasMoreData, queuedControlFrame bool) { if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { - return nil, false + return nil, false, false } if len(s.retransmissionQueue) > 0 { f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v) if f != nil || hasMoreRetransmissions { if f == nil { - return nil, true + return nil, true, false } // We always claim that we have more data to send. // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future. - return f, true + return f, true, false } } @@ -255,21 +262,18 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun Offset: s.writeOffset, DataLenPresent: true, Fin: true, - }, false + }, false, false } - return nil, false + return nil, false, false } sendWindow := s.flowController.SendWindowSize() if sendWindow == 0 { if s.flowController.IsNewlyBlocked() { - s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{ - StreamID: s.streamID, - MaximumStreamData: s.writeOffset, - }) - return nil, false + s.queuedBlockedFrame = true + return nil, false, true } - return nil, true + return nil, true, false } f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v) @@ -281,7 +285,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun if f.Fin { s.finSent = true } - return f, hasMoreData + return f, hasMoreData, false } func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) { @@ -442,14 +446,11 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool s.numOutstandingFrames = 0 s.retransmissionQueue = nil newlyCompleted := s.isNewlyCompleted() + s.queuedResetStreamFrame = true s.mutex.Unlock() s.signalWrite() - s.sender.queueControlFrame(&wire.ResetStreamFrame{ - StreamID: s.streamID, - FinalSize: s.writeOffset, - ErrorCode: errorCode, - }) + s.sender.onHasStreamControlFrame(s.streamID, s) if newlyCompleted { s.sender.onStreamCompleted(s.streamID) } @@ -472,6 +473,26 @@ func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { s.cancelWriteImpl(frame.ErrorCode, true) } +func (s *sendStream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if !s.queuedBlockedFrame && !s.queuedResetStreamFrame { + return ackhandler.Frame{}, false, false + } + if s.queuedBlockedFrame { + s.queuedBlockedFrame = false + return ackhandler.Frame{ + Frame: &wire.StreamDataBlockedFrame{StreamID: s.streamID, MaximumStreamData: s.writeOffset}, + }, true, s.queuedResetStreamFrame + } + // RESET_STREAM frame + s.queuedResetStreamFrame = false + return ackhandler.Frame{ + Frame: &wire.ResetStreamFrame{StreamID: s.streamID, FinalSize: s.writeOffset, ErrorCode: s.cancelWriteErr.ErrorCode}, + }, true, false +} + func (s *sendStream) Context() context.Context { return s.ctx } diff --git a/send_stream_test.go b/send_stream_test.go index 4c8df1c4e..6279e9ffa 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -328,10 +328,6 @@ var _ = Describe("Send Stream", func() { mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0)) mockFC.EXPECT().IsNewlyBlocked().Return(true) - mockSender.EXPECT().queueControlFrame(&wire.StreamDataBlockedFrame{ - StreamID: streamID, - MaximumStreamData: 3, - }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -345,9 +341,17 @@ var _ = Describe("Send Stream", func() { Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) Expect(f.Frame.Data).To(HaveLen(3)) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) _, ok, hasMoreData = str.popStreamFrame(1000, protocol.Version1) Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) + cf, ok, hasMore := str.getControlFrame() + Expect(ok).To(BeTrue()) + Expect(cf.Frame).To(Equal(&wire.StreamDataBlockedFrame{ + StreamID: streamID, + MaximumStreamData: 3, + })) + Expect(hasMore).To(BeFalse()) // make the Write go routine return str.closeForShutdown(nil) Eventually(done).Should(BeClosed()) @@ -686,22 +690,26 @@ var _ = Describe("Send Stream", func() { Context("canceling writing", func() { It("queues a RESET_STREAM frame", func() { gomock.InOrder( - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ - StreamID: streamID, - FinalSize: 1234, - ErrorCode: 9876, - }), + mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()), mockSender.EXPECT().onStreamCompleted(streamID), ) str.writeOffset = 1234 str.CancelWrite(9876) + cf, ok, hasMore := str.getControlFrame() + Expect(ok).To(BeTrue()) + Expect(cf.Frame).To(Equal(&wire.ResetStreamFrame{ + StreamID: streamID, + FinalSize: 1234, + ErrorCode: 9876, + })) + Expect(hasMore).To(BeFalse()) }) // This test is inherently racy, as it tests a concurrent call to Write() and CancelRead(). // A single successful run of this test therefore doesn't mean a lot, // for reliable results it has to be run many times. It("returns a nil error when the whole slice has been sent out", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(1) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).MaxTimes(1) mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(1) mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1) @@ -724,7 +732,7 @@ var _ = Describe("Send Stream", func() { }) It("unblocks Write", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) mockSender.EXPECT().onHasStreamData(streamID, str) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) @@ -752,7 +760,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't pop STREAM frames after being canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) mockSender.EXPECT().onHasStreamData(streamID, str) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) @@ -776,7 +784,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't pop STREAM frames after being canceled, for large writes", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) mockSender.EXPECT().onHasStreamData(streamID, str) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) @@ -805,7 +813,7 @@ var _ = Describe("Send Stream", func() { }) It("ignores acknowledgements for STREAM frames after it was cancelled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) mockSender.EXPECT().onHasStreamData(streamID, str) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) @@ -826,7 +834,7 @@ var _ = Describe("Send Stream", func() { }) It("cancels the context", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) mockSender.EXPECT().onStreamCompleted(gomock.Any()) Expect(str.Context().Done()).ToNot(BeClosed()) str.CancelWrite(1234) @@ -836,7 +844,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't allow further calls to Write", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.CancelWrite(1234) _, err := strWithTimeout.Write([]byte("foobar")) @@ -848,43 +856,54 @@ var _ = Describe("Send Stream", func() { }) It("only cancels once", func() { - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 1234}) + mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()) mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.CancelWrite(1234) str.CancelWrite(4321) + cf, ok, hasMore := str.getControlFrame() + Expect(ok).To(BeTrue()) + Expect(cf.Frame).To(Equal(&wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 1234})) + Expect(hasMore).To(BeFalse()) + cf, ok, hasMore = str.getControlFrame() + Expect(ok).To(BeFalse()) + Expect(cf.Frame).To(BeNil()) + Expect(hasMore).To(BeFalse()) }) It("queues a RESET_STREAM frame, even if the stream was already closed", func() { mockSender.EXPECT().onHasStreamData(streamID, str) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{})) - }) + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) mockSender.EXPECT().onStreamCompleted(gomock.Any()) Expect(str.Close()).To(Succeed()) // don't EXPECT any calls to queueControlFrame str.CancelWrite(123) + f, ok, hasMore := str.getControlFrame() + Expect(ok).To(BeTrue()) + Expect(f.Frame).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{})) + Expect(hasMore).To(BeFalse()) }) }) Context("receiving STOP_SENDING frames", func() { It("queues a RESET_STREAM frames, and copies the error code from the STOP_SENDING frame", func() { - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ - StreamID: streamID, - ErrorCode: 101, - }) + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) // Don't EXPECT calls to onStreamCompleted. // The application needs to learn about the cancellation first. str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 101, }) + f, ok, hasMore := str.getControlFrame() + Expect(ok).To(BeTrue()) + Expect(f.Frame).To(Equal(&wire.ResetStreamFrame{ + StreamID: streamID, + ErrorCode: 101, + })) + Expect(hasMore).To(BeFalse()) }) It("discards the stream when CancelWrite is called after receiving STOP_SENDING", func() { - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ - StreamID: streamID, - ErrorCode: 101, - }) + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 101, @@ -896,7 +915,7 @@ var _ = Describe("Send Stream", func() { It("unblocks Write", func() { mockSender.EXPECT().onHasStreamData(streamID, str) - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -918,7 +937,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't allow further calls to Write", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 123, @@ -933,7 +952,7 @@ var _ = Describe("Send Stream", func() { }) It("handles Close after STOP_SENDING", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 123, @@ -948,7 +967,7 @@ var _ = Describe("Send Stream", func() { _, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) Expect(ok).To(BeTrue()) gomock.InOrder( - mockSender.EXPECT().queueControlFrame(gomock.Any()), + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()), mockSender.EXPECT().onStreamCompleted(gomock.Any()), ) str.handleStopSendingFrame(&wire.StopSendingFrame{ @@ -961,7 +980,7 @@ var _ = Describe("Send Stream", func() { mockSender.EXPECT().onHasStreamData(gomock.Any(), gomock.Any()) str.Close() gomock.InOrder( - mockSender.EXPECT().queueControlFrame(gomock.Any()), + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()), mockSender.EXPECT().onStreamCompleted(gomock.Any()), ) str.handleStopSendingFrame(&wire.StopSendingFrame{ @@ -1075,7 +1094,7 @@ var _ = Describe("Send Stream", func() { Eventually(done).Should(BeClosed()) Expect(f).ToNot(BeNil()) gomock.InOrder( - mockSender.EXPECT().queueControlFrame(gomock.Any()), + mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()), mockSender.EXPECT().onStreamCompleted(streamID), ) str.CancelWrite(9876) diff --git a/stream.go b/stream.go index 673956ed8..1ed263233 100644 --- a/stream.go +++ b/stream.go @@ -24,9 +24,8 @@ var errDeadline net.Error = &deadlineError{} // The streamSender is notified by the stream about various events. type streamSender interface { - queueControlFrame(wire.Frame) onHasStreamData(protocol.StreamID, sendStreamI) - onHasStreamWindowUpdate(protocol.StreamID, receiveStreamI) + onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter) // must be called without holding the mutex that is acquired by closeForShutdown onStreamCompleted(protocol.StreamID) } @@ -35,14 +34,17 @@ type streamSender interface { // This is necessary in order to keep track when both halves have been completed. type uniStreamSender struct { streamSender - onStreamCompletedImpl func() + onStreamCompletedImpl func() + onHasStreamControlFrameImpl func(protocol.StreamID, streamControlFrameGetter) } -func (s *uniStreamSender) queueControlFrame(f wire.Frame) { s.streamSender.queueControlFrame(f) } func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str sendStreamI) { s.streamSender.onHasStreamData(id, str) } func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() } +func (s *uniStreamSender) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) { + s.onHasStreamControlFrameImpl(id, str) +} var _ streamSender = &uniStreamSender{} @@ -52,7 +54,6 @@ type streamI interface { // for receiving handleStreamFrame(*wire.StreamFrame) error handleResetStreamFrame(*wire.ResetStreamFrame) error - getWindowUpdate() protocol.ByteCount // for sending hasData() bool handleStopSendingFrame(*wire.StopSendingFrame) @@ -78,7 +79,10 @@ type stream struct { sendStreamCompleted bool } -var _ Stream = &stream{} +var ( + _ Stream = &stream{} + _ streamControlFrameGetter = &receiveStream{} +) // newStream creates a new Stream func newStream( @@ -96,6 +100,9 @@ func newStream( s.checkIfCompleted() s.completedMutex.Unlock() }, + onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) { + sender.onHasStreamControlFrame(streamID, s) + }, } s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController) senderForReceiveStream := &uniStreamSender{ @@ -106,6 +113,9 @@ func newStream( s.checkIfCompleted() s.completedMutex.Unlock() }, + onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) { + sender.onHasStreamControlFrame(streamID, s) + }, } s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController) return s @@ -121,6 +131,14 @@ func (s *stream) Close() error { return s.sendStream.Close() } +func (s *stream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) { + f, ok, _ := s.sendStream.getControlFrame() + if ok { + return f, true, true + } + return s.receiveStream.getControlFrame() +} + func (s *stream) SetDeadline(t time.Time) error { _ = s.SetReadDeadline(t) // SetReadDeadline never errors _ = s.SetWriteDeadline(t) // SetWriteDeadline never errors diff --git a/streams_map.go b/streams_map.go index 041636c34..0ce91287b 100644 --- a/streams_map.go +++ b/streams_map.go @@ -62,6 +62,7 @@ type streamsMap struct { maxIncomingUniStreams uint64 sender streamSender + queueControlFrame func(wire.Frame) newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController mutex sync.Mutex @@ -77,14 +78,16 @@ var _ streamManager = &streamsMap{} func newStreamsMap( ctx context.Context, sender streamSender, + queueControlFrame func(wire.Frame), newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, maxIncomingBidiStreams uint64, maxIncomingUniStreams uint64, perspective protocol.Perspective, -) streamManager { +) *streamsMap { m := &streamsMap{ ctx: ctx, perspective: perspective, + queueControlFrame: queueControlFrame, newFlowController: newFlowController, maxIncomingBidiStreams: maxIncomingBidiStreams, maxIncomingUniStreams: maxIncomingUniStreams, @@ -101,7 +104,7 @@ func (m *streamsMap) initMaps() { id := num.StreamID(protocol.StreamTypeBidi, m.perspective) return newStream(m.ctx, id, m.sender, m.newFlowController(id)) }, - m.sender.queueControlFrame, + m.queueControlFrame, ) m.incomingBidiStreams = newIncomingStreamsMap( protocol.StreamTypeBidi, @@ -110,7 +113,7 @@ func (m *streamsMap) initMaps() { return newStream(m.ctx, id, m.sender, m.newFlowController(id)) }, m.maxIncomingBidiStreams, - m.sender.queueControlFrame, + m.queueControlFrame, ) m.outgoingUniStreams = newOutgoingStreamsMap( protocol.StreamTypeUni, @@ -118,7 +121,7 @@ func (m *streamsMap) initMaps() { id := num.StreamID(protocol.StreamTypeUni, m.perspective) return newSendStream(m.ctx, id, m.sender, m.newFlowController(id)) }, - m.sender.queueControlFrame, + m.queueControlFrame, ) m.incomingUniStreams = newIncomingStreamsMap( protocol.StreamTypeUni, @@ -127,7 +130,7 @@ func (m *streamsMap) initMaps() { return newReceiveStream(id, m.sender, m.newFlowController(id)) }, m.maxIncomingUniStreams, - m.sender.queueControlFrame, + m.queueControlFrame, ) } diff --git a/streams_map_incoming_test.go b/streams_map_incoming_test.go index b5abba518..04fca4f52 100644 --- a/streams_map_incoming_test.go +++ b/streams_map_incoming_test.go @@ -12,7 +12,6 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "go.uber.org/mock/gomock" ) type mockGenericStream struct { @@ -34,10 +33,10 @@ func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { var _ = Describe("Streams Map (incoming)", func() { var ( - m *incomingStreamsMap[*mockGenericStream] - newItemCounter int - mockSender *MockStreamSender - maxNumStreams uint64 + m *incomingStreamsMap[*mockGenericStream] + newItemCounter int + maxNumStreams uint64 + queuedControlFrames []wire.Frame ) streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeUni}[rand.Intn(2)] @@ -53,8 +52,8 @@ var _ = Describe("Streams Map (incoming)", func() { BeforeEach(func() { maxNumStreams = 5 }) JustBeforeEach(func() { + queuedControlFrames = []wire.Frame{} newItemCounter = 0 - mockSender = NewMockStreamSender(mockCtrl) m = newIncomingStreamsMap( streamType, func(num protocol.StreamNum) *mockGenericStream { @@ -62,7 +61,7 @@ var _ = Describe("Streams Map (incoming)", func() { return &mockGenericStream{num: num} }, maxNumStreams, - mockSender.queueControlFrame, + func(f wire.Frame) { queuedControlFrames = append(queuedControlFrames, f) }, ) }) @@ -171,7 +170,6 @@ var _ = Describe("Streams Map (incoming)", func() { }) It("deletes streams", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) str, err := m.AcceptStream(context.Background()) @@ -191,7 +189,6 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(err).ToNot(HaveOccurred()) Expect(str.num).To(Equal(protocol.StreamNum(1))) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued - mockSender.EXPECT().queueControlFrame(gomock.Any()) str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.num).To(Equal(protocol.StreamNum(2))) @@ -206,7 +203,6 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued - mockSender.EXPECT().queueControlFrame(gomock.Any()) str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) @@ -227,18 +223,17 @@ var _ = Describe("Streams Map (incoming)", func() { _, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) } - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - msf := f.(*wire.MaxStreamsFrame) - Expect(msf.Type).To(BeEquivalentTo(streamType)) - Expect(msf.MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) - checkFrameSerialization(f) - }) + Expect(queuedControlFrames).To(BeEmpty()) Expect(m.DeleteStream(3)).To(Succeed()) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) - checkFrameSerialization(f) - }) + Expect(queuedControlFrames).To(HaveLen(1)) + msf := queuedControlFrames[0].(*wire.MaxStreamsFrame) + Expect(msf.Type).To(BeEquivalentTo(streamType)) + Expect(msf.MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) + checkFrameSerialization(msf) Expect(m.DeleteStream(4)).To(Succeed()) + Expect(queuedControlFrames).To(HaveLen(2)) + Expect(queuedControlFrames[1].(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) + checkFrameSerialization(queuedControlFrames[1]) }) Context("using high stream limits", func() { @@ -253,19 +248,19 @@ var _ = Describe("Streams Map (incoming)", func() { _, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) } - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) - checkFrameSerialization(f) - }) + Expect(queuedControlFrames).To(BeEmpty()) Expect(m.DeleteStream(4)).To(Succeed()) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) - checkFrameSerialization(f) - }) + Expect(queuedControlFrames).To(HaveLen(1)) + Expect(queuedControlFrames[0].(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) + checkFrameSerialization(queuedControlFrames[0]) Expect(m.DeleteStream(3)).To(Succeed()) + Expect(queuedControlFrames).To(HaveLen(2)) + Expect(queuedControlFrames[1].(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) + checkFrameSerialization(queuedControlFrames[1]) // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent Expect(m.DeleteStream(2)).To(Succeed()) Expect(m.DeleteStream(1)).To(Succeed()) + Expect(queuedControlFrames).To(HaveLen(2)) }) }) diff --git a/streams_map_outgoing_test.go b/streams_map_outgoing_test.go index 3ae337177..7c5c4af12 100644 --- a/streams_map_outgoing_test.go +++ b/streams_map_outgoing_test.go @@ -15,14 +15,13 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "go.uber.org/mock/gomock" ) var _ = Describe("Streams Map (outgoing)", func() { var ( - m *outgoingStreamsMap[*mockGenericStream] - newStr func(num protocol.StreamNum) *mockGenericStream - mockSender *MockStreamSender + m *outgoingStreamsMap[*mockGenericStream] + newStr func(num protocol.StreamNum) *mockGenericStream + queuedControlFrames []wire.Frame ) const streamType = 42 @@ -37,11 +36,15 @@ var _ = Describe("Streams Map (outgoing)", func() { } BeforeEach(func() { + queuedControlFrames = []wire.Frame{} newStr = func(num protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: num} } - mockSender = NewMockStreamSender(mockCtrl) - m = newOutgoingStreamsMap[*mockGenericStream](streamType, newStr, mockSender.queueControlFrame) + m = newOutgoingStreamsMap[*mockGenericStream]( + streamType, + newStr, + func(f wire.Frame) { queuedControlFrames = append(queuedControlFrames, f) }, + ) }) Context("no stream ID limit", func() { @@ -130,7 +133,6 @@ var _ = Describe("Streams Map (outgoing)", func() { Context("with stream ID limits", func() { It("errors when no stream can be opened immediately", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.OpenStream() expectTooManyStreamsError(err) }) @@ -143,7 +145,6 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("blocks until a stream can be opened synchronously", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -159,7 +160,6 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("unblocks when the context is canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) go func() { @@ -181,7 +181,6 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("opens streams in the right order", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() done1 := make(chan struct{}) go func() { defer GinkgoRecover() @@ -210,7 +209,6 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("opens streams in the right order, when one of the contexts is canceled", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() done1 := make(chan struct{}) go func() { defer GinkgoRecover() @@ -249,7 +247,6 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("unblocks multiple OpenStreamSync calls at the same time", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -282,7 +279,6 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("returns an error for OpenStream while an OpenStreamSync call is blocking", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(2) openedSync := make(chan struct{}) go func() { defer GinkgoRecover() @@ -322,7 +318,6 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("stops opening synchronously when it is closed", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) testErr := errors.New("test error") done := make(chan struct{}) go func() { @@ -355,33 +350,31 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) } - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - bf := f.(*wire.StreamsBlockedFrame) - Expect(bf.Type).To(BeEquivalentTo(streamType)) - Expect(bf.StreamLimit).To(BeEquivalentTo(6)) - }) + Expect(queuedControlFrames).To(BeEmpty()) _, err := m.OpenStream() Expect(err).To(MatchError(&StreamLimitReachedError{})) + Expect(queuedControlFrames).To(HaveLen(1)) + bf := queuedControlFrames[0].(*wire.StreamsBlockedFrame) + Expect(bf.Type).To(BeEquivalentTo(streamType)) + Expect(bf.StreamLimit).To(BeEquivalentTo(6)) }) It("only sends one STREAMS_BLOCKED frame for one stream ID", func() { m.SetMaxStream(1) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) - }) _, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) + Expect(queuedControlFrames).To(BeEmpty()) // try to open a stream twice, but expect only one STREAMS_BLOCKED to be sent _, err = m.OpenStream() expectTooManyStreamsError(err) + Expect(queuedControlFrames).To(HaveLen(1)) + Expect(queuedControlFrames[0].(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) _, err = m.OpenStream() expectTooManyStreamsError(err) + Expect(queuedControlFrames).To(HaveLen(1)) }) It("queues a STREAMS_BLOCKED frame when there more streams waiting for OpenStreamSync than MAX_STREAMS allows", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(0)) - }) done := make(chan struct{}, 2) go func() { defer GinkgoRecover() @@ -396,13 +389,14 @@ var _ = Describe("Streams Map (outgoing)", func() { done <- struct{}{} }() waitForEnqueued(2) + Expect(queuedControlFrames).To(HaveLen(1)) + Expect(queuedControlFrames[0].(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(0)) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) - }) m.SetMaxStream(1) Eventually(done).Should(Receive()) Consistently(done).ShouldNot(Receive()) + Expect(queuedControlFrames).To(HaveLen(2)) + Expect(queuedControlFrames[1].(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1)) m.SetMaxStream(2) Eventually(done).Should(Receive()) }) @@ -414,10 +408,6 @@ var _ = Describe("Streams Map (outgoing)", func() { const n = 100 fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n) - var blockedAt []protocol.StreamNum - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) - }).AnyTimes() done := make(map[int]chan struct{}) for i := 1; i <= n; i++ { c := make(chan struct{}) @@ -456,6 +446,10 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(str.num).To(Equal(protocol.StreamNum(n + 1))) } } + var blockedAt []protocol.StreamNum + for _, f := range queuedControlFrames { + blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) + } Expect(blockedAt).To(Equal(limits)) }) @@ -464,11 +458,6 @@ var _ = Describe("Streams Map (outgoing)", func() { const n = 100 fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n) - var blockedAt []protocol.StreamNum - mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) - }).AnyTimes() - ctx, cancel := context.WithCancel(context.Background()) streamsToCancel := make(map[protocol.StreamNum]struct{}) // used as a set for i := 0; i < 10; i++ { @@ -537,6 +526,10 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(streamIDs[i]).To(Equal(i + 1)) } } + var blockedAt []protocol.StreamNum + for _, f := range queuedControlFrames { + blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) + } Expect(blockedAt).To(Equal(limits)) }) }) diff --git a/streams_map_test.go b/streams_map_test.go index 6e1f1c4e0..bbd29fabb 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -14,7 +14,6 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "go.uber.org/mock/gomock" ) func (e streamError) TestError() error { @@ -72,8 +71,9 @@ var _ = Describe("Streams Map", func() { Context(perspective.String(), func() { var ( - m *streamsMap - mockSender *MockStreamSender + m *streamsMap + mockSender *MockStreamSender + queuedControlFrames []wire.Frame ) const ( @@ -89,8 +89,17 @@ var _ = Describe("Streams Map", func() { } BeforeEach(func() { + queuedControlFrames = []wire.Frame{} mockSender = NewMockStreamSender(mockCtrl) - m = newStreamsMap(context.Background(), mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap) + m = newStreamsMap( + context.Background(), + mockSender, + func(f wire.Frame) { queuedControlFrames = append(queuedControlFrames, f) }, + newFlowController, + MaxBidiStreamNum, + MaxUniStreamNum, + perspective, + ) }) Context("opening", func() { @@ -140,10 +149,7 @@ var _ = Describe("Streams Map", func() { }) Context("deleting", func() { - BeforeEach(func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - allowUnlimitedStreams() - }) + BeforeEach(func() { allowUnlimitedStreams() }) It("deletes outgoing bidirectional streams", func() { id := ids.firstOutgoingBidiStream @@ -339,7 +345,6 @@ var _ = Describe("Streams Map", func() { }) It("processes the parameter for outgoing streams", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.OpenStream() expectTooManyStreamsError(err) m.UpdateLimits(&wire.TransportParameters{ @@ -347,7 +352,6 @@ var _ = Describe("Streams Map", func() { MaxUniStreamNum: 8, }) - mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) // test we can only 5 bidirectional streams for i := 0; i < 5; i++ { str, err := m.OpenStream() @@ -364,6 +368,7 @@ var _ = Describe("Streams Map", func() { } _, err = m.OpenUniStream() expectTooManyStreamsError(err) + Expect(queuedControlFrames).To(HaveLen(3)) }) if perspective == protocol.PerspectiveClient { @@ -399,10 +404,6 @@ var _ = Describe("Streams Map", func() { } Context("handling MAX_STREAMS frames", func() { - BeforeEach(func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() - }) - It("processes IDs for outgoing bidirectional streams", func() { _, err := m.OpenStream() expectTooManyStreamsError(err) @@ -438,11 +439,13 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) _, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: MaxBidiStreamNum + 1, - }) Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed()) + Expect(queuedControlFrames).To(Equal([]wire.Frame{ + &wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: MaxBidiStreamNum + 1, + }, + })) }) It("sends a MAX_STREAMS frame for unidirectional streams", func() { @@ -450,11 +453,13 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) _, err = m.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: MaxUniStreamNum + 1, - }) Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed()) + Expect(queuedControlFrames).To(Equal([]wire.Frame{ + &wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: MaxUniStreamNum + 1, + }, + })) }) }) @@ -477,7 +482,6 @@ var _ = Describe("Streams Map", func() { if perspective == protocol.PerspectiveClient { It("resets for 0-RTT", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() m.ResetFor0RTT() // make sure that calls to open / accept streams fail _, err := m.OpenStream() diff --git a/window_update_queue.go b/window_update_queue.go deleted file mode 100644 index 4ae79f90e..000000000 --- a/window_update_queue.go +++ /dev/null @@ -1,62 +0,0 @@ -package quic - -import ( - "sync" - - "github.com/quic-go/quic-go/internal/flowcontrol" - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/wire" -) - -type windowUpdateQueue struct { - mutex sync.Mutex - - queue map[protocol.StreamID]receiveStreamI - - connFlowController flowcontrol.ConnectionFlowController - callback func(wire.Frame) -} - -func newWindowUpdateQueue( - connFC flowcontrol.ConnectionFlowController, - cb func(wire.Frame), -) *windowUpdateQueue { - return &windowUpdateQueue{ - queue: make(map[protocol.StreamID]receiveStreamI), - connFlowController: connFC, - callback: cb, - } -} - -func (q *windowUpdateQueue) AddStream(id protocol.StreamID, str receiveStreamI) { - q.mutex.Lock() - q.queue[id] = str - q.mutex.Unlock() -} - -func (q *windowUpdateQueue) RemoveStream(id protocol.StreamID) { - q.mutex.Lock() - delete(q.queue, id) - q.mutex.Unlock() -} - -func (q *windowUpdateQueue) QueueAll() { - q.mutex.Lock() - // queue a connection-level window update - if offset := q.connFlowController.GetWindowUpdate(); offset > 0 { - q.callback(&wire.MaxDataFrame{MaximumData: offset}) - } - // queue all stream-level window updates - for id, str := range q.queue { - delete(q.queue, id) - offset := str.getWindowUpdate() - if offset == 0 { // can happen if we received a final offset, right after queueing the window update - continue - } - q.callback(&wire.MaxStreamDataFrame{ - StreamID: id, - MaximumStreamData: offset, - }) - } - q.mutex.Unlock() -} diff --git a/window_update_queue_test.go b/window_update_queue_test.go deleted file mode 100644 index 7f9400e96..000000000 --- a/window_update_queue_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package quic - -import ( - "github.com/quic-go/quic-go/internal/mocks" - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/wire" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("Window Update Queue", func() { - var ( - q *windowUpdateQueue - connFC *mocks.MockConnectionFlowController - queuedFrames []wire.Frame - ) - - BeforeEach(func() { - connFC = mocks.NewMockConnectionFlowController(mockCtrl) - queuedFrames = queuedFrames[:0] - q = newWindowUpdateQueue(connFC, func(f wire.Frame) { - queuedFrames = append(queuedFrames, f) - }) - }) - - It("adds stream offsets and gets MAX_STREAM_DATA frames", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes() - stream1 := NewMockStreamI(mockCtrl) - stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10)) - stream3 := NewMockStreamI(mockCtrl) - stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30)) - q.AddStream(3, stream3) - q.AddStream(1, stream1) - q.QueueAll() - Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, MaximumStreamData: 10})) - Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, MaximumStreamData: 30})) - }) - - It("deletes the entry after getting the MAX_STREAM_DATA frame", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes() - stream10 := NewMockStreamI(mockCtrl) - stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100)) - q.AddStream(10, stream10) - q.QueueAll() - Expect(queuedFrames).To(HaveLen(1)) - q.QueueAll() - Expect(queuedFrames).To(HaveLen(1)) - }) - - It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes() - stream12 := NewMockStreamI(mockCtrl) - q.AddStream(12, stream12) - q.RemoveStream(12) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("removes closed streams from the queue", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes() - stream12 := NewMockStreamI(mockCtrl) - q.AddStream(12, stream12) - q.RemoveStream(12) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)) - stream5 := NewMockStreamI(mockCtrl) - stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) - q.AddStream(5, stream5) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("removes streams for which the flow controller returns an offset of 0 from the queue", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).AnyTimes() - stream5 := NewMockStreamI(mockCtrl) - stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) - q.AddStream(5, stream5) - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - // don't EXPECT any further calls to GetOrOpenReveiveStream and to getWindowUpdate - q.QueueAll() - Expect(queuedFrames).To(BeEmpty()) - }) - - It("queues MAX_DATA frames", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) - q.QueueAll() - Expect(queuedFrames).To(Equal([]wire.Frame{&wire.MaxDataFrame{MaximumData: 0x1337}})) - }) - - It("deduplicates", func() { - connFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)) - stream10 := NewMockStreamI(mockCtrl) - stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200)) - q.AddStream(10, stream10) - q.AddStream(10, stream10) - q.QueueAll() - Expect(queuedFrames).To(Equal([]wire.Frame{&wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 200}})) - }) -})