3RNN/Lib/site-packages/sklearn/utils/tests/test_set_output.py

465 lines
15 KiB
Python
Raw Permalink Normal View History

2024-05-26 19:49:15 +02:00
import importlib
from collections import namedtuple
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from sklearn._config import config_context, get_config
from sklearn.preprocessing import StandardScaler
from sklearn.utils._set_output import (
ADAPTERS_MANAGER,
ContainerAdapterProtocol,
_get_adapter_from_container,
_get_output_config,
_safe_set_output,
_SetOutputMixin,
_wrap_data_with_container,
check_library_installed,
)
from sklearn.utils.fixes import CSR_CONTAINERS
def test_pandas_adapter():
    """Exercise the pandas container adapter.

    Covers ``create_container`` (from an ndarray and from a dataframe),
    ``is_supported_container``, ``rename_columns``, ``hstack`` and the
    ``inplace`` switch of ``create_container``.
    """
    pd = pytest.importorskip("pandas")
    pandas_adapter = ADAPTERS_MANAGER.adapters["pandas"]

    X_np = np.asarray([[1, 0, 3], [0, 0, 1]])
    columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    index = np.asarray([0, 1])
    X_df_orig = pd.DataFrame([[1, 2], [1, 3]], index=index)

    # ndarray input, columns supplied through a callable.
    X_container = pandas_adapter.create_container(
        X_np, X_df_orig, columns=lambda: columns
    )
    assert isinstance(X_container, pd.DataFrame)
    assert_array_equal(X_container.columns, columns)
    assert_array_equal(X_container.index, index)

    # Dataframe input: the index of the dataframe passed in is kept as is.
    new_columns = np.asarray(["f0", "f1"], dtype=object)
    X_df = pd.DataFrame([[1, 2], [1, 3]], index=[10, 12])
    new_df = pandas_adapter.create_container(X_df, X_df_orig, columns=new_columns)
    assert_array_equal(new_df.columns, new_columns)
    assert_array_equal(new_df.index, X_df.index)

    # Only dataframes are recognised as supported containers.
    assert pandas_adapter.is_supported_container(X_df)
    assert not pandas_adapter.is_supported_container(X_np)

    # rename_columns replaces the column labels.
    new_columns = np.array(["a", "c"], dtype=object)
    new_df = pandas_adapter.rename_columns(X_df, new_columns)
    assert_array_equal(new_df.columns, new_columns)

    # hstack concatenates dataframes column-wise.
    left = pd.DataFrame([[1, 2, 5], [3, 4, 6]], columns=["a", "b", "e"])
    right = pd.DataFrame([[4], [5]], columns=["c"])
    X_stacked = pandas_adapter.hstack([left, right])
    expected_df = pd.DataFrame(
        [[1, 2, 5, 4], [3, 4, 6, 5]], columns=["a", "b", "e", "c"]
    )
    pd.testing.assert_frame_equal(X_stacked, expected_df)

    # Renaming must also work when the existing column names are duplicated,
    # which can happen with ColumnTransformer.
    # Non-regression test for gh-28260.
    X_df = pd.DataFrame([[1, 2], [1, 3]], columns=["a", "a"])
    new_columns = np.array(["x__a", "y__a"], dtype=object)
    new_df = pandas_adapter.rename_columns(X_df, new_columns)
    assert_array_equal(new_df.columns, new_columns)

    # inplace=False: a different object comes back and the input keeps its
    # original column labels.
    X_df = pd.DataFrame([[1, 2], [1, 3]], index=index)
    X_output = pandas_adapter.create_container(
        X_df, X_df, columns=["a", "b"], inplace=False
    )
    assert X_output is not X_df
    assert list(X_df.columns) == [0, 1]
    assert list(X_output.columns) == ["a", "b"]

    # inplace=True: the very same object comes back, relabelled.
    X_df = pd.DataFrame([[1, 2], [1, 3]], index=index)
    X_output = pandas_adapter.create_container(
        X_df, X_df, columns=["a", "b"], inplace=True
    )
    assert X_output is X_df
    assert list(X_df.columns) == ["a", "b"]
    assert list(X_output.columns) == ["a", "b"]
