use synctest for transport tests (#5391)

* use synctest for transport tests

* disable TestTransportReplaceWithClosed on Go 1.24
This commit is contained in:
Marten Seemann
2025-10-19 15:51:01 +08:00
committed by GitHub
parent 7772755df2
commit cb7f6adea0
2 changed files with 369 additions and 322 deletions

View File

@@ -16,5 +16,7 @@ golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -8,6 +8,7 @@ import (
"math"
"net"
"runtime"
"strings"
"sync/atomic"
"syscall"
"testing"
@@ -15,11 +16,13 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/qlog"
"github.com/quic-go/quic-go/qlogwriter"
"github.com/quic-go/quic-go/testutils/events"
"github.com/quic-go/quic-go/testutils/simnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -62,6 +65,23 @@ func (h *mockPacketHandler) destroy(err error) {
func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {}
func newSimnetLink(t *testing.T, rtt time.Duration) (client, server net.PacketConn, close func()) {
t.Helper()
n := &simnet.Simnet{Router: &simnet.PerfectRouter{}}
settings := simnet.NodeBiDiLinkSettings{
Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt / 1024, Latency: rtt / 4},
Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt / 1024, Latency: rtt / 4},
}
client = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9001}, settings)
server = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9002}, settings)
require.NoError(t, n.Start())
return client, server, func() {
require.NoError(t, n.Close())
}
}
func TestTransportPacketHandling(t *testing.T) {
tr := &Transport{Conn: newUDPConnLocalhost(t)}
tr.init(true)
@@ -143,38 +163,42 @@ func TestTransportAndDialConcurrentClose(t *testing.T) {
func TestTransportErrFromConn(t *testing.T) {
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
readErrChan := make(chan error, 2)
tr := Transport{
Conn: &mockPacketConn{
readErrs: readErrChan,
localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234},
},
}
defer tr.Close()
tr.init(true)
synctest.Test(t, func(t *testing.T) {
readErrChan := make(chan error, 2)
tr := Transport{
Conn: &mockPacketConn{
readErrs: readErrChan,
localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234},
},
}
defer tr.Close()
tr.init(true)
errChan := make(chan error, 1)
ph := &mockPacketHandler{destruction: errChan}
(*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph)
errChan := make(chan error, 1)
ph := &mockPacketHandler{destruction: errChan}
(*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph)
// temporary errors don't lead to a shutdown...
var tempErr deadlineError
require.True(t, tempErr.Temporary())
readErrChan <- tempErr
// don't expect any calls to phm.Close
time.Sleep(scaleDuration(10 * time.Millisecond))
// temporary errors don't lead to a shutdown...
var tempErr deadlineError
require.True(t, tempErr.Temporary())
readErrChan <- tempErr
// don't expect any calls to phm.Close
synctest.Wait()
// ...but non-temporary errors do
readErrChan <- errors.New("read failed")
select {
case err := <-errChan:
// ...but non-temporary errors do
readErrChan <- errors.New("read failed")
synctest.Wait()
select {
case err := <-errChan:
require.ErrorIs(t, err, ErrTransportClosed)
case <-time.After(time.Second):
t.Fatal("timeout")
}
_, err := tr.Listen(&tls.Config{}, nil)
require.ErrorIs(t, err, ErrTransportClosed)
case <-time.After(time.Second):
t.Fatal("timeout")
}
_, err := tr.Listen(&tls.Config{}, nil)
require.ErrorIs(t, err, ErrTransportClosed)
})
}
func TestTransportStatelessResetReceiving(t *testing.T) {
@@ -209,195 +233,205 @@ func TestTransportStatelessResetReceiving(t *testing.T) {
}
func TestTransportStatelessResetSending(t *testing.T) {
var eventRecorder events.Recorder
tr := &Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: 4,
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
Tracer: &eventRecorder,
}
tr.init(true)
defer tr.Close()
synctest.Test(t, func(t *testing.T) {
const rtt = 10 * time.Millisecond
clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
defer closeFn()
connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
var eventRecorder events.Recorder
tr := &Transport{
Conn: serverConn,
ConnectionIDLength: 4,
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
Tracer: &eventRecorder,
}
tr.init(true)
defer tr.Close()
// now send a packet with a connection ID that doesn't exist
b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne)
require.NoError(t, err)
connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
conn := newUDPConnLocalhost(t)
// now send a packet with a connection ID that doesn't exist
b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne)
require.NoError(t, err)
// no stateless reset sent for packets smaller than MinStatelessResetSize
smallPacket := append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...)
_, err = conn.WriteTo(smallPacket, tr.Conn.LocalAddr())
require.NoError(t, err)
require.Eventually(t,
func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
time.Second,
10*time.Millisecond,
)
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT},
Raw: qlog.RawInfo{Length: len(smallPacket)},
Trigger: qlog.PacketDropUnknownConnectionID,
// no stateless reset sent for packets smaller than MinStatelessResetSize
smallPacket := append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...)
_, err = clientConn.WriteTo(smallPacket, tr.Conn.LocalAddr())
require.NoError(t, err)
time.Sleep(rtt) // so that the packet arrives at the server
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT},
Raw: qlog.RawInfo{Length: len(smallPacket)},
Trigger: qlog.PacketDropUnknownConnectionID,
},
},
},
eventRecorder.Events(qlog.PacketDropped{}),
)
eventRecorder.Events(qlog.PacketDropped{}),
)
// but a stateless reset is sent for packets larger than MinStatelessResetSize
_, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(time.Second))
p := make([]byte, 1024)
n, addr, err := conn.ReadFrom(p)
require.NoError(t, err)
require.Equal(t, addr, tr.Conn.LocalAddr())
srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID)
require.Contains(t, string(p[:n]), string(srt[:]))
// but a stateless reset is sent for packets larger than MinStatelessResetSize
_, err = clientConn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
require.NoError(t, err)
clientConn.SetReadDeadline(time.Now().Add(time.Second))
p := make([]byte, 1024)
n, addr, err := clientConn.ReadFrom(p)
require.NoError(t, err)
require.Equal(t, addr, tr.Conn.LocalAddr())
srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID)
require.Contains(t, string(p[:n]), string(srt[:]))
})
}
func TestTransportUnparseableQUICPackets(t *testing.T) {
var eventRecorder events.Recorder
tr := &Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: 10,
Tracer: &eventRecorder,
}
require.NoError(t, tr.init(true))
defer tr.Close()
synctest.Test(t, func(t *testing.T) {
const rtt = 10 * time.Millisecond
clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
defer closeFn()
conn := newUDPConnLocalhost(t)
_, err := conn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr())
require.NoError(t, err)
var eventRecorder events.Recorder
tr := &Transport{
Conn: serverConn,
ConnectionIDLength: 10,
Tracer: &eventRecorder,
}
require.NoError(t, tr.init(true))
defer tr.Close()
require.Eventually(t,
func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
time.Second,
10*time.Millisecond,
)
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Raw: qlog.RawInfo{Length: 4},
Trigger: qlog.PacketDropHeaderParseError,
_, err := clientConn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr())
require.NoError(t, err)
time.Sleep(rtt) // so that the packet arrives at the server
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Raw: qlog.RawInfo{Length: 4},
Trigger: qlog.PacketDropHeaderParseError,
},
},
},
eventRecorder.Events(qlog.PacketDropped{}),
)
eventRecorder.Events(qlog.PacketDropped{}),
)
})
}
func TestTransportListening(t *testing.T) {
var eventRecorder events.Recorder
tr := &Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: 5,
Tracer: &eventRecorder,
}
require.NoError(t, tr.init(true))
defer tr.Close()
synctest.Test(t, func(t *testing.T) {
const rtt = 10 * time.Millisecond
clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
defer closeFn()
conn := newUDPConnLocalhost(t)
data := wire.ComposeVersionNegotiation([]byte{1, 2, 3, 4, 5}, []byte{6, 7, 8, 9, 10}, []protocol.Version{protocol.Version1})
var eventRecorder events.Recorder
tr := &Transport{
Conn: serverConn,
ConnectionIDLength: 5,
Tracer: &eventRecorder,
}
require.NoError(t, tr.init(true))
defer tr.Close()
_, err := conn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
require.Eventually(t,
func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
time.Second,
10*time.Millisecond,
)
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Raw: qlog.RawInfo{Length: len(data)},
Trigger: qlog.PacketDropUnknownConnectionID,
data := wire.ComposeVersionNegotiation([]byte{1, 2, 3, 4, 5}, []byte{6, 7, 8, 9, 10}, []protocol.Version{protocol.Version1})
_, err := clientConn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
time.Sleep(rtt) // so that the packet arrives at the server
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Raw: qlog.RawInfo{Length: len(data)},
Trigger: qlog.PacketDropUnknownConnectionID,
},
},
},
eventRecorder.Events(qlog.PacketDropped{}),
)
eventRecorder.Clear()
eventRecorder.Events(qlog.PacketDropped{}),
)
eventRecorder.Clear()
ln, err := tr.Listen(&tls.Config{}, nil)
require.NoError(t, err)
ln, err := tr.Listen(&tls.Config{}, nil)
require.NoError(t, err)
_, err = conn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
require.Eventually(t,
func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
time.Second,
10*time.Millisecond,
)
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation},
Raw: qlog.RawInfo{Length: len(data)},
Trigger: qlog.PacketDropUnexpectedPacket,
_, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
time.Sleep(rtt) // so that the packet arrives at the server
require.Equal(t,
[]qlogwriter.Event{
qlog.PacketDropped{
Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation},
Raw: qlog.RawInfo{Length: len(data)},
Trigger: qlog.PacketDropUnexpectedPacket,
},
},
},
eventRecorder.Events(qlog.PacketDropped{}),
)
eventRecorder.Events(qlog.PacketDropped{}),
)
// only a single listener can be set
_, err = tr.Listen(&tls.Config{}, nil)
require.Error(t, err)
require.ErrorIs(t, err, errListenerAlreadySet)
// only a single listener can be set
_, err = tr.Listen(&tls.Config{}, nil)
require.Error(t, err)
require.ErrorIs(t, err, errListenerAlreadySet)
require.NoError(t, ln.Close())
// now it's possible to add a new listener
ln, err = tr.Listen(&tls.Config{}, nil)
require.NoError(t, err)
defer ln.Close()
require.NoError(t, ln.Close())
// now it's possible to add a new listener
ln, err = tr.Listen(&tls.Config{}, nil)
require.NoError(t, err)
defer ln.Close()
})
}
func TestTransportNonQUICPackets(t *testing.T) {
tr := &Transport{Conn: newUDPConnLocalhost(t)}
defer tr.Close()
synctest.Test(t, func(t *testing.T) {
const rtt = 10 * time.Millisecond
clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
defer closeFn()
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(5*time.Millisecond))
defer cancel()
_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024))
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
tr := &Transport{Conn: serverConn}
defer tr.Close()
conn := newUDPConnLocalhost(t)
data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3}
_, err = conn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
_, err = conn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(time.Second))
defer cancel()
b := make([]byte, 1024)
n, addr, err := tr.ReadNonQUICPacket(ctx, b)
require.NoError(t, err)
require.Equal(t, data, b[:n])
require.Equal(t, addr, conn.LocalAddr())
// now send a lot of packets without reading them
for i := range 2 * maxQueuedNonQUICPackets {
data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...)
_, err = conn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
}
time.Sleep(scaleDuration(10 * time.Millisecond))
var received int
for {
ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
_, _, err := tr.ReadNonQUICPacket(ctx, b)
if errors.Is(err, context.DeadlineExceeded) {
break
}
_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024))
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3}
_, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
received++
}
require.Equal(t, received, maxQueuedNonQUICPackets)
_, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
defer cancel()
b := make([]byte, 1024)
n, addr, err := tr.ReadNonQUICPacket(ctx, b)
require.NoError(t, err)
require.Equal(t, data, b[:n])
require.Equal(t, addr, clientConn.LocalAddr())
// now send a lot of packets without reading them
for i := range 2 * maxQueuedNonQUICPackets {
data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...)
_, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
}
time.Sleep(rtt) // so that all packets arrive at the server
var received int
for {
ctx, cancel = context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
_, _, err := tr.ReadNonQUICPacket(ctx, b)
if errors.Is(err, context.DeadlineExceeded) {
break
}
require.NoError(t, err)
received++
}
require.Equal(t, received, maxQueuedNonQUICPackets)
})
}
type faultySyscallConn struct{ net.PacketConn }
@@ -466,69 +500,79 @@ func testTransportDial(t *testing.T, early bool) {
originalClientConnConstructor := newClientConnection
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
var conn *connTestHooks
handshakeChan := make(chan struct{})
blockRun := make(chan struct{})
if early {
conn = &connTestHooks{
earlyConnReady: func() <-chan struct{} { return handshakeChan },
handshakeComplete: func() <-chan struct{} { return make(chan struct{}) },
}
} else {
conn = &connTestHooks{
handshakeComplete: func() <-chan struct{} { return handshakeChan },
}
}
conn.run = func() error { <-blockRun; return errors.New("done") }
defer close(blockRun)
synctest.Test(t, func(t *testing.T) {
_, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond)
defer closeFn()
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ qlogwriter.Trace,
_ utils.Logger,
_ protocol.Version,
) *wrappedConn {
return &wrappedConn{testHooks: conn}
}
tr := &Transport{Conn: newUDPConnLocalhost(t)}
tr.init(true)
defer tr.Close()
errChan := make(chan error, 1)
go func() {
var err error
var conn *connTestHooks
handshakeChan := make(chan struct{})
blockRun := make(chan struct{})
if early {
_, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
conn = &connTestHooks{
earlyConnReady: func() <-chan struct{} { return handshakeChan },
handshakeComplete: func() <-chan struct{} { return make(chan struct{}) },
}
} else {
_, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
conn = &connTestHooks{
handshakeComplete: func() <-chan struct{} { return handshakeChan },
}
}
errChan <- err
}()
conn.run = func() error { <-blockRun; return errors.New("done") }
defer close(blockRun)
select {
case <-errChan:
t.Fatal("Dial shouldn't have returned")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ qlogwriter.Trace,
_ utils.Logger,
_ protocol.Version,
) *wrappedConn {
return &wrappedConn{testHooks: conn}
}
close(handshakeChan)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
}
tr := &Transport{Conn: serverConn}
tr.init(true)
defer tr.Close()
errChan := make(chan error, 1)
go func() {
var err error
if early {
_, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
} else {
_, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
}
errChan <- err
}()
synctest.Wait()
select {
case <-errChan:
t.Fatal("Dial shouldn't have returned")
default:
}
close(handshakeChan)
synctest.Wait()
select {
case err := <-errChan:
require.NoError(t, err)
default:
}
})
}
func TestTransportDialingVersionNegotiation(t *testing.T) {
@@ -603,8 +647,11 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
}
func TestTransportReplaceWithClosed(t *testing.T) {
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
// synctest works slightly differently on Go 1.24,
// so we skip the test
if strings.HasPrefix(runtime.Version(), "go1.24") {
t.Skip("skipping on Go 1.24 due to synctest issues")
}
t.Run("local", func(t *testing.T) {
testTransportReplaceWithClosed(t, true)
})
@@ -614,91 +661,89 @@ func TestTransportReplaceWithClosed(t *testing.T) {
}
func testTransportReplaceWithClosed(t *testing.T, local bool) {
srk := StatelessResetKey{1, 2, 3, 4}
tr := &Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: 4,
StatelessResetKey: &srk,
}
tr.init(true)
defer tr.Close()
synctest.Test(t, func(t *testing.T) {
clientConn, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond)
defer closeFn()
dur := scaleDuration(20 * time.Millisecond)
srk := StatelessResetKey{1, 2, 3, 4}
tr := &Transport{
Conn: serverConn,
ConnectionIDLength: 4,
StatelessResetKey: &srk,
}
tr.init(true)
defer tr.Close()
var closePacket []byte
if local {
closePacket = []byte("foobar")
}
var closePacket []byte
if local {
closePacket = []byte("foobar")
}
handler := &mockPacketHandler{}
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
m := (*packetHandlerMap)(tr)
require.True(t, m.Add(connID, handler))
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur)
const expiry = 50 * time.Millisecond
handler := &mockPacketHandler{}
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
m := (*packetHandlerMap)(tr)
require.True(t, m.Add(connID, handler))
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, expiry)
p := make([]byte, 100)
p[0] = 0x40 // QUIC bit
copy(p[1:], connID.Bytes())
p := make([]byte, 100)
p[0] = 0x40 // QUIC bit
copy(p[1:], connID.Bytes())
conn := newUDPConnLocalhost(t)
var sent atomic.Int64
errChan := make(chan error, 1)
stopSending := make(chan struct{})
go func() {
defer close(errChan)
ticker := time.NewTicker(dur / 50)
timeout := time.NewTimer(scaleDuration(time.Second))
var sent atomic.Int64
errChan := make(chan error, 1)
stopSending := make(chan struct{})
go func() {
defer close(errChan)
ticker := time.NewTicker(expiry / 200)
timeout := time.NewTimer(time.Second)
for {
select {
case <-stopSending:
return
case <-timeout.C:
errChan <- errors.New("timeout")
return
case <-ticker.C:
}
if _, err := clientConn.WriteTo(p, tr.Conn.LocalAddr()); err != nil {
errChan <- err
return
}
sent.Add(1)
}
}()
// For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff
var received int
clientConn.SetReadDeadline(time.Now().Add(time.Hour))
for {
select {
case <-stopSending:
return
case <-timeout.C:
errChan <- errors.New("timeout")
return
case <-ticker.C:
b := make([]byte, 100)
n, _, err := clientConn.ReadFrom(b)
require.NoError(t, err)
// at some point, the connection is cleaned up, and we'll receive a stateless reset
if !bytes.Equal(b[:n], []byte("foobar")) {
require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
close(stopSending) // stop sending packets
break
}
if _, err := conn.WriteTo(p, tr.Conn.LocalAddr()); err != nil {
errChan <- err
return
}
sent.Add(1)
received++
}
}()
// For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff
var received int
conn.SetReadDeadline(time.Now().Add(scaleDuration(time.Second)))
for {
b := make([]byte, 100)
n, _, err := conn.ReadFrom(b)
require.NoError(t, err)
// at some point, the connection is cleaned up, and we'll receive a stateless reset
if !bytes.Equal(b[:n], []byte("foobar")) {
require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
close(stopSending) // stop sending packets
break
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
received++
}
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
numSent := sent.Load()
if !local {
require.Zero(t, received)
t.Logf("sent %d packets", numSent)
return
}
t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received)
// timer resolution on Windows is terrible
if runtime.GOOS != "windows" {
require.GreaterOrEqual(t, numSent, int64(8))
}
require.GreaterOrEqual(t, received, int(math.Floor(math.Log2(float64(numSent)))))
require.LessOrEqual(t, received, int(math.Ceil(math.Log2(float64(numSent)))))
numSent := sent.Load()
if !local {
require.Zero(t, received)
t.Logf("sent %d packets", numSent)
return
}
t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received)
require.Equal(t, int(math.Ceil(math.Log2(float64(numSent)))), received)
})
}