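"""Tests for the `set_output` API and the dataframe container adapters defined in
`sklearn.utils._set_output`."""
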
import importlib
from collections import namedtuple

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from sklearn._config import config_context, get_config
from sklearn.preprocessing import StandardScaler
from sklearn.utils._set_output import (
    ADAPTERS_MANAGER,
    ContainerAdapterProtocol,
    _get_adapter_from_container,
    _get_output_config,
    _safe_set_output,
    _SetOutputMixin,
    _wrap_data_with_container,
    check_library_installed,
)
from sklearn.utils.fixes import CSR_CONTAINERS


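# The container adapters registered in ADAPTERS_MANAGER expose a small interface
# (create_container, is_supported_container, rename_columns, hstack) that the
# set_output machinery uses to wrap transformer outputs; the tests below exercise
# each method directly.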
def test_pandas_adapter():
    """Check that the pandas adapter has the expected behavior."""
    pd = pytest.importorskip("pandas")
    X_np = np.asarray([[1, 0, 3], [0, 0, 1]])
    columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    index = np.asarray([0, 1])
    X_df_orig = pd.DataFrame([[1, 2], [1, 3]], index=index)

    adapter = ADAPTERS_MANAGER.adapters["pandas"]
    X_container = adapter.create_container(X_np, X_df_orig, columns=lambda: columns)
    assert isinstance(X_container, pd.DataFrame)
    assert_array_equal(X_container.columns, columns)
    assert_array_equal(X_container.index, index)

    # The index of the dataframe being wrapped is preserved, not replaced by
    # `X_df_orig`'s index
    new_columns = np.asarray(["f0", "f1"], dtype=object)
    X_df = pd.DataFrame([[1, 2], [1, 3]], index=[10, 12])
    new_df = adapter.create_container(X_df, X_df_orig, columns=new_columns)
    assert_array_equal(new_df.columns, new_columns)
    assert_array_equal(new_df.index, X_df.index)

    assert adapter.is_supported_container(X_df)
    assert not adapter.is_supported_container(X_np)

    # adapter.rename_columns updates the columns
    new_columns = np.array(["a", "c"], dtype=object)
    new_df = adapter.rename_columns(X_df, new_columns)
    assert_array_equal(new_df.columns, new_columns)

    # adapter.hstack stacks the dataframes horizontally
    X_df_1 = pd.DataFrame([[1, 2, 5], [3, 4, 6]], columns=["a", "b", "e"])
    X_df_2 = pd.DataFrame([[4], [5]], columns=["c"])
    X_stacked = adapter.hstack([X_df_1, X_df_2])

    expected_df = pd.DataFrame(
        [[1, 2, 5, 4], [3, 4, 6, 5]], columns=["a", "b", "e", "c"]
    )
    pd.testing.assert_frame_equal(X_stacked, expected_df)

    # Check that the columns are properly updated even with duplicate column names.
    # This use case can happen when using ColumnTransformer.
    # Non-regression test for gh-28260.
    X_df = pd.DataFrame([[1, 2], [1, 3]], columns=["a", "a"])
    new_columns = np.array(["x__a", "y__a"], dtype=object)
    new_df = adapter.rename_columns(X_df, new_columns)
    assert_array_equal(new_df.columns, new_columns)

    # Check the behavior of the `inplace` parameter in `create_container`.
    # With inplace=False, a copy is triggered
    X_df = pd.DataFrame([[1, 2], [1, 3]], index=index)
    X_output = adapter.create_container(X_df, X_df, columns=["a", "b"], inplace=False)
    assert X_output is not X_df
    assert list(X_df.columns) == [0, 1]
    assert list(X_output.columns) == ["a", "b"]

    # With inplace=True, the operation is done in place
    X_df = pd.DataFrame([[1, 2], [1, 3]], index=index)
    X_output = adapter.create_container(X_df, X_df, columns=["a", "b"], inplace=True)
    assert X_output is X_df
    assert list(X_df.columns) == ["a", "b"]
    assert list(X_output.columns) == ["a", "b"]


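# The polars adapter mirrors the pandas adapter; polars dataframes have no index,
# so only column handling is checked here.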
def test_polars_adapter():
    """Check that the polars adapter has the expected behavior."""
    pl = pytest.importorskip("polars")
    from polars.testing import assert_frame_equal

    X_np = np.array([[1, 0, 3], [0, 0, 1]])
    columns = ["f1", "f2", "f3"]
    X_df_orig = pl.DataFrame(X_np, schema=columns, orient="row")

    adapter = ADAPTERS_MANAGER.adapters["polars"]
    X_container = adapter.create_container(X_np, X_df_orig, columns=lambda: columns)

    assert isinstance(X_container, pl.DataFrame)
    assert_array_equal(X_container.columns, columns)

    # Update columns with create_container
    new_columns = np.asarray(["a", "b", "c"], dtype=object)
    new_df = adapter.create_container(X_df_orig, X_df_orig, columns=new_columns)
    assert_array_equal(new_df.columns, new_columns)

    assert adapter.is_supported_container(X_df_orig)
    assert not adapter.is_supported_container(X_np)

    # adapter.rename_columns updates the columns
    new_columns = np.array(["a", "c", "g"], dtype=object)
    new_df = adapter.rename_columns(X_df_orig, new_columns)
    assert_array_equal(new_df.columns, new_columns)

    # adapter.hstack stacks the dataframes horizontally
    X_df_1 = pl.DataFrame([[1, 2, 5], [3, 4, 6]], schema=["a", "b", "e"], orient="row")
    X_df_2 = pl.DataFrame([[4], [5]], schema=["c"], orient="row")
    X_stacked = adapter.hstack([X_df_1, X_df_2])

    expected_df = pl.DataFrame(
        [[1, 2, 5, 4], [3, 4, 6, 5]], schema=["a", "b", "e", "c"], orient="row"
    )
    assert_frame_equal(X_stacked, expected_df)

    # Check the behavior of the `inplace` parameter in `create_container`.
    # With inplace=False, a copy is triggered
    X_df = pl.DataFrame([[1, 2], [1, 3]], schema=["a", "b"], orient="row")
    X_output = adapter.create_container(X_df, X_df, columns=["c", "d"], inplace=False)
    assert X_output is not X_df
    assert list(X_df.columns) == ["a", "b"]
    assert list(X_output.columns) == ["c", "d"]

    # With inplace=True, the operation is done in place
    X_df = pl.DataFrame([[1, 2], [1, 3]], schema=["a", "b"], orient="row")
    X_output = adapter.create_container(X_df, X_df, columns=["c", "d"], inplace=True)
    assert X_output is X_df
    assert list(X_df.columns) == ["c", "d"]
    assert list(X_output.columns) == ["c", "d"]


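# A sparse transform output cannot be wrapped into a dataframe, so
# `_wrap_data_with_container` must raise an informative error when the output
# container is set to "pandas".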
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test__container_error_validation(csr_container):
    """Check errors in _wrap_data_with_container."""
    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    X_csr = csr_container(X)
    match = "The transformer outputs a scipy sparse matrix."
    with config_context(transform_output="pandas"):
        with pytest.raises(ValueError, match=match):
            _wrap_data_with_container("transform", X_csr, X, StandardScaler())


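# Minimal estimator stubs used by the tests below to cover the different
# combinations of having/not having `transform`, `get_feature_names_out` and
# `_SetOutputMixin`.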
class EstimatorWithoutSetOutputAndWithoutTransform:
    pass


class EstimatorNoSetOutputWithTransform:
    def transform(self, X, y=None):
        return X  # pragma: no cover


class EstimatorWithSetOutput(_SetOutputMixin):
    def fit(self, X, y=None):
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        return X

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


def test__safe_set_output():
    """Check that _safe_set_output works as expected."""

    # An estimator without a transform method does not raise when configuring
    # the "transform" output.
    est = EstimatorWithoutSetOutputAndWithoutTransform()
    _safe_set_output(est, transform="pandas")

    # An estimator with transform but without set_output raises
    est = EstimatorNoSetOutputWithTransform()
    with pytest.raises(ValueError, match="Unable to configure output"):
        _safe_set_output(est, transform="pandas")

    est = EstimatorWithSetOutput().fit(np.asarray([[1, 2, 3]]))
    _safe_set_output(est, transform="pandas")
    config = _get_output_config("transform", est)
    assert config["dense"] == "pandas"

    _safe_set_output(est, transform="default")
    config = _get_output_config("transform", est)
    assert config["dense"] == "default"

    # transform=None is a no-op, so the config remains "default"
    _safe_set_output(est, transform=None)
    config = _get_output_config("transform", est)
    assert config["dense"] == "default"


class EstimatorNoSetOutputWithTransformNoFeatureNamesOut(_SetOutputMixin):
    def transform(self, X, y=None):
        return X  # pragma: no cover


def test_set_output_mixin():
    """An estimator without get_feature_names_out does not define `set_output`."""
    est = EstimatorNoSetOutputWithTransformNoFeatureNamesOut()
    assert not hasattr(est, "set_output")


def test__safe_set_output_error():
    """Check that transform raises with an invalid output config."""
    X = np.asarray([[1, 0, 3], [0, 0, 1]])

    est = EstimatorWithSetOutput()
    _safe_set_output(est, transform="bad")

    msg = "output config must be in"
    with pytest.raises(ValueError, match=msg):
        est.transform(X)


@pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
def test_set_output_method(dataframe_lib):
    """Check that set_output makes transform return a dataframe."""
    lib = pytest.importorskip(dataframe_lib)

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(X)

    # transform=None is a no-op
    est2 = est.set_output(transform=None)
    assert est2 is est
    X_trans_np = est2.transform(X)
    assert isinstance(X_trans_np, np.ndarray)

    est.set_output(transform=dataframe_lib)

    X_trans_df = est.transform(X)
    assert isinstance(X_trans_df, lib.DataFrame)


def test_set_output_method_error():
    """Check that transform fails with an invalid output config."""
    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(X)
    est.set_output(transform="bad")

    msg = "output config must be in"
    with pytest.raises(ValueError, match=msg):
        est.transform(X)


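# `_get_output_config` resolves the estimator-local setting first and falls back
# to the global `transform_output` configuration.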
@pytest.mark.parametrize("transform_output", ["pandas", "polars"])
def test__get_output_config(transform_output):
    """Check that _get_output_config works as expected."""

    # Without a configuration set, the global config is used
    global_config = get_config()["transform_output"]
    config = _get_output_config("transform")
    assert config["dense"] == global_config

    with config_context(transform_output=transform_output):
        # With estimator=None, the global config is used
        config = _get_output_config("transform")
        assert config["dense"] == transform_output

        est = EstimatorNoSetOutputWithTransform()
        config = _get_output_config("transform", est)
        assert config["dense"] == transform_output

        est = EstimatorWithSetOutput()
        # If the estimator has no local config, the global config is used
        config = _get_output_config("transform", est)
        assert config["dense"] == transform_output

        # If the estimator has a local config, it takes precedence over the global one
        est.set_output(transform="default")
        config = _get_output_config("transform", est)
        assert config["dense"] == "default"

        est.set_output(transform=transform_output)
        config = _get_output_config("transform", est)
        assert config["dense"] == transform_output


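# Passing `auto_wrap_output_keys=None` at class definition time opts an estimator
# out of output wrapping entirely: no `set_output` method is added and `transform`
# results are returned unchanged.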
class EstimatorWithSetOutputNoAutoWrap(_SetOutputMixin, auto_wrap_output_keys=None):
    def transform(self, X, y=None):
        return X


def test_get_output_auto_wrap_false():
    """Check that auto_wrap_output_keys=None does not wrap."""
    est = EstimatorWithSetOutputNoAutoWrap()
    assert not hasattr(est, "set_output")

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    assert X is est.transform(X)


def test_auto_wrap_output_keys_errors_with_incorrect_input():
    msg = "auto_wrap_output_keys must be None or a tuple of keys."
    with pytest.raises(ValueError, match=msg):

        class BadEstimator(_SetOutputMixin, auto_wrap_output_keys="bad_parameter"):
            pass


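# A sibling mixin with its own `__init_subclass__` keyword: `_SetOutputMixin` must
# cooperate and forward unknown class keywords up the MRO.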
class AnotherMixin:
    def __init_subclass__(cls, custom_parameter, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.custom_parameter = custom_parameter


def test_set_output_mixin_custom_mixin():
    """Check that multiple __init_subclass__ hooks pass their parameters up."""

    class BothMixinEstimator(_SetOutputMixin, AnotherMixin, custom_parameter=123):
        def transform(self, X, y=None):
            return X

        def get_feature_names_out(self, input_features=None):
            return input_features

    est = BothMixinEstimator()
    assert est.custom_parameter == 123
    assert hasattr(est, "set_output")


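# With `class C(A, B)`, the MRO is C -> A -> B -> Base, so the wrapped `transform`
# must resolve to the implementation defined on B.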
def test_set_output_mro():
    """Check that multi-inheritance resolves to the correct class method.

    Non-regression test for gh-25293.
    """

    class Base(_SetOutputMixin):
        def transform(self, X):
            return "Base"  # noqa

    class A(Base):
        pass

    class B(Base):
        def transform(self, X):
            return "B"

    class C(A, B):
        pass

    assert C().transform(None) == "B"


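# A transformer may legitimately produce an output with a different index than its
# input; the wrapping must keep the index produced by `transform`.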
class EstimatorWithSetOutputIndex(_SetOutputMixin):
    def fit(self, X, y=None):
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        import pandas as pd

        # Transform by giving the output a new index.
        return pd.DataFrame(X.to_numpy(), index=[f"s{i}" for i in range(X.shape[0])])

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


def test_set_output_pandas_keep_index():
    """Check that set_output does not override the transformer's output index.

    Non-regression test for gh-25730.
    """
    pd = pytest.importorskip("pandas")

    X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
    est = EstimatorWithSetOutputIndex().set_output(transform="pandas")
    est.fit(X)

    X_trans = est.transform(X)
    assert_array_equal(X_trans.index, ["s0", "s1"])


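# Outputs that are not plain arrays (e.g. namedtuples) are passed through unchanged
# by the auto-wrapping.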
class EstimatorReturnTuple(_SetOutputMixin):
    def __init__(self, OutputTuple):
        self.OutputTuple = OutputTuple

    def transform(self, X, y=None):
        return self.OutputTuple(X, 2 * X)


def test_set_output_named_tuple_out():
    """Check that namedtuples are kept by default."""
    Output = namedtuple("Output", "X, Y")
    X = np.asarray([[1, 2, 3]])
    est = EstimatorReturnTuple(OutputTuple=Output)
    X_trans = est.transform(X)

    assert isinstance(X_trans, Output)
    assert_array_equal(X_trans.X, X)
    assert_array_equal(X_trans.Y, 2 * X)


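# The wrapped output of a list input gets its column names from
# `get_feature_names_out`.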
class EstimatorWithListInput(_SetOutputMixin):
    def fit(self, X, y=None):
        assert isinstance(X, list)
        self.n_features_in_ = len(X[0])
        return self

    def transform(self, X, y=None):
        return X

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


@pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
def test_set_output_list_input(dataframe_lib):
    """Check set_output for list input.

    Non-regression test for #27037.
    """
    lib = pytest.importorskip(dataframe_lib)

    X = [[0, 1, 2, 3], [4, 5, 6, 7]]
    est = EstimatorWithListInput()
    est.set_output(transform=dataframe_lib)

    X_out = est.fit(X).transform(X)
    assert isinstance(X_out, lib.DataFrame)
    assert_array_equal(X_out.columns, ["X0", "X1", "X2", "X3"])


@pytest.mark.parametrize("name", sorted(ADAPTERS_MANAGER.adapters))
def test_adapter_class_has_interface(name):
    """Check adapters have the correct interface."""
    assert isinstance(ADAPTERS_MANAGER.adapters[name], ContainerAdapterProtocol)


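# `importlib.import_module` is monkeypatched to simulate pandas being absent.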
def test_check_library_installed(monkeypatch):
    """Check that the ImportError is re-raised with a more informative message."""
    orig_import_module = importlib.import_module

    def patched_import_module(name):
        if name == "pandas":
            raise ImportError()
        return orig_import_module(name, package=None)

    monkeypatch.setattr(importlib, "import_module", patched_import_module)

    msg = "Setting output container to 'pandas' requires"
    with pytest.raises(ImportError, match=msg):
        check_library_installed("pandas")


def test_get_adapter_from_container():
    """Check the behavior of `_get_adapter_from_container`."""
    pd = pytest.importorskip("pandas")
    X = pd.DataFrame({"a": [1, 2, 3], "b": [10, 20, 100]})
    adapter = _get_adapter_from_container(X)
    assert adapter.container_lib == "pandas"
    err_msg = "The container does not have a registered adapter in scikit-learn."
    with pytest.raises(ValueError, match=err_msg):
        _get_adapter_from_container(X.to_numpy())