def test_polars_adapter():
    """Exercise the Polars container adapter.

    Covers ``create_container`` (from an ndarray and from a dataframe),
    ``is_supported_container``, ``rename_columns``, ``hstack`` and the
    ``inplace`` switch of ``create_container``.
    """
    pl = pytest.importorskip("polars")
    # polars is an optional dependency, so its testing helper is imported only
    # once importorskip has confirmed the library is available.
    from polars.testing import assert_frame_equal

    polars_adapter = ADAPTERS_MANAGER.adapters["polars"]

    X_np = np.array([[1, 0, 3], [0, 0, 1]])
    columns = ["f1", "f2", "f3"]
    X_df_orig = pl.DataFrame(X_np, schema=columns, orient="row")

    # ndarray input, columns supplied through a callable.
    X_container = polars_adapter.create_container(
        X_np, X_df_orig, columns=lambda: columns
    )
    assert isinstance(X_container, pl.DataFrame)
    assert_array_equal(X_container.columns, columns)

    # Dataframe input: create_container applies the new column names.
    new_columns = np.asarray(["a", "b", "c"], dtype=object)
    new_df = polars_adapter.create_container(
        X_df_orig, X_df_orig, columns=new_columns
    )
    assert_array_equal(new_df.columns, new_columns)

    # Only dataframes are recognised as supported containers.
    assert polars_adapter.is_supported_container(X_df_orig)
    assert not polars_adapter.is_supported_container(X_np)

    # rename_columns replaces the column labels.
    new_columns = np.array(["a", "c", "g"], dtype=object)
    new_df = polars_adapter.rename_columns(X_df_orig, new_columns)
    assert_array_equal(new_df.columns, new_columns)

    # hstack concatenates dataframes column-wise.
    left = pl.DataFrame([[1, 2, 5], [3, 4, 6]], schema=["a", "b", "e"], orient="row")
    right = pl.DataFrame([[4], [5]], schema=["c"], orient="row")
    X_stacked = polars_adapter.hstack([left, right])
    expected_df = pl.DataFrame(
        [[1, 2, 5, 4], [3, 4, 6, 5]], schema=["a", "b", "e", "c"], orient="row"
    )
    assert_frame_equal(X_stacked, expected_df)

    # inplace=False: a different object comes back and the input keeps its
    # original column names.
    X_df = pl.DataFrame([[1, 2], [1, 3]], schema=["a", "b"], orient="row")
    X_output = polars_adapter.create_container(
        X_df, X_df, columns=["c", "d"], inplace=False
    )
    assert X_output is not X_df
    assert list(X_df.columns) == ["a", "b"]
    assert list(X_output.columns) == ["c", "d"]

    # inplace=True: the very same object comes back, renamed.
    X_df = pl.DataFrame([[1, 2], [1, 3]], schema=["a", "b"], orient="row")
    X_output = polars_adapter.create_container(
        X_df, X_df, columns=["c", "d"], inplace=True
    )
    assert X_output is X_df
    assert list(X_df.columns) == ["c", "d"]
    assert list(X_output.columns) == ["c", "d"]
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test__container_error_validation(csr_container):
    """Check that wrapping a sparse transform output for pandas raises."""
    dense = np.asarray([[1, 0, 3], [0, 0, 1]])
    sparse = csr_container(dense)
    expected_msg = "The transformer outputs a scipy sparse matrix."

    with config_context(transform_output="pandas"), pytest.raises(
        ValueError, match=expected_msg
    ):
        _wrap_data_with_container("transform", sparse, dense, StandardScaler())
class EstimatorWithoutSetOutputAndWithoutTransform:
    """Bare placeholder estimator: defines no ``transform`` and no ``set_output``."""
class EstimatorNoSetOutputWithTransform:
    """Estimator exposing ``transform`` while defining no ``set_output``."""

    def transform(self, X, y=None):
        """Hand ``X`` back untouched; ``y`` is accepted and ignored."""
        return X  # pragma: no cover
class EstimatorWithSetOutput(_SetOutputMixin):
    """Small transformer built on ``_SetOutputMixin`` for the tests below.

    It defines ``fit``, ``transform`` and ``get_feature_names_out``.
    """

    def fit(self, X, y=None):
        """Store the number of columns of ``X`` and return ``self``."""
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        """Identity: hand back ``X`` as received."""
        return X

    def get_feature_names_out(self, input_features=None):
        """Return ``"X0", "X1", ...`` (one per fitted feature) as an object array."""
        names = [f"X{i}" for i in range(self.n_features_in_)]
        return np.asarray(names, dtype=object)
