# Inzynierka/Lib/site-packages/sklearn/utils/tests/test_metaestimators.py
# 2023-06-02 12:51:02 +02:00
#
# 173 lines
# 5.2 KiB
# Python

import numpy as np
import pytest
import warnings
import pickle
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.metaestimators import available_if
class Prefix:
    """Minimal delegate object that exposes a no-op ``func`` method."""

    def func(self):
        """Do nothing and return ``None``."""
        pass
class MockMetaEstimator:
    """This is a mock meta estimator"""

    # Class-level delegate looked up by name by the decorator below.
    a_prefix = Prefix()

    # NOTE: the docstring of ``func`` is asserted verbatim by
    # ``test_delegated_docstring``; do not reword it.
    @if_delegate_has_method(delegate="a_prefix")
    def func(self):
        """This is a mock delegated function"""
        pass
@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated")
def test_delegated_docstring():
    """Check that a delegated method keeps its docstring.

    The docstring is looked up three ways: through the class ``__dict__``,
    through attribute access on the class, and through an instance.
    """
    expected = "This is a mock delegated function"

    # Raw entry stored in the class namespace.
    assert expected in str(MockMetaEstimator.__dict__["func"].__doc__)
    # Attribute access on the class itself.
    assert expected in str(MockMetaEstimator.func.__doc__)
    # Attribute access on an instance.
    assert expected in str(MockMetaEstimator().func.__doc__)
class MetaEst:
    """A mock meta estimator whose ``predict`` is delegated to ``sub_est``.

    Parameters
    ----------
    sub_est : object
        Delegate named by the ``predict`` decorator.
    better_sub_est : object, default=None
        Second delegate; only referenced by the subclasses that pass several
        delegate names to the decorator.
    """

    def __init__(self, sub_est, better_sub_est=None):
        self.sub_est = sub_est
        self.better_sub_est = better_sub_est

    @if_delegate_has_method(delegate="sub_est")
    def predict(self):
        """No-op method; only its availability is exercised by the tests."""
        pass
class MetaEstTestTuple(MetaEst):
    """A mock meta estimator to test passing a tuple of delegates"""

    @if_delegate_has_method(delegate=("sub_est", "better_sub_est"))
    def predict(self):
        """No-op method; only its availability is exercised by the tests."""
        pass
class MetaEstTestList(MetaEst):
    """A mock meta estimator to test passing a list of delegates"""

    @if_delegate_has_method(delegate=["sub_est", "better_sub_est"])
    def predict(self):
        """No-op method; only its availability is exercised by the tests."""
        pass
class HasPredict:
    """A mock sub-estimator with predict method"""

    def predict(self):
        """Do nothing and return ``None``."""
        pass
class HasNoPredict:
    """A mock sub-estimator with no predict method"""
class HasPredictAsNDArray:
    """A mock sub-estimator where predict is a NumPy array"""

    # ``predict`` is deliberately an array rather than a callable: the test
    # using this class checks that attribute detection copes with it.
    predict = np.ones((10, 2), dtype=np.int64)
@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated")
def test_if_delegate_has_method():
    """Check when ``predict`` is exposed for single and multiple delegates."""
    # Single delegate: availability mirrors the delegate.
    assert hasattr(MetaEst(HasPredict()), "predict")
    assert not hasattr(MetaEst(HasNoPredict()), "predict")

    # Tuple of delegate names: in these cases the outcome follows the
    # first-listed delegate (``sub_est``), whatever ``better_sub_est`` offers.
    assert not hasattr(MetaEstTestTuple(HasNoPredict(), HasNoPredict()), "predict")
    assert hasattr(MetaEstTestTuple(HasPredict(), HasNoPredict()), "predict")
    assert not hasattr(MetaEstTestTuple(HasNoPredict(), HasPredict()), "predict")

    # List of delegate names: same expectations as for the tuple form.
    assert not hasattr(MetaEstTestList(HasNoPredict(), HasPredict()), "predict")
    assert hasattr(MetaEstTestList(HasPredict(), HasPredict()), "predict")
class AvailableParameterEstimator:
    """This estimator's `available` parameter toggles the presence of a method

    Parameters
    ----------
    available : bool, default=True
        Value returned by the ``available_if`` check for ``available_func``.
    return_value : object, default=1
        Value returned by ``available_func``.
    """

    def __init__(self, available=True, return_value=1):
        self.available = available
        self.return_value = return_value

    # NOTE: the docstring of ``available_func`` is asserted verbatim by
    # ``test_available_if_docstring``; do not reword it.
    @available_if(lambda est: est.available)
    def available_func(self):
        """This is a mock available_if function"""
        return self.return_value
def test_available_if_docstring():
    """Check that an ``available_if`` method keeps its docstring.

    The docstring is looked up three ways: through the class ``__dict__``,
    through attribute access on the class, and through an instance.
    """
    expected = "This is a mock available_if function"

    # Raw entry stored in the class namespace.
    assert expected in str(
        AvailableParameterEstimator.__dict__["available_func"].__doc__
    )
    # Attribute access on the class itself.
    assert expected in str(AvailableParameterEstimator.available_func.__doc__)
    # Attribute access on an instance.
    assert expected in str(AvailableParameterEstimator().available_func.__doc__)
def test_available_if():
    """Check that ``available_func`` exists only when ``available`` is true."""
    assert hasattr(AvailableParameterEstimator(), "available_func")
    assert not hasattr(AvailableParameterEstimator(available=False), "available_func")
def test_available_if_unbound_method():
    """Check that an ``available_if`` method works when called unbound."""
    # This is a non regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/20614
    # to make sure that decorated functions can be used as an unbound method,
    # for instance when monkeypatching.
    est = AvailableParameterEstimator()
    # The unbound call must not only succeed but also forward the result of
    # the wrapped method.
    assert AvailableParameterEstimator.available_func(est) == est.return_value

    est = AvailableParameterEstimator(available=False)
    with pytest.raises(
        AttributeError,
        match="This 'AvailableParameterEstimator' has no attribute 'available_func'",
    ):
        AvailableParameterEstimator.available_func(est)
@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated")
def test_if_delegate_has_method_numpy_array():
    """Check that we can check for an attribute that is a NumPy array.

    This is a non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/21144
    """
    estimator = MetaEst(HasPredictAsNDArray())
    assert hasattr(estimator, "predict")
def test_if_delegate_has_method_deprecated():
    """Check the deprecation warning of if_delegate_has_method"""
    # don't warn when creating the decorator: any FutureWarning raised inside
    # this context is turned into an error and fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error", FutureWarning)
        _ = if_delegate_has_method(delegate="predict")

    # Only when calling it
    with pytest.warns(FutureWarning, match="if_delegate_has_method was deprecated"):
        # Assert the lookup result too, so the warning path is also known to
        # resolve the delegated method.
        assert hasattr(MetaEst(HasPredict()), "predict")
def test_available_if_methods_can_be_pickled():
    """Check that available_if methods can be pickled.

    Non-regression test for #21344.
    """
    return_value = 10
    est = AvailableParameterEstimator(available=True, return_value=return_value)

    # Round-trip the bound method itself (not the estimator) through pickle.
    pickled_bytes = pickle.dumps(est.available_func)
    unpickled_func = pickle.loads(pickled_bytes)

    assert unpickled_func() == return_value