forked from quic-go/quic-go
333 lines
8.9 KiB
Go
333 lines
8.9 KiB
Go
package self_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
mrand "math/rand/v2"
|
|
"net"
|
|
"runtime"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/quic-go/quic-go/internal/synctest"
|
|
"github.com/quic-go/quic-go/testutils/simnet"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
conn, err := quic.Dial(
|
|
ctx,
|
|
clientConn,
|
|
ln.Addr(),
|
|
clientConf,
|
|
getQuicConfig(&quic.Config{
|
|
MaxIdleTimeout: timeout,
|
|
HandshakeIdleTimeout: timeout,
|
|
DisablePathMTUDiscovery: true,
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
str, err := conn.OpenUniStream()
|
|
require.NoError(t, err)
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
defer str.Close()
|
|
_, err := str.Write(data)
|
|
errChan <- err
|
|
}()
|
|
|
|
serverConn, err := ln.Accept(ctx)
|
|
require.NoError(t, err)
|
|
serverStr, err := serverConn.AcceptUniStream(ctx)
|
|
require.NoError(t, err)
|
|
b, err := io.ReadAll(&readerWithTimeout{Reader: serverStr, Timeout: timeout})
|
|
require.NoError(t, err)
|
|
require.Equal(t, b, data)
|
|
serverConn.CloseWithError(0, "")
|
|
|
|
return conn
|
|
}
|
|
|
|
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
conn, err := quic.Dial(
|
|
ctx,
|
|
clientConn,
|
|
ln.Addr(),
|
|
clientConf,
|
|
getQuicConfig(&quic.Config{
|
|
MaxIdleTimeout: timeout,
|
|
HandshakeIdleTimeout: timeout,
|
|
DisablePathMTUDiscovery: true,
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(errChan)
|
|
defer conn.CloseWithError(0, "")
|
|
str, err := conn.AcceptUniStream(ctx)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
b, err := io.ReadAll(&readerWithTimeout{Reader: str, Timeout: timeout})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
if !bytes.Equal(b, data) {
|
|
errChan <- fmt.Errorf("data mismatch: %x != %x", b, data)
|
|
return
|
|
}
|
|
}()
|
|
|
|
serverConn, err := ln.Accept(ctx)
|
|
require.NoError(t, err)
|
|
serverStr, err := serverConn.OpenUniStream()
|
|
require.NoError(t, err)
|
|
_, err = serverStr.Write(data)
|
|
require.NoError(t, err)
|
|
require.NoError(t, serverStr.Close())
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(timeout):
|
|
t.Fatal("server connection not closed")
|
|
}
|
|
|
|
select {
|
|
case <-conn.Context().Done():
|
|
case <-time.After(timeout):
|
|
t.Fatal("server connection not closed")
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, _ []byte) *quic.Conn {
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
conn, err := quic.Dial(
|
|
ctx,
|
|
clientConn,
|
|
ln.Addr(),
|
|
clientConf,
|
|
getQuicConfig(&quic.Config{
|
|
MaxIdleTimeout: timeout,
|
|
HandshakeIdleTimeout: timeout,
|
|
DisablePathMTUDiscovery: true,
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
serverConn, err := ln.Accept(ctx)
|
|
require.NoError(t, err)
|
|
serverConn.CloseWithError(0, "")
|
|
|
|
return conn
|
|
}
|
|
|
|
func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
|
|
var toClient, toServer atomic.Int32
|
|
return func(d direction, p simnet.Packet) bool {
|
|
switch d {
|
|
case directionToClient:
|
|
c := toClient.Add(1)
|
|
if d == dir || dir == directionBoth {
|
|
return slices.Contains(ns, int(c))
|
|
}
|
|
case directionToServer:
|
|
c := toServer.Add(1)
|
|
if dir == d || dir == directionBoth {
|
|
return slices.Contains(ns, int(c))
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
}
|
|
|
|
func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool {
|
|
const maxSequentiallyDropped = 10
|
|
var mx sync.Mutex
|
|
var toClient, toServer int
|
|
return func(d direction, p simnet.Packet) bool {
|
|
drop := mrand.IntN(3) == 0
|
|
|
|
mx.Lock()
|
|
defer mx.Unlock()
|
|
// never drop more than 10 consecutive packets
|
|
if d == directionToClient || d == directionBoth {
|
|
if drop {
|
|
toClient++
|
|
if toClient > maxSequentiallyDropped {
|
|
drop = false
|
|
}
|
|
}
|
|
if !drop {
|
|
toClient = 0
|
|
}
|
|
}
|
|
if d == directionToServer || d == directionBoth {
|
|
if drop {
|
|
toServer++
|
|
if toServer > maxSequentiallyDropped {
|
|
drop = false
|
|
}
|
|
}
|
|
if !drop {
|
|
toServer = 0
|
|
}
|
|
}
|
|
return drop
|
|
}
|
|
}
|
|
|
|
func TestHandshakeWithPacketLoss(t *testing.T) {
|
|
data := GeneratePRData(5000)
|
|
const timeout = 2 * time.Minute
|
|
const rtt = 20 * time.Millisecond
|
|
|
|
type dropPattern string
|
|
|
|
const (
|
|
dropPatternDrop1stPacket dropPattern = "drop 1st packet"
|
|
dropPatternDropFirst3Packets dropPattern = "drop first 3 packets"
|
|
dropPatternDropOneThirdOfPackets dropPattern = "drop 1/3 of packets"
|
|
)
|
|
|
|
type testConfig struct {
|
|
postQuantum bool
|
|
longCertChain bool
|
|
doRetry bool
|
|
}
|
|
|
|
for _, dir := range []direction{directionToClient, directionToServer, directionBoth} {
|
|
for _, pattern := range []dropPattern{
|
|
dropPatternDrop1stPacket,
|
|
dropPatternDropFirst3Packets,
|
|
dropPatternDropOneThirdOfPackets,
|
|
} {
|
|
t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) {
|
|
for _, conf := range []testConfig{
|
|
{postQuantum: false, longCertChain: false, doRetry: true},
|
|
{postQuantum: false, longCertChain: false, doRetry: false},
|
|
{postQuantum: false, longCertChain: true, doRetry: false},
|
|
{postQuantum: true, longCertChain: false, doRetry: false},
|
|
{postQuantum: true, longCertChain: true, doRetry: false},
|
|
} {
|
|
for _, test := range []struct {
|
|
name string
|
|
fn func(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn
|
|
}{
|
|
{"client speaks first", dropTestProtocolClientSpeaksFirst},
|
|
{"server speaks first", dropTestProtocolServerSpeaksFirst},
|
|
{"nobody speaks", dropTestProtocolNobodySpeaks},
|
|
} {
|
|
t.Run(fmt.Sprintf("retry: %t/%s", conf.doRetry, test.name), func(t *testing.T) {
|
|
synctest.Test(t, func(t *testing.T) {
|
|
clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
|
|
serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}
|
|
var fn func(direction, simnet.Packet) bool
|
|
switch pattern {
|
|
case dropPatternDrop1stPacket:
|
|
fn = dropCallbackDropNthPacket(dir, 1)
|
|
case dropPatternDropFirst3Packets:
|
|
fn = dropCallbackDropNthPacket(dir, 1, 2, 3)
|
|
case dropPatternDropOneThirdOfPackets:
|
|
fn = dropCallbackDropOneThird(dir)
|
|
}
|
|
var numDropped atomic.Int32
|
|
n := &simnet.Simnet{
|
|
Router: &directionAwareDroppingRouter{
|
|
ClientAddr: clientAddr,
|
|
ServerAddr: serverAddr,
|
|
Drop: func(d direction, p simnet.Packet) bool {
|
|
drop := fn(d, p)
|
|
if drop {
|
|
numDropped.Add(1)
|
|
}
|
|
return drop
|
|
},
|
|
},
|
|
}
|
|
settings := simnet.NodeBiDiLinkSettings{
|
|
Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
|
|
Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
|
|
}
|
|
clientConn := n.NewEndpoint(clientAddr, settings)
|
|
defer clientConn.Close()
|
|
serverConn := n.NewEndpoint(serverAddr, settings)
|
|
defer serverConn.Close()
|
|
require.NoError(t, n.Start())
|
|
defer n.Close()
|
|
|
|
var tlsConf *tls.Config
|
|
if conf.longCertChain {
|
|
tlsConf = getTLSConfigWithLongCertChain()
|
|
} else {
|
|
tlsConf = getTLSConfig()
|
|
}
|
|
clientConf := getTLSClientConfig()
|
|
if !conf.postQuantum {
|
|
clientConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
|
|
}
|
|
|
|
tr := &quic.Transport{
|
|
Conn: serverConn,
|
|
VerifySourceAddress: func(net.Addr) bool { return conf.doRetry },
|
|
}
|
|
defer tr.Close()
|
|
|
|
ln, err := tr.Listen(
|
|
tlsConf,
|
|
getQuicConfig(&quic.Config{
|
|
MaxIdleTimeout: timeout,
|
|
HandshakeIdleTimeout: timeout,
|
|
DisablePathMTUDiscovery: true,
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
conn := test.fn(t, ln, clientConn, clientConf, timeout, data)
|
|
if !strings.HasPrefix(runtime.Version(), "go1.24") {
|
|
curveID := getCurveID(conn.ConnectionState().TLS)
|
|
if conf.postQuantum {
|
|
require.Equal(t, tls.X25519MLKEM768, curveID)
|
|
} else {
|
|
require.Equal(t, tls.CurveP384, curveID)
|
|
}
|
|
}
|
|
|
|
if pattern != dropPatternDropOneThirdOfPackets {
|
|
require.NotZero(t, numDropped.Load())
|
|
}
|
|
t.Logf("dropped %d packets", numDropped.Load())
|
|
})
|
|
})
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|