From 9cf118fc82f0f2f93ba6e4585c0433db332be984 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 28 Dec 2024 15:38:54 +0800 Subject: [PATCH] fix flaky MITM packet injection test (#4810) Instead of transferring data in one direction, we should echo the data sent. This makes sure that a roughly equal number of packets is sent in each direction. --- integrationtests/self/mitm_test.go | 73 ++++++++++++++++++------------ 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index a5b6acb08..11cc97cf0 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -63,7 +63,7 @@ func TestMITMDuplicatePackets(t *testing.T) { }) } -func TestMITInjectCorruptedPackets(t *testing.T) { +func TestMITMInjectCorruptedPackets(t *testing.T) { t.Run("towards the server", func(t *testing.T) { testMITMInjectCorruptedPackets(t, quicproxy.DirectionIncoming) }) @@ -142,7 +142,7 @@ func testMITMInjectRandomPackets(t *testing.T, direction quicproxy.Direction) { return false } - runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback) + runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback) } func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) { @@ -162,20 +162,19 @@ func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) { return false } - runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback) + runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback) } func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction) { serverTransport, clientTransport := getTransportsForMITMTest(t) - rtt := scaleDuration(10 * time.Millisecond) + rtt := scaleDuration(5 * time.Millisecond) var numCorrupted atomic.Int32 - const interval = 4 dropCallback := func(dir quicproxy.Direction, b []byte) bool { if dir != direction { return false } - if rand.Intn(interval) == 0 { + if rand.Intn(4) == 0 { numCorrupted.Add(1) pos := rand.Intn(len(b)) b[pos] = byte(rand.Intn(256)) @@ -190,17 +189,12 @@ func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction) return false } - runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback) + runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback) require.NotZero(t, int(numCorrupted.Load())) } -func runMITMTest(t *testing.T, - serverTransport, clientTransport *quic.Transport, - direction quicproxy.Direction, - rtt time.Duration, - dropCallback quicproxy.DropCallback, -) { - ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil)) +func runMITMTest(t *testing.T, serverTr, clientTr *quic.Transport, rtt time.Duration, dropCb quicproxy.DropCallback) { + ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() @@ -208,41 +202,64 @@ func runMITMTest(t *testing.T, proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), DelayPacket: func(dir quicproxy.Direction, b []byte) time.Duration { return rtt / 2 }, - DropPacket: dropCallback, + DropPacket: dropCb, }) require.NoError(t, err) defer proxy.Close() - conn, err := clientTransport.Dial( - context.Background(), + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + conn, err := clientTr.Dial( + ctx, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()}, getTLSClientConfig(), getQuicConfig(nil), ) require.NoError(t, err) + defer conn.CloseWithError(0, "") - serverConn, err := ln.Accept(context.Background()) - require.NoError(t, err) - serverStr, err := serverConn.OpenUniStream() + serverConn, err := ln.Accept(ctx) require.NoError(t, err) + defer serverConn.CloseWithError(0, "") - errChan := make(chan error, 1) + str, err := conn.OpenStreamSync(ctx) + require.NoError(t, err) + clientErrChan := make(chan error, 1) go func() { - defer close(errChan) - if _, err := serverStr.Write(PRData); err != nil { - errChan <- err + _, err := str.Write(PRData) + clientErrChan <- err + str.Close() + }() + + serverStr, err := serverConn.AcceptStream(ctx) + require.NoError(t, err) + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + if _, err := io.Copy(serverStr, serverStr); err != nil { + serverErrChan <- err return } serverStr.Close() }() - require.NoError(t, <-errChan) + require.NoError(t, <-serverErrChan) - str, err := conn.AcceptUniStream(context.Background()) - require.NoError(t, err) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, PRData, data) - require.NoError(t, conn.CloseWithError(0, "")) + + select { + case err := <-clientErrChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } + select { + case err := <-serverErrChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } } func TestMITMForgedVersionNegotiationPacket(t *testing.T) {