use synctest for the datagram test (#5398)

This commit is contained in:
Marten Seemann
2025-10-23 19:32:01 +02:00
committed by GitHub
parent 49d41dc218
commit 44db8994c9
3 changed files with 142 additions and 128 deletions

View File

@@ -3,6 +3,7 @@ package self_test
import (
"bytes"
"context"
"math"
mrand "math/rand/v2"
"net"
"sync/atomic"
@@ -10,8 +11,9 @@ import (
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/testutils/simnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -124,117 +126,129 @@ func TestDatagramSizeLimit(t *testing.T) {
}
func TestDatagramLoss(t *testing.T) {
const rtt = 10 * time.Millisecond
const numDatagrams = 100
const datagramSize = 500
synctest.Test(t, func(t *testing.T) {
const rtt = 100 * time.Millisecond
const numDatagrams = 100
const datagramSize = 500
server, err := quic.Listen(
newUDPConnLocalhost(t),
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
)
require.NoError(t, err)
defer server.Close()
clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}
var droppedToClient, droppedToServer, total atomic.Int32
n := &simnet.Simnet{
Router: &directionAwareDroppingRouter{
ClientAddr: clientAddr,
ServerAddr: serverAddr,
Drop: func(d direction, p simnet.Packet) bool {
if wire.IsLongHeaderPacket(p.Data[0]) { // don't drop Long Header packets
return false
}
if len(p.Data) < datagramSize { // don't drop ACK-only packets
return false
}
total.Add(1)
// drop about 20% of Short Header packets with DATAGRAM frames
if mrand.Int()%5 == 0 {
switch d {
case directionToClient:
droppedToClient.Add(1)
case directionToServer:
droppedToServer.Add(1)
}
return true
}
return false
},
},
}
settings := simnet.NodeBiDiLinkSettings{
Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
}
clientPacketConn := n.NewEndpoint(clientAddr, settings)
defer clientPacketConn.Close()
serverPacketConn := n.NewEndpoint(serverAddr, settings)
defer serverPacketConn.Close()
require.NoError(t, n.Start())
defer n.Close()
var droppedIncoming, droppedOutgoing, total atomic.Int32
proxy := &quicproxy.Proxy{
Conn: newUDPConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DropPacket: func(dir quicproxy.Direction, _, _ net.Addr, packet []byte) bool {
if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets
return false
}
if len(packet) < datagramSize { // don't drop ACK-only packets
return false
}
total.Add(1)
// drop about 20% of Short Header packets with DATAGRAM frames
if mrand.Int()%5 == 0 {
switch dir {
case quicproxy.DirectionIncoming:
droppedIncoming.Add(1)
case quicproxy.DirectionOutgoing:
droppedOutgoing.Add(1)
server, err := quic.Listen(
serverPacketConn,
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
)
require.NoError(t, err)
defer server.Close()
const sendInterval = time.Second // send a datagram every second
ctx, cancel := context.WithTimeout(context.Background(), (numDatagrams+10)*sendInterval)
defer cancel()
clientConn, err := quic.Dial(
ctx,
clientPacketConn,
serverPacketConn.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
)
require.NoError(t, err)
defer clientConn.CloseWithError(0, "")
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
defer serverConn.CloseWithError(0, "")
var clientDatagrams, serverDatagrams int
clientErrChan := make(chan error, 1)
go func() {
defer close(clientErrChan)
for {
if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
clientErrChan <- err
return
}
return true
clientDatagrams++
}
return false
},
DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
}()
// SendDatagram blocks when the queue is full (maxDatagramSendQueueLen),
// add some extra margin for the handshake, networking and ACKs.
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(4*numDatagrams*time.Millisecond))
defer cancel()
clientConn, err := quic.Dial(
ctx,
newUDPConnLocalhost(t),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
)
require.NoError(t, err)
defer clientConn.CloseWithError(0, "")
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
defer serverConn.CloseWithError(0, "")
var clientDatagrams, serverDatagrams int
clientErrChan := make(chan error, 1)
go func() {
defer close(clientErrChan)
for {
if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
clientErrChan <- err
return
}
clientDatagrams++
for i := range numDatagrams {
payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
require.NoError(t, clientConn.SendDatagram(payload))
require.NoError(t, serverConn.SendDatagram(payload))
time.Sleep(sendInterval)
}
}()
for i := range numDatagrams {
payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
require.NoError(t, clientConn.SendDatagram(payload))
require.NoError(t, serverConn.SendDatagram(payload))
time.Sleep(scaleDuration(time.Millisecond / 2))
}
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
for {
if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
serverErrChan <- err
return
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
for {
if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
serverErrChan <- err
return
}
serverDatagrams++
}
serverDatagrams++
}()
select {
case err := <-clientErrChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(5 * numDatagrams * sendInterval):
t.Fatal("timeout")
}
select {
case err := <-serverErrChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(5 * numDatagrams * sendInterval):
t.Fatal("timeout")
}
}()
select {
case err := <-clientErrChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
t.Fatal("timeout")
}
select {
case err := <-serverErrChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
t.Fatal("timeout")
}
numDroppedIncoming := droppedIncoming.Load()
numDroppedOutgoing := droppedOutgoing.Load()
t.Logf("dropped %d incoming and %d outgoing out of %d packets", numDroppedIncoming, numDroppedOutgoing, total.Load())
assert.NotZero(t, numDroppedIncoming)
assert.NotZero(t, numDroppedOutgoing)
t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
assert.EqualValues(t, numDatagrams-numDroppedIncoming, serverDatagrams, "datagrams received by the server")
t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
assert.EqualValues(t, numDatagrams-numDroppedOutgoing, clientDatagrams, "datagrams received by the client")
numDroppedToClient := droppedToClient.Load()
numDroppedToServer := droppedToServer.Load()
t.Logf("dropped %d to client and %d to server out of %d packets", numDroppedToClient, numDroppedToServer, total.Load())
assert.NotZero(t, numDroppedToClient)
assert.NotZero(t, numDroppedToServer)
t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
assert.EqualValues(t, numDatagrams-numDroppedToServer, serverDatagrams, "datagrams received by the server")
t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
assert.EqualValues(t, numDatagrams-numDroppedToClient, clientDatagrams, "datagrams received by the client")
})
}

