use synctest for the stateless reset tests (#5415)

This commit is contained in:
Marten Seemann
2025-11-01 17:58:37 +01:00
committed by GitHub
parent d3211a4193
commit a882b9a003

View File

@@ -3,13 +3,14 @@ package self_test
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"net"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/testutils/simnet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -23,104 +24,104 @@ func TestStatelessResets(t *testing.T) {
} }
func testStatelessReset(t *testing.T, connIDLen int) { func testStatelessReset(t *testing.T, connIDLen int) {
var statelessResetKey quic.StatelessResetKey synctest.Test(t, func(t *testing.T) {
rand.Read(statelessResetKey[:]) var drop atomic.Bool
clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t,
c := newUDPConnLocalhost(t) time.Millisecond,
tr := &quic.Transport{ &droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }},
Conn: c,
StatelessResetKey: &statelessResetKey,
}
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
serverErr := make(chan error, 1)
go func() {
conn, err := ln.Accept(context.Background())
if err != nil {
serverErr <- err
return
}
str, err := conn.OpenStream()
if err != nil {
serverErr <- err
return
}
_, err = str.Write([]byte("foobar"))
if err != nil {
serverErr <- err
return
}
close(serverErr)
}()
var drop atomic.Bool
proxy := quicproxy.Proxy{
Conn: newUDPConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DropPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) bool { return drop.Load() },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
var conn *quic.Conn
if connIDLen > 0 {
cl := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: connIDLen,
}
defer cl.Close()
var err error
conn, err = cl.Dial(
context.Background(),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
) )
defer closeFn(t)
var statelessResetKey quic.StatelessResetKey
rand.Read(statelessResetKey[:])
tr := &quic.Transport{
Conn: serverPacketConn,
StatelessResetKey: &statelessResetKey,
}
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err) require.NoError(t, err)
} else {
conn, err = quic.Dial( serverErr := make(chan error, 1)
context.Background(), go func() {
newUDPConnLocalhost(t), conn, err := ln.Accept(context.Background())
proxy.LocalAddr(), if err != nil {
getTLSClientConfig(), serverErr <- err
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}), return
) }
str, err := conn.OpenStream()
if err != nil {
serverErr <- err
return
}
_, err = str.Write([]byte("foobar"))
if err != nil {
serverErr <- err
return
}
close(serverErr)
}()
var conn *quic.Conn
if connIDLen > 0 {
cl := &quic.Transport{
Conn: clientPacketConn,
ConnectionIDLength: connIDLen,
}
defer cl.Close()
var err error
conn, err = cl.Dial(
context.Background(),
serverPacketConn.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
require.NoError(t, err)
} else {
conn, err = quic.Dial(
context.Background(),
clientPacketConn,
serverPacketConn.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
require.NoError(t, err)
}
str, err := conn.AcceptStream(context.Background())
require.NoError(t, err) require.NoError(t, err)
} data := make([]byte, 6)
str, err := conn.AcceptStream(context.Background()) _, err = str.Read(data)
require.NoError(t, err) require.NoError(t, err)
data := make([]byte, 6) require.Equal(t, []byte("foobar"), data)
_, err = str.Read(data)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), data)
// make sure that the CONNECTION_CLOSE is dropped // make sure that the CONNECTION_CLOSE is dropped
drop.Store(true) drop.Store(true)
require.NoError(t, ln.Close()) require.NoError(t, ln.Close())
require.NoError(t, tr.Close()) require.NoError(t, tr.Close())
require.NoError(t, <-serverErr) require.NoError(t, <-serverErr)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// We need to create a new Transport here, since the old one is still sending out // We need to create a new Transport here, since the old one is still sending out
// CONNECTION_CLOSE packets for (recently) closed connections). // CONNECTION_CLOSE packets for (recently) closed connections).
tr2 := &quic.Transport{ tr2 := &quic.Transport{
Conn: c, Conn: serverPacketConn,
StatelessResetKey: &statelessResetKey, StatelessResetKey: &statelessResetKey,
} }
defer tr2.Close() defer tr2.Close()
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil)) ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err) require.NoError(t, err)
drop.Store(false) drop.Store(false)
// Trigger something (not too small) to be sent, so that we receive the stateless reset. // Trigger something (not too small) to be sent, so that we receive the stateless reset.
// If the client already sent another packet, it might already have received a packet. // If the client already sent another packet, it might already have received a packet.
_, serr := str.Write([]byte("Lorem ipsum dolor sit amet.")) _, serr := str.Write([]byte("Lorem ipsum dolor sit amet."))
if serr == nil { if serr == nil {
_, serr = str.Read([]byte{0}) _, serr = str.Read([]byte{0})
} }
require.Error(t, serr) require.Error(t, serr)
require.IsType(t, &quic.StatelessResetError{}, serr) require.IsType(t, &quic.StatelessResetError{}, serr)
require.NoError(t, ln2.Close()) require.NoError(t, ln2.Close())
})
} }