"""Global configuration state and functions for management"""

import os
import threading
from contextlib import contextmanager as contextmanager


# Spellings of a boolean environment variable that are read as "disabled".
_FALSY_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def _env_flag(name, default=False):
    """Read a boolean flag from the environment variable `name`.

    Calling ``bool()`` on the raw text is True for every non-empty string,
    including ``"0"`` and ``"False"``, so the text is interpreted explicitly.

    Parameters
    ----------
    name : str
        Name of the environment variable to read.

    default : bool, default=False
        Value returned when the variable is not set at all.

    Returns
    -------
    flag : bool
        `default` if the variable is unset; False if its value is one of
        ``""``, ``"0"``, ``"false"``, ``"no"`` or ``"off"`` (compared
        case-insensitively, surrounding whitespace ignored); True otherwise.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in _FALSY_ENV_VALUES


# Default configuration, evaluated once at import time. Three entries can be
# seeded from environment variables; the others are fixed literals.
# `_get_threadlocal_config` copies this dict for a thread that has no
# configuration of its own yet.
_global_config = {
    "assume_finite": _env_flag("SKLEARN_ASSUME_FINITE"),
    "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
    "print_changed_only": True,
    "display": "diagram",
    "pairwise_dist_chunk_size": int(
        os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
    ),
    "enable_cython_pairwise_dist": True,
    "array_api_dispatch": False,
    "transform_output": "default",
    "enable_metadata_routing": False,
    "skip_parameter_validation": False,
}
# Per-thread storage; the active configuration is kept on it as the
# attribute `global_config`.
_threadlocal = threading.local()


def _get_threadlocal_config():
    """Get a threadlocal **mutable** configuration.

    If the configuration does not exist yet on `_threadlocal`, it is created
    as a shallow copy of the default global configuration.

    Returns
    -------
    config : dict
        The dict stored on `_threadlocal` itself, not a copy: mutating it
        changes what later calls return.
    """
    if not hasattr(_threadlocal, "global_config"):
        _threadlocal.global_config = _global_config.copy()
    return _threadlocal.global_config


def get_config():
    """Retrieve current values for configuration set by :func:`set_config`.

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    set_config : Set global scikit-learn configuration.

    Examples
    --------
    >>> import sklearn
    >>> config = sklearn.get_config()
    >>> config.keys()
    dict_keys([...])
    """
    # Return a copy of the threadlocal configuration so that users will
    # not be able to modify the configuration with the returned dict.
    return _get_threadlocal_config().copy()


def set_config(
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Set global scikit-learn configuration.

    Every parameter left as `None` keeps its current value. The new values
    are written to the configuration returned by
    `_get_threadlocal_config`.

    .. versionadded:: 0.19

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error.  Global default: False.

        .. versionadded:: 0.19

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. Global default: 1024.

        .. versionadded:: 0.20

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()' while the default
        behaviour would be to print 'SVC(C=1.0, cache_size=200, ...)' with
        all the non-changed parameters.

        .. versionadded:: 0.21

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. Default is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        The number of row vectors per chunk for the accelerated pairwise-
        distances reduction backend. Default is 256 (suitable for most of
        modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Use the accelerated pairwise-distances reduction backend when
        possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Use Array API dispatching when inputs follow the Array API standard.
        Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for more
        details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and values in
        the fit method of estimators and for arguments passed to public helper
        functions. It can save time in some situations but can lead to low level
        crashes and exceptions with confusing error messages.

        Note that for data parameters, such as `X` and `y`, only type validation is
        skipped but validation with `check_array` will continue to run.

        .. versionadded:: 1.3

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Examples
    --------
    >>> from sklearn import set_config
    >>> set_config(display='diagram')  # doctest: +SKIP
    """
    # Run the only check that can raise before any state is modified, so a
    # rejected `array_api_dispatch` cannot leave the other settings
    # half-applied.
    if array_api_dispatch is not None:
        # Imported here rather than at module level, as in the original code.
        from .utils._array_api import _check_array_api_dispatch

        _check_array_api_dispatch(array_api_dispatch)

    requested = {
        "assume_finite": assume_finite,
        "working_memory": working_memory,
        "print_changed_only": print_changed_only,
        "display": display,
        "pairwise_dist_chunk_size": pairwise_dist_chunk_size,
        "enable_cython_pairwise_dist": enable_cython_pairwise_dist,
        "array_api_dispatch": array_api_dispatch,
        "transform_output": transform_output,
        "enable_metadata_routing": enable_metadata_routing,
        "skip_parameter_validation": skip_parameter_validation,
    }

    # `None` means "leave this setting alone", so only explicit values are
    # written into the live configuration dict.
    local_config = _get_threadlocal_config()
    local_config.update(
        {key: value for key, value in requested.items() if value is not None}
    )


@contextmanager
def config_context(
    *,
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Context manager for global scikit-learn configuration.

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. If None, the existing value won't change.
        The default value is False.

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. If None, the existing value won't change.
        The default value is 1024.

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. If None, the existing value won't change.
        The default value is True.

        .. versionchanged:: 0.23
           Default changed from False to True.

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. If None, the existing value won't change.
        The default value is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        The number of row vectors per chunk for the accelerated pairwise-
        distances reduction backend. Default is 256 (suitable for most of
        modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Use the accelerated pairwise-distances reduction backend when
        possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Use Array API dispatching when inputs follow the Array API standard.
        Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for more
        details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and values in
        the fit method of estimators and for arguments passed to public helper
        functions. It can save time in some situations but can lead to low level
        crashes and exceptions with confusing error messages.

        Note that for data parameters, such as `X` and `y`, only type validation is
        skipped but validation with `check_array` will continue to run.

        .. versionadded:: 1.3

    Yields
    ------
    None.

    See Also
    --------
    set_config : Set global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Notes
    -----
    All settings, not just those presently modified, will be returned to
    their previous values when the context manager is exited.

    Examples
    --------
    >>> import sklearn
    >>> from sklearn.utils.validation import assert_all_finite
    >>> with sklearn.config_context(assume_finite=True):
    ...     assert_all_finite([float('nan')])
    >>> with sklearn.config_context(assume_finite=True):
    ...     with sklearn.config_context(assume_finite=False):
    ...         assert_all_finite([float('nan')])
    Traceback (most recent call last):
    ...
    ValueError: Input contains NaN...
    """
    # Snapshot (a copy) of the configuration to restore on exit.
    old_config = get_config()

    try:
        # Applied inside the ``try`` so that, if `set_config` raises after
        # changing some settings, the ``finally`` clause still restores the
        # snapshot instead of leaking a half-applied configuration.
        set_config(
            assume_finite=assume_finite,
            working_memory=working_memory,
            print_changed_only=print_changed_only,
            display=display,
            pairwise_dist_chunk_size=pairwise_dist_chunk_size,
            enable_cython_pairwise_dist=enable_cython_pairwise_dist,
            array_api_dispatch=array_api_dispatch,
            transform_output=transform_output,
            enable_metadata_routing=enable_metadata_routing,
            skip_parameter_validation=skip_parameter_validation,
        )
        yield
    finally:
        set_config(**old_config)