124 lines
4.0 KiB
Python
124 lines
4.0 KiB
Python
|
"""Module that customize joblib tools for scikit-learn usage."""
|
||
|
|
||
|
import functools
|
||
|
import warnings
|
||
|
from functools import update_wrapper
|
||
|
|
||
|
import joblib
|
||
|
|
||
|
from .._config import config_context, get_config
|
||
|
|
||
|
|
||
|
def _with_config(delayed_func, config):
|
||
|
"""Helper function that intends to attach a config to a delayed function."""
|
||
|
if hasattr(delayed_func, "with_config"):
|
||
|
return delayed_func.with_config(config)
|
||
|
else:
|
||
|
warnings.warn(
|
||
|
"`sklearn.utils.parallel.Parallel` needs to be used in "
|
||
|
"conjunction with `sklearn.utils.parallel.delayed` instead of "
|
||
|
"`joblib.delayed` to correctly propagate the scikit-learn "
|
||
|
"configuration to the joblib workers.",
|
||
|
UserWarning,
|
||
|
)
|
||
|
return delayed_func
|
||
|
|
||
|
|
||
|
class Parallel(joblib.Parallel):
|
||
|
"""Tweak of :class:`joblib.Parallel` that propagates the scikit-learn configuration.
|
||
|
|
||
|
This subclass of :class:`joblib.Parallel` ensures that the active configuration
|
||
|
(thread-local) of scikit-learn is propagated to the parallel workers for the
|
||
|
duration of the execution of the parallel tasks.
|
||
|
|
||
|
The API does not change and you can refer to :class:`joblib.Parallel`
|
||
|
documentation for more details.
|
||
|
|
||
|
.. versionadded:: 1.3
|
||
|
"""
|
||
|
|
||
|
def __call__(self, iterable):
|
||
|
"""Dispatch the tasks and return the results.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
iterable : iterable
|
||
|
Iterable containing tuples of (delayed_function, args, kwargs) that should
|
||
|
be consumed.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
results : list
|
||
|
List of results of the tasks.
|
||
|
"""
|
||
|
# Capture the thread-local scikit-learn configuration at the time
|
||
|
# Parallel.__call__ is issued since the tasks can be dispatched
|
||
|
# in a different thread depending on the backend and on the value of
|
||
|
# pre_dispatch and n_jobs.
|
||
|
config = get_config()
|
||
|
iterable_with_config = (
|
||
|
(_with_config(delayed_func, config), args, kwargs)
|
||
|
for delayed_func, args, kwargs in iterable
|
||
|
)
|
||
|
return super().__call__(iterable_with_config)
|
||
|
|
||
|
|
||
|
# remove when https://github.com/joblib/joblib/issues/1071 is fixed
|
||
|
def delayed(function):
|
||
|
"""Decorator used to capture the arguments of a function.
|
||
|
|
||
|
This alternative to `joblib.delayed` is meant to be used in conjunction
|
||
|
with `sklearn.utils.parallel.Parallel`. The latter captures the the scikit-
|
||
|
learn configuration by calling `sklearn.get_config()` in the current
|
||
|
thread, prior to dispatching the first task. The captured configuration is
|
||
|
then propagated and enabled for the duration of the execution of the
|
||
|
delayed function in the joblib workers.
|
||
|
|
||
|
.. versionchanged:: 1.3
|
||
|
`delayed` was moved from `sklearn.utils.fixes` to `sklearn.utils.parallel`
|
||
|
in scikit-learn 1.3.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
function : callable
|
||
|
The function to be delayed.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
output: tuple
|
||
|
Tuple containing the delayed function, the positional arguments, and the
|
||
|
keyword arguments.
|
||
|
"""
|
||
|
|
||
|
@functools.wraps(function)
|
||
|
def delayed_function(*args, **kwargs):
|
||
|
return _FuncWrapper(function), args, kwargs
|
||
|
|
||
|
return delayed_function
|
||
|
|
||
|
|
||
|
class _FuncWrapper:
|
||
|
"""Load the global configuration before calling the function."""
|
||
|
|
||
|
def __init__(self, function):
|
||
|
self.function = function
|
||
|
update_wrapper(self, self.function)
|
||
|
|
||
|
def with_config(self, config):
|
||
|
self.config = config
|
||
|
return self
|
||
|
|
||
|
def __call__(self, *args, **kwargs):
|
||
|
config = getattr(self, "config", None)
|
||
|
if config is None:
|
||
|
warnings.warn(
|
||
|
"`sklearn.utils.parallel.delayed` should be used with "
|
||
|
"`sklearn.utils.parallel.Parallel` to make it possible to propagate "
|
||
|
"the scikit-learn configuration of the current thread to the "
|
||
|
"joblib workers.",
|
||
|
UserWarning,
|
||
|
)
|
||
|
config = {}
|
||
|
with config_context(**config):
|
||
|
return self.function(*args, **kwargs)
|