implement receiver side behavior for RESET_STREAM_AT (#5235)

* implement receiver side behavior for RESET_STREAM_AT

* simplify reliable offset tracking
This commit is contained in:
Marten Seemann
2025-06-26 14:42:08 +08:00
committed by GitHub
parent 1b9add1bec
commit b2f24318af
2 changed files with 205 additions and 46 deletions

View File

@@ -42,6 +42,9 @@ type ReceiveStream struct {
cancelErr *StreamError
closeForShutdownErr error
readPos protocol.ByteCount
reliableSize protocol.ByteCount
readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
deadline time.Time
@@ -128,7 +131,7 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
s.errorRead = true
return false, false, 0, io.EOF
}
if s.cancelledRemotely || s.cancelledLocally {
if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) {
s.errorRead = true
return false, false, 0, s.cancelErr
}
@@ -151,9 +154,9 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
if s.closeForShutdownErr != nil {
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
}
if s.cancelledRemotely || s.cancelledLocally {
if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) {
s.errorRead = true
return hasStreamWindowUpdate, hasConnWindowUpdate, 0, s.cancelErr
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
}
deadline := s.deadline
@@ -194,14 +197,11 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
if s.readPosInFrame > len(s.currentFrame) {
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}
m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
s.readPosInFrame += m
bytesRead += m
// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if !s.cancelledRemotely {
// informed about the final offset for this stream
if !s.cancelledRemotely || s.readPos < s.reliableSize {
hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m))
if hasStream {
s.queuedMaxStreamData = true
@@ -212,6 +212,14 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
}
}
s.readPosInFrame += m
s.readPos += protocol.ByteCount(m)
bytesRead += m
if s.cancelledRemotely && s.readPos >= s.reliableSize {
s.flowController.Abandon()
}
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.currentFrame = nil
if s.currentFrameDone != nil {
@@ -221,6 +229,10 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF
}
}
if s.cancelledRemotely && s.readPos >= s.reliableSize {
s.errorRead = true
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
}
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil
}
@@ -231,7 +243,7 @@ func (s *ReceiveStream) dequeueNextFrame() {
s.currentFrameDone()
}
offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop()
s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset
s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset && !s.cancelledRemotely
s.readPosInFrame = 0
}
@@ -323,11 +335,19 @@ func (s *ReceiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame,
}
s.finalOffset = frame.FinalSize
// senders are allowed to reduce the reliable size, but frames might have been reordered
if (!s.cancelledRemotely && s.reliableSize == 0) || frame.ReliableSize < s.reliableSize {
s.reliableSize = frame.ReliableSize
}
if s.readPos >= s.reliableSize {
// calling Abandon multiple times is a no-op
s.flowController.Abandon()
}
// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.cancelledRemotely {
return nil
}
s.flowController.Abandon()
// don't save the error if the RESET_STREAM frames was received after CancelRead was called
if s.cancelledLocally {
return nil

View File

@@ -447,7 +447,16 @@ func TestReceiveStreamCancellation(t *testing.T) {
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
}
func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) {
func TestReceiveStreamCancelReadAfterFIN(t *testing.T) {
t.Run("FIN not read", func(t *testing.T) {
testReceiveStreamCancelReadAfterFIN(t, false)
})
t.Run("FIN read", func(t *testing.T) {
testReceiveStreamCancelReadAfterFIN(t, true)
})
}
func testReceiveStreamCancelReadAfterFIN(t *testing.T, finRead bool) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
@@ -456,46 +465,38 @@ func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
if finRead {
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
n, err := str.Read(make([]byte, 10))
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 6, n)
}
// if the FIN was received, but not read yet, a STOP_SENDING frame is queued
mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str)
mockFC.EXPECT().Abandon()
if !finRead {
mockFC.EXPECT().Abandon()
mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str)
}
str.CancelRead(1337)
f, ok, hasMore := str.getControlFrame(time.Now())
require.True(t, ok)
require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame)
require.False(t, hasMore)
// if the EOF was already read, no STOP_SENDING frame is queued
if finRead {
require.False(t, ok)
require.False(t, hasMore)
} else {
require.True(t, ok)
require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame)
require.False(t, hasMore)
}
// Read returns the error
n, err := str.Read([]byte{0})
require.Zero(t, n)
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false})
}
func TestReceiveStreamCancelReadAfterFINRead(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newReceiveStream(42, mockSender, mockFC)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
n, err := str.Read(make([]byte, 10))
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 6, n)
// if the EOF was already read, no STOP_SENDING frame is queued
str.CancelRead(1234)
_, ok, hasMore := str.getControlFrame(time.Now())
require.False(t, ok)
require.False(t, hasMore)
// Read returns the error
n, err = str.Read([]byte{0})
require.Zero(t, n)
require.ErrorIs(t, err, io.EOF)
if finRead {
require.ErrorIs(t, err, io.EOF)
} else {
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false})
}
}
func TestReceiveStreamReset(t *testing.T) {
@@ -520,7 +521,7 @@ func TestReceiveStreamReset(t *testing.T) {
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
gomock.InOrder(
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()),
mockFC.EXPECT().Abandon(),
mockFC.EXPECT().Abandon().MinTimes(1),
)
require.NoError(t, str.handleResetStreamFrame(
&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42},
@@ -616,14 +617,14 @@ func TestReceiveStreamConcurrentReads(t *testing.T) {
const num = 3
errChan := make(chan error, num)
for i := 0; i < num; i++ {
for range num {
go func() {
_, err := str.Read(make([]byte, 8))
errChan <- err
}()
}
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
for i := 0; i < num; i++ {
for range num {
select {
case err := <-errChan:
require.ErrorIs(t, err, io.EOF)
@@ -634,3 +635,141 @@ func TestReceiveStreamConcurrentReads(t *testing.T) {
require.Equal(t, protocol.ByteCount(6), bytesRead)
require.Equal(t, int32(1), numCompleted.Load())
}
func TestReceiveStreamResetStreamAtBeforeReadOffset(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newReceiveStream(42, mockSender, mockFC)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3))
b := make([]byte, 3)
n, err := str.Read(b)
require.NoError(t, err)
require.Equal(t, 3, n)
require.Equal(t, []byte("foo"), b)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
mockFC.EXPECT().Abandon()
str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now())
require.True(t, mockCtrl.Satisfied())
// Read returns the error
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
n, err = str.Read([]byte{0})
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
require.Zero(t, n)
}
func TestReceiveStreamResetStreamAtAfterReadOffset(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newReceiveStream(42, mockSender, mockFC)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
b := make([]byte, 2)
n, err := str.Read(b)
require.NoError(t, err)
require.Equal(t, 2, n)
require.Equal(t, []byte("fo"), b)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now())
require.True(t, mockCtrl.Satisfied())
// Read returns the error
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
n, err = str.Read(b)
require.NoError(t, err)
require.Equal(t, 2, n)
require.Equal(t, []byte("ob"), b)
require.True(t, mockCtrl.Satisfied())
gomock.InOrder(
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
mockFC.EXPECT().Abandon(),
)
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
n, err = str.Read(b)
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
require.Equal(t, 2, n)
require.Equal(t, []byte("ar"), b)
}
func TestReceiveStreamMultipleResetStreamAt(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newReceiveStream(42, mockSender, mockFC)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3))
b := make([]byte, 3)
n, err := str.Read(b)
require.NoError(t, err)
require.Equal(t, 3, n)
require.Equal(t, []byte("foo"), b)
require.True(t, mockCtrl.Satisfied())
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now())
require.True(t, mockCtrl.Satisfied())
// receiving a reordered RESET_STREAM_AT frame has no effect
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now())
require.True(t, mockCtrl.Satisfied())
// receiving a RESET_STREAM_AT frame with a smaller reliable size is valid
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
mockFC.EXPECT().Abandon()
str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now())
// Read returns the error
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
n, err = str.Read(b)
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
require.Zero(t, n)
}
func TestReceiveStreamResetStreamAtAfterResetStream(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newReceiveStream(42, mockSender, mockFC)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3))
b := make([]byte, 3)
n, err := str.Read(b)
require.NoError(t, err)
require.Equal(t, 3, n)
require.Equal(t, []byte("foo"), b)
require.True(t, mockCtrl.Satisfied())
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10}, time.Now())
require.True(t, mockCtrl.Satisfied())
// receiving a reordered RESET_STREAM_AT frame has no effect
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now())
require.True(t, mockCtrl.Satisfied())
// Read returns the error
mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
n, err = str.Read(b)
require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
require.Zero(t, n)
}