diff --git a/streams_map.go b/streams_map.go index 61a5398e..4ad53a13 100644 --- a/streams_map.go +++ b/streams_map.go @@ -278,7 +278,9 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { } func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { + m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) + m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) } diff --git a/streams_map_generic_helper.go b/streams_map_generic_helper.go index 692f093e..26b56233 100644 --- a/streams_map_generic_helper.go +++ b/streams_map_generic_helper.go @@ -11,6 +11,7 @@ import ( // This definition must be in a file that Genny doesn't process. type item interface { generic.Type + updateSendWindow(protocol.ByteCount) closeForShutdown(error) } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index d59045d0..6d83a6b2 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -18,8 +18,9 @@ import ( type mockGenericStream struct { num protocol.StreamNum - closed bool - closeErr error + closed bool + closeErr error + sendWindow protocol.ByteCount } func (s *mockGenericStream) closeForShutdown(err error) { @@ -27,6 +28,10 @@ func (s *mockGenericStream) closeForShutdown(err error) { s.closeErr = err } +func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { + s.sendWindow = limit +} + var _ = Describe("Streams Map (incoming)", func() { var ( m *incomingItemsMap diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index 72d94211..c14ab981 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -180,6 +180,17 @@ func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) { m.unblockOpenSync() } +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingBidiStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream func (m *outgoingBidiStreamsMap) unblockOpenSync() { if len(m.openQueue) == 0 { diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 01edd790..2fe89936 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -178,6 +178,17 @@ func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) { m.unblockOpenSync() } +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingItemsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream func (m *outgoingItemsMap) unblockOpenSync() { if len(m.openQueue) == 0 { diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index a9189718..fb07dd99 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -112,6 +112,16 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(str2.(*mockGenericStream).closed).To(BeTrue()) Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) }) + + It("updates the send window", func() { + str1, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str2, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + m.UpdateSendWindow(1337) + Expect(str1.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) + Expect(str2.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) + }) }) Context("with stream ID limits", func() { diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 3fd1f177..9177c143 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -180,6 +180,17 @@ func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { m.unblockOpenSync() } +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream func (m *outgoingUniStreamsMap) unblockOpenSync() { if len(m.openQueue) == 0 { diff --git a/streams_map_test.go b/streams_map_test.go index 81e4b128..8ef91fc2 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -348,6 +348,38 @@ var _ = Describe("Streams Map", func() { expectTooManyStreamsError(err) }) + if perspective == protocol.PerspectiveClient { + It("applies parameters to existing streams (needed for 0-RTT)", func() { + m.UpdateLimits(&wire.TransportParameters{ + MaxBidiStreamNum: 1000, + MaxUniStreamNum: 1000, + }) + flowControllers := make(map[protocol.StreamID]*mocks.MockStreamFlowController) + m.newFlowController = func(id protocol.StreamID) flowcontrol.StreamFlowController { + fc := mocks.NewMockStreamFlowController(mockCtrl) + flowControllers[id] = fc + return fc + } + + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + unistr, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + + Expect(flowControllers).To(HaveKey(str.StreamID())) + flowControllers[str.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(4321)) + Expect(flowControllers).To(HaveKey(unistr.StreamID())) + flowControllers[unistr.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(1234)) + + m.UpdateLimits(&wire.TransportParameters{ + MaxBidiStreamNum: 1000, + InitialMaxStreamDataUni: 1234, + MaxUniStreamNum: 1000, + InitialMaxStreamDataBidiRemote: 4321, + }) + }) + } + Context("handling MAX_STREAMS frames", func() { BeforeEach(func() { mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()