def test__safe_set_output():
    """Check ``_safe_set_output`` on estimators with and without ``set_output``."""
    # No transform at all: configuring the transform output must not raise.
    no_transform_est = EstimatorWithoutSetOutputAndWithoutTransform()
    _safe_set_output(no_transform_est, transform="pandas")

    # A transform but no set_output: configuring the output is an error.
    no_set_output_est = EstimatorNoSetOutputWithTransform()
    with pytest.raises(ValueError, match="Unable to configure output"):
        _safe_set_output(no_set_output_est, transform="pandas")

    est = EstimatorWithSetOutput().fit(np.asarray([[1, 2, 3]]))
    # (requested value, expected "dense" config afterwards), applied in order.
    # The last step passes None, which is a no-op, so "default" is retained.
    steps = [
        ("pandas", "pandas"),
        ("default", "default"),
        (None, "default"),
    ]
    for requested, expected_dense in steps:
        _safe_set_output(est, transform=requested)
        config = _get_output_config("transform", est)
        assert config["dense"] == expected_dense
class EstimatorNoSetOutputWithTransformNoFeatureNamesOut(_SetOutputMixin):
    """Mixin subclass that defines ``transform`` but no ``get_feature_names_out``."""

    def transform(self, X, y=None):
        """Hand ``X`` back untouched; ``y`` is accepted and ignored."""
        return X  # pragma: no cover
def test_set_output_mixin():
    """Check `set_output` is absent when `get_feature_names_out` is not defined."""
    estimator = EstimatorNoSetOutputWithTransformNoFeatureNamesOut()
    has_set_output = hasattr(estimator, "set_output")
    assert not has_set_output
def test__safe_set_output_error():
    """Check that ``transform`` raises once an invalid output config is set."""
    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    estimator = EstimatorWithSetOutput()

    # Setting the bad value is accepted; the error surfaces at transform time.
    _safe_set_output(estimator, transform="bad")

    with pytest.raises(ValueError, match="output config must be in"):
        estimator.transform(data)
@pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
def test_set_output_method(dataframe_lib):
    """Check that ``set_output`` makes ``transform`` return a dataframe."""
    lib = pytest.importorskip(dataframe_lib)
    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    estimator = EstimatorWithSetOutput().fit(data)

    # transform=None is a no-op: the same estimator comes back and the output
    # stays an ndarray.
    same_estimator = estimator.set_output(transform=None)
    assert same_estimator is estimator
    array_out = same_estimator.transform(data)
    assert isinstance(array_out, np.ndarray)

    # Requesting the dataframe library switches the output container.
    estimator.set_output(transform=dataframe_lib)
    frame_out = estimator.transform(data)
    assert isinstance(frame_out, lib.DataFrame)
def test_set_output_method_error():
    """Check that ``transform`` raises after ``set_output`` got an invalid value."""
    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    estimator = EstimatorWithSetOutput().fit(data)

    # The invalid value is only rejected when transform is called.
    estimator.set_output(transform="bad")

    with pytest.raises(ValueError, match="output config must be in"):
        estimator.transform(data)
@pytest.mark.parametrize("transform_output", ["pandas", "polars"])
def test__get_output_config(transform_output):
    """Check how ``_get_output_config`` combines global and estimator settings."""
    # Outside any context and without an estimator, the global value is used.
    expected_global = get_config()["transform_output"]
    assert _get_output_config("transform")["dense"] == expected_global

    with config_context(transform_output=transform_output):
        # No estimator given: the value from the active context is used.
        assert _get_output_config("transform")["dense"] == transform_output

        # Estimator that has no set_output: still the context value.
        plain_est = EstimatorNoSetOutputWithTransform()
        plain_config = _get_output_config("transform", plain_est)
        assert plain_config["dense"] == transform_output

        # Estimator with set_output but nothing configured yet: context value.
        est = EstimatorWithSetOutput()
        unset_config = _get_output_config("transform", est)
        assert unset_config["dense"] == transform_output

        # Once configured, the estimator's own setting takes precedence.
        est.set_output(transform="default")
        local_config = _get_output_config("transform", est)
        assert local_config["dense"] == "default"

    # After leaving the context, an estimator-level setting is reported as is.
    est.set_output(transform=transform_output)
    final_config = _get_output_config("transform", est)
    assert final_config["dense"] == transform_output
