278 lines
8.7 KiB
Python
278 lines
8.7 KiB
Python
from functools import wraps
|
|
|
|
from scipy.sparse import issparse
|
|
|
|
from . import check_pandas_support
|
|
from .._config import get_config
|
|
from ._available_if import available_if
|
|
|
|
|
|
def _wrap_in_pandas_container(
|
|
data_to_wrap,
|
|
*,
|
|
columns,
|
|
index=None,
|
|
):
|
|
"""Create a Pandas DataFrame.
|
|
|
|
If `data_to_wrap` is a DataFrame, then the `columns` and `index` will be changed
|
|
inplace. If `data_to_wrap` is a ndarray, then a new DataFrame is created with
|
|
`columns` and `index`.
|
|
|
|
Parameters
|
|
----------
|
|
data_to_wrap : {ndarray, dataframe}
|
|
Data to be wrapped as pandas dataframe.
|
|
|
|
columns : callable, ndarray, or None
|
|
The column names or a callable that returns the column names. The
|
|
callable is useful if the column names require some computation.
|
|
If `columns` is a callable that raises an error, `columns` will have
|
|
the same semantics as `None`. If `None` and `data_to_wrap` is already a
|
|
dataframe, then the column names are not changed. If `None` and
|
|
`data_to_wrap` is **not** a dataframe, then columns are
|
|
`range(n_features)`.
|
|
|
|
index : array-like, default=None
|
|
Index for data. `index` is ignored if `data_to_wrap` is already a DataFrame.
|
|
|
|
Returns
|
|
-------
|
|
dataframe : DataFrame
|
|
Container with column names or unchanged `output`.
|
|
"""
|
|
if issparse(data_to_wrap):
|
|
raise ValueError("Pandas output does not support sparse data.")
|
|
|
|
if callable(columns):
|
|
try:
|
|
columns = columns()
|
|
except Exception:
|
|
columns = None
|
|
|
|
pd = check_pandas_support("Setting output container to 'pandas'")
|
|
|
|
if isinstance(data_to_wrap, pd.DataFrame):
|
|
if columns is not None:
|
|
data_to_wrap.columns = columns
|
|
return data_to_wrap
|
|
|
|
return pd.DataFrame(data_to_wrap, index=index, columns=columns)
|
|
|
|
|
|
def _get_output_config(method, estimator=None):
|
|
"""Get output config based on estimator and global configuration.
|
|
|
|
Parameters
|
|
----------
|
|
method : {"transform"}
|
|
Estimator's method for which the output container is looked up.
|
|
|
|
estimator : estimator instance or None
|
|
Estimator to get the output configuration from. If `None`, check global
|
|
configuration is used.
|
|
|
|
Returns
|
|
-------
|
|
config : dict
|
|
Dictionary with keys:
|
|
|
|
- "dense": specifies the dense container for `method`. This can be
|
|
`"default"` or `"pandas"`.
|
|
"""
|
|
est_sklearn_output_config = getattr(estimator, "_sklearn_output_config", {})
|
|
if method in est_sklearn_output_config:
|
|
dense_config = est_sklearn_output_config[method]
|
|
else:
|
|
dense_config = get_config()[f"{method}_output"]
|
|
|
|
if dense_config not in {"default", "pandas"}:
|
|
raise ValueError(
|
|
f"output config must be 'default' or 'pandas' got {dense_config}"
|
|
)
|
|
|
|
return {"dense": dense_config}
|
|
|
|
|
|
def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
    """Wrap output with container based on an estimator's or global config.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method to get container output for.

    data_to_wrap : {ndarray, dataframe}
        Data to wrap with container.

    original_input : array-like
        Original input of function. If it exposes a non-callable `index`
        attribute (e.g. a pandas DataFrame), that index is reused for the
        wrapped output.

    estimator : estimator instance
        Estimator to get the output configuration from.

    Returns
    -------
    output : {ndarray, dataframe}
        If the output config is "default" or the estimator is not configured
        for wrapping return `data_to_wrap` unchanged.
        If the output config is "pandas", return `data_to_wrap` as a pandas
        DataFrame.
    """
    output_config = _get_output_config(method, estimator)

    if output_config["dense"] == "default" or not _auto_wrap_is_configured(estimator):
        return data_to_wrap

    # dense_config == "pandas"
    index = getattr(original_input, "index", None)
    if callable(index):
        # Plain Python sequences such as `list` and `tuple` expose `index` as a
        # method rather than row labels; passing that bound method to the
        # DataFrame constructor would fail, so fall back to a default index.
        index = None

    return _wrap_in_pandas_container(
        data_to_wrap=data_to_wrap,
        index=index,
        columns=estimator.get_feature_names_out,
    )
