diff --git a/connection.go b/connection.go index b891b1b1d..ec3a91bfe 100644 --- a/connection.go +++ b/connection.go @@ -1342,6 +1342,7 @@ func (s *connection) handleFrame( s.handleMaxStreamsFrame(frame) case *wire.DataBlockedFrame: case *wire.StreamDataBlockedFrame: + err = s.handleStreamDataBlockedFrame(frame) case *wire.StreamsBlockedFrame: case *wire.StopSendingFrame: err = s.handleStopSendingFrame(frame) @@ -1477,6 +1478,13 @@ func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) er return nil } +func (s *connection) handleStreamDataBlockedFrame(frame *wire.StreamDataBlockedFrame) error { + // We don't need to do anything in response to a STREAM_DATA_BLOCKED frame, + // but we need to make sure that the stream ID is valid. + _, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + return err +} + func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { s.streamsMap.HandleMaxStreamsFrame(frame) } diff --git a/connection_test.go b/connection_test.go index 47b9e1cae..b39f382eb 100644 --- a/connection_test.go +++ b/connection_test.go @@ -223,6 +223,7 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) { str.EXPECT().handleResetStreamFrame(rsf, now) require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)) // STREAM_DATA_BLOCKED frames are not passed to the stream + streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil) require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)) }) @@ -237,7 +238,7 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) { streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil) require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now)) // STREAM_DATA_BLOCKED frames are not passed to the stream - // TODO(#4822): validate stream ID of STREAM_DATA_BLOCKED frames + streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil) require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)) }) @@ -253,8 +254,8 @@ func TestConnectionHandleReceiveStreamFrames(t *testing.T) { streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr) require.ErrorIs(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now), testErr) // STREAM_DATA_BLOCKED frames are not passed to the stream - // TODO(#4822): validate stream ID of STREAM_DATA_BLOCKED frames - require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now)) + streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr) + require.ErrorIs(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now), testErr) }) }