class EstimatorWithSetOutputNoAutoWrap(_SetOutputMixin, auto_wrap_output_keys=None):
    """Mixin subclass created with ``auto_wrap_output_keys=None``."""

    def transform(self, X, y=None):
        """Hand ``X`` back untouched; ``y`` is accepted and ignored."""
        return X
def test_get_output_auto_wrap_false():
    """Check that ``auto_wrap_output_keys=None`` leaves ``transform`` unwrapped."""
    estimator = EstimatorWithSetOutputNoAutoWrap()
    assert not hasattr(estimator, "set_output")

    # The exact input object must come back, not a wrapped or copied one.
    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    assert data is estimator.transform(data)
def test_auto_wrap_output_keys_errors_with_incorrect_input():
    """Check that a non-tuple, non-None ``auto_wrap_output_keys`` is rejected."""
    expected_msg = "auto_wrap_output_keys must be None or a tuple of keys."

    # The error is expected while the class statement itself is executed.
    with pytest.raises(ValueError, match=expected_msg):

        class BadEstimator(_SetOutputMixin, auto_wrap_output_keys="bad_parameter"):
            pass
class AnotherMixin:
    """Mixin taking a required ``custom_parameter`` class keyword."""

    def __init_subclass__(cls, custom_parameter, **kwargs):
        """Forward remaining class keywords, then store ``custom_parameter``."""
        # Pass the other keywords up the chain first so cooperating
        # ``__init_subclass__`` implementations receive them.
        super().__init_subclass__(**kwargs)
        cls.custom_parameter = custom_parameter
def test_set_output_mixin_custom_mixin():
    """Check that class keywords reach a second mixin's ``__init_subclass__``."""

    class BothMixinEstimator(_SetOutputMixin, AnotherMixin, custom_parameter=123):
        def transform(self, X, y=None):
            return X

        def get_feature_names_out(self, input_features=None):
            return input_features

    estimator = BothMixinEstimator()
    # The keyword was consumed by AnotherMixin ...
    assert estimator.custom_parameter == 123
    # ... and the class still exposes set_output.
    assert hasattr(estimator, "set_output")
def test_set_output_mro():
    """Check that ``transform`` is resolved through the MRO under multi-inheritance.

    Non-regression test gh-25293.
    """

    class Base(_SetOutputMixin):
        def transform(self, X):
            return "Base"  # noqa

    class A(Base):
        pass

    class B(Base):
        def transform(self, X):
            return "B"

    # C -> A -> B -> Base: A defines no transform, so B's must be the one used.
    class C(A, B):
        pass

    result = C().transform(None)
    assert result == "B"
class EstimatorWithSetOutputIndex(_SetOutputMixin):
    """Transformer whose ``transform`` returns a dataframe with a new index."""

    def fit(self, X, y=None):
        """Store the number of columns of ``X`` and return ``self``."""
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        """Return the values of ``X`` in a new dataframe indexed ``s0, s1, ...``."""
        # Imported here because pandas is an optional dependency of the tests.
        import pandas as pd

        row_labels = [f"s{i}" for i in range(X.shape[0])]
        return pd.DataFrame(X.to_numpy(), index=row_labels)

    def get_feature_names_out(self, input_features=None):
        """Return ``"X0", "X1", ...`` (one per fitted feature) as an object array."""
        names = [f"X{i}" for i in range(self.n_features_in_)]
        return np.asarray(names, dtype=object)
def test_set_output_pandas_keep_index():
    """Check that the index produced by ``transform`` survives ``set_output``.

    Non-regression test for gh-25730.
    """
    pd = pytest.importorskip("pandas")
    frame = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])

    estimator = EstimatorWithSetOutputIndex().set_output(transform="pandas")
    estimator.fit(frame)
    transformed = estimator.transform(frame)

    # The transformer's own index must be kept, not replaced by the input's.
    assert_array_equal(transformed.index, ["s0", "s1"])
