diff --git a/internal/handshake/params_negotiator_base.go b/internal/handshake/params_negotiator_base.go index 8a65d07f..d99d5a55 100644 --- a/internal/handshake/params_negotiator_base.go +++ b/internal/handshake/params_negotiator_base.go @@ -16,7 +16,6 @@ type ParamsNegotiator interface { GetReceiveStreamFlowControlWindow() protocol.ByteCount GetReceiveConnectionFlowControlWindow() protocol.ByteCount GetMaxOutgoingStreams() uint32 - GetMaxIncomingStreams() uint32 // get the idle timeout that was sent by the peer GetRemoteIdleTimeout() time.Duration // determines if the client requests omission of connection IDs. @@ -40,7 +39,6 @@ type paramsNegotiatorBase struct { omitConnectionID bool requestConnectionIDOmission bool - maxStreamsPerConnection uint32 maxIncomingDynamicStreamsPerConnection uint32 idleTimeout time.Duration remoteIdleTimeout time.Duration @@ -59,18 +57,12 @@ func (h *paramsNegotiatorBase) init(params *TransportParameters) { h.idleTimeout = params.IdleTimeout if h.perspective == protocol.PerspectiveServer { - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective } else { - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective } } -func (h *paramsNegotiatorBase) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { - return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) -} - func (h *paramsNegotiatorBase) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) } @@ -109,14 +101,6 @@ func (h *paramsNegotiatorBase) GetMaxOutgoingStreams() uint32 { return h.maxIncomingDynamicStreamsPerConnection } -func (h *paramsNegotiatorBase) GetMaxIncomingStreams() uint32 { - h.mutex.RLock() - defer h.mutex.RUnlock() - - maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection - return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier)) -} - func (h *paramsNegotiatorBase) setRemoteIdleTimeout(t time.Duration) { h.remoteIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, t) } diff --git a/internal/handshake/params_negotiator_gquic.go b/internal/handshake/params_negotiator_gquic.go index 329ddfa9..89b41538 100644 --- a/internal/handshake/params_negotiator_gquic.go +++ b/internal/handshake/params_negotiator_gquic.go @@ -42,13 +42,6 @@ func (h *paramsNegotiatorGQUIC) SetFromMap(params map[Tag][]byte) error { } h.omitConnectionID = (clientValue == 0) } - if value, ok := params[TagMSPC]; ok { - clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return errMalformedTag - } - h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) - } if value, ok := params[TagMIDS]; ok { clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) if err != nil { @@ -99,8 +92,6 @@ func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) { utils.LittleEndian.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) cfcw := bytes.NewBuffer([]byte{}) utils.LittleEndian.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) - mspc := bytes.NewBuffer([]byte{}) - utils.LittleEndian.WriteUint32(mspc, h.maxStreamsPerConnection) mids := bytes.NewBuffer([]byte{}) utils.LittleEndian.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection) icsl := bytes.NewBuffer([]byte{}) @@ -108,7 +99,6 @@ func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) { return map[Tag][]byte{ TagICSL: icsl.Bytes(), - TagMSPC: mspc.Bytes(), TagMIDS: mids.Bytes(), TagCFCW: cfcw.Bytes(), TagSFCW: sfcw.Bytes(), diff --git a/internal/handshake/params_negotiator_gquic_test.go b/internal/handshake/params_negotiator_gquic_test.go index 252fb698..6f154d66 100644 --- a/internal/handshake/params_negotiator_gquic_test.go +++ b/internal/handshake/params_negotiator_gquic_test.go @@ -40,7 +40,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) - Expect(entryMap).To(HaveKey(TagMSPC)) Expect(entryMap).To(HaveKey(TagMIDS)) }) @@ -68,16 +67,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { Expect(entryMap[TagICSL]).To(Equal([]byte{0xad, 0xfb, 0xca, 0xde})) }) - It("sets the negotiated value for maximum streams in the SHLO", func() { - val := 50 - Expect(val).To(BeNumerically("<", protocol.MaxStreamsPerConnection)) - err := pn.SetFromMap(map[Tag][]byte{TagMSPC: []byte{byte(val), 0, 0, 0}}) - Expect(err).ToNot(HaveOccurred()) - entryMap, err := pn.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - Expect(entryMap[TagMSPC]).To(Equal([]byte{byte(val), 0, 0, 0})) - }) - It("always sends its own value for the maximum incoming dynamic streams in the SHLO", func() { err := pn.SetFromMap(map[Tag][]byte{TagMIDS: []byte{5, 0, 0, 0}}) Expect(err).ToNot(HaveOccurred()) @@ -98,8 +87,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) Expect(binary.LittleEndian.Uint32(entryMap[TagICSL])).To(BeEquivalentTo(idleTimeout / time.Second)) - Expect(entryMap).To(HaveKey(TagMSPC)) - Expect(binary.LittleEndian.Uint32(entryMap[TagMSPC])).To(BeEquivalentTo(protocol.MaxStreamsPerConnection)) Expect(entryMap).To(HaveKey(TagMIDS)) Expect(binary.LittleEndian.Uint32(entryMap[TagMIDS])).To(BeEquivalentTo(protocol.MaxIncomingDynamicStreamsPerConnection)) Expect(entryMap).To(HaveKey(TagSFCW)) @@ -223,12 +210,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { }) Context("max streams per connection", func() { - It("errors when given an invalid max streams per connection value", func() { - values := map[Tag][]byte{TagMSPC: {2, 0, 0}} // 1 byte too short - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - }) - It("errors when given an invalid max dynamic incoming streams per connection value", func() { values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short err := pn.SetFromMap(values) @@ -240,28 +221,18 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { // this test only works if the value given here is smaller than protocol.MaxStreamsPerConnection err := pn.SetFromMap(map[Tag][]byte{ TagMIDS: {2, 0, 0, 0}, - TagMSPC: {1, 0, 0, 0}, }) Expect(err).ToNot(HaveOccurred()) Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(2))) }) It("uses the the MSPC value, if no MIDS is given", func() { - err := pn.SetFromMap(map[Tag][]byte{TagMIDS: {3, 0, 0, 0}}) + err := pn.SetFromMap(map[Tag][]byte{ + TagMIDS: {3, 0, 0, 0}, + }) Expect(err).ToNot(HaveOccurred()) Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(3))) }) }) - - Context("incoming connections", func() { - It("always uses the constant value, no matter what the client sent", func() { - err := pn.SetFromMap(map[Tag][]byte{ - TagMSPC: {3, 0, 0, 0}, - TagMIDS: {3, 0, 0, 0}, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetMaxIncomingStreams()).To(BeNumerically(">", protocol.MaxStreamsPerConnection)) - }) - }) }) }) diff --git a/internal/mocks/params_negotiator.go b/internal/mocks/params_negotiator.go index e95aff9e..361c1d11 100644 --- a/internal/mocks/params_negotiator.go +++ b/internal/mocks/params_negotiator.go @@ -95,18 +95,6 @@ func (mr *MockParamsNegotiatorMockRecorder) GetMaxOutgoingStreams() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxOutgoingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxOutgoingStreams)) } -// GetMaxIncomingStreams mocks base method -func (m *MockParamsNegotiator) GetMaxIncomingStreams() uint32 { - ret := m.ctrl.Call(m, "GetMaxIncomingStreams") - ret0, _ := ret[0].(uint32) - return ret0 -} - -// GetMaxIncomingStreams indicates an expected call of GetMaxIncomingStreams -func (mr *MockParamsNegotiatorMockRecorder) GetMaxIncomingStreams() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxIncomingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxIncomingStreams)) -} - // GetRemoteIdleTimeout mocks base method func (m *MockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { ret := m.ctrl.Call(m, "GetRemoteIdleTimeout") diff --git a/session_test.go b/session_test.go index 504e38a0..12165de5 100644 --- a/session_test.go +++ b/session_test.go @@ -159,7 +159,6 @@ func (m *mockParamsNegotiator) GetReceiveConnectionFlowControlWindow() protocol. return protocol.ReceiveConnectionFlowControlWindow } func (m *mockParamsNegotiator) GetMaxOutgoingStreams() uint32 { return 100 } -func (m *mockParamsNegotiator) GetMaxIncomingStreams() uint32 { return 100 } func (m *mockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { return time.Hour } func (m *mockParamsNegotiator) OmitConnectionID() bool { return false } @@ -1564,9 +1563,7 @@ var _ = Describe("Session", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() - for i := 0; i < 10; i++ { + for i := 0; i < protocol.MaxStreamsPerConnection; i++ { _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } @@ -1575,8 +1572,6 @@ var _ = Describe("Session", func() { }) It("does not error when many streams are opened and closed", func() { - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() for i := 2; i <= 1000; i++ { s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) diff --git a/streams_map.go b/streams_map.go index 16e865e2..2e7dc9ed 100644 --- a/streams_map.go +++ b/streams_map.go @@ -7,6 +7,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" ) @@ -34,6 +35,7 @@ type streamsMap struct { numOutgoingStreams uint32 numIncomingStreams uint32 + maxIncomingStreams uint32 } type streamLambda func(*stream) (bool, error) @@ -43,6 +45,12 @@ type newStreamLambda func(protocol.StreamID) *stream var errMapAccess = errors.New("streamsMap: Error accessing the streams map") func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap { + // add some tolerance to the maximum incoming streams value + maxStreams := uint32(protocol.MaxIncomingDynamicStreamsPerConnection) + maxIncomingStreams := utils.MaxUint32( + maxStreams+protocol.MaxStreamsMinimumIncrement, + uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), + ) sm := streamsMap{ perspective: pers, streams: make(map[protocol.StreamID]*stream), @@ -50,6 +58,7 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC newStream: newStream, removeStreamCallback: removeStreamCallback, connParams: connParams, + maxIncomingStreams: maxIncomingStreams, } sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex @@ -126,7 +135,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { } func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { - if m.numIncomingStreams >= m.connParams.GetMaxIncomingStreams() { + if m.numIncomingStreams >= m.maxIncomingStreams { return nil, qerr.TooManyOpenStreams } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { diff --git a/streams_map_test.go b/streams_map_test.go index 23965fa1..dffaa105 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -11,10 +11,7 @@ import ( ) var _ = Describe("Streams Map", func() { - const ( - maxIncomingStreams = 75 - maxOutgoingStreams = 60 - ) + const maxOutgoingStreams = 60 var ( m *streamsMap @@ -23,9 +20,7 @@ var _ = Describe("Streams Map", func() { setNewStreamsMap := func(p protocol.Perspective) { mockPn = mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams)) - mockPn.EXPECT().GetMaxIncomingStreams().AnyTimes().Return(uint32(maxIncomingStreams)) newStream := func(id protocol.StreamID) *stream { return newStream(id, func() {}, nil, nil) @@ -113,21 +108,21 @@ var _ = Describe("Streams Map", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - for i := 0; i < maxIncomingStreams; i++ { + for i := uint32(0); i < m.maxIncomingStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } - _, err := m.GetOrOpenStream(protocol.StreamID(2*maxIncomingStreams + 3)) + _, err := m.GetOrOpenStream(protocol.StreamID(2*m.maxIncomingStreams + 3)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("errors when too many streams are opened implicitely", func() { - _, err := m.GetOrOpenStream(protocol.StreamID(maxIncomingStreams*2 + 1)) + _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 1)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("does not error when many streams are opened and closed", func() { - for i := 2; i < 10*maxIncomingStreams; i++ { + for i := uint32(2); i < 10*m.maxIncomingStreams; i++ { str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) deleteStream(str.StreamID())