From 44db8994c9cdee1de00b45c2faad942ca51e5b26 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 23 Oct 2025 19:32:01 +0200 Subject: [PATCH] use synctest for the datagram test (#5398) --- integrationtests/self/datagram_test.go | 222 ++++++++++--------- integrationtests/self/handshake_drop_test.go | 32 +-- integrationtests/self/simnet_helper_test.go | 16 +- 3 files changed, 142 insertions(+), 128 deletions(-) diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index fdc499f6..fc21e829 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -3,6 +3,7 @@ package self_test import ( "bytes" "context" + "math" mrand "math/rand/v2" "net" "sync/atomic" @@ -10,8 +11,9 @@ import ( "time" "github.com/quic-go/quic-go" - quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" + "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -124,117 +126,129 @@ func TestDatagramSizeLimit(t *testing.T) { } func TestDatagramLoss(t *testing.T) { - const rtt = 10 * time.Millisecond - const numDatagrams = 100 - const datagramSize = 500 + synctest.Test(t, func(t *testing.T) { + const rtt = 100 * time.Millisecond + const numDatagrams = 100 + const datagramSize = 500 - server, err := quic.Listen( - newUDPConnLocalhost(t), - getTLSConfig(), - getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}), - ) - require.NoError(t, err) - defer server.Close() + clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001} + serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002} + var droppedToClient, droppedToServer, total atomic.Int32 + n := &simnet.Simnet{ + Router: &directionAwareDroppingRouter{ + ClientAddr: clientAddr, + ServerAddr: serverAddr, + Drop: func(d direction, p simnet.Packet) bool { + if wire.IsLongHeaderPacket(p.Data[0]) { // don't drop Long Header packets + return false + } + if len(p.Data) < datagramSize { // don't drop ACK-only packets + return false + } + total.Add(1) + // drop about 20% of Short Header packets with DATAGRAM frames + if mrand.Int()%5 == 0 { + switch d { + case directionToClient: + droppedToClient.Add(1) + case directionToServer: + droppedToServer.Add(1) + } + return true + } + return false + }, + }, + } + settings := simnet.NodeBiDiLinkSettings{ + Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4}, + Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4}, + } + clientPacketConn := n.NewEndpoint(clientAddr, settings) + defer clientPacketConn.Close() + serverPacketConn := n.NewEndpoint(serverAddr, settings) + defer serverPacketConn.Close() + require.NoError(t, n.Start()) + defer n.Close() - var droppedIncoming, droppedOutgoing, total atomic.Int32 - proxy := &quicproxy.Proxy{ - Conn: newUDPConnLocalhost(t), - ServerAddr: server.Addr().(*net.UDPAddr), - DropPacket: func(dir quicproxy.Direction, _, _ net.Addr, packet []byte) bool { - if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets - return false - } - if len(packet) < datagramSize { // don't drop ACK-only packets - return false - } - total.Add(1) - // drop about 20% of Short Header packets with DATAGRAM frames - if mrand.Int()%5 == 0 { - switch dir { - case quicproxy.DirectionIncoming: - droppedIncoming.Add(1) - case quicproxy.DirectionOutgoing: - droppedOutgoing.Add(1) + server, err := quic.Listen( + serverPacketConn, + getTLSConfig(), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}), + ) + require.NoError(t, err) + defer server.Close() + + const sendInterval = time.Second // send a datagram every second + ctx, cancel := context.WithTimeout(context.Background(), (numDatagrams+10)*sendInterval) + defer cancel() + clientConn, err := quic.Dial( + ctx, + clientPacketConn, + serverPacketConn.LocalAddr(), + getTLSClientConfig(), + getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}), + ) + require.NoError(t, err) + defer clientConn.CloseWithError(0, "") + + serverConn, err := server.Accept(ctx) + require.NoError(t, err) + defer serverConn.CloseWithError(0, "") + + var clientDatagrams, serverDatagrams int + clientErrChan := make(chan error, 1) + go func() { + defer close(clientErrChan) + for { + if _, err := clientConn.ReceiveDatagram(ctx); err != nil { + clientErrChan <- err + return } - return true + clientDatagrams++ } - return false - }, - DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 }, - } - require.NoError(t, proxy.Start()) - defer proxy.Close() + }() - // SendDatagram blocks when the queue is full (maxDatagramSendQueueLen), - // add some extra margin for the handshake, networking and ACKs. - ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(4*numDatagrams*time.Millisecond)) - defer cancel() - clientConn, err := quic.Dial( - ctx, - newUDPConnLocalhost(t), - proxy.LocalAddr(), - getTLSClientConfig(), - getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}), - ) - require.NoError(t, err) - defer clientConn.CloseWithError(0, "") - - serverConn, err := server.Accept(ctx) - require.NoError(t, err) - defer serverConn.CloseWithError(0, "") - - var clientDatagrams, serverDatagrams int - clientErrChan := make(chan error, 1) - go func() { - defer close(clientErrChan) - for { - if _, err := clientConn.ReceiveDatagram(ctx); err != nil { - clientErrChan <- err - return - } - clientDatagrams++ + for i := range numDatagrams { + payload := bytes.Repeat([]byte{uint8(i)}, datagramSize) + require.NoError(t, clientConn.SendDatagram(payload)) + require.NoError(t, serverConn.SendDatagram(payload)) + time.Sleep(sendInterval) } - }() - for i := range numDatagrams { - payload := bytes.Repeat([]byte{uint8(i)}, datagramSize) - require.NoError(t, clientConn.SendDatagram(payload)) - require.NoError(t, serverConn.SendDatagram(payload)) - time.Sleep(scaleDuration(time.Millisecond / 2)) - } - - serverErrChan := make(chan error, 1) - go func() { - defer close(serverErrChan) - for { - if _, err := serverConn.ReceiveDatagram(ctx); err != nil { - serverErrChan <- err - return + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + for { + if _, err := serverConn.ReceiveDatagram(ctx); err != nil { + serverErrChan <- err + return + } + serverDatagrams++ } - serverDatagrams++ + }() + + select { + case err := <-clientErrChan: + require.ErrorIs(t, err, context.DeadlineExceeded) + case <-time.After(5 * numDatagrams * sendInterval): + t.Fatal("timeout") + } + select { + case err := <-serverErrChan: + require.ErrorIs(t, err, context.DeadlineExceeded) + case <-time.After(5 * numDatagrams * sendInterval): + t.Fatal("timeout") } - }() - select { - case err := <-clientErrChan: - require.ErrorIs(t, err, context.DeadlineExceeded) - case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)): - t.Fatal("timeout") - } - select { - case err := <-serverErrChan: - require.ErrorIs(t, err, context.DeadlineExceeded) - case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)): - t.Fatal("timeout") - } - - numDroppedIncoming := droppedIncoming.Load() - numDroppedOutgoing := droppedOutgoing.Load() - t.Logf("dropped %d incoming and %d outgoing out of %d packets", numDroppedIncoming, numDroppedOutgoing, total.Load()) - assert.NotZero(t, numDroppedIncoming) - assert.NotZero(t, numDroppedOutgoing) - t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams) - assert.EqualValues(t, numDatagrams-numDroppedIncoming, serverDatagrams, "datagrams received by the server") - t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams) - assert.EqualValues(t, numDatagrams-numDroppedOutgoing, clientDatagrams, "datagrams received by the client") + numDroppedToClient := droppedToClient.Load() + numDroppedToServer := droppedToServer.Load() + t.Logf("dropped %d to client and %d to server out of %d packets", numDroppedToClient, numDroppedToServer, total.Load()) + assert.NotZero(t, numDroppedToClient) + assert.NotZero(t, numDroppedToServer) + t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams) + assert.EqualValues(t, numDatagrams-numDroppedToServer, serverDatagrams, "datagrams received by the server") + t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams) + assert.EqualValues(t, numDatagrams-numDroppedToClient, clientDatagrams, "datagrams received by the client") + }) } diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 63655a0f..09592c0a 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -147,16 +147,16 @@ func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn ne } func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool { - var incoming, outgoing atomic.Int32 + var toClient, toServer atomic.Int32 return func(d direction, p simnet.Packet) bool { switch d { - case directionIncoming: - c := incoming.Add(1) + case directionToClient: + c := toClient.Add(1) if d == dir || dir == directionBoth { return slices.Contains(ns, int(c)) } - case directionOutgoing: - c := outgoing.Add(1) + case directionToServer: + c := toServer.Add(1) if dir == d || dir == directionBoth { return slices.Contains(ns, int(c)) } @@ -168,33 +168,33 @@ func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet. func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool { const maxSequentiallyDropped = 10 var mx sync.Mutex - var incoming, outgoing int + var toClient, toServer int return func(d direction, p simnet.Packet) bool { drop := mrand.IntN(3) == 0 mx.Lock() defer mx.Unlock() // never drop more than 10 consecutive packets - if d == directionIncoming || d == directionBoth { + if d == directionToClient || d == directionBoth { if drop { - incoming++ - if incoming > maxSequentiallyDropped { + toClient++ + if toClient > maxSequentiallyDropped { drop = false } } if !drop { - incoming = 0 + toClient = 0 } } - if d == directionOutgoing || d == directionBoth { + if d == directionToServer || d == directionBoth { if drop { - outgoing++ - if outgoing > maxSequentiallyDropped { + toServer++ + if toServer > maxSequentiallyDropped { drop = false } } if !drop { - outgoing = 0 + toServer = 0 } } return drop @@ -220,13 +220,13 @@ func TestHandshakeWithPacketLoss(t *testing.T) { doRetry bool } - for _, dir := range []direction{directionIncoming, directionOutgoing, directionBoth} { + for _, dir := range []direction{directionToClient, directionToServer, directionBoth} { for _, pattern := range []dropPattern{ dropPatternDrop1stPacket, dropPatternDropFirst3Packets, dropPatternDropOneThirdOfPackets, } { - t.Run(fmt.Sprintf("%s in %s direction", pattern, dir), func(t *testing.T) { + t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) { for _, conf := range []testConfig{ {postQuantum: false, longCertChain: false, doRetry: true}, {postQuantum: false, longCertChain: false, doRetry: false}, diff --git a/integrationtests/self/simnet_helper_test.go b/integrationtests/self/simnet_helper_test.go index 35d41f70..10f2afd9 100644 --- a/integrationtests/self/simnet_helper_test.go +++ b/integrationtests/self/simnet_helper_test.go @@ -23,17 +23,17 @@ type direction uint8 const ( directionUnknown = iota - directionIncoming - directionOutgoing + directionToClient + directionToServer directionBoth ) func (d direction) String() string { switch d { - case directionIncoming: - return "incoming" - case directionOutgoing: - return "outgoing" + case directionToClient: + return "to client" + case directionToServer: + return "to server" case directionBoth: return "both" } @@ -54,9 +54,9 @@ func (d *directionAwareDroppingRouter) SendPacket(p simnet.Packet) error { var dir direction switch p.To.String() { case d.ClientAddr.String(): - dir = directionIncoming + dir = directionToClient case d.ServerAddr.String(): - dir = directionOutgoing + dir = directionToServer default: dir = directionUnknown }