|
|
|
|
|
|
def _wrap_method_output(f, method):
    """Wrapper used by `_SetOutputMixin` to automatically wrap methods.

    Parameters
    ----------
    f : callable
        Method to wrap; it is called as ``f(self, X, *args, **kwargs)``.

    method : str
        Configuration key (e.g. ``"transform"``) used to look up the output
        container for the result of ``f``.

    Returns
    -------
    wrapped : callable
        Function carrying ``f``'s metadata (via `functools.wraps`) whose result
        is passed through `_wrap_data_with_container`. When ``f`` returns a
        tuple, only its first element is wrapped; a namedtuple keeps its type.
    """

    @wraps(f)
    def wrapped(self, X, *args, **kwargs):
        data_to_wrap = f(self, X, *args, **kwargs)
        if not isinstance(data_to_wrap, tuple):
            return _wrap_data_with_container(method, data_to_wrap, X, self)

        # only wrap the first output for cross decomposition
        return_tuple = (
            _wrap_data_with_container(method, data_to_wrap[0], X, self),
            *data_to_wrap[1:],
        )
        # Rebuild namedtuples through their documented `_make` classmethod so
        # callers keep attribute access; a bare tuple would drop the type.
        if hasattr(type(data_to_wrap), "_make"):
            return type(data_to_wrap)._make(return_tuple)
        return return_tuple

    return wrapped
|
|
|
|
|
|
def _auto_wrap_is_configured(estimator):
|
|
"""Return True if estimator is configured for auto-wrapping the transform method.
|
|
|
|
`_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
|
|
is manually disabled.
|
|
"""
|
|
auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
|
|
return (
|
|
hasattr(estimator, "get_feature_names_out")
|
|
and "transform" in auto_wrap_output_keys
|
|
)
|
|
|
|
|
|
class _SetOutputMixin:
    """Mixin that wraps transform-like methods to honour the output config.

    When a subclass is created, the `transform` and `fit_transform` methods it
    defines itself are wrapped so that their output is converted to the
    container chosen via `set_output` or, failing that, the global
    configuration.

    `set_output` is only available when the estimator defines
    `get_feature_names_out` and `"transform"` is among its auto-wrap keys
    (as with the default `auto_wrap_output_keys`).
    """

    def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
        super().__init_subclass__(**kwargs)

        # Accept only `None` (wrapping disabled) or a tuple of config keys.
        if auto_wrap_output_keys is not None and not isinstance(
            auto_wrap_output_keys, tuple
        ):
            raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")

        # Each subclass gets a fresh set recording the keys it wraps; it stays
        # empty when wrapping is disabled.
        cls._sklearn_auto_wrap_output_keys = set()
        if auto_wrap_output_keys is None:
            return

        # Pairs of (method name, configuration key controlling its output).
        wrappable = (
            ("transform", "transform"),
            ("fit_transform", "transform"),
        )
        for method_name, config_key in wrappable:
            if hasattr(cls, method_name) and config_key in auto_wrap_output_keys:
                cls._sklearn_auto_wrap_output_keys.add(config_key)
                # Only wrap methods defined by cls itself; an inherited method
                # is left untouched, so one already wrapped by a parent
                # `_SetOutputMixin` subclass is not wrapped a second time.
                if method_name in cls.__dict__:
                    original_method = getattr(cls, method_name)
                    setattr(
                        cls,
                        method_name,
                        _wrap_method_output(original_method, config_key),
                    )

    @available_if(_auto_wrap_is_configured)
    def set_output(self, *, transform=None):
        """Set output container.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        Parameters
        ----------
        transform : {"default", "pandas"}, default=None
            Configure output of `transform` and `fit_transform`.

            - `"default"`: Default output format of a transformer
            - `"pandas"`: DataFrame output
            - `None`: Transform configuration is unchanged

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if transform is not None:
            # Create the per-instance configuration lazily on first use.
            if not hasattr(self, "_sklearn_output_config"):
                self._sklearn_output_config = {}
            self._sklearn_output_config["transform"] = transform
        return self
|
|
|
|
|
|
def _safe_set_output(estimator, *, transform=None):
|
|
"""Safely call estimator.set_output and error if it not available.
|
|
|
|
This is used by meta-estimators to set the output for child estimators.
|
|
|
|
Parameters
|
|
----------
|
|
estimator : estimator instance
|
|
Estimator instance.
|
|
|
|
transform : {"default", "pandas"}, default=None
|
|
Configure output of the following estimator's methods:
|
|
|
|
- `"transform"`
|
|
- `"fit_transform"`
|
|
|
|
If `None`, this operation is a no-op.
|
|
|
|
Returns
|
|
-------
|
|
estimator : estimator instance
|
|
Estimator instance.
|
|
"""
|
|
set_output_for_transform = (
|
|
hasattr(estimator, "transform")
|
|
or hasattr(estimator, "fit_transform")
|
|
and transform is not None
|
|
)
|
|
if not set_output_for_transform:
|
|
# If estimator can not transform, then `set_output` does not need to be
|
|
# called.
|
|
return
|
|
|
|
if not hasattr(estimator, "set_output"):
|
|
raise ValueError(
|
|
f"Unable to configure output for {estimator} because `set_output` "
|
|
"is not available."
|
|
)
|
|
return estimator.set_output(transform=transform)
|