View File

@@ -147,16 +147,16 @@ func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn ne
}
func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
var incoming, outgoing atomic.Int32
var toClient, toServer atomic.Int32
return func(d direction, p simnet.Packet) bool {
switch d {
case directionIncoming:
c := incoming.Add(1)
case directionToClient:
c := toClient.Add(1)
if d == dir || dir == directionBoth {
return slices.Contains(ns, int(c))
}
case directionOutgoing:
c := outgoing.Add(1)
case directionToServer:
c := toServer.Add(1)
if dir == d || dir == directionBoth {
return slices.Contains(ns, int(c))
}
@@ -168,33 +168,33 @@ func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.
func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool {
const maxSequentiallyDropped = 10
var mx sync.Mutex
var incoming, outgoing int
var toClient, toServer int
return func(d direction, p simnet.Packet) bool {
drop := mrand.IntN(3) == 0
mx.Lock()
defer mx.Unlock()
// never drop more than 10 consecutive packets
if d == directionIncoming || d == directionBoth {
if d == directionToClient || d == directionBoth {
if drop {
incoming++
if incoming > maxSequentiallyDropped {
toClient++
if toClient > maxSequentiallyDropped {
drop = false
}
}
if !drop {
incoming = 0
toClient = 0
}
}
if d == directionOutgoing || d == directionBoth {
if d == directionToServer || d == directionBoth {
if drop {
outgoing++
if outgoing > maxSequentiallyDropped {
toServer++
if toServer > maxSequentiallyDropped {
drop = false
}
}
if !drop {
outgoing = 0
toServer = 0
}
}
return drop
@@ -220,13 +220,13 @@ func TestHandshakeWithPacketLoss(t *testing.T) {
doRetry bool
}
for _, dir := range []direction{directionIncoming, directionOutgoing, directionBoth} {
for _, dir := range []direction{directionToClient, directionToServer, directionBoth} {
for _, pattern := range []dropPattern{
dropPatternDrop1stPacket,
dropPatternDropFirst3Packets,
dropPatternDropOneThirdOfPackets,
} {
t.Run(fmt.Sprintf("%s in %s direction", pattern, dir), func(t *testing.T) {
t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) {
for _, conf := range []testConfig{
{postQuantum: false, longCertChain: false, doRetry: true},
{postQuantum: false, longCertChain: false, doRetry: false},

View File

@@ -23,17 +23,17 @@ type direction uint8
const (
directionUnknown = iota
directionIncoming
directionOutgoing
directionToClient
directionToServer
directionBoth
)
func (d direction) String() string {
switch d {
case directionIncoming:
return "incoming"
case directionOutgoing:
return "outgoing"
case directionToClient:
return "to client"
case directionToServer:
return "to server"
case directionBoth:
return "both"
}
@@ -54,9 +54,9 @@ func (d *directionAwareDroppingRouter) SendPacket(p simnet.Packet) error {
var dir direction
switch p.To.String() {
case d.ClientAddr.String():
dir = directionIncoming
dir = directionToClient
case d.ServerAddr.String():
dir = directionOutgoing
dir = directionToServer
default:
dir = directionUnknown
}