diff --git a/internal/mocks/cpm.go b/internal/mocks/cpm.go new file mode 100644 index 00000000..686928f6 --- /dev/null +++ b/internal/mocks/cpm.go @@ -0,0 +1,153 @@ +// Automatically generated by MockGen. DO NOT EDIT! +// Source: github.com/lucas-clemente/quic-go/handshake (interfaces: ConnectionParametersManager) + +package mocks + +import ( + gomock "github.com/golang/mock/gomock" + handshake "github.com/lucas-clemente/quic-go/handshake" + protocol "github.com/lucas-clemente/quic-go/protocol" + time "time" +) + +// Mock of ConnectionParametersManager interface +type MockConnectionParametersManager struct { + ctrl *gomock.Controller + recorder *_MockConnectionParametersManagerRecorder +} + +// Recorder for MockConnectionParametersManager (not exported) +type _MockConnectionParametersManagerRecorder struct { + mock *MockConnectionParametersManager +} + +func NewMockConnectionParametersManager(ctrl *gomock.Controller) *MockConnectionParametersManager { + mock := &MockConnectionParametersManager{ctrl: ctrl} + mock.recorder = &_MockConnectionParametersManagerRecorder{mock} + return mock +} + +func (_m *MockConnectionParametersManager) EXPECT() *_MockConnectionParametersManagerRecorder { + return _m.recorder +} + +func (_m *MockConnectionParametersManager) GetHelloMap() (map[handshake.Tag][]byte, error) { + ret := _m.ctrl.Call(_m, "GetHelloMap") + ret0, _ := ret[0].(map[handshake.Tag][]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetHelloMap() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetHelloMap") +} + +func (_m *MockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { + ret := _m.ctrl.Call(_m, "GetIdleConnectionStateLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetIdleConnectionStateLifetime() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetIdleConnectionStateLifetime") +} + +func (_m *MockConnectionParametersManager) GetMaxIncomingStreams() uint32 { + ret := _m.ctrl.Call(_m, "GetMaxIncomingStreams") + ret0, _ := ret[0].(uint32) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetMaxIncomingStreams() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxIncomingStreams") +} + +func (_m *MockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { + ret := _m.ctrl.Call(_m, "GetMaxOutgoingStreams") + ret0, _ := ret[0].(uint32) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetMaxOutgoingStreams() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxOutgoingStreams") +} + +func (_m *MockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetMaxReceiveConnectionFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveConnectionFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveConnectionFlowControlWindow") +} + +func (_m *MockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetMaxReceiveStreamFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveStreamFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveStreamFlowControlWindow") +} + +func (_m *MockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetReceiveConnectionFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveConnectionFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveConnectionFlowControlWindow") +} + +func (_m *MockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetReceiveStreamFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveStreamFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveStreamFlowControlWindow") +} + +func (_m *MockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetSendConnectionFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetSendConnectionFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendConnectionFlowControlWindow") +} + +func (_m *MockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetSendStreamFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) GetSendStreamFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendStreamFlowControlWindow") +} + +func (_m *MockConnectionParametersManager) SetFromMap(_param0 map[handshake.Tag][]byte) error { + ret := _m.ctrl.Call(_m, "SetFromMap", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) SetFromMap(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "SetFromMap", arg0) +} + +func (_m *MockConnectionParametersManager) TruncateConnectionID() bool { + ret := _m.ctrl.Call(_m, "TruncateConnectionID") + ret0, _ := ret[0].(bool) + return ret0 +} + +func (_mr *_MockConnectionParametersManagerRecorder) TruncateConnectionID() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "TruncateConnectionID") +} diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index f87196d0..b0f34d9b 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -1,3 +1,4 @@ package mocks //go:generate mockgen -destination flow_control_manager.go -package mocks github.com/lucas-clemente/quic-go/flowcontrol FlowControlManager +//go:generate mockgen -destination cpm.go -package mocks github.com/lucas-clemente/quic-go/handshake ConnectionParametersManager diff --git a/streams_map_test.go b/streams_map_test.go index 31a89386..b1697143 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/lucas-clemente/quic-go/handshake" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -52,25 +53,28 @@ func (m *mockConnectionParametersManager) TruncateConnectionID() bool { return f var _ handshake.ConnectionParametersManager = &mockConnectionParametersManager{} var _ = Describe("Streams Map", func() { + const ( + maxIncomingStreams = 75 + maxOutgoingStreams = 60 + ) + var ( - cpm handshake.ConnectionParametersManager - m *streamsMap + m *streamsMap + mockCpm *mocks.MockConnectionParametersManager ) setNewStreamsMap := func(p protocol.Perspective) { - m = newStreamsMap(nil, p, cpm) + mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) + + mockCpm.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams)) + mockCpm.EXPECT().GetMaxIncomingStreams().AnyTimes().Return(uint32(maxIncomingStreams)) + + m = newStreamsMap(nil, p, mockCpm) m.newStream = func(id protocol.StreamID) *stream { return &stream{streamID: id} } } - BeforeEach(func() { - cpm = &mockConnectionParametersManager{ - maxIncomingStreams: 75, - maxOutgoingStreams: 60, - } - }) - AfterEach(func() { Expect(m.openStreams).To(HaveLen(len(m.streams))) }) @@ -106,11 +110,11 @@ var _ = Describe("Streams Map", func() { }) It("returns nil for closed streams", func() { - s, err := m.GetOrOpenStream(5) + _, err := m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) err = m.RemoveStream(5) Expect(err).NotTo(HaveOccurred()) - s, err = m.GetOrOpenStream(5) + s, err := m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(s).To(BeNil()) }) @@ -134,28 +138,22 @@ var _ = Describe("Streams Map", func() { }) Context("counting streams", func() { - var maxNumStreams int - - BeforeEach(func() { - maxNumStreams = int(cpm.GetMaxIncomingStreams()) - }) - It("errors when too many streams are opened", func() { - for i := 0; i < maxNumStreams; i++ { + for i := 0; i < maxIncomingStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } - _, err := m.GetOrOpenStream(protocol.StreamID(2*maxNumStreams + 3)) + _, err := m.GetOrOpenStream(protocol.StreamID(2*maxIncomingStreams + 3)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("errors when too many streams are opened implicitely", func() { - _, err := m.GetOrOpenStream(protocol.StreamID(maxNumStreams*2 + 1)) + _, err := m.GetOrOpenStream(protocol.StreamID(maxIncomingStreams*2 + 1)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("does not error when many streams are opened and closed", func() { - for i := 2; i < 10*maxNumStreams; i++ { + for i := 2; i < 10*maxIncomingStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) m.RemoveStream(protocol.StreamID(i*2 + 1)) @@ -182,14 +180,8 @@ var _ = Describe("Streams Map", func() { }) Context("counting streams", func() { - var maxNumStreams int - - BeforeEach(func() { - maxNumStreams = int(cpm.GetMaxOutgoingStreams()) - }) - It("errors when too many streams are opened", func() { - for i := 1; i <= maxNumStreams; i++ { + for i := 1; i <= maxOutgoingStreams; i++ { _, err := m.OpenStream() Expect(err).NotTo(HaveOccurred()) } @@ -198,7 +190,7 @@ var _ = Describe("Streams Map", func() { }) It("does not error when many streams are opened and closed", func() { - for i := 2; i < 10*maxNumStreams; i++ { + for i := 2; i < 10*maxOutgoingStreams; i++ { str, err := m.OpenStream() Expect(err).NotTo(HaveOccurred()) m.RemoveStream(str.StreamID()) @@ -206,11 +198,11 @@ var _ = Describe("Streams Map", func() { }) It("allows many server- and client-side streams at the same time", func() { - for i := 1; i < int(cpm.GetMaxOutgoingStreams()); i++ { + for i := 1; i < maxOutgoingStreams; i++ { _, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) } - for i := 0; i < int(cpm.GetMaxIncomingStreams()); i++ { + for i := 0; i < maxOutgoingStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1)) Expect(err).ToNot(HaveOccurred()) } @@ -218,14 +210,8 @@ var _ = Describe("Streams Map", func() { }) Context("opening streams synchronously", func() { - var maxNumStreams int - - BeforeEach(func() { - maxNumStreams = int(cpm.GetMaxOutgoingStreams()) - }) - openMaxNumStreams := func() { - for i := 1; i <= maxNumStreams; i++ { + for i := 1; i <= maxOutgoingStreams; i++ { _, err := m.OpenStream() Expect(err).NotTo(HaveOccurred()) } @@ -249,7 +235,7 @@ var _ = Describe("Streams Map", func() { err := m.RemoveStream(6) Expect(err).ToNot(HaveOccurred()) Eventually(func() bool { return returned }).Should(BeTrue()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxNumStreams + 2))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2))) }) It("stops waiting when an error is registered", func() {