diff --git a/transport_test.go b/transport_test.go index 8c918250..37ec0438 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5,8 +5,10 @@ import ( "context" "crypto/tls" "errors" + "math" "net" - "os" + "runtime" + "sync/atomic" "syscall" "testing" "time" @@ -598,6 +600,8 @@ func TestTransportDialingVersionNegotiation(t *testing.T) { } func TestTransportReplaceWithClosed(t *testing.T) { + t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") + t.Run("local", func(t *testing.T) { testTransportReplaceWithClosed(t, true) }) @@ -616,7 +620,7 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) { tr.init(true) defer tr.Close() - dur := scaleDuration(10 * time.Millisecond) + dur := scaleDuration(20 * time.Millisecond) var closePacket []byte if local { @@ -627,7 +631,6 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) { connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) m := (*packetHandlerMap)(tr) require.True(t, m.Add(connID, handler)) - start := time.Now() m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur) p := make([]byte, 100) @@ -635,38 +638,64 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) { copy(p[1:], connID.Bytes()) conn := newUDPConnLocalhost(t) - var sent int - for now := range time.NewTicker(dur / 20).C { - _, err := conn.WriteTo(p, tr.Conn.LocalAddr()) - require.NoError(t, err) - sent++ - if now.After(start.Add(dur / 2)) { - break + var sent atomic.Int64 + errChan := make(chan error, 1) + stopSending := make(chan struct{}) + go func() { + defer close(errChan) + ticker := time.NewTicker(dur / 50) + timeout := time.NewTimer(scaleDuration(time.Second)) + for { + select { + case <-stopSending: + return + case <-timeout.C: + errChan <- errors.New("timeout") + return + case <-ticker.C: + } + if _, err := conn.WriteTo(p, tr.Conn.LocalAddr()); err != nil { + errChan <- err + return + } + sent.Add(1) } - } + }() + // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff - for i := 0; i*i < sent; i++ { - conn.SetReadDeadline(time.Now().Add(time.Second)) - b := make([]byte, 100) - if local { - n, _, err := conn.ReadFrom(b) - require.NoError(t, err) - require.Equal(t, []byte("foobar"), b[:n]) - } - } - // Afterwards, we receive a stateless reset, not a copy of the CONNECTION_CLOSE packet. - // Retry a few times, since the connection is deleted from the map on a timer. - require.Eventually(t, func() bool { - _, err := conn.WriteTo(p, tr.Conn.LocalAddr()) - require.NoError(t, err) - conn.SetReadDeadline(time.Now().Add(dur / 4)) + var received int + conn.SetReadDeadline(time.Now().Add(scaleDuration(time.Second))) + for { b := make([]byte, 100) n, _, err := conn.ReadFrom(b) - if errors.Is(err, os.ErrDeadlineExceeded) || bytes.Equal(b[:n], []byte("foobar")) { - return false - } require.NoError(t, err) - require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) - return true - }, scaleDuration(200*time.Millisecond), scaleDuration(10*time.Millisecond)) + // at some point, the connection is cleaned up, and we'll receive a stateless reset + if !bytes.Equal(b[:n], []byte("foobar")) { + require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) + close(stopSending) // stop sending packets + break + } + received++ + } + + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + numSent := sent.Load() + if !local { + require.Zero(t, received) + t.Logf("sent %d packets", numSent) + return + } + t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received) + // timer resolution on Windows is terrible + if runtime.GOOS != "windows" { + require.GreaterOrEqual(t, numSent, int64(8)) + } + require.GreaterOrEqual(t, received, int(math.Floor(math.Log2(float64(numSent))))) + require.LessOrEqual(t, received, int(math.Ceil(math.Log2(float64(numSent))))) }