diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 2de49ce3..e23f8e33 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -16,5 +16,7 @@ golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/transport_test.go b/transport_test.go index 0938b742..cf085251 100644 --- a/transport_test.go +++ b/transport_test.go @@ -8,6 +8,7 @@ import ( "math" "net" "runtime" + "strings" "sync/atomic" "syscall" "testing" @@ -15,11 +16,13 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" + "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -62,6 +65,23 @@ func (h *mockPacketHandler) destroy(err error) { func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {} +func newSimnetLink(t *testing.T, rtt time.Duration) (client, server net.PacketConn, close func()) { + t.Helper() + + n := &simnet.Simnet{Router: &simnet.PerfectRouter{}} + settings := simnet.NodeBiDiLinkSettings{ + Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt / 1024, Latency: rtt / 4}, + Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt / 1024, Latency: rtt / 4}, + } + + client = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9001}, settings) + server = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9002}, settings) + require.NoError(t, n.Start()) + return client, server, func() { + require.NoError(t, n.Close()) + } +} + func TestTransportPacketHandling(t *testing.T) { tr := &Transport{Conn: newUDPConnLocalhost(t)} tr.init(true) @@ -143,38 +163,42 @@ func TestTransportAndDialConcurrentClose(t *testing.T) { func TestTransportErrFromConn(t *testing.T) { t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") - readErrChan := make(chan error, 2) - tr := Transport{ - Conn: &mockPacketConn{ - readErrs: readErrChan, - localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, - }, - } - defer tr.Close() - tr.init(true) + synctest.Test(t, func(t *testing.T) { + readErrChan := make(chan error, 2) + tr := Transport{ + Conn: &mockPacketConn{ + readErrs: readErrChan, + localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, + }, + } + defer tr.Close() + tr.init(true) - errChan := make(chan error, 1) - ph := &mockPacketHandler{destruction: errChan} - (*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph) + errChan := make(chan error, 1) + ph := &mockPacketHandler{destruction: errChan} + (*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph) - // temporary errors don't lead to a shutdown... - var tempErr deadlineError - require.True(t, tempErr.Temporary()) - readErrChan <- tempErr - // don't expect any calls to phm.Close - time.Sleep(scaleDuration(10 * time.Millisecond)) + // temporary errors don't lead to a shutdown... + var tempErr deadlineError + require.True(t, tempErr.Temporary()) + readErrChan <- tempErr + // don't expect any calls to phm.Close + synctest.Wait() - // ...but non-temporary errors do - readErrChan <- errors.New("read failed") - select { - case err := <-errChan: + // ...but non-temporary errors do + readErrChan <- errors.New("read failed") + synctest.Wait() + + select { + case err := <-errChan: + require.ErrorIs(t, err, ErrTransportClosed) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + _, err := tr.Listen(&tls.Config{}, nil) require.ErrorIs(t, err, ErrTransportClosed) - case <-time.After(time.Second): - t.Fatal("timeout") - } - - _, err := tr.Listen(&tls.Config{}, nil) - require.ErrorIs(t, err, ErrTransportClosed) + }) } func TestTransportStatelessResetReceiving(t *testing.T) { @@ -209,195 +233,205 @@ func TestTransportStatelessResetReceiving(t *testing.T) { } func TestTransportStatelessResetSending(t *testing.T) { - var eventRecorder events.Recorder - tr := &Transport{ - Conn: newUDPConnLocalhost(t), - ConnectionIDLength: 4, - StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, - Tracer: &eventRecorder, - } - tr.init(true) - defer tr.Close() + synctest.Test(t, func(t *testing.T) { + const rtt = 10 * time.Millisecond + clientConn, serverConn, closeFn := newSimnetLink(t, rtt) + defer closeFn() - connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) + var eventRecorder events.Recorder + tr := &Transport{ + Conn: serverConn, + ConnectionIDLength: 4, + StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, + Tracer: &eventRecorder, + } + tr.init(true) + defer tr.Close() - // now send a packet with a connection ID that doesn't exist - b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne) - require.NoError(t, err) + connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) - conn := newUDPConnLocalhost(t) + // now send a packet with a connection ID that doesn't exist + b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne) + require.NoError(t, err) - // no stateless reset sent for packets smaller than MinStatelessResetSize - smallPacket := append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...) - _, err = conn.WriteTo(smallPacket, tr.Conn.LocalAddr()) - require.NoError(t, err) - require.Eventually(t, - func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 }, - time.Second, - 10*time.Millisecond, - ) - require.Equal(t, - []qlogwriter.Event{ - qlog.PacketDropped{ - Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT}, - Raw: qlog.RawInfo{Length: len(smallPacket)}, - Trigger: qlog.PacketDropUnknownConnectionID, + // no stateless reset sent for packets smaller than MinStatelessResetSize + smallPacket := append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...) + _, err = clientConn.WriteTo(smallPacket, tr.Conn.LocalAddr()) + require.NoError(t, err) + + time.Sleep(rtt) // so that the packet arrives at the server + + require.Equal(t, + []qlogwriter.Event{ + qlog.PacketDropped{ + Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT}, + Raw: qlog.RawInfo{Length: len(smallPacket)}, + Trigger: qlog.PacketDropUnknownConnectionID, + }, }, - }, - eventRecorder.Events(qlog.PacketDropped{}), - ) + eventRecorder.Events(qlog.PacketDropped{}), + ) - // but a stateless reset is sent for packets larger than MinStatelessResetSize - _, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr()) - require.NoError(t, err) - conn.SetReadDeadline(time.Now().Add(time.Second)) - p := make([]byte, 1024) - n, addr, err := conn.ReadFrom(p) - require.NoError(t, err) - require.Equal(t, addr, tr.Conn.LocalAddr()) - srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID) - require.Contains(t, string(p[:n]), string(srt[:])) + // but a stateless reset is sent for packets larger than MinStatelessResetSize + _, err = clientConn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr()) + require.NoError(t, err) + clientConn.SetReadDeadline(time.Now().Add(time.Second)) + p := make([]byte, 1024) + n, addr, err := clientConn.ReadFrom(p) + require.NoError(t, err) + require.Equal(t, addr, tr.Conn.LocalAddr()) + srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID) + require.Contains(t, string(p[:n]), string(srt[:])) + }) } func TestTransportUnparseableQUICPackets(t *testing.T) { - var eventRecorder events.Recorder - tr := &Transport{ - Conn: newUDPConnLocalhost(t), - ConnectionIDLength: 10, - Tracer: &eventRecorder, - } - require.NoError(t, tr.init(true)) - defer tr.Close() + synctest.Test(t, func(t *testing.T) { + const rtt = 10 * time.Millisecond + clientConn, serverConn, closeFn := newSimnetLink(t, rtt) + defer closeFn() - conn := newUDPConnLocalhost(t) - _, err := conn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr()) - require.NoError(t, err) + var eventRecorder events.Recorder + tr := &Transport{ + Conn: serverConn, + ConnectionIDLength: 10, + Tracer: &eventRecorder, + } + require.NoError(t, tr.init(true)) + defer tr.Close() - require.Eventually(t, - func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 }, - time.Second, - 10*time.Millisecond, - ) - require.Equal(t, - []qlogwriter.Event{ - qlog.PacketDropped{ - Raw: qlog.RawInfo{Length: 4}, - Trigger: qlog.PacketDropHeaderParseError, + _, err := clientConn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr()) + require.NoError(t, err) + + time.Sleep(rtt) // so that the packet arrives at the server + + require.Equal(t, + []qlogwriter.Event{ + qlog.PacketDropped{ + Raw: qlog.RawInfo{Length: 4}, + Trigger: qlog.PacketDropHeaderParseError, + }, }, - }, - eventRecorder.Events(qlog.PacketDropped{}), - ) + eventRecorder.Events(qlog.PacketDropped{}), + ) + }) } func TestTransportListening(t *testing.T) { - var eventRecorder events.Recorder - tr := &Transport{ - Conn: newUDPConnLocalhost(t), - ConnectionIDLength: 5, - Tracer: &eventRecorder, - } - require.NoError(t, tr.init(true)) - defer tr.Close() + synctest.Test(t, func(t *testing.T) { + const rtt = 10 * time.Millisecond + clientConn, serverConn, closeFn := newSimnetLink(t, rtt) + defer closeFn() - conn := newUDPConnLocalhost(t) - data := wire.ComposeVersionNegotiation([]byte{1, 2, 3, 4, 5}, []byte{6, 7, 8, 9, 10}, []protocol.Version{protocol.Version1}) + var eventRecorder events.Recorder + tr := &Transport{ + Conn: serverConn, + ConnectionIDLength: 5, + Tracer: &eventRecorder, + } + require.NoError(t, tr.init(true)) + defer tr.Close() - _, err := conn.WriteTo(data, tr.Conn.LocalAddr()) - require.NoError(t, err) - require.Eventually(t, - func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 }, - time.Second, - 10*time.Millisecond, - ) - require.Equal(t, - []qlogwriter.Event{ - qlog.PacketDropped{ - Raw: qlog.RawInfo{Length: len(data)}, - Trigger: qlog.PacketDropUnknownConnectionID, + data := wire.ComposeVersionNegotiation([]byte{1, 2, 3, 4, 5}, []byte{6, 7, 8, 9, 10}, []protocol.Version{protocol.Version1}) + + _, err := clientConn.WriteTo(data, tr.Conn.LocalAddr()) + require.NoError(t, err) + + time.Sleep(rtt) // so that the packet arrives at the server + + require.Equal(t, + []qlogwriter.Event{ + qlog.PacketDropped{ + Raw: qlog.RawInfo{Length: len(data)}, + Trigger: qlog.PacketDropUnknownConnectionID, + }, }, - }, - eventRecorder.Events(qlog.PacketDropped{}), - ) - eventRecorder.Clear() + eventRecorder.Events(qlog.PacketDropped{}), + ) + eventRecorder.Clear() - ln, err := tr.Listen(&tls.Config{}, nil) - require.NoError(t, err) + ln, err := tr.Listen(&tls.Config{}, nil) + require.NoError(t, err) - _, err = conn.WriteTo(data, tr.Conn.LocalAddr()) - require.NoError(t, err) - require.Eventually(t, - func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 }, - time.Second, - 10*time.Millisecond, - ) - require.Equal(t, - []qlogwriter.Event{ - qlog.PacketDropped{ - Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation}, - Raw: qlog.RawInfo{Length: len(data)}, - Trigger: qlog.PacketDropUnexpectedPacket, + _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr()) + require.NoError(t, err) + time.Sleep(rtt) // so that the packet arrives at the server + + require.Equal(t, + []qlogwriter.Event{ + qlog.PacketDropped{ + Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation}, + Raw: qlog.RawInfo{Length: len(data)}, + Trigger: qlog.PacketDropUnexpectedPacket, + }, }, - }, - eventRecorder.Events(qlog.PacketDropped{}), - ) + eventRecorder.Events(qlog.PacketDropped{}), + ) - // only a single listener can be set - _, err = tr.Listen(&tls.Config{}, nil) - require.Error(t, err) - require.ErrorIs(t, err, errListenerAlreadySet) + // only a single listener can be set + _, err = tr.Listen(&tls.Config{}, nil) + require.Error(t, err) + require.ErrorIs(t, err, errListenerAlreadySet) - require.NoError(t, ln.Close()) - // now it's possible to add a new listener - ln, err = tr.Listen(&tls.Config{}, nil) - require.NoError(t, err) - defer ln.Close() + require.NoError(t, ln.Close()) + // now it's possible to add a new listener + ln, err = tr.Listen(&tls.Config{}, nil) + require.NoError(t, err) + defer ln.Close() + }) } func TestTransportNonQUICPackets(t *testing.T) { - tr := &Transport{Conn: newUDPConnLocalhost(t)} - defer tr.Close() + synctest.Test(t, func(t *testing.T) { + const rtt = 10 * time.Millisecond + clientConn, serverConn, closeFn := newSimnetLink(t, rtt) + defer closeFn() - ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(5*time.Millisecond)) - defer cancel() - _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024)) - require.Error(t, err) - require.ErrorIs(t, err, context.DeadlineExceeded) + tr := &Transport{Conn: serverConn} + defer tr.Close() - conn := newUDPConnLocalhost(t) - data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3} - _, err = conn.WriteTo(data, tr.Conn.LocalAddr()) - require.NoError(t, err) - _, err = conn.WriteTo(data, tr.Conn.LocalAddr()) - require.NoError(t, err) - - ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(time.Second)) - defer cancel() - b := make([]byte, 1024) - n, addr, err := tr.ReadNonQUICPacket(ctx, b) - require.NoError(t, err) - require.Equal(t, data, b[:n]) - require.Equal(t, addr, conn.LocalAddr()) - - // now send a lot of packets without reading them - for i := range 2 * maxQueuedNonQUICPackets { - data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...) - _, err = conn.WriteTo(data, tr.Conn.LocalAddr()) - require.NoError(t, err) - } - time.Sleep(scaleDuration(10 * time.Millisecond)) - - var received int - for { - ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond)) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - _, _, err := tr.ReadNonQUICPacket(ctx, b) - if errors.Is(err, context.DeadlineExceeded) { - break - } + _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024)) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + + data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3} + _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr()) require.NoError(t, err) - received++ - } - require.Equal(t, received, maxQueuedNonQUICPackets) + _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr()) + require.NoError(t, err) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + b := make([]byte, 1024) + n, addr, err := tr.ReadNonQUICPacket(ctx, b) + require.NoError(t, err) + require.Equal(t, data, b[:n]) + require.Equal(t, addr, clientConn.LocalAddr()) + + // now send a lot of packets without reading them + for i := range 2 * maxQueuedNonQUICPackets { + data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...) + _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr()) + require.NoError(t, err) + } + + time.Sleep(rtt) // so that all packets arrive at the server + + var received int + for { + ctx, cancel = context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + _, _, err := tr.ReadNonQUICPacket(ctx, b) + if errors.Is(err, context.DeadlineExceeded) { + break + } + require.NoError(t, err) + received++ + } + require.Equal(t, received, maxQueuedNonQUICPackets) + }) } type faultySyscallConn struct{ net.PacketConn } @@ -466,69 +500,79 @@ func testTransportDial(t *testing.T, early bool) { originalClientConnConstructor := newClientConnection t.Cleanup(func() { newClientConnection = originalClientConnConstructor }) - var conn *connTestHooks - handshakeChan := make(chan struct{}) - blockRun := make(chan struct{}) - if early { - conn = &connTestHooks{ - earlyConnReady: func() <-chan struct{} { return handshakeChan }, - handshakeComplete: func() <-chan struct{} { return make(chan struct{}) }, - } - } else { - conn = &connTestHooks{ - handshakeComplete: func() <-chan struct{} { return handshakeChan }, - } - } - conn.run = func() error { <-blockRun; return errors.New("done") } - defer close(blockRun) + synctest.Test(t, func(t *testing.T) { + _, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond) + defer closeFn() - newClientConnection = func( - _ context.Context, - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ qlogwriter.Trace, - _ utils.Logger, - _ protocol.Version, - ) *wrappedConn { - return &wrappedConn{testHooks: conn} - } - - tr := &Transport{Conn: newUDPConnLocalhost(t)} - tr.init(true) - defer tr.Close() - - errChan := make(chan error, 1) - go func() { - var err error + var conn *connTestHooks + handshakeChan := make(chan struct{}) + blockRun := make(chan struct{}) if early { - _, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil) + conn = &connTestHooks{ + earlyConnReady: func() <-chan struct{} { return handshakeChan }, + handshakeComplete: func() <-chan struct{} { return make(chan struct{}) }, + } } else { - _, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil) + conn = &connTestHooks{ + handshakeComplete: func() <-chan struct{} { return handshakeChan }, + } } - errChan <- err - }() + conn.run = func() error { <-blockRun; return errors.New("done") } + defer close(blockRun) - select { - case <-errChan: - t.Fatal("Dial shouldn't have returned") - case <-time.After(scaleDuration(10 * time.Millisecond)): - } + newClientConnection = func( + _ context.Context, + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ *statelessResetter, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ bool, + _ bool, + _ qlogwriter.Trace, + _ utils.Logger, + _ protocol.Version, + ) *wrappedConn { + return &wrappedConn{testHooks: conn} + } - close(handshakeChan) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(time.Second): - } + tr := &Transport{Conn: serverConn} + tr.init(true) + defer tr.Close() + + errChan := make(chan error, 1) + go func() { + var err error + if early { + _, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil) + } else { + _, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil) + } + errChan <- err + }() + + synctest.Wait() + + select { + case <-errChan: + t.Fatal("Dial shouldn't have returned") + default: + } + + close(handshakeChan) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + }) } func TestTransportDialingVersionNegotiation(t *testing.T) { @@ -603,8 +647,11 @@ func TestTransportDialingVersionNegotiation(t *testing.T) { } func TestTransportReplaceWithClosed(t *testing.T) { - t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") - + // synctest works slightly differently on Go 1.24, + // so we skip the test + if strings.HasPrefix(runtime.Version(), "go1.24") { + t.Skip("skipping on Go 1.24 due to synctest issues") + } t.Run("local", func(t *testing.T) { testTransportReplaceWithClosed(t, true) }) @@ -614,91 +661,89 @@ func TestTransportReplaceWithClosed(t *testing.T) { } func testTransportReplaceWithClosed(t *testing.T, local bool) { - srk := StatelessResetKey{1, 2, 3, 4} - tr := &Transport{ - Conn: newUDPConnLocalhost(t), - ConnectionIDLength: 4, - StatelessResetKey: &srk, - } - tr.init(true) - defer tr.Close() + synctest.Test(t, func(t *testing.T) { + clientConn, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond) + defer closeFn() - dur := scaleDuration(20 * time.Millisecond) + srk := StatelessResetKey{1, 2, 3, 4} + tr := &Transport{ + Conn: serverConn, + ConnectionIDLength: 4, + StatelessResetKey: &srk, + } + tr.init(true) + defer tr.Close() - var closePacket []byte - if local { - closePacket = []byte("foobar") - } + var closePacket []byte + if local { + closePacket = []byte("foobar") + } - handler := &mockPacketHandler{} - connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - m := (*packetHandlerMap)(tr) - require.True(t, m.Add(connID, handler)) - m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur) + const expiry = 50 * time.Millisecond + handler := &mockPacketHandler{} + connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + m := (*packetHandlerMap)(tr) + require.True(t, m.Add(connID, handler)) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, expiry) - p := make([]byte, 100) - p[0] = 0x40 // QUIC bit - copy(p[1:], connID.Bytes()) + p := make([]byte, 100) + p[0] = 0x40 // QUIC bit + copy(p[1:], connID.Bytes()) - conn := newUDPConnLocalhost(t) - var sent atomic.Int64 - errChan := make(chan error, 1) - stopSending := make(chan struct{}) - go func() { - defer close(errChan) - ticker := time.NewTicker(dur / 50) - timeout := time.NewTimer(scaleDuration(time.Second)) + var sent atomic.Int64 + errChan := make(chan error, 1) + stopSending := make(chan struct{}) + go func() { + defer close(errChan) + ticker := time.NewTicker(expiry / 200) + timeout := time.NewTimer(time.Second) + for { + select { + case <-stopSending: + return + case <-timeout.C: + errChan <- errors.New("timeout") + return + case <-ticker.C: + } + if _, err := clientConn.WriteTo(p, tr.Conn.LocalAddr()); err != nil { + errChan <- err + return + } + sent.Add(1) + } + }() + + // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff + var received int + clientConn.SetReadDeadline(time.Now().Add(time.Hour)) for { - select { - case <-stopSending: - return - case <-timeout.C: - errChan <- errors.New("timeout") - return - case <-ticker.C: + b := make([]byte, 100) + n, _, err := clientConn.ReadFrom(b) + require.NoError(t, err) + // at some point, the connection is cleaned up, and we'll receive a stateless reset + if !bytes.Equal(b[:n], []byte("foobar")) { + require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) + close(stopSending) // stop sending packets + break } - if _, err := conn.WriteTo(p, tr.Conn.LocalAddr()); err != nil { - errChan <- err - return - } - sent.Add(1) + received++ } - }() - // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff - var received int - conn.SetReadDeadline(time.Now().Add(scaleDuration(time.Second))) - for { - b := make([]byte, 100) - n, _, err := conn.ReadFrom(b) - require.NoError(t, err) - // at some point, the connection is cleaned up, and we'll receive a stateless reset - if !bytes.Equal(b[:n], []byte("foobar")) { - require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) - close(stopSending) // stop sending packets - break + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") } - received++ - } - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(time.Second): - t.Fatal("timeout") - } - - numSent := sent.Load() - if !local { - require.Zero(t, received) - t.Logf("sent %d packets", numSent) - return - } - t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received) - // timer resolution on Windows is terrible - if runtime.GOOS != "windows" { - require.GreaterOrEqual(t, numSent, int64(8)) - } - require.GreaterOrEqual(t, received, int(math.Floor(math.Log2(float64(numSent))))) - require.LessOrEqual(t, received, int(math.Ceil(math.Log2(float64(numSent))))) + numSent := sent.Load() + if !local { + require.Zero(t, received) + t.Logf("sent %d packets", numSent) + return + } + t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received) + require.Equal(t, int(math.Ceil(math.Log2(float64(numSent)))), received) + }) }