"""Tests for the private ``set_output`` helpers in ``sklearn.utils._set_output``."""
import re

import pytest

import numpy as np
from scipy.sparse import csr_matrix
from numpy.testing import assert_array_equal

from sklearn._config import config_context, get_config
from sklearn.utils._set_output import (
    _get_output_config,
    _safe_set_output,
    _SetOutputMixin,
    _wrap_in_pandas_container,
)


def test__wrap_in_pandas_container_dense():
    """Check _wrap_in_pandas_container for dense data."""
    pd = pytest.importorskip("pandas")
    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    index = np.asarray([0, 1])

    # `columns` is passed as a callable to exercise the lazy code path.
    dense_named = _wrap_in_pandas_container(X, columns=lambda: columns, index=index)
    assert isinstance(dense_named, pd.DataFrame)
    assert_array_equal(dense_named.columns, columns)
    assert_array_equal(dense_named.index, index)


def test__wrap_in_pandas_container_dense_update_columns_and_index():
    """Check that a DataFrame input gets new columns but keeps its own index."""
    pd = pytest.importorskip("pandas")
    X_df = pd.DataFrame([[1, 0, 3], [0, 0, 1]], columns=["a", "b", "c"])
    new_columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    new_index = [10, 12]

    new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index)
    assert_array_equal(new_df.columns, new_columns)

    # Index does not change when the input is a DataFrame
    assert_array_equal(new_df.index, X_df.index)


def test__wrap_in_pandas_container_error_validation():
    """Check errors in _wrap_in_pandas_container."""
    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    X_csr = csr_matrix(X)
    match = "Pandas output does not support sparse data"
    with pytest.raises(ValueError, match=match):
        _wrap_in_pandas_container(X_csr, columns=["a", "b", "c"])


class EstimatorWithoutSetOutputAndWithoutTransform:
    """Plain object with neither `transform` nor `set_output`."""


class EstimatorNoSetOutputWithTransform:
    """Object that defines `transform` but does not inherit `_SetOutputMixin`."""

    def transform(self, X, y=None):
        return X  # pragma: no cover


class EstimatorWithSetOutput(_SetOutputMixin):
    """Identity transformer that defines `get_feature_names_out`."""

    def fit(self, X, y=None):
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        return X

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


def test__safe_set_output():
    """Check _safe_set_output works as expected."""

    # Estimator without transform will not raise when setting set_output for transform.
    est = EstimatorWithoutSetOutputAndWithoutTransform()
    _safe_set_output(est, transform="pandas")

    # Estimator with transform but without set_output will raise
    est = EstimatorNoSetOutputWithTransform()
    with pytest.raises(ValueError, match="Unable to configure output"):
        _safe_set_output(est, transform="pandas")

    est = EstimatorWithSetOutput().fit(np.asarray([[1, 2, 3]]))
    _safe_set_output(est, transform="pandas")
    config = _get_output_config("transform", est)
    assert config["dense"] == "pandas"

    _safe_set_output(est, transform="default")
    config = _get_output_config("transform", est)
    assert config["dense"] == "default"

    # transform is None is a no-op, so the config remains "default"
    _safe_set_output(est, transform=None)
    config = _get_output_config("transform", est)
    assert config["dense"] == "default"


class EstimatorNoSetOutputWithTransformNoFeatureNamesOut(_SetOutputMixin):
    """`_SetOutputMixin` subclass that lacks `get_feature_names_out`."""

    def transform(self, X, y=None):
        return X  # pragma: no cover


def test_set_output_mixin():
    """Estimator without get_feature_names_out does not define `set_output`."""
    est = EstimatorNoSetOutputWithTransformNoFeatureNamesOut()
    assert not hasattr(est, "set_output")


def test__safe_set_output_error():
    """Check transform with invalid config."""
    X = np.asarray([[1, 0, 3], [0, 0, 1]])

    est = EstimatorWithSetOutput()
    # The invalid value is accepted here; the error only surfaces on transform.
    _safe_set_output(est, transform="bad")

    msg = "output config must be 'default'"
    with pytest.raises(ValueError, match=msg):
        est.transform(X)


def test_set_output_method():
    """Check that the output is pandas."""
    pd = pytest.importorskip("pandas")

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(X)

    # transform=None is a no-op
    est2 = est.set_output(transform=None)
    assert est2 is est
    X_trans_np = est2.transform(X)
    assert isinstance(X_trans_np, np.ndarray)

    est.set_output(transform="pandas")

    X_trans_pd = est.transform(X)
    assert isinstance(X_trans_pd, pd.DataFrame)


def test_set_output_method_error():
    """Check transform fails with invalid transform."""

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(X)
    est.set_output(transform="bad")

    msg = "output config must be 'default'"
    with pytest.raises(ValueError, match=msg):
        est.transform(X)


