460 lines
14 KiB
Python
460 lines
14 KiB
Python
import importlib
|
|
from functools import wraps
|
|
from typing import Protocol, runtime_checkable
|
|
|
|
import numpy as np
|
|
from scipy.sparse import issparse
|
|
|
|
from .._config import get_config
|
|
from ._available_if import available_if
|
|
|
|
|
|
def check_library_installed(library):
|
|
"""Check library is installed."""
|
|
try:
|
|
return importlib.import_module(library)
|
|
except ImportError as exc:
|
|
raise ImportError(
|
|
f"Setting output container to '{library}' requires {library} to be"
|
|
" installed"
|
|
) from exc
|
|
|
|
|
|
def get_columns(columns):
|
|
if callable(columns):
|
|
try:
|
|
return columns()
|
|
except Exception:
|
|
return None
|
|
return columns
|
|
|
|
|
|
@runtime_checkable
|
|
class ContainerAdapterProtocol(Protocol):
|
|
container_lib: str
|
|
|
|
def create_container(self, X_output, X_original, columns, inplace=False):
|
|
"""Create container from `X_output` with additional metadata.
|
|
|
|
Parameters
|
|
----------
|
|
X_output : {ndarray, dataframe}
|
|
Data to wrap.
|
|
|
|
X_original : {ndarray, dataframe}
|
|
Original input dataframe. This is used to extract the metadata that should
|
|
be passed to `X_output`, e.g. pandas row index.
|
|
|
|
columns : callable, ndarray, or None
|
|
The column names or a callable that returns the column names. The
|
|
callable is useful if the column names require some computation. If `None`,
|
|
then no columns are passed to the container's constructor.
|
|
|
|
inplace : bool, default=False
|
|
Whether or not we intend to modify `X_output` in-place. However, it does
|
|
not guarantee that we return the same object if the in-place operation
|
|
is not possible.
|
|
|
|
Returns
|
|
-------
|
|
wrapped_output : container_type
|
|
`X_output` wrapped into the container type.
|
|
"""
|
|
|
|
def is_supported_container(self, X):
|
|
"""Return True if X is a supported container.
|
|
|
|
Parameters
|
|
----------
|
|
Xs: container
|
|
Containers to be checked.
|
|
|
|
Returns
|
|
-------
|
|
is_supported_container : bool
|
|
True if X is a supported container.
|
|
"""
|
|
|
|
def rename_columns(self, X, columns):
|
|
"""Rename columns in `X`.
|
|
|
|
Parameters
|
|
----------
|
|
X : container
|
|
Container which columns is updated.
|
|
|
|
columns : ndarray of str
|
|
Columns to update the `X`'s columns with.
|
|
|
|
Returns
|
|
-------
|
|
updated_container : container
|
|
Container with new names.
|
|
"""
|
|
|
|
def hstack(self, Xs):
|
|
"""Stack containers horizontally (column-wise).
|
|
|
|
Parameters
|
|
----------
|
|
Xs : list of containers
|
|
List of containers to stack.
|
|
|
|
Returns
|
|
-------
|
|
stacked_Xs : container
|
|
Stacked containers.
|
|
"""
|
|
|
|
|
|
class PandasAdapter:
|
|
container_lib = "pandas"
|
|
|
|
def create_container(self, X_output, X_original, columns, inplace=True):
|
|
pd = check_library_installed("pandas")
|
|
columns = get_columns(columns)
|
|
|
|
if not inplace or not isinstance(X_output, pd.DataFrame):
|
|
# In all these cases, we need to create a new DataFrame
|
|
|
|
# Unfortunately, we cannot use `getattr(container, "index")`
|
|
# because `list` exposes an `index` attribute.
|
|
if isinstance(X_output, pd.DataFrame):
|
|
index = X_output.index
|
|
elif isinstance(X_original, pd.DataFrame):
|
|
index = X_original.index
|
|
else:
|
|
index = None
|
|
|
|
# We don't pass columns here because it would intend columns selection
|
|
# instead of renaming.
|
|
X_output = pd.DataFrame(X_output, index=index, copy=not inplace)
|
|
|
|
if columns is not None:
|
|
return self.rename_columns(X_output, columns)
|
|
return X_output
|
|
|
|
def is_supported_container(self, X):
|
|
pd = check_library_installed("pandas")
|
|
return isinstance(X, pd.DataFrame)
|
|
|
|
def rename_columns(self, X, columns):
|
|
# we cannot use `rename` since it takes a dictionary and at this stage we have
|
|
# potentially duplicate column names in `X`
|
|
X.columns = columns
|
|
return X
|
|
|
|
def hstack(self, Xs):
|
|
pd = check_library_installed("pandas")
|
|
return pd.concat(Xs, axis=1)
|
|
|
|
|
|
class PolarsAdapter:
|
|
container_lib = "polars"
|
|
|
|
def create_container(self, X_output, X_original, columns, inplace=True):
|
|
pl = check_library_installed("polars")
|
|
columns = get_columns(columns)
|
|
columns = columns.tolist() if isinstance(columns, np.ndarray) else columns
|
|
|
|
if not inplace or not isinstance(X_output, pl.DataFrame):
|
|
# In all these cases, we need to create a new DataFrame
|
|
return pl.DataFrame(X_output, schema=columns, orient="row")
|
|
|
|
if columns is not None:
|
|
return self.rename_columns(X_output, columns)
|
|
return X_output
|
|
|
|
def is_supported_container(self, X):
|
|
pl = check_library_installed("polars")
|
|
return isinstance(X, pl.DataFrame)
|
|
|
|
def rename_columns(self, X, columns):
|
|
# we cannot use `rename` since it takes a dictionary and at this stage we have
|
|
# potentially duplicate column names in `X`
|
|
X.columns = columns
|
|
return X
|
|
|
|
def hstack(self, Xs):
|
|
pl = check_library_installed("polars")
|
|
return pl.concat(Xs, how="horizontal")
|
|
|
|
|
|
class ContainerAdaptersManager:
|
|
def __init__(self):
|
|
self.adapters = {}
|
|
|
|
@property
|
|
def supported_outputs(self):
|
|
return {"default"} | set(self.adapters)
|
|
|
|
def register(self, adapter):
|
|
self.adapters[adapter.container_lib] = adapter
|
|
|
|
|
|
ADAPTERS_MANAGER = ContainerAdaptersManager()
|
|
ADAPTERS_MANAGER.register(PandasAdapter())
|
|
ADAPTERS_MANAGER.register(PolarsAdapter())
|
|
|
|
|
|
def _get_adapter_from_container(container):
|
|
"""Get the adapter that knows how to handle such container.
|
|
|
|
See :class:`sklearn.utils._set_output.ContainerAdapterProtocol` for more
|
|
details.
|
|
"""
|
|
module_name = container.__class__.__module__.split(".")[0]
|
|
try:
|
|
return ADAPTERS_MANAGER.adapters[module_name]
|
|
except KeyError as exc:
|
|
available_adapters = list(ADAPTERS_MANAGER.adapters.keys())
|
|
raise ValueError(
|
|
"The container does not have a registered adapter in scikit-learn. "
|
|
f"Available adapters are: {available_adapters} while the container "
|
|
f"provided is: {container!r}."
|
|
) from exc
|
|
|
|
|
|
def _get_container_adapter(method, estimator=None):
|
|
"""Get container adapter."""
|
|
dense_config = _get_output_config(method, estimator)["dense"]
|
|
try:
|
|
return ADAPTERS_MANAGER.adapters[dense_config]
|
|
except KeyError:
|
|
return None
|
|
|
|
|
|
def _get_output_config(method, estimator=None):
|
|
"""Get output config based on estimator and global configuration.
|
|
|
|
Parameters
|
|
----------
|
|
method : {"transform"}
|
|
Estimator's method for which the output container is looked up.
|
|
|
|
estimator : estimator instance or None
|
|
Estimator to get the output configuration from. If `None`, check global
|
|
configuration is used.
|
|
|
|
Returns
|
|
-------
|
|
config : dict
|
|
Dictionary with keys:
|
|
|
|
- "dense": specifies the dense container for `method`. This can be
|
|
`"default"` or `"pandas"`.
|
|
"""
|
|
est_sklearn_output_config = getattr(estimator, "_sklearn_output_config", {})
|
|
if method in est_sklearn_output_config:
|
|
dense_config = est_sklearn_output_config[method]
|
|
else:
|
|
dense_config = get_config()[f"{method}_output"]
|
|
|
|
supported_outputs = ADAPTERS_MANAGER.supported_outputs
|
|
if dense_config not in supported_outputs:
|
|
raise ValueError(
|
|
f"output config must be in {sorted(supported_outputs)}, got {dense_config}"
|
|
)
|
|
|
|
return {"dense": dense_config}
|
|
|
|
|
|
def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
|
|
"""Wrap output with container based on an estimator's or global config.
|
|
|
|
Parameters
|
|
----------
|
|
method : {"transform"}
|
|
Estimator's method to get container output for.
|
|
|
|
data_to_wrap : {ndarray, dataframe}
|
|
Data to wrap with container.
|
|
|
|
original_input : {ndarray, dataframe}
|
|
Original input of function.
|
|
|
|
estimator : estimator instance
|
|
Estimator with to get the output configuration from.
|
|
|
|
Returns
|
|
-------
|
|
output : {ndarray, dataframe}
|
|
If the output config is "default" or the estimator is not configured
|
|
for wrapping return `data_to_wrap` unchanged.
|
|
If the output config is "pandas", return `data_to_wrap` as a pandas
|
|
DataFrame.
|
|
"""
|
|
output_config = _get_output_config(method, estimator)
|
|
|
|
if output_config["dense"] == "default" or not _auto_wrap_is_configured(estimator):
|
|
return data_to_wrap
|
|
|
|
dense_config = output_config["dense"]
|
|
if issparse(data_to_wrap):
|
|
raise ValueError(
|
|
"The transformer outputs a scipy sparse matrix. "
|
|
"Try to set the transformer output to a dense array or disable "
|
|
f"{dense_config.capitalize()} output with set_output(transform='default')."
|
|
)
|
|
|
|
adapter = ADAPTERS_MANAGER.adapters[dense_config]
|
|
return adapter.create_container(
|
|
data_to_wrap,
|
|
original_input,
|
|
columns=estimator.get_feature_names_out,
|
|
)
|
|
|
|
|
|
def _wrap_method_output(f, method):
|
|
"""Wrapper used by `_SetOutputMixin` to automatically wrap methods."""
|
|
|
|
@wraps(f)
|
|
def wrapped(self, X, *args, **kwargs):
|
|
data_to_wrap = f(self, X, *args, **kwargs)
|
|
if isinstance(data_to_wrap, tuple):
|
|
# only wrap the first output for cross decomposition
|
|
return_tuple = (
|
|
_wrap_data_with_container(method, data_to_wrap[0], X, self),
|
|
*data_to_wrap[1:],
|
|
)
|
|
# Support for namedtuples `_make` is a documented API for namedtuples:
|
|
# https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
|
|
if hasattr(type(data_to_wrap), "_make"):
|
|
return type(data_to_wrap)._make(return_tuple)
|
|
return return_tuple
|
|
|
|
return _wrap_data_with_container(method, data_to_wrap, X, self)
|
|
|
|
return wrapped
|
|
|
|
|
|
def _auto_wrap_is_configured(estimator):
|
|
"""Return True if estimator is configured for auto-wrapping the transform method.
|
|
|
|
`_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
|
|
is manually disabled.
|
|
"""
|
|
auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
|
|
return (
|
|
hasattr(estimator, "get_feature_names_out")
|
|
and "transform" in auto_wrap_output_keys
|
|
)
|
|
|
|
|
|
class _SetOutputMixin:
|
|
"""Mixin that dynamically wraps methods to return container based on config.
|
|
|
|
Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures
|
|
it based on `set_output` of the global configuration.
|
|
|
|
`set_output` is only defined if `get_feature_names_out` is defined and
|
|
`auto_wrap_output_keys` is the default value.
|
|
"""
|
|
|
|
def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
|
|
super().__init_subclass__(**kwargs)
|
|
|
|
# Dynamically wraps `transform` and `fit_transform` and configure it's
|
|
# output based on `set_output`.
|
|
if not (
|
|
isinstance(auto_wrap_output_keys, tuple) or auto_wrap_output_keys is None
|
|
):
|
|
raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")
|
|
|
|
if auto_wrap_output_keys is None:
|
|
cls._sklearn_auto_wrap_output_keys = set()
|
|
return
|
|
|
|
# Mapping from method to key in configurations
|
|
method_to_key = {
|
|
"transform": "transform",
|
|
"fit_transform": "transform",
|
|
}
|
|
cls._sklearn_auto_wrap_output_keys = set()
|
|
|
|
for method, key in method_to_key.items():
|
|
if not hasattr(cls, method) or key not in auto_wrap_output_keys:
|
|
continue
|
|
cls._sklearn_auto_wrap_output_keys.add(key)
|
|
|
|
# Only wrap methods defined by cls itself
|
|
if method not in cls.__dict__:
|
|
continue
|
|
wrapped_method = _wrap_method_output(getattr(cls, method), key)
|
|
setattr(cls, method, wrapped_method)
|
|
|
|
@available_if(_auto_wrap_is_configured)
|
|
def set_output(self, *, transform=None):
|
|
"""Set output container.
|
|
|
|
See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
|
|
for an example on how to use the API.
|
|
|
|
Parameters
|
|
----------
|
|
transform : {"default", "pandas", "polars"}, default=None
|
|
Configure output of `transform` and `fit_transform`.
|
|
|
|
- `"default"`: Default output format of a transformer
|
|
- `"pandas"`: DataFrame output
|
|
- `"polars"`: Polars output
|
|
- `None`: Transform configuration is unchanged
|
|
|
|
.. versionadded:: 1.4
|
|
`"polars"` option was added.
|
|
|
|
Returns
|
|
-------
|
|
self : estimator instance
|
|
Estimator instance.
|
|
"""
|
|
if transform is None:
|
|
return self
|
|
|
|
if not hasattr(self, "_sklearn_output_config"):
|
|
self._sklearn_output_config = {}
|
|
|
|
self._sklearn_output_config["transform"] = transform
|
|
return self
|
|
|
|
|
|
def _safe_set_output(estimator, *, transform=None):
|
|
"""Safely call estimator.set_output and error if it not available.
|
|
|
|
This is used by meta-estimators to set the output for child estimators.
|
|
|
|
Parameters
|
|
----------
|
|
estimator : estimator instance
|
|
Estimator instance.
|
|
|
|
transform : {"default", "pandas", "polars"}, default=None
|
|
Configure output of the following estimator's methods:
|
|
|
|
- `"transform"`
|
|
- `"fit_transform"`
|
|
|
|
If `None`, this operation is a no-op.
|
|
|
|
Returns
|
|
-------
|
|
estimator : estimator instance
|
|
Estimator instance.
|
|
"""
|
|
set_output_for_transform = (
|
|
hasattr(estimator, "transform")
|
|
or hasattr(estimator, "fit_transform")
|
|
and transform is not None
|
|
)
|
|
if not set_output_for_transform:
|
|
# If estimator can not transform, then `set_output` does not need to be
|
|
# called.
|
|
return
|
|
|
|
if not hasattr(estimator, "set_output"):
|
|
raise ValueError(
|
|
f"Unable to configure output for {estimator} because `set_output` "
|
|
"is not available."
|
|
)
|
|
return estimator.set_output(transform=transform)
|