From b2f24318afa16a366bd2ad414e5ae06c34aca098 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 26 Jun 2025 14:42:08 +0800 Subject: [PATCH] implement receiver side behavior for RESET_STREAM_AT (#5235) * implement receiver side behavior for RESET_STREAM_AT * simplify reliable offset tracking --- receive_stream.go | 40 ++++++-- receive_stream_test.go | 211 ++++++++++++++++++++++++++++++++++------- 2 files changed, 205 insertions(+), 46 deletions(-) diff --git a/receive_stream.go b/receive_stream.go index 3b2d618c..61f4cdf4 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -42,6 +42,9 @@ type ReceiveStream struct { cancelErr *StreamError closeForShutdownErr error + readPos protocol.ByteCount + reliableSize protocol.ByteCount + readChan chan struct{} readOnce chan struct{} // cap: 1, to protect against concurrent use of Read deadline time.Time @@ -128,7 +131,7 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW s.errorRead = true return false, false, 0, io.EOF } - if s.cancelledRemotely || s.cancelledLocally { + if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) { s.errorRead = true return false, false, 0, s.cancelErr } @@ -151,9 +154,9 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW if s.closeForShutdownErr != nil { return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr } - if s.cancelledRemotely || s.cancelledLocally { + if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) { s.errorRead = true - return hasStreamWindowUpdate, hasConnWindowUpdate, 0, s.cancelErr + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr } deadline := s.deadline @@ -194,14 +197,11 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW if s.readPosInFrame > len(s.currentFrame) { return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) } - m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) - s.readPosInFrame += m - bytesRead += m // when a RESET_STREAM was received, the flow controller was already - // informed about the final byteOffset for this stream - if !s.cancelledRemotely { + // informed about the final offset for this stream + if !s.cancelledRemotely || s.readPos < s.reliableSize { hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m)) if hasStream { s.queuedMaxStreamData = true @@ -212,6 +212,14 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW } } + s.readPosInFrame += m + s.readPos += protocol.ByteCount(m) + bytesRead += m + + if s.cancelledRemotely && s.readPos >= s.reliableSize { + s.flowController.Abandon() + } + if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { s.currentFrame = nil if s.currentFrameDone != nil { @@ -221,6 +229,10 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF } } + if s.cancelledRemotely && s.readPos >= s.reliableSize { + s.errorRead = true + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr + } return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil } @@ -231,7 +243,7 @@ func (s *ReceiveStream) dequeueNextFrame() { s.currentFrameDone() } offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop() - s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset + s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset && !s.cancelledRemotely s.readPosInFrame = 0 } @@ -323,11 +335,19 @@ func (s *ReceiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame, } s.finalOffset = frame.FinalSize + // senders are allowed to reduce the reliable size, but frames might have been reordered + if (!s.cancelledRemotely && s.reliableSize == 0) || frame.ReliableSize < s.reliableSize { + s.reliableSize = frame.ReliableSize + } + if s.readPos >= s.reliableSize { + // calling Abandon multiple times is a no-op + s.flowController.Abandon() + } // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset) if s.cancelledRemotely { return nil } - s.flowController.Abandon() + // don't save the error if the RESET_STREAM frames was received after CancelRead was called if s.cancelledLocally { return nil diff --git a/receive_stream_test.go b/receive_stream_test.go index 2b9f38d2..3619fa3c 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -447,7 +447,16 @@ func TestReceiveStreamCancellation(t *testing.T) { require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}) } -func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) { +func TestReceiveStreamCancelReadAfterFIN(t *testing.T) { + t.Run("FIN not read", func(t *testing.T) { + testReceiveStreamCancelReadAfterFIN(t, false) + }) + t.Run("FIN read", func(t *testing.T) { + testReceiveStreamCancelReadAfterFIN(t, true) + }) +} + +func testReceiveStreamCancelReadAfterFIN(t *testing.T, finRead bool) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) @@ -456,46 +465,38 @@ func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) + if finRead { + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + n, err := str.Read(make([]byte, 10)) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 6, n) + } // if the FIN was received, but not read yet, a STOP_SENDING frame is queued - mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str) - mockFC.EXPECT().Abandon() + if !finRead { + mockFC.EXPECT().Abandon() + mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str) + } str.CancelRead(1337) f, ok, hasMore := str.getControlFrame(time.Now()) - require.True(t, ok) - require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame) - require.False(t, hasMore) + // if the EOF was already read, no STOP_SENDING frame is queued + if finRead { + require.False(t, ok) + require.False(t, hasMore) + } else { + require.True(t, ok) + require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame) + require.False(t, hasMore) + } // Read returns the error n, err := str.Read([]byte{0}) require.Zero(t, n) - require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false}) -} - -func TestReceiveStreamCancelReadAfterFINRead(t *testing.T) { - mockCtrl := gomock.NewController(t) - mockFC := mocks.NewMockStreamFlowController(mockCtrl) - mockSender := NewMockStreamSender(mockCtrl) - str := newReceiveStream(42, mockSender, mockFC) - - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) - require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) - n, err := str.Read(make([]byte, 10)) - require.ErrorIs(t, err, io.EOF) - require.Equal(t, 6, n) - - // if the EOF was already read, no STOP_SENDING frame is queued - str.CancelRead(1234) - _, ok, hasMore := str.getControlFrame(time.Now()) - require.False(t, ok) - require.False(t, hasMore) - - // Read returns the error - n, err = str.Read([]byte{0}) - require.Zero(t, n) - require.ErrorIs(t, err, io.EOF) + if finRead { + require.ErrorIs(t, err, io.EOF) + } else { + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false}) + } } func TestReceiveStreamReset(t *testing.T) { @@ -520,7 +521,7 @@ func TestReceiveStreamReset(t *testing.T) { mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) gomock.InOrder( mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()), - mockFC.EXPECT().Abandon(), + mockFC.EXPECT().Abandon().MinTimes(1), ) require.NoError(t, str.handleResetStreamFrame( &wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42}, @@ -616,14 +617,14 @@ func TestReceiveStreamConcurrentReads(t *testing.T) { const num = 3 errChan := make(chan error, num) - for i := 0; i < num; i++ { + for range num { go func() { _, err := str.Read(make([]byte, 8)) errChan <- err }() } require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) - for i := 0; i < num; i++ { + for range num { select { case err := <-errChan: require.ErrorIs(t, err, io.EOF) @@ -634,3 +635,141 @@ func TestReceiveStreamConcurrentReads(t *testing.T) { require.Equal(t, protocol.ByteCount(6), bytesRead) require.Equal(t, int32(1), numCompleted.Load()) } + +func TestReceiveStreamResetStreamAtBeforeReadOffset(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) + b := make([]byte, 3) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, []byte("foo"), b) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + mockFC.EXPECT().Abandon() + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // Read returns the error + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read([]byte{0}) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Zero(t, n) +} + +func TestReceiveStreamResetStreamAtAfterReadOffset(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + b := make([]byte, 2) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("fo"), b) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // Read returns the error + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + n, err = str.Read(b) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("ob"), b) + require.True(t, mockCtrl.Satisfied()) + + gomock.InOrder( + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), + mockFC.EXPECT().Abandon(), + ) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read(b) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Equal(t, 2, n) + require.Equal(t, []byte("ar"), b) +} + +func TestReceiveStreamMultipleResetStreamAt(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) + b := make([]byte, 3) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, []byte("foo"), b) + require.True(t, mockCtrl.Satisfied()) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // receiving a reordered RESET_STREAM_AT frame has no effect + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // receiving a RESET_STREAM_AT frame with a smaller reliable size is valid + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + mockFC.EXPECT().Abandon() + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now()) + + // Read returns the error + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read(b) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Zero(t, n) +} + +func TestReceiveStreamResetStreamAtAfterResetStream(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) + b := make([]byte, 3) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, []byte("foo"), b) + require.True(t, mockCtrl.Satisfied()) + + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // receiving a reordered RESET_STREAM_AT frame has no effect + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // Read returns the error + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read(b) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Zero(t, n) +}