fix flaky MITM packet injection test (#4810)

Instead of transferring data in one direction, we should echo the data
sent. This makes sure that a roughly equal number of packets is sent in
each direction.
This commit is contained in:
Marten Seemann
2024-12-28 15:38:54 +08:00
committed by GitHub
parent 56bebc4e61
commit 9cf118fc82

View File

@@ -63,7 +63,7 @@ func TestMITMDuplicatePackets(t *testing.T) {
})
}
func TestMITInjectCorruptedPackets(t *testing.T) {
func TestMITMInjectCorruptedPackets(t *testing.T) {
t.Run("towards the server", func(t *testing.T) {
testMITMInjectCorruptedPackets(t, quicproxy.DirectionIncoming)
})
@@ -142,7 +142,7 @@ func testMITMInjectRandomPackets(t *testing.T, direction quicproxy.Direction) {
return false
}
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
}
func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) {
@@ -162,20 +162,19 @@ func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) {
return false
}
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
}
func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction) {
serverTransport, clientTransport := getTransportsForMITMTest(t)
rtt := scaleDuration(10 * time.Millisecond)
rtt := scaleDuration(5 * time.Millisecond)
var numCorrupted atomic.Int32
const interval = 4
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
if dir != direction {
return false
}
if rand.Intn(interval) == 0 {
if rand.Intn(4) == 0 {
numCorrupted.Add(1)
pos := rand.Intn(len(b))
b[pos] = byte(rand.Intn(256))
@@ -190,17 +189,12 @@ func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction)
return false
}
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
require.NotZero(t, int(numCorrupted.Load()))
}
func runMITMTest(t *testing.T,
serverTransport, clientTransport *quic.Transport,
direction quicproxy.Direction,
rtt time.Duration,
dropCallback quicproxy.DropCallback,
) {
ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil))
func runMITMTest(t *testing.T, serverTr, clientTr *quic.Transport, rtt time.Duration, dropCb quicproxy.DropCallback) {
ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
@@ -208,41 +202,64 @@ func runMITMTest(t *testing.T,
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(dir quicproxy.Direction, b []byte) time.Duration { return rtt / 2 },
DropPacket: dropCallback,
DropPacket: dropCb,
})
require.NoError(t, err)
defer proxy.Close()
conn, err := clientTransport.Dial(
context.Background(),
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := clientTr.Dial(
ctx,
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
serverStr, err := serverConn.OpenUniStream()
serverConn, err := ln.Accept(ctx)
require.NoError(t, err)
defer serverConn.CloseWithError(0, "")
errChan := make(chan error, 1)
str, err := conn.OpenStreamSync(ctx)
require.NoError(t, err)
clientErrChan := make(chan error, 1)
go func() {
defer close(errChan)
if _, err := serverStr.Write(PRData); err != nil {
errChan <- err
_, err := str.Write(PRData)
clientErrChan <- err
str.Close()
}()
serverStr, err := serverConn.AcceptStream(ctx)
require.NoError(t, err)
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
if _, err := io.Copy(serverStr, serverStr); err != nil {
serverErrChan <- err
return
}
serverStr.Close()
}()
require.NoError(t, <-errChan)
require.NoError(t, <-serverErrChan)
str, err := conn.AcceptUniStream(context.Background())
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, PRData, data)
require.NoError(t, conn.CloseWithError(0, ""))
select {
case err := <-clientErrChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
select {
case err := <-serverErrChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestMITMForgedVersionNegotiationPacket(t *testing.T) {