diff --git a/http3/conn.go b/http3/conn.go index 746029e17..a2c8ba61e 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -120,12 +120,21 @@ func (c *connection) openRequestStream( disableCompression bool, maxHeaderBytes uint64, ) (*requestStream, error) { - c.streamMx.Lock() - maxStreamID := c.maxStreamID - lastStreamID := c.lastStreamID - c.streamMx.Unlock() - if maxStreamID != protocol.InvalidStreamID && lastStreamID >= maxStreamID { - return nil, errGoAway + if c.perspective == protocol.PerspectiveClient { + c.streamMx.Lock() + maxStreamID := c.maxStreamID + var nextStreamID quic.StreamID + if c.lastStreamID == protocol.InvalidStreamID { + nextStreamID = 0 + } else { + nextStreamID = c.lastStreamID + 4 + } + c.streamMx.Unlock() + // Streams with stream ID equal to or greater than the stream ID carried in the GOAWAY frame + // will be rejected, see section 5.2 of RFC 9114. + if maxStreamID != protocol.InvalidStreamID && nextStreamID >= maxStreamID { + return nil, errGoAway + } } str, err := c.OpenStreamSync(ctx) diff --git a/http3/conn_test.go b/http3/conn_test.go index 13e23b342..5096c3a9c 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -283,13 +283,13 @@ func testConnGoAway(t *testing.T, withStream bool) { ) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) - b = (&goAwayFrame{StreamID: 4}).Append(b) + b = (&goAwayFrame{StreamID: 8}).Append(b) var mockStr *mockquic.MockStream var str quic.Stream if withStream { mockStr = mockquic.NewMockStream(mockCtrl) - mockStr.EXPECT().StreamID().Return(4).AnyTimes() + mockStr.EXPECT().StreamID().Return(0).AnyTimes() mockStr.EXPECT().Context().Return(context.Background()).AnyTimes() qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(mockStr, nil) s, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) @@ -327,7 +327,20 @@ func testConnGoAway(t *testing.T, withStream bool) { case <-time.After(scaleDuration(10 * time.Millisecond)): } - _, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) + // The stream ID in the GOAWAY frame is 8, so it's possible to open stream 4. + mockStr2 := mockquic.NewMockStream(mockCtrl) + mockStr2.EXPECT().StreamID().Return(4).AnyTimes() + mockStr2.EXPECT().Context().Return(context.Background()).AnyTimes() + qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(mockStr2, nil) + str2, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) + require.NoError(t, err) + mockStr2.EXPECT().Close() + str2.Close() + mockStr2.EXPECT().CancelRead(gomock.Any()) + str2.CancelRead(1337) + + // It's not possible to open stream 8. + _, err = conn.openRequestStream(context.Background(), nil, nil, true, 1000) require.ErrorIs(t, err, errGoAway) mockStr.EXPECT().Close() diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index b04aa47c0..83535d351 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -1188,14 +1188,12 @@ func testHTTPRequestAfterGracefulShutdown(t *testing.T, setGetBody bool) { require.Nil(t, req.GetBody) } + // By increasing the RTT, we make sure that the request is sent before the client receives the GOAWAY frame. inShutdown.Store(true) go server1.Shutdown(context.Background()) go server2.ServeListener(ln) defer server2.Close() - // so that graceful shutdown can actually start - time.Sleep(scaleDuration(10 * time.Millisecond)) - resp, err = cl.Do(req) if !setGetBody { require.ErrorContains(t, err, "after Request.Body was written; define Request.GetBody to avoid this error")