def test__get_output_config():
    """Check _get_output_config works as expected."""

    # Without a configuration set, the global config is used
    global_config = get_config()["transform_output"]
    config = _get_output_config("transform")
    assert config["dense"] == global_config

    with config_context(transform_output="pandas"):
        # with estimator=None, the global config is used
        config = _get_output_config("transform")
        assert config["dense"] == "pandas"

        est = EstimatorNoSetOutputWithTransform()
        config = _get_output_config("transform", est)
        assert config["dense"] == "pandas"

        est = EstimatorWithSetOutput()
        # If estimator has no config, use global config
        config = _get_output_config("transform", est)
        assert config["dense"] == "pandas"

        # If estimator has a config, use local config
        est.set_output(transform="default")
        config = _get_output_config("transform", est)
        assert config["dense"] == "default"

    est.set_output(transform="pandas")
    config = _get_output_config("transform", est)
    assert config["dense"] == "pandas"


class EstimatorWithSetOutputNoAutoWrap(_SetOutputMixin, auto_wrap_output_keys=None):
    """`_SetOutputMixin` subclass that opts out with `auto_wrap_output_keys=None`."""

    def transform(self, X, y=None):
        return X


def test_get_output_auto_wrap_false():
    """Check that auto_wrap_output_keys=None does not wrap."""
    est = EstimatorWithSetOutputNoAutoWrap()
    assert not hasattr(est, "set_output")

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    # Identity check: the very same object must come back, unwrapped.
    assert X is est.transform(X)


def test_auto_wrap_output_keys_errors_with_incorrect_input():
    """Check that a non-tuple, non-None auto_wrap_output_keys raises at class creation."""
    # `match` is applied as a regular expression, so escape the message to
    # make the trailing period match literally instead of any character.
    msg = re.escape("auto_wrap_output_keys must be None or a tuple of keys.")
    with pytest.raises(ValueError, match=msg):

        class BadEstimator(_SetOutputMixin, auto_wrap_output_keys="bad_parameter"):
            pass


class AnotherMixin:
    """Mixin whose `__init_subclass__` consumes a required `custom_parameter`."""

    def __init_subclass__(cls, custom_parameter, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.custom_parameter = custom_parameter


def test_set_output_mixin_custom_mixin():
    """Check that multiple init_subclasses passes parameters up."""

    class BothMixinEstimator(_SetOutputMixin, AnotherMixin, custom_parameter=123):
        def transform(self, X, y=None):
            return X

        def get_feature_names_out(self, input_features=None):
            return input_features

    est = BothMixinEstimator()
    assert est.custom_parameter == 123
    assert hasattr(est, "set_output")


def test__wrap_in_pandas_container_column_errors():
    """If a callable `columns` errors, it has the same semantics as columns=None."""
    pd = pytest.importorskip("pandas")

    def get_columns():
        raise ValueError("No feature names defined")

    # DataFrame input: the existing column labels are kept.
    X_df = pd.DataFrame({"feat1": [1, 2, 3], "feat2": [3, 4, 5]})

    X_wrapped = _wrap_in_pandas_container(X_df, columns=get_columns)
    assert_array_equal(X_wrapped.columns, X_df.columns)

    # ndarray input: columns fall back to the positional range.
    X_np = np.asarray([[1, 3], [2, 4], [3, 5]])
    X_wrapped = _wrap_in_pandas_container(X_np, columns=get_columns)
    assert_array_equal(X_wrapped.columns, range(X_np.shape[1]))


def test_set_output_mro():
    """Check that multi-inheritance resolves to the correct class method.

    Non-regression test gh-25293.
    """

    class Base(_SetOutputMixin):
        def transform(self, X):
            return "Base"  # noqa

    class A(Base):
        pass

    class B(Base):
        def transform(self, X):
            return "B"

    class C(A, B):
        pass

    # The MRO of C is (C, A, B, Base), so B.transform must win over Base's.
    assert C().transform(None) == "B"


class EstimatorWithSetOutputIndex(_SetOutputMixin):
    """Transformer whose `transform` returns a DataFrame with a brand-new index."""

    def fit(self, X, y=None):
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        # Imported lazily so that importing this module does not require pandas.
        import pandas as pd

        # transform by giving output a new index.
        return pd.DataFrame(X.to_numpy(), index=[f"s{i}" for i in range(X.shape[0])])

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


def test_set_output_pandas_keep_index():
    """Check that set_output does not override index.

    Non-regression test for gh-25730.
    """
    pd = pytest.importorskip("pandas")

    X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
    est = EstimatorWithSetOutputIndex().set_output(transform="pandas")
    est.fit(X)

    X_trans = est.transform(X)
    assert_array_equal(X_trans.index, ["s0", "s1"])