use synctest for the handshake drop test (#5397)

This test used to take 15-30s locally, and even more on CI. It now runs in less than 1ms.
This commit is contained in:
Marten Seemann
2025-10-23 16:56:35 +02:00
committed by GitHub
parent 756bdc0104
commit 49d41dc218
4 changed files with 218 additions and 106 deletions

View File

@@ -3,65 +3,35 @@ package self_test
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"fmt"
"io"
"math"
mrand "math/rand/v2"
"net"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/testutils/simnet"
"github.com/stretchr/testify/require"
)
func startDropTestListenerAndProxy(t *testing.T, rtt, timeout time.Duration, dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (_ *quic.Listener, proxyAddr net.Addr) {
t.Helper()
conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
})
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
} else {
tlsConf = getTLSConfig()
}
tr := &quic.Transport{
Conn: newUDPConnLocalhost(t),
VerifySourceAddress: func(net.Addr) bool { return doRetry },
}
t.Cleanup(func() { tr.Close() })
ln, err := tr.Listen(tlsConf, conf)
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
proxy := quicproxy.Proxy{
Conn: newUDPConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DropPacket: dropCallback,
DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
t.Cleanup(func() { proxy.Close() })
return ln, proxy.LocalAddr()
}
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration, data []byte) {
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := quic.Dial(
ctx,
newUDPConnLocalhost(t),
addr,
getTLSClientConfig(),
clientConn,
ln.Addr(),
clientConf,
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
@@ -88,16 +58,18 @@ func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, addr net
require.NoError(t, err)
require.Equal(t, b, data)
serverConn.CloseWithError(0, "")
return conn
}
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration, data []byte) {
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := quic.Dial(
ctx,
newUDPConnLocalhost(t),
addr,
getTLSClientConfig(),
clientConn,
ln.Addr(),
clientConf,
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
@@ -146,16 +118,18 @@ func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, addr net
case <-time.After(timeout):
t.Fatal("server connection not closed")
}
return conn
}
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration) {
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, _ []byte) *quic.Conn {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := quic.Dial(
ctx,
newUDPConnLocalhost(t),
addr,
getTLSClientConfig(),
clientConn,
ln.Addr(),
clientConf,
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
@@ -168,33 +142,40 @@ func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, addr net.Addr
serverConn, err := ln.Accept(ctx)
require.NoError(t, err)
serverConn.CloseWithError(0, "")
return conn
}
func dropCallbackDropNthPacket(direction quicproxy.Direction, n int) quicproxy.DropCallback {
func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
var incoming, outgoing atomic.Int32
return func(d quicproxy.Direction, _, _ net.Addr, packet []byte) bool {
var p int32
return func(d direction, p simnet.Packet) bool {
switch d {
case quicproxy.DirectionIncoming:
p = incoming.Add(1)
case quicproxy.DirectionOutgoing:
p = outgoing.Add(1)
case directionIncoming:
c := incoming.Add(1)
if d == dir || dir == directionBoth {
return slices.Contains(ns, int(c))
}
case directionOutgoing:
c := outgoing.Add(1)
if dir == d || dir == directionBoth {
return slices.Contains(ns, int(c))
}
}
return p == int32(n) && d.Is(direction)
return false
}
}
func dropCallbackDropOneThird(direction quicproxy.Direction) quicproxy.DropCallback {
func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool {
const maxSequentiallyDropped = 10
var mx sync.Mutex
var incoming, outgoing int
return func(d quicproxy.Direction, _, _ net.Addr, _ []byte) bool {
return func(d direction, p simnet.Packet) bool {
drop := mrand.IntN(3) == 0
mx.Lock()
defer mx.Unlock()
// never drop more than 10 consecutive packets
if d.Is(quicproxy.DirectionIncoming) {
if d == directionIncoming || d == directionBoth {
if drop {
incoming++
if incoming > maxSequentiallyDropped {
@@ -205,7 +186,7 @@ func dropCallbackDropOneThird(direction quicproxy.Direction) quicproxy.DropCallb
incoming = 0
}
}
if d.Is(quicproxy.DirectionOutgoing) {
if d == directionOutgoing || d == directionBoth {
if drop {
outgoing++
if outgoing > maxSequentiallyDropped {
@@ -225,63 +206,127 @@ func TestHandshakeWithPacketLoss(t *testing.T) {
const timeout = 2 * time.Minute
const rtt = 20 * time.Millisecond
type dropPattern struct {
name string
fn quicproxy.DropCallback
}
type dropPattern string
type serverConfig struct {
const (
dropPatternDrop1stPacket dropPattern = "drop 1st packet"
dropPatternDropFirst3Packets dropPattern = "drop first 3 packets"
dropPatternDropOneThirdOfPackets dropPattern = "drop 1/3 of packets"
)
type testConfig struct {
postQuantum bool
longCertChain bool
doRetry bool
}
for _, direction := range []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} {
for _, dropPattern := range []dropPattern{
{name: "drop 1st packet", fn: dropCallbackDropNthPacket(direction, 1)},
{name: "drop 2nd packet", fn: dropCallbackDropNthPacket(direction, 2)},
{name: "drop 1/3 of packets", fn: dropCallbackDropOneThird(direction)},
for _, dir := range []direction{directionIncoming, directionOutgoing, directionBoth} {
for _, pattern := range []dropPattern{
dropPatternDrop1stPacket,
dropPatternDropFirst3Packets,
dropPatternDropOneThirdOfPackets,
} {
t.Run(fmt.Sprintf("%s in %s direction", dropPattern.name, direction), func(t *testing.T) {
for _, conf := range []serverConfig{
{longCertChain: false, doRetry: true},
{longCertChain: false, doRetry: false},
{longCertChain: true, doRetry: false},
t.Run(fmt.Sprintf("%s in %s direction", pattern, dir), func(t *testing.T) {
for _, conf := range []testConfig{
{postQuantum: false, longCertChain: false, doRetry: true},
{postQuantum: false, longCertChain: false, doRetry: false},
{postQuantum: false, longCertChain: true, doRetry: false},
{postQuantum: true, longCertChain: false, doRetry: false},
{postQuantum: true, longCertChain: true, doRetry: false},
} {
t.Run(fmt.Sprintf("retry: %t", conf.doRetry), func(t *testing.T) {
t.Run("client speaks first", func(t *testing.T) {
ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolClientSpeaksFirst(t, ln, proxyAddr, timeout, data)
})
for _, test := range []struct {
name string
fn func(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn
}{
{"client speaks first", dropTestProtocolClientSpeaksFirst},
{"server speaks first", dropTestProtocolServerSpeaksFirst},
{"nobody speaks", dropTestProtocolNobodySpeaks},
} {
t.Run(fmt.Sprintf("retry: %t/%s", conf.doRetry, test.name), func(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}
var fn func(direction, simnet.Packet) bool
switch pattern {
case dropPatternDrop1stPacket:
fn = dropCallbackDropNthPacket(dir, 1)
case dropPatternDropFirst3Packets:
fn = dropCallbackDropNthPacket(dir, 1, 2, 3)
case dropPatternDropOneThirdOfPackets:
fn = dropCallbackDropOneThird(dir)
}
var numDropped atomic.Int32
n := &simnet.Simnet{
Router: &directionAwareDroppingRouter{
ClientAddr: clientAddr,
ServerAddr: serverAddr,
Drop: func(d direction, p simnet.Packet) bool {
drop := fn(d, p)
if drop {
numDropped.Add(1)
}
return drop
},
},
}
settings := simnet.NodeBiDiLinkSettings{
Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
}
clientConn := n.NewEndpoint(clientAddr, settings)
defer clientConn.Close()
serverConn := n.NewEndpoint(serverAddr, settings)
defer serverConn.Close()
require.NoError(t, n.Start())
defer n.Close()
t.Run("server speaks first", func(t *testing.T) {
ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolServerSpeaksFirst(t, ln, proxyAddr, timeout, data)
})
var tlsConf *tls.Config
if conf.longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
} else {
tlsConf = getTLSConfig()
}
clientConf := getTLSClientConfig()
if !conf.postQuantum {
clientConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
}
t.Run("nobody speaks", func(t *testing.T) {
ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolNobodySpeaks(t, ln, proxyAddr, timeout)
tr := &quic.Transport{
Conn: serverConn,
VerifySourceAddress: func(net.Addr) bool { return conf.doRetry },
}
defer tr.Close()
ln, err := tr.Listen(
tlsConf,
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
defer ln.Close()
conn := test.fn(t, ln, clientConn, clientConf, timeout, data)
if !strings.HasPrefix(runtime.Version(), "go1.24") {
curveID := getCurveID(conn.ConnectionState().TLS)
if conf.postQuantum {
require.Equal(t, tls.X25519MLKEM768, curveID)
} else {
require.Equal(t, tls.CurveP384, curveID)
}
}
if pattern != dropPatternDropOneThirdOfPackets {
require.NotZero(t, numDropped.Load())
}
t.Logf("dropped %d packets", numDropped.Load())
})
})
})
}
}
})
}
}
}
func TestPostQuantumClientHello(t *testing.T) {
origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
t.Cleanup(func() { wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient })
b := make([]byte, 2500) // the ClientHello will now span across 3 packets
rand.Read(b)
wire.AdditionalTransportParametersClient = map[uint64][]byte{
// We don't use a greased transport parameter here, since the transport parameter serialization function
// will add a greased transport parameter, and therefore there's a risk of a collision.
// Instead, we just use pseudorandom constant value.
1234567: b,
}
ln, proxyPort := startDropTestListenerAndProxy(t, 10*time.Millisecond, 20*time.Second, dropCallbackDropOneThird(quicproxy.DirectionIncoming), false, false)
dropTestProtocolClientSpeaksFirst(t, ln, proxyPort, time.Minute, GeneratePRData(5000))
}

View File

@@ -0,0 +1,9 @@
//go:build !go1.25
package self_test
import "crypto/tls"
func getCurveID(connState tls.ConnectionState) tls.CurveID {
return 0
}

View File

@@ -0,0 +1,9 @@
//go:build go1.25
package self_test
import "crypto/tls"
func getCurveID(connState tls.ConnectionState) tls.CurveID {
return connState.CurveID
}

View File

@@ -1,6 +1,10 @@
package self_test
import "github.com/quic-go/quic-go/testutils/simnet"
import (
"net"
"github.com/quic-go/quic-go/testutils/simnet"
)
type droppingRouter struct {
simnet.PerfectRouter
@@ -15,4 +19,49 @@ func (d *droppingRouter) SendPacket(p simnet.Packet) error {
return d.PerfectRouter.SendPacket(p)
}
type direction uint8
const (
directionUnknown = iota
directionIncoming
directionOutgoing
directionBoth
)
func (d direction) String() string {
switch d {
case directionIncoming:
return "incoming"
case directionOutgoing:
return "outgoing"
case directionBoth:
return "both"
}
return "unknown"
}
var _ simnet.Router = &droppingRouter{}
type directionAwareDroppingRouter struct {
simnet.PerfectRouter
ClientAddr, ServerAddr *net.UDPAddr
Drop func(direction direction, p simnet.Packet) bool
}
func (d *directionAwareDroppingRouter) SendPacket(p simnet.Packet) error {
var dir direction
switch p.To.String() {
case d.ClientAddr.String():
dir = directionIncoming
case d.ServerAddr.String():
dir = directionOutgoing
default:
dir = directionUnknown
}
if d.Drop(dir, p) {
return nil
}
return d.PerfectRouter.SendPacket(p)
}