class EstimatorReturnTuple(_SetOutputMixin):
    """Transformer whose ``transform`` returns an ``OutputTuple`` instance."""

    def __init__(self, OutputTuple):
        """Keep the callable used to build the return value of ``transform``."""
        self.OutputTuple = OutputTuple

    def transform(self, X, y=None):
        """Return ``OutputTuple(X, 2 * X)``."""
        doubled = 2 * X
        return self.OutputTuple(X, doubled)
def test_set_output_named_tuple_out():
    """Check that a namedtuple returned by ``transform`` is kept by default."""
    Output = namedtuple("Output", "X, Y")
    data = np.asarray([[1, 2, 3]])

    estimator = EstimatorReturnTuple(OutputTuple=Output)
    result = estimator.transform(data)

    # Same namedtuple type, with both fields carrying the expected values.
    assert isinstance(result, Output)
    assert_array_equal(result.X, data)
    assert_array_equal(result.Y, 2 * data)
class EstimatorWithListInput(_SetOutputMixin):
    """Transformer whose ``fit`` insists on receiving a plain ``list``."""

    def fit(self, X, y=None):
        """Assert ``X`` is a list, store the length of its first row, return self."""
        assert isinstance(X, list)
        first_row = X[0]
        self.n_features_in_ = len(first_row)
        return self

    def transform(self, X, y=None):
        """Identity: hand back ``X`` as received."""
        return X

    def get_feature_names_out(self, input_features=None):
        """Return ``"X0", "X1", ...`` (one per fitted feature) as an object array."""
        names = [f"X{i}" for i in range(self.n_features_in_)]
        return np.asarray(names, dtype=object)
@pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
def test_set_output_list_input(dataframe_lib):
    """Check that ``set_output`` works when the input is a list of lists.

    Non-regression test for #27037.
    """
    lib = pytest.importorskip(dataframe_lib)
    rows = [[0, 1, 2, 3], [4, 5, 6, 7]]

    estimator = EstimatorWithListInput()
    estimator.set_output(transform=dataframe_lib)
    result = estimator.fit(rows).transform(rows)

    assert isinstance(result, lib.DataFrame)
    assert_array_equal(result.columns, ["X0", "X1", "X2", "X3"])
@pytest.mark.parametrize("name", sorted(ADAPTERS_MANAGER.adapters))
def test_adapter_class_has_interface(name):
    """Check that each registered adapter satisfies ``ContainerAdapterProtocol``."""
    adapter = ADAPTERS_MANAGER.adapters[name]
    assert isinstance(adapter, ContainerAdapterProtocol)
def test_check_library_installed(monkeypatch):
    """Check the error raised when the requested library cannot be imported.

    ``importlib.import_module`` is patched so that importing ``"pandas"``
    raises a bare ``ImportError``. ``check_library_installed("pandas")`` must
    then raise an ``ImportError`` whose message starts with
    "Setting output container to 'pandas' requires".
    """
    orig_import_module = importlib.import_module

    def patched_import_module(name, package=None):
        """Fail for pandas; otherwise defer to the real ``import_module``."""
        if name == "pandas":
            raise ImportError()
        # Forward ``package`` and return the imported module so that any other
        # caller going through the patch gets the real function's result
        # instead of ``None``.
        return orig_import_module(name, package=package)

    monkeypatch.setattr(importlib, "import_module", patched_import_module)

    msg = "Setting output container to 'pandas' requires"
    with pytest.raises(ImportError, match=msg):
        check_library_installed("pandas")
def test_get_adapter_from_container():
    """Check the behavior of `_get_adapter_from_container`."""
    pd = pytest.importorskip("pandas")
    frame = pd.DataFrame({"a": [1, 2, 3], "b": [10, 20, 100]})

    # A pandas dataframe resolves to the pandas adapter.
    found_adapter = _get_adapter_from_container(frame)
    assert found_adapter.container_lib == "pandas"

    # A plain ndarray has no registered adapter and must be rejected.
    expected_msg = "The container does not have a registered adapter in scikit-learn."
    with pytest.raises(ValueError, match=expected_msg):
        _get_adapter_from_container(frame.to_numpy())