From a584e580996490b38983c591d6ab78a4b4f0d11f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 22 Dec 2024 11:36:14 +0800 Subject: [PATCH] flowcontrol: reset the connection send window when 0-RTT is rejected (#4764) --- integrationtests/self/zero_rtt_test.go | 124 ++++++++++++++---- .../flowcontrol/connection_flow_controller.go | 1 + .../connection_flow_controller_test.go | 10 +- 3 files changed, 105 insertions(+), 30 deletions(-) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index c9316aefe..ed86537bb 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "os" "sync" "sync/atomic" "testing" @@ -522,17 +523,24 @@ func Test0RTTWithIncreasedStreamLimit(t *testing.T) { require.NotZero(t, counter.Load()) } -func check0RTTRejected(t *testing.T, ln *quic.EarlyListener, addr net.Addr, conf *tls.Config) quic.Connection { +func check0RTTRejected(t *testing.T, + ln *quic.EarlyListener, + addr net.Addr, + conf *tls.Config, + sendData bool, +) (clientConn, serverConn quic.Connection) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.DialEarly(ctx, newUPDConnLocalhost(t), addr, conf, getQuicConfig(nil)) require.NoError(t, err) require.False(t, conn.ConnectionState().TLS.HandshakeComplete) - str, err := conn.OpenUniStream() - require.NoError(t, err) - _, err = str.Write(make([]byte, 3000)) - require.NoError(t, err) - require.NoError(t, str.Close()) + if sendData { + str, err := conn.OpenUniStream() + require.NoError(t, err) + _, err = str.Write(make([]byte, 3000)) + require.NoError(t, err) + require.NoError(t, str.Close()) + } select { case <-conn.HandshakeComplete(): @@ -544,12 +552,13 @@ func check0RTTRejected(t *testing.T, ln *quic.EarlyListener, addr net.Addr, conf // make sure the server doesn't process the data ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) defer cancel() - serverConn, err := ln.Accept(ctx) + serverConn, err = ln.Accept(ctx) require.NoError(t, err) require.False(t, serverConn.ConnectionState().Used0RTT) - _, err = serverConn.AcceptUniStream(ctx) - require.Equal(t, context.DeadlineExceeded, err) - require.NoError(t, serverConn.CloseWithError(0, "")) + if sendData { + _, err = serverConn.AcceptUniStream(ctx) + require.Equal(t, context.DeadlineExceeded, err) + } ctx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -557,15 +566,7 @@ func check0RTTRejected(t *testing.T, ln *quic.EarlyListener, addr net.Addr, conf require.NoError(t, err) require.True(t, nextConn.ConnectionState().TLS.HandshakeComplete) require.False(t, nextConn.ConnectionState().Used0RTT) - t.Cleanup(func() { - nextConn.CloseWithError(0, "") - select { - case <-serverConn.Context().Done(): - case <-time.After(time.Second): - t.Fatal("server connection did not close in time") - } - }) - return nextConn + return nextConn, serverConn } func Test0RTTRejectedOnStreamLimitDecrease(t *testing.T) { @@ -605,13 +606,8 @@ func Test0RTTRejectedOnStreamLimitDecrease(t *testing.T) { defer ln.Close() proxy, num0RTT := runCountingProxyAndCount0RTTPackets(t, ln.Addr().(*net.UDPAddr).Port, rtt) - conn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf) - - // The client should send 0-RTT packets, but the server doesn't process them. - n := num0RTT.Load() - t.Logf("sent %d 0-RTT packets", n) - require.NotZero(t, n) - require.Empty(t, counter.getRcvd0RTTPacketNumbers()) + conn, serverConn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf, true) + defer conn.CloseWithError(0, "") // It should now be possible to open new bidirectional streams up to the new limit... for i := 0; i < newMaxBidiStreams; i++ { @@ -632,6 +628,68 @@ func Test0RTTRejectedOnStreamLimitDecrease(t *testing.T) { _, err = conn.OpenUniStream() require.Error(t, err) require.Contains(t, err.Error(), "too many open streams") + + serverConn.CloseWithError(0, "") + // The client should send 0-RTT packets, but the server doesn't process them. + n := num0RTT.Load() + t.Logf("sent %d 0-RTT packets", n) + require.NotZero(t, n) + require.Empty(t, counter.getRcvd0RTTPacketNumbers()) +} + +func Test0RTTRejectedOnConnectionWindowDecrease(t *testing.T) { + const rtt = 5 * time.Millisecond + + const ( + connFlowControlWindow = 100 + newConnFlowControlWindow = connFlowControlWindow - 1 + ) + + tlsConf := getTLSConfig() + clientConf := dialAndReceiveTicket(t, + rtt, + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + InitialConnectionReceiveWindow: connFlowControlWindow, + }), + nil, + ) + + ln, err := quic.ListenEarly( + newUPDConnLocalhost(t), + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + InitialConnectionReceiveWindow: newConnFlowControlWindow, + }), + ) + require.NoError(t, err) + defer ln.Close() + proxy, _ := runCountingProxyAndCount0RTTPackets(t, ln.Addr().(*net.UDPAddr).Port, rtt) + + conn, serverConn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf, false) + defer conn.CloseWithError(0, "") + defer serverConn.CloseWithError(0, "") + + str, err := conn.OpenStream() + require.NoError(t, err) + str.SetWriteDeadline(time.Now().Add(scaleDuration(50 * time.Millisecond))) + n, err := str.Write(make([]byte, 2000)) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + require.Equal(t, newConnFlowControlWindow, n) + + // make sure that only 99 bytes were received + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + serverStr, err := serverConn.AcceptStream(ctx) + require.NoError(t, err) + serverStr.SetReadDeadline(time.Now().Add(scaleDuration(50 * time.Millisecond))) + n, err = io.ReadFull(serverStr, make([]byte, newConnFlowControlWindow)) + require.NoError(t, err) + require.Equal(t, newConnFlowControlWindow, n) + _, err = serverStr.Read([]byte{0}) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) } func Test0RTTRejectedOnALPNChanged(t *testing.T) { @@ -657,9 +715,13 @@ func Test0RTTRejectedOnALPNChanged(t *testing.T) { defer ln.Close() proxy, num0RTTPackets := runCountingProxyAndCount0RTTPackets(t, ln.Addr().(*net.UDPAddr).Port, rtt) - conn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf) + conn, serverConn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf, true) + defer conn.CloseWithError(0, "") + defer serverConn.CloseWithError(0, "") + require.Equal(t, "new-alpn", conn.ConnectionState().TLS.NegotiatedProtocol) + serverConn.CloseWithError(0, "") // The client should send 0-RTT packets, but the server doesn't process them. num0RTT := num0RTTPackets.Load() t.Logf("Sent %d 0-RTT packets.", num0RTT) @@ -685,8 +747,10 @@ func Test0RTTRejectedWhenDisabled(t *testing.T) { defer ln.Close() proxy, num0RTTPackets := runCountingProxyAndCount0RTTPackets(t, ln.Addr().(*net.UDPAddr).Port, rtt) - check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf) + conn, serverConn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf, true) + defer conn.CloseWithError(0, "") + serverConn.CloseWithError(0, "") // The client should send 0-RTT packets, but the server doesn't process them. num0RTT := num0RTTPackets.Load() t.Logf("Sent %d 0-RTT packets.", num0RTT) @@ -713,9 +777,11 @@ func Test0RTTRejectedOnDatagramsDisabled(t *testing.T) { defer ln.Close() proxy, num0RTTPackets := runCountingProxyAndCount0RTTPackets(t, ln.Addr().(*net.UDPAddr).Port, rtt) - conn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf) + conn, serverConn := check0RTTRejected(t, ln, proxy.LocalAddr(), clientConf, true) + defer conn.CloseWithError(0, "") require.False(t, conn.ConnectionState().SupportsDatagrams) + serverConn.CloseWithError(0, "") // The client should send 0-RTT packets, but the server doesn't process them. num0RTT := num0RTTPackets.Load() t.Logf("Sent %d 0-RTT packets.", num0RTT) diff --git a/internal/flowcontrol/connection_flow_controller.go b/internal/flowcontrol/connection_flow_controller.go index 016ffaa96..362a9a956 100644 --- a/internal/flowcontrol/connection_flow_controller.go +++ b/internal/flowcontrol/connection_flow_controller.go @@ -106,5 +106,6 @@ func (c *connectionFlowController) Reset() error { } c.bytesSent = 0 c.lastBlockedAt = 0 + c.sendWindow = 0 return nil } diff --git a/internal/flowcontrol/connection_flow_controller_test.go b/internal/flowcontrol/connection_flow_controller_test.go index 0598d1dcf..1e1288b38 100644 --- a/internal/flowcontrol/connection_flow_controller_test.go +++ b/internal/flowcontrol/connection_flow_controller_test.go @@ -59,8 +59,16 @@ func TestConnectionFlowControlViolation(t *testing.T) { require.Equal(t, qerr.FlowControlError, terr.ErrorCode) } -// TODO (#4732): add a test for successfully resetting the flow controller func TestConnectionFlowControllerReset(t *testing.T) { + fc := NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger) + fc.UpdateSendWindow(100) + fc.AddBytesSent(10) + require.Equal(t, protocol.ByteCount(90), fc.SendWindowSize()) + require.NoError(t, fc.Reset()) + require.Zero(t, fc.SendWindowSize()) +} + +func TestConnectionFlowControllerResetAfterReading(t *testing.T) { fc := NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger) fc.AddBytesRead(1) require.EqualError(t, fc.Reset(), "flow controller reset after reading data")