Files
quic-go/integrationtests/self/handshake_drop_test.go
2025-10-23 19:32:01 +02:00

333 lines
8.9 KiB
Go

package self_test
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"math"
mrand "math/rand/v2"
"net"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/testutils/simnet"
"github.com/stretchr/testify/require"
)
// dropTestProtocolClientSpeaksFirst dials ln from clientConn, sends data from the
// client on a unidirectional stream, and checks that the server receives exactly
// that data. The write's own error is checked as well.
//
// It returns the client connection. CloseWithError has already been called on it
// (via defer) by the time this function returns.
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	conn, err := quic.Dial(
		ctx,
		clientConn,
		ln.Addr(),
		clientConf,
		getQuicConfig(&quic.Config{
			MaxIdleTimeout:          timeout,
			HandshakeIdleTimeout:    timeout,
			DisablePathMTUDiscovery: true,
		}),
	)
	require.NoError(t, err)
	defer conn.CloseWithError(0, "")

	str, err := conn.OpenUniStream()
	require.NoError(t, err)
	// Write from a separate goroutine so that the server side below can accept
	// and read concurrently. The write result is reported on errChan.
	errChan := make(chan error, 1)
	go func() {
		defer str.Close()
		_, err := str.Write(data)
		errChan <- err
	}()

	serverConn, err := ln.Accept(ctx)
	require.NoError(t, err)
	serverStr, err := serverConn.AcceptUniStream(ctx)
	require.NoError(t, err)
	b, err := io.ReadAll(&readerWithTimeout{Reader: serverStr, Timeout: timeout})
	require.NoError(t, err)
	require.Equal(t, data, b)

	// Surface a failed client-side Write instead of dropping its result. The
	// value is sent before the deferred str.Close, so it is normally already
	// available once the server has read to EOF.
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(timeout):
		t.Fatal("timeout waiting for the client to finish writing")
	}
	serverConn.CloseWithError(0, "")
	return conn
}
// dropTestProtocolServerSpeaksFirst dials ln from clientConn, then has the server
// open a unidirectional stream and send data. A client goroutine accepts that
// stream, checks the received bytes against data, and closes the client
// connection when done.
//
// It returns the client connection after waiting for its context to be done,
// i.e. after the connection has been closed.
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	conn, err := quic.Dial(
		ctx,
		clientConn,
		ln.Addr(),
		clientConf,
		getQuicConfig(&quic.Config{
			MaxIdleTimeout:          timeout,
			HandshakeIdleTimeout:    timeout,
			DisablePathMTUDiscovery: true,
		}),
	)
	require.NoError(t, err)

	// The client reader reports a failure on errChan. On success it sends
	// nothing and closing the channel yields a nil error. Deferred calls run in
	// reverse order, so the connection is closed before the channel is.
	errChan := make(chan error, 1)
	go func() {
		defer close(errChan)
		defer conn.CloseWithError(0, "")
		str, err := conn.AcceptUniStream(ctx)
		if err != nil {
			errChan <- err
			return
		}
		b, err := io.ReadAll(&readerWithTimeout{Reader: str, Timeout: timeout})
		if err != nil {
			errChan <- err
			return
		}
		if !bytes.Equal(b, data) {
			errChan <- fmt.Errorf("data mismatch: %x != %x", b, data)
			return
		}
	}()

	serverConn, err := ln.Accept(ctx)
	require.NoError(t, err)
	serverStr, err := serverConn.OpenUniStream()
	require.NoError(t, err)
	_, err = serverStr.Write(data)
	require.NoError(t, err)
	require.NoError(t, serverStr.Close())

	// Wait for the client goroutine to receive and verify the data.
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(timeout):
		t.Fatal("timeout waiting for the client to receive the data")
	}
	// The client goroutine closed conn; wait until that has taken effect.
	select {
	case <-conn.Context().Done():
	case <-time.After(timeout):
		t.Fatal("client connection not closed")
	}
	return conn
}
// dropTestProtocolNobodySpeaks only establishes a connection. The client dials
// ln, and the server accepts and immediately closes its side. No application
// data is exchanged, so the data argument is ignored.
//
// It returns the client connection. CloseWithError has already been called on it
// (via defer) by the time this function returns.
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, _ []byte) *quic.Conn {
	dialCtx, cancelDial := context.WithTimeout(context.Background(), timeout)
	defer cancelDial()

	quicConf := getQuicConfig(&quic.Config{
		MaxIdleTimeout:          timeout,
		HandshakeIdleTimeout:    timeout,
		DisablePathMTUDiscovery: true,
	})
	clientSide, err := quic.Dial(dialCtx, clientConn, ln.Addr(), clientConf, quicConf)
	require.NoError(t, err)
	defer clientSide.CloseWithError(0, "")

	serverSide, err := ln.Accept(dialCtx)
	require.NoError(t, err)
	serverSide.CloseWithError(0, "")

	return clientSide
}
// dropCallbackDropNthPacket returns a drop callback that drops the packets whose
// 1-based sequence number is listed in ns. Packets are counted separately for
// each direction (to client / to server). Every packet in those two directions
// advances its counter, but only packets travelling in dir (or in either
// direction when dir is directionBoth) are actually dropped. The counters are
// atomic, so the callback may be invoked concurrently.
func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
	var sentToClient, sentToServer atomic.Int32
	// targeted reports whether packets travelling in d may be dropped.
	targeted := func(d direction) bool { return dir == d || dir == directionBoth }
	return func(d direction, _ simnet.Packet) bool {
		var counter *atomic.Int32
		switch d {
		case directionToClient:
			counter = &sentToClient
		case directionToServer:
			counter = &sentToServer
		default:
			return false
		}
		seq := int(counter.Add(1))
		return targeted(d) && slices.Contains(ns, seq)
	}
}
// dropCallbackDropOneThird returns a drop callback that randomly drops about one
// in three packets travelling in dir. When dir is directionBoth, packets in
// either direction are candidates. Packets travelling in any other direction
// are never dropped, matching dropCallbackDropNthPacket.
//
// To keep the connection able to make progress, at most maxSequentiallyDropped
// consecutive packets are dropped per direction. After that the next packet is
// let through and the run counter resets. The counters are guarded by a mutex,
// so the callback may be invoked concurrently.
func dropCallbackDropOneThird(dir direction) func(direction, simnet.Packet) bool {
	const maxSequentiallyDropped = 10
	var mx sync.Mutex
	var toClient, toServer int
	return func(d direction, _ simnet.Packet) bool {
		// Honor the requested direction. Previously the argument was discarded,
		// so every direction variant of the test dropped packets both ways.
		if dir != directionBoth && d != dir {
			return false
		}
		drop := mrand.IntN(3) == 0
		mx.Lock()
		defer mx.Unlock()
		// never drop more than maxSequentiallyDropped consecutive packets
		if d == directionToClient || d == directionBoth {
			if drop {
				toClient++
				if toClient > maxSequentiallyDropped {
					drop = false
				}
			}
			if !drop {
				toClient = 0
			}
		}
		if d == directionToServer || d == directionBoth {
			if drop {
				toServer++
				if toServer > maxSequentiallyDropped {
					drop = false
				}
			}
			if !drop {
				toServer = 0
			}
		}
		return drop
	}
}
// TestHandshakeWithPacketLoss checks that a QUIC connection can be established,
// and a small application-level exchange completed, when packets are dropped on
// a simulated (simnet) network. It covers every combination of:
//   - drop direction: to client, to server, or both;
//   - drop pattern: the 1st packet, the first 3 packets, or ~1/3 of packets;
//   - TLS setup: post-quantum key exchange or P-384 only, short or long
//     certificate chain, and with or without a Retry (VerifySourceAddress);
//   - application protocol: client speaks first, server speaks first, or
//     nobody speaks.
//
// Each case runs inside synctest.Test. After the exchange the negotiated curve
// is checked (skipped on Go 1.24). For the deterministic drop patterns, the
// test also asserts that at least one packet was actually dropped.
func TestHandshakeWithPacketLoss(t *testing.T) {
	data := GeneratePRData(5000)
	const timeout = 2 * time.Minute
	const rtt = 20 * time.Millisecond

	type dropPattern string

	const (
		dropPatternDrop1stPacket         dropPattern = "drop 1st packet"
		dropPatternDropFirst3Packets     dropPattern = "drop first 3 packets"
		dropPatternDropOneThirdOfPackets dropPattern = "drop 1/3 of packets"
	)

	type testConfig struct {
		postQuantum   bool
		longCertChain bool
		doRetry       bool
	}

	for _, dir := range []direction{directionToClient, directionToServer, directionBoth} {
		for _, pattern := range []dropPattern{
			dropPatternDrop1stPacket,
			dropPatternDropFirst3Packets,
			dropPatternDropOneThirdOfPackets,
		} {
			t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) {
				for _, conf := range []testConfig{
					{postQuantum: false, longCertChain: false, doRetry: true},
					{postQuantum: false, longCertChain: false, doRetry: false},
					{postQuantum: false, longCertChain: true, doRetry: false},
					{postQuantum: true, longCertChain: false, doRetry: false},
					{postQuantum: true, longCertChain: true, doRetry: false},
				} {
					for _, test := range []struct {
						name string
						fn   func(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn
					}{
						{"client speaks first", dropTestProtocolClientSpeaksFirst},
						{"server speaks first", dropTestProtocolServerSpeaksFirst},
						{"nobody speaks", dropTestProtocolNobodySpeaks},
					} {
						// Encode every testConfig field in the subtest name. With only
						// doRetry, several configs produced the same name and go test
						// disambiguated them with #01/#02 suffixes, hiding which TLS
						// setup actually failed.
						name := fmt.Sprintf(
							"retry: %t, post-quantum: %t, long cert chain: %t/%s",
							conf.doRetry, conf.postQuantum, conf.longCertChain, test.name,
						)
						t.Run(name, func(t *testing.T) {
							synctest.Test(t, func(t *testing.T) {
								clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
								serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}

								var fn func(direction, simnet.Packet) bool
								switch pattern {
								case dropPatternDrop1stPacket:
									fn = dropCallbackDropNthPacket(dir, 1)
								case dropPatternDropFirst3Packets:
									fn = dropCallbackDropNthPacket(dir, 1, 2, 3)
								case dropPatternDropOneThirdOfPackets:
									fn = dropCallbackDropOneThird(dir)
								}

								// Wrap the drop callback to count how many packets were dropped.
								var numDropped atomic.Int32
								n := &simnet.Simnet{
									Router: &directionAwareDroppingRouter{
										ClientAddr: clientAddr,
										ServerAddr: serverAddr,
										Drop: func(d direction, p simnet.Packet) bool {
											drop := fn(d, p)
											if drop {
												numDropped.Add(1)
											}
											return drop
										},
									},
								}
								// Each link adds rtt/4 of latency in each direction.
								settings := simnet.NodeBiDiLinkSettings{
									Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
									Uplink:   simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
								}
								clientConn := n.NewEndpoint(clientAddr, settings)
								defer clientConn.Close()
								serverConn := n.NewEndpoint(serverAddr, settings)
								defer serverConn.Close()
								require.NoError(t, n.Start())
								defer n.Close()

								var tlsConf *tls.Config
								if conf.longCertChain {
									tlsConf = getTLSConfigWithLongCertChain()
								} else {
									tlsConf = getTLSConfig()
								}
								clientConf := getTLSClientConfig()
								if !conf.postQuantum {
									clientConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
								}

								tr := &quic.Transport{
									Conn:                serverConn,
									VerifySourceAddress: func(net.Addr) bool { return conf.doRetry },
								}
								defer tr.Close()
								ln, err := tr.Listen(
									tlsConf,
									getQuicConfig(&quic.Config{
										MaxIdleTimeout:          timeout,
										HandshakeIdleTimeout:    timeout,
										DisablePathMTUDiscovery: true,
									}),
								)
								require.NoError(t, err)
								defer ln.Close()

								conn := test.fn(t, ln, clientConn, clientConf, timeout, data)
								if !strings.HasPrefix(runtime.Version(), "go1.24") {
									curveID := getCurveID(conn.ConnectionState().TLS)
									if conf.postQuantum {
										require.Equal(t, tls.X25519MLKEM768, curveID)
									} else {
										require.Equal(t, tls.CurveP384, curveID)
									}
								}
								// The random 1/3 pattern may legitimately drop nothing.
								if pattern != dropPatternDropOneThirdOfPackets {
									require.NotZero(t, numDropped.Load())
								}
								t.Logf("dropped %d packets", numDropped.Load())
							})
						})
					}
				}
			})
		}
	}
}