forked from quic-go/quic-go
fix flaky MITM packet injection test (#4810)
Instead of transferring data in one direction, we should echo the data sent. This makes sure that a roughly equal number of packets is sent in each direction.
This commit is contained in:
@@ -63,7 +63,7 @@ func TestMITMDuplicatePackets(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMITInjectCorruptedPackets(t *testing.T) {
|
||||
func TestMITMInjectCorruptedPackets(t *testing.T) {
|
||||
t.Run("towards the server", func(t *testing.T) {
|
||||
testMITMInjectCorruptedPackets(t, quicproxy.DirectionIncoming)
|
||||
})
|
||||
@@ -142,7 +142,7 @@ func testMITMInjectRandomPackets(t *testing.T, direction quicproxy.Direction) {
|
||||
return false
|
||||
}
|
||||
|
||||
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
|
||||
runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
|
||||
}
|
||||
|
||||
func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) {
|
||||
@@ -162,20 +162,19 @@ func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) {
|
||||
return false
|
||||
}
|
||||
|
||||
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
|
||||
runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
|
||||
}
|
||||
|
||||
func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction) {
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
rtt := scaleDuration(5 * time.Millisecond)
|
||||
|
||||
var numCorrupted atomic.Int32
|
||||
const interval = 4
|
||||
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
|
||||
if dir != direction {
|
||||
return false
|
||||
}
|
||||
if rand.Intn(interval) == 0 {
|
||||
if rand.Intn(4) == 0 {
|
||||
numCorrupted.Add(1)
|
||||
pos := rand.Intn(len(b))
|
||||
b[pos] = byte(rand.Intn(256))
|
||||
@@ -190,17 +189,12 @@ func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction)
|
||||
return false
|
||||
}
|
||||
|
||||
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
|
||||
runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
|
||||
require.NotZero(t, int(numCorrupted.Load()))
|
||||
}
|
||||
|
||||
func runMITMTest(t *testing.T,
|
||||
serverTransport, clientTransport *quic.Transport,
|
||||
direction quicproxy.Direction,
|
||||
rtt time.Duration,
|
||||
dropCallback quicproxy.DropCallback,
|
||||
) {
|
||||
ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
func runMITMTest(t *testing.T, serverTr, clientTr *quic.Transport, rtt time.Duration, dropCb quicproxy.DropCallback) {
|
||||
ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
@@ -208,41 +202,64 @@ func runMITMTest(t *testing.T,
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(dir quicproxy.Direction, b []byte) time.Duration { return rtt / 2 },
|
||||
DropPacket: dropCallback,
|
||||
DropPacket: dropCb,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := clientTransport.Dial(
|
||||
context.Background(),
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
conn, err := clientTr.Dial(
|
||||
ctx,
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenUniStream()
|
||||
serverConn, err := ln.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer serverConn.CloseWithError(0, "")
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
str, err := conn.OpenStreamSync(ctx)
|
||||
require.NoError(t, err)
|
||||
clientErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
if _, err := serverStr.Write(PRData); err != nil {
|
||||
errChan <- err
|
||||
_, err := str.Write(PRData)
|
||||
clientErrChan <- err
|
||||
str.Close()
|
||||
}()
|
||||
|
||||
serverStr, err := serverConn.AcceptStream(ctx)
|
||||
require.NoError(t, err)
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
if _, err := io.Copy(serverStr, serverStr); err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
serverStr.Close()
|
||||
}()
|
||||
require.NoError(t, <-errChan)
|
||||
require.NoError(t, <-serverErrChan)
|
||||
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
require.NoError(t, conn.CloseWithError(0, ""))
|
||||
|
||||
select {
|
||||
case err := <-clientErrChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
select {
|
||||
case err := <-serverErrChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMITMForgedVersionNegotiationPacket(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user