From daff6256b948798e6bc9adf798a42fd137dfef5d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 15 Oct 2017 09:23:42 +0800 Subject: [PATCH] don't negotiate the maximum incoming streams value The maximum number of incoming streams doesn't need to be negotiated. It is a value that is only announced by the peer, and has to be respected. Furthermore, Chrome doesn't seem to care about the MSPC value anymore (since MIDS was introduced), so there's no need to send this in the handshake any more. --- internal/handshake/params_negotiator_base.go | 16 --------- internal/handshake/params_negotiator_gquic.go | 10 ------ .../handshake/params_negotiator_gquic_test.go | 35 ++----------------- internal/mocks/params_negotiator.go | 12 ------- session_test.go | 7 +--- streams_map.go | 11 +++++- streams_map_test.go | 15 +++----- 7 files changed, 19 insertions(+), 87 deletions(-) diff --git a/internal/handshake/params_negotiator_base.go b/internal/handshake/params_negotiator_base.go index 8a65d07f..d99d5a55 100644 --- a/internal/handshake/params_negotiator_base.go +++ b/internal/handshake/params_negotiator_base.go @@ -16,7 +16,6 @@ type ParamsNegotiator interface { GetReceiveStreamFlowControlWindow() protocol.ByteCount GetReceiveConnectionFlowControlWindow() protocol.ByteCount GetMaxOutgoingStreams() uint32 - GetMaxIncomingStreams() uint32 // get the idle timeout that was sent by the peer GetRemoteIdleTimeout() time.Duration // determines if the client requests omission of connection IDs. @@ -40,7 +39,6 @@ type paramsNegotiatorBase struct { omitConnectionID bool requestConnectionIDOmission bool - maxStreamsPerConnection uint32 maxIncomingDynamicStreamsPerConnection uint32 idleTimeout time.Duration remoteIdleTimeout time.Duration @@ -59,18 +57,12 @@ func (h *paramsNegotiatorBase) init(params *TransportParameters) { h.idleTimeout = params.IdleTimeout if h.perspective == protocol.PerspectiveServer { - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective } else { - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective } } -func (h *paramsNegotiatorBase) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { - return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) -} - func (h *paramsNegotiatorBase) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) } @@ -109,14 +101,6 @@ func (h *paramsNegotiatorBase) GetMaxOutgoingStreams() uint32 { return h.maxIncomingDynamicStreamsPerConnection } -func (h *paramsNegotiatorBase) GetMaxIncomingStreams() uint32 { - h.mutex.RLock() - defer h.mutex.RUnlock() - - maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection - return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier)) -} - func (h *paramsNegotiatorBase) setRemoteIdleTimeout(t time.Duration) { h.remoteIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, t) } diff --git a/internal/handshake/params_negotiator_gquic.go b/internal/handshake/params_negotiator_gquic.go index 329ddfa9..89b41538 100644 --- a/internal/handshake/params_negotiator_gquic.go +++ b/internal/handshake/params_negotiator_gquic.go @@ -42,13 +42,6 @@ func (h *paramsNegotiatorGQUIC) SetFromMap(params map[Tag][]byte) error { } h.omitConnectionID = (clientValue == 0) } - if value, ok := params[TagMSPC]; ok { - clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return errMalformedTag - } - h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) - } if value, ok := params[TagMIDS]; ok { clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) if err != nil { @@ -99,8 +92,6 @@ func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) { utils.LittleEndian.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) cfcw := bytes.NewBuffer([]byte{}) utils.LittleEndian.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) - mspc := bytes.NewBuffer([]byte{}) - utils.LittleEndian.WriteUint32(mspc, h.maxStreamsPerConnection) mids := bytes.NewBuffer([]byte{}) utils.LittleEndian.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection) icsl := bytes.NewBuffer([]byte{}) @@ -108,7 +99,6 @@ func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) { return map[Tag][]byte{ TagICSL: icsl.Bytes(), - TagMSPC: mspc.Bytes(), TagMIDS: mids.Bytes(), TagCFCW: cfcw.Bytes(), TagSFCW: sfcw.Bytes(), diff --git a/internal/handshake/params_negotiator_gquic_test.go b/internal/handshake/params_negotiator_gquic_test.go index 252fb698..6f154d66 100644 --- a/internal/handshake/params_negotiator_gquic_test.go +++ b/internal/handshake/params_negotiator_gquic_test.go @@ -40,7 +40,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) - Expect(entryMap).To(HaveKey(TagMSPC)) Expect(entryMap).To(HaveKey(TagMIDS)) }) @@ -68,16 +67,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { Expect(entryMap[TagICSL]).To(Equal([]byte{0xad, 0xfb, 0xca, 0xde})) }) - It("sets the negotiated value for maximum streams in the SHLO", func() { - val := 50 - Expect(val).To(BeNumerically("<", protocol.MaxStreamsPerConnection)) - err := pn.SetFromMap(map[Tag][]byte{TagMSPC: []byte{byte(val), 0, 0, 0}}) - Expect(err).ToNot(HaveOccurred()) - entryMap, err := pn.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - Expect(entryMap[TagMSPC]).To(Equal([]byte{byte(val), 0, 0, 0})) - }) - It("always sends its own value for the maximum incoming dynamic streams in the SHLO", func() { err := pn.SetFromMap(map[Tag][]byte{TagMIDS: []byte{5, 0, 0, 0}}) Expect(err).ToNot(HaveOccurred()) @@ -98,8 +87,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) Expect(binary.LittleEndian.Uint32(entryMap[TagICSL])).To(BeEquivalentTo(idleTimeout / time.Second)) - Expect(entryMap).To(HaveKey(TagMSPC)) - Expect(binary.LittleEndian.Uint32(entryMap[TagMSPC])).To(BeEquivalentTo(protocol.MaxStreamsPerConnection)) Expect(entryMap).To(HaveKey(TagMIDS)) Expect(binary.LittleEndian.Uint32(entryMap[TagMIDS])).To(BeEquivalentTo(protocol.MaxIncomingDynamicStreamsPerConnection)) Expect(entryMap).To(HaveKey(TagSFCW)) @@ -223,12 +210,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { }) Context("max streams per connection", func() { - It("errors when given an invalid max streams per connection value", func() { - values := map[Tag][]byte{TagMSPC: {2, 0, 0}} // 1 byte too short - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - }) - It("errors when given an invalid max dynamic incoming streams per connection value", func() { values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short err := pn.SetFromMap(values) @@ -240,28 +221,18 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { // this test only works if the value given here is smaller than protocol.MaxStreamsPerConnection err := pn.SetFromMap(map[Tag][]byte{ TagMIDS: {2, 0, 0, 0}, - TagMSPC: {1, 0, 0, 0}, }) Expect(err).ToNot(HaveOccurred()) Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(2))) }) It("uses the the MSPC value, if no MIDS is given", func() { - err := pn.SetFromMap(map[Tag][]byte{TagMIDS: {3, 0, 0, 0}}) + err := pn.SetFromMap(map[Tag][]byte{ + TagMIDS: {3, 0, 0, 0}, + }) Expect(err).ToNot(HaveOccurred()) Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(3))) }) }) - - Context("incoming connections", func() { - It("always uses the constant value, no matter what the client sent", func() { - err := pn.SetFromMap(map[Tag][]byte{ - TagMSPC: {3, 0, 0, 0}, - TagMIDS: {3, 0, 0, 0}, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetMaxIncomingStreams()).To(BeNumerically(">", protocol.MaxStreamsPerConnection)) - }) - }) }) }) diff --git a/internal/mocks/params_negotiator.go b/internal/mocks/params_negotiator.go index e95aff9e..361c1d11 100644 --- a/internal/mocks/params_negotiator.go +++ b/internal/mocks/params_negotiator.go @@ -95,18 +95,6 @@ func (mr *MockParamsNegotiatorMockRecorder) GetMaxOutgoingStreams() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxOutgoingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxOutgoingStreams)) } -// GetMaxIncomingStreams mocks base method -func (m *MockParamsNegotiator) GetMaxIncomingStreams() uint32 { - ret := m.ctrl.Call(m, "GetMaxIncomingStreams") - ret0, _ := ret[0].(uint32) - return ret0 -} - -// GetMaxIncomingStreams indicates an expected call of GetMaxIncomingStreams -func (mr *MockParamsNegotiatorMockRecorder) GetMaxIncomingStreams() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxIncomingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxIncomingStreams)) -} - // GetRemoteIdleTimeout mocks base method func (m *MockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { ret := m.ctrl.Call(m, "GetRemoteIdleTimeout") diff --git a/session_test.go b/session_test.go index 504e38a0..12165de5 100644 --- a/session_test.go +++ b/session_test.go @@ -159,7 +159,6 @@ func (m *mockParamsNegotiator) GetReceiveConnectionFlowControlWindow() protocol. return protocol.ReceiveConnectionFlowControlWindow } func (m *mockParamsNegotiator) GetMaxOutgoingStreams() uint32 { return 100 } -func (m *mockParamsNegotiator) GetMaxIncomingStreams() uint32 { return 100 } func (m *mockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { return time.Hour } func (m *mockParamsNegotiator) OmitConnectionID() bool { return false } @@ -1564,9 +1563,7 @@ var _ = Describe("Session", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() - for i := 0; i < 10; i++ { + for i := 0; i < protocol.MaxStreamsPerConnection; i++ { _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } @@ -1575,8 +1572,6 @@ var _ = Describe("Session", func() { }) It("does not error when many streams are opened and closed", func() { - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() for i := 2; i <= 1000; i++ { s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) diff --git a/streams_map.go b/streams_map.go index 16e865e2..2e7dc9ed 100644 --- a/streams_map.go +++ b/streams_map.go @@ -7,6 +7,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" ) @@ -34,6 +35,7 @@ type streamsMap struct { numOutgoingStreams uint32 numIncomingStreams uint32 + maxIncomingStreams uint32 } type streamLambda func(*stream) (bool, error) @@ -43,6 +45,12 @@ type newStreamLambda func(protocol.StreamID) *stream var errMapAccess = errors.New("streamsMap: Error accessing the streams map") func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap { + // add some tolerance to the maximum incoming streams value + maxStreams := uint32(protocol.MaxIncomingDynamicStreamsPerConnection) + maxIncomingStreams := utils.MaxUint32( + maxStreams+protocol.MaxStreamsMinimumIncrement, + uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), + ) sm := streamsMap{ perspective: pers, streams: make(map[protocol.StreamID]*stream), @@ -50,6 +58,7 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC newStream: newStream, removeStreamCallback: removeStreamCallback, connParams: connParams, + maxIncomingStreams: maxIncomingStreams, } sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex @@ -126,7 +135,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { } func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { - if m.numIncomingStreams >= m.connParams.GetMaxIncomingStreams() { + if m.numIncomingStreams >= m.maxIncomingStreams { return nil, qerr.TooManyOpenStreams } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { diff --git a/streams_map_test.go b/streams_map_test.go index 23965fa1..dffaa105 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -11,10 +11,7 @@ import ( ) var _ = Describe("Streams Map", func() { - const ( - maxIncomingStreams = 75 - maxOutgoingStreams = 60 - ) + const maxOutgoingStreams = 60 var ( m *streamsMap @@ -23,9 +20,7 @@ var _ = Describe("Streams Map", func() { setNewStreamsMap := func(p protocol.Perspective) { mockPn = mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams)) - mockPn.EXPECT().GetMaxIncomingStreams().AnyTimes().Return(uint32(maxIncomingStreams)) newStream := func(id protocol.StreamID) *stream { return newStream(id, func() {}, nil, nil) @@ -113,21 +108,21 @@ var _ = Describe("Streams Map", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - for i := 0; i < maxIncomingStreams; i++ { + for i := uint32(0); i < m.maxIncomingStreams; i++ { _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } - _, err := m.GetOrOpenStream(protocol.StreamID(2*maxIncomingStreams + 3)) + _, err := m.GetOrOpenStream(protocol.StreamID(2*m.maxIncomingStreams + 3)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("errors when too many streams are opened implicitely", func() { - _, err := m.GetOrOpenStream(protocol.StreamID(maxIncomingStreams*2 + 1)) + _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 1)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("does not error when many streams are opened and closed", func() { - for i := 2; i < 10*maxIncomingStreams; i++ { + for i := uint32(2); i < 10*m.maxIncomingStreams; i++ { str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) deleteStream(str.StreamID())