672 lines
21 KiB
Python
672 lines
21 KiB
Python
|
from numbers import Integral, Real
|
||
|
|
||
|
import numpy as np
|
||
|
from scipy.sparse import csr_matrix
|
||
|
import pytest
|
||
|
|
||
|
from sklearn.base import BaseEstimator
|
||
|
from sklearn.model_selection import LeaveOneOut
|
||
|
from sklearn.utils import deprecated
|
||
|
from sklearn.utils._param_validation import Hidden
|
||
|
from sklearn.utils._param_validation import Interval
|
||
|
from sklearn.utils._param_validation import Options
|
||
|
from sklearn.utils._param_validation import StrOptions
|
||
|
from sklearn.utils._param_validation import _ArrayLikes
|
||
|
from sklearn.utils._param_validation import _Booleans
|
||
|
from sklearn.utils._param_validation import _Callables
|
||
|
from sklearn.utils._param_validation import _CVObjects
|
||
|
from sklearn.utils._param_validation import _InstancesOf
|
||
|
from sklearn.utils._param_validation import _MissingValues
|
||
|
from sklearn.utils._param_validation import _PandasNAConstraint
|
||
|
from sklearn.utils._param_validation import _IterablesNotString
|
||
|
from sklearn.utils._param_validation import _NoneConstraint
|
||
|
from sklearn.utils._param_validation import _RandomStates
|
||
|
from sklearn.utils._param_validation import _SparseMatrices
|
||
|
from sklearn.utils._param_validation import _VerboseHelper
|
||
|
from sklearn.utils._param_validation import HasMethods
|
||
|
from sklearn.utils._param_validation import make_constraint
|
||
|
from sklearn.utils._param_validation import generate_invalid_param_val
|
||
|
from sklearn.utils._param_validation import generate_valid_param
|
||
|
from sklearn.utils._param_validation import validate_params
|
||
|
from sklearn.utils._param_validation import InvalidParameterError
|
||
|
|
||
|
|
||
|
# Some helpers for the tests
@validate_params({"a": [Real], "b": [Real], "c": [Real], "d": [Real]})
def _func(a, b=0, *args, c, d=0, **kwargs):
    """Dummy function used to exercise parameter validation of plain functions.

    The positional parameters ``a``/``b`` and the keyword-only parameters
    ``c``/``d`` are each declared as real numbers. The variadic ``*args`` and
    ``**kwargs`` have no entry in the constraints dict. The body does nothing.
    """
|
||
|
|
||
|
|
||
|
class _Class:
    """Plain helper class, used as an ``_InstancesOf`` target and as a host for
    validated methods."""

    @validate_params({"a": [Real]})
    def _method(self, a):
        """Method whose ``a`` argument is declared as a real number."""

    @deprecated()
    @validate_params({"a": [Real]})
    def _deprecated_method(self, a):
        """Same contract as ``_method``, additionally wrapped by ``deprecated``."""
|
||
|
|
||
|
|
||
|
class _Estimator(BaseEstimator):
    """Minimal estimator used to test the validation of estimator parameters.

    Parameters are stored as-is in ``__init__`` and only validated in ``fit``.
    """

    # ``a`` is declared as a real number; checked by ``_validate_params`` in fit.
    _parameter_constraints: dict = {"a": [Real]}

    def __init__(self, a):
        self.a = a

    def fit(self, X=None, y=None):
        """Validate the parameters and return the fitted estimator.

        ``X`` and ``y`` are unused; they only mimic the estimator ``fit`` API.
        Returning ``self`` follows the scikit-learn convention for ``fit``.
        """
        self._validate_params()
        return self
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("interval_type", [Integral, Real])
def test_interval_range(interval_type):
    """Check which bounds of [-2, 2] belong to the interval for each `closed`."""
    # closed -> (is -2 expected inside?, is 2 expected inside?)
    expected_membership = {
        "left": (True, False),
        "right": (False, True),
        "both": (True, True),
        "neither": (False, False),
    }
    for closed, (low_inside, high_inside) in expected_membership.items():
        bounded = Interval(interval_type, -2, 2, closed=closed)
        assert (-2 in bounded) is low_inside
        assert (2 in bounded) is high_inside
|
||
|
|
||
|
|
||
|
def test_interval_inf_in_bounds():
    """Check that an unbounded (None) side contains infinity only if it is closed.

    This only applies to intervals over the reals.
    """
    upper_unbounded = Interval(Real, 0, None, closed="right")
    assert np.inf in upper_unbounded

    lower_unbounded = Interval(Real, None, 0, closed="left")
    assert -np.inf in lower_unbounded

    fully_open = Interval(Real, None, None, closed="neither")
    for infinity in (np.inf, -np.inf):
        assert infinity not in fully_open
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "interval",
    [
        Interval(Real, 0, 1, closed="left"),
        Interval(Real, None, None, closed="both"),
    ],
)
def test_nan_not_in_interval(interval):
    """NaN must never belong to an interval, whatever its bounds."""
    nan_is_member = np.nan in interval
    assert not nan_is_member
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "params, error, match",
    [
        (
            {"type": Integral, "left": 1.0, "right": 2, "closed": "both"},
            TypeError,
            r"Expecting left to be an int for an interval over the integers",
        ),
        (
            {"type": Integral, "left": 1, "right": 2.0, "closed": "neither"},
            TypeError,
            r"Expecting right to be an int for an interval over the integers",
        ),
        (
            {"type": Integral, "left": None, "right": 0, "closed": "left"},
            ValueError,
            r"left can't be None when closed == left",
        ),
        (
            {"type": Integral, "left": 0, "right": None, "closed": "right"},
            ValueError,
            r"right can't be None when closed == right",
        ),
        (
            {"type": Integral, "left": 1, "right": -1, "closed": "both"},
            ValueError,
            r"right can't be less than left",
        ),
    ],
)
def test_interval_errors(params, error, match):
    """Invalid combinations of Interval arguments raise an informative error."""
    with pytest.raises(error, match=match):
        Interval(**params)
|
||
|
|
||
|
|
||
|
def test_stroptions():
    """Basic behaviour of the StrOptions constraint, including deprecated options."""
    str_options = StrOptions({"a", "b", "c"}, deprecated={"c"})

    for accepted in ("a", "c"):
        assert str_options.is_satisfied_by(accepted)
    assert not str_options.is_satisfied_by("d")

    # deprecated options are flagged in the string representation
    assert "'c' (deprecated)" in str(str_options)
|
||
|
|
||
|
|
||
|
def test_options():
    """Basic behaviour of the Options constraint, including deprecated values."""
    real_options = Options(Real, {-0.5, 0.5, np.inf}, deprecated={-0.5})

    for accepted in (-0.5, np.inf):
        assert real_options.is_satisfied_by(accepted)
    assert not real_options.is_satisfied_by(1.23)

    # deprecated values are flagged in the string representation
    assert "-0.5 (deprecated)" in str(real_options)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constraint_type, expected_type_name",
    [
        (int, "int"),
        (Integral, "int"),
        (Real, "float"),
        (np.ndarray, "numpy.ndarray"),
    ],
)
def test_instances_of_type_human_readable(constraint_type, expected_type_name):
    """Check the string representation of the _InstancesOf constraint.

    The abstract numeric types are shown with a builtin name (Integral -> "int",
    Real -> "float"), and numpy's ndarray with its module-qualified name.
    """
    # `constraint_type` (not `type`) avoids shadowing the builtin inside the test.
    constraint = _InstancesOf(constraint_type)
    assert str(constraint) == f"an instance of '{expected_type_name}'"
|
||
|
|
||
|
|
||
|
def test_hasmethods():
    """HasMethods accepts only objects exposing every required method."""
    required = HasMethods(["a", "b"])

    class _HasBoth:
        def a(self):
            pass  # pragma: no cover

        def b(self):
            pass  # pragma: no cover

    class _HasOnlyA:
        def a(self):
            pass  # pragma: no cover

    assert required.is_satisfied_by(_HasBoth())
    assert not required.is_satisfied_by(_HasOnlyA())
    assert str(required) == "an object implementing 'a' and 'b'"
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constraint",
    [
        Interval(Real, None, 0, closed="left"),
        Interval(Real, 0, None, closed="left"),
        Interval(Real, None, None, closed="neither"),
        StrOptions({"a", "b", "c"}),
        _MissingValues(),
        _VerboseHelper(),
        HasMethods("fit"),
        _IterablesNotString(),
        _CVObjects(),
    ],
)
def test_generate_invalid_param_val(constraint):
    """The generated value must be rejected by the constraint it was built for."""
    invalid_value = generate_invalid_param_val(constraint)
    assert not constraint.is_satisfied_by(invalid_value)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "integer_interval, real_interval",
    [
        (
            Interval(Integral, None, 3, closed="right"),
            Interval(Real, -5, 5, closed="both"),
        ),
        (
            Interval(Integral, None, 3, closed="right"),
            Interval(Real, -5, 5, closed="neither"),
        ),
        (
            Interval(Integral, None, 3, closed="right"),
            Interval(Real, 4, 5, closed="both"),
        ),
        (
            Interval(Integral, None, 3, closed="right"),
            Interval(Real, 5, None, closed="left"),
        ),
        (
            Interval(Integral, None, 3, closed="right"),
            Interval(Real, 4, None, closed="neither"),
        ),
        (
            Interval(Integral, 3, None, closed="left"),
            Interval(Real, -5, 5, closed="both"),
        ),
        (
            Interval(Integral, 3, None, closed="left"),
            Interval(Real, -5, 5, closed="neither"),
        ),
        (
            Interval(Integral, 3, None, closed="left"),
            Interval(Real, 1, 2, closed="both"),
        ),
        (
            Interval(Integral, 3, None, closed="left"),
            Interval(Real, None, -5, closed="left"),
        ),
        (
            Interval(Integral, 3, None, closed="left"),
            Interval(Real, None, -4, closed="neither"),
        ),
        (
            Interval(Integral, -5, 5, closed="both"),
            Interval(Real, None, 1, closed="right"),
        ),
        (
            Interval(Integral, -5, 5, closed="both"),
            Interval(Real, 1, None, closed="left"),
        ),
        (
            Interval(Integral, -5, 5, closed="both"),
            Interval(Real, -10, -4, closed="neither"),
        ),
        (
            Interval(Integral, -5, 5, closed="both"),
            Interval(Real, -10, -4, closed="right"),
        ),
        (
            Interval(Integral, -5, 5, closed="neither"),
            Interval(Real, 6, 10, closed="neither"),
        ),
        (
            Interval(Integral, -5, 5, closed="neither"),
            Interval(Real, 6, 10, closed="left"),
        ),
        (
            Interval(Integral, 2, None, closed="left"),
            Interval(Real, 0, 1, closed="both"),
        ),
        (
            Interval(Integral, 1, None, closed="left"),
            Interval(Real, 0, 1, closed="both"),
        ),
    ],
)
def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval):
    """Whichever of the two intervals the invalid value is generated for, it must
    satisfy neither of the interval constraints.
    """
    for target_interval in (real_interval, integer_interval):
        bad_value = generate_invalid_param_val(
            target_interval, constraints=[real_interval, integer_interval]
        )
        assert not real_interval.is_satisfied_by(bad_value)
        assert not integer_interval.is_satisfied_by(bad_value)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constraints",
    [
        [_ArrayLikes()],
        [_InstancesOf(list)],
        [_Callables()],
        [_NoneConstraint()],
        [_RandomStates()],
        [_SparseMatrices()],
        [_Booleans()],
        [Interval(Real, None, None, closed="both")],
        [
            Interval(Integral, 0, None, closed="left"),
            Interval(Real, None, 0, closed="neither"),
        ],
    ],
)
def test_generate_invalid_param_val_all_valid(constraints):
    """NotImplementedError is raised when no invalid value can be produced for the
    given constraint(s).
    """
    first_constraint = constraints[0]
    with pytest.raises(NotImplementedError):
        generate_invalid_param_val(first_constraint, constraints=constraints)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constraint",
    [
        _ArrayLikes(),
        _Callables(),
        _InstancesOf(list),
        _NoneConstraint(),
        _RandomStates(),
        _SparseMatrices(),
        _Booleans(),
        _VerboseHelper(),
        _MissingValues(),
        StrOptions({"a", "b", "c"}),
        Options(Integral, {1, 2, 3}),
        Interval(Integral, None, None, closed="neither"),
        Interval(Integral, 0, 10, closed="neither"),
        Interval(Integral, 0, None, closed="neither"),
        Interval(Integral, None, 0, closed="neither"),
        Interval(Real, 0, 1, closed="neither"),
        Interval(Real, 0, None, closed="both"),
        Interval(Real, None, 0, closed="right"),
        HasMethods("fit"),
        _IterablesNotString(),
        _CVObjects(),
    ],
)
def test_generate_valid_param(constraint):
    """The generated value must be accepted by the constraint it was built for."""
    valid_value = generate_valid_param(constraint)
    assert constraint.is_satisfied_by(valid_value)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constraint_declaration, value",
    [
        (Interval(Real, 0, 1, closed="both"), 0.42),
        (Interval(Integral, 0, None, closed="neither"), 42),
        (StrOptions({"a", "b", "c"}), "b"),
        (Options(type, {np.float32, np.float64}), np.float64),
        (callable, lambda x: x + 1),
        (None, None),
        ("array-like", [[1, 2], [3, 4]]),
        ("array-like", np.array([[1, 2], [3, 4]])),
        ("sparse matrix", csr_matrix([[1, 2], [3, 4]])),
        ("random_state", 0),
        ("random_state", np.random.RandomState(0)),
        ("random_state", None),
        (_Class, _Class()),
        (int, 1),
        (Real, 0.5),
        ("boolean", False),
        ("verbose", 1),
        ("missing_values", -1),
        ("missing_values", -1.0),
        ("missing_values", None),
        ("missing_values", float("nan")),
        ("missing_values", np.nan),
        ("missing_values", "missing"),
        (HasMethods("fit"), _Estimator(a=0)),
        ("cv_object", 5),
    ],
)
def test_is_satisfied_by(constraint_declaration, value):
    """Each declaration, once turned into a constraint, accepts its paired value."""
    built_constraint = make_constraint(constraint_declaration)
    assert built_constraint.is_satisfied_by(value)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "constraint_declaration, expected_constraint_class",
    [
        (Interval(Real, 0, 1, closed="both"), Interval),
        (StrOptions({"option1", "option2"}), StrOptions),
        (Options(Real, {0.42, 1.23}), Options),
        ("array-like", _ArrayLikes),
        ("sparse matrix", _SparseMatrices),
        ("random_state", _RandomStates),
        (None, _NoneConstraint),
        (callable, _Callables),
        (int, _InstancesOf),
        ("boolean", _Booleans),
        ("verbose", _VerboseHelper),
        ("missing_values", _MissingValues),
        (HasMethods("fit"), HasMethods),
        ("cv_object", _CVObjects),
    ],
)
def test_make_constraint(constraint_declaration, expected_constraint_class):
    """Check that make_constraint dispatches to the appropriate constraint class."""
    constraint = make_constraint(constraint_declaration)
    # exact class check on purpose: a subclass would not be the expected dispatch
    assert constraint.__class__ is expected_constraint_class
|
||
|
|
||
|
|
||
|
def test_make_constraint_unknown():
    """An unrecognised constraint declaration raises an informative ValueError."""
    bogus_declaration = "not a valid constraint"
    with pytest.raises(ValueError, match="Unknown constraint"):
        make_constraint(bogus_declaration)
|
||
|
|
||
|
|
||
|
def test_validate_params():
    """validate_params catches invalid values however the arguments are passed."""
    # (name of the invalid parameter, call passing an invalid value for it)
    invalid_calls = [
        ("a", lambda: _func("wrong", c=1)),
        ("b", lambda: _func(*[1, "wrong"], c=1)),
        ("c", lambda: _func(1, **{"c": "wrong"})),
        ("d", lambda: _func(1, c=1, d="wrong")),
        # extra positional and keyword arguments must not disturb validation
        ("b", lambda: _func(0, *["wrong", 2, 3], c=4, **{"e": 5})),
        ("c", lambda: _func(0, *[1, 2, 3], c="four", **{"e": 5})),
    ]
    for param_name, call in invalid_calls:
        expected = f"The '{param_name}' parameter of _func must be"
        with pytest.raises(InvalidParameterError, match=expected):
            call()
|
||
|
|
||
|
|
||
|
def test_validate_params_missing_params():
    """Parameters that have no entry in the constraints dict do not trigger an
    error.
    """

    @validate_params({"a": [int]})
    def func(a, b):
        """Only ``a`` has a declared constraint; ``b`` has none."""

    func(1, 2)
|
||
|
|
||
|
|
||
|
def test_decorate_validated_function():
    """Check that functions decorated with validate_params can be decorated again."""
    decorated_function = deprecated()(_func)

    with pytest.warns(FutureWarning, match="Function _func is deprecated"):
        decorated_function(1, 2, c=3)

    # the outer decorator does not interfere with validation
    with pytest.warns(FutureWarning, match="Function _func is deprecated"):
        with pytest.raises(
            InvalidParameterError, match=r"The 'c' parameter of _func must be"
        ):
            decorated_function(1, 2, c="wrong")
|
||
|
|
||
|
|
||
|
def test_validate_params_method():
    """validate_params also applies to methods, whether decorated again or not."""
    instance = _Class()

    with pytest.raises(
        InvalidParameterError, match="The 'a' parameter of _Class._method must be"
    ):
        instance._method("wrong")

    # a validated method can still be wrapped by another decorator
    with pytest.warns(FutureWarning, match="Function _deprecated_method is deprecated"):
        with pytest.raises(
            InvalidParameterError,
            match="The 'a' parameter of _Class._deprecated_method must be",
        ):
            instance._deprecated_method("wrong")
|
||
|
|
||
|
|
||
|
def test_validate_params_estimator():
    """Estimator parameters are validated at fit time, not at construction."""
    # constructing with an invalid value must not raise
    estimator = _Estimator("wrong")

    expected_msg = "The 'a' parameter of _Estimator must be"
    with pytest.raises(InvalidParameterError, match=expected_msg):
        estimator.fit()
|
||
|
|
||
|
|
||
|
def test_stroptions_deprecated_subset():
    """Deprecated options that are not among the options are rejected."""
    options, deprecated_options = {"a", "b", "c"}, {"a", "d"}
    with pytest.raises(ValueError, match="deprecated options must be a subset"):
        StrOptions(options, deprecated=deprecated_options)
|
||
|
|
||
|
|
||
|
def test_hidden_constraint():
    """Hidden constraints accept values but are not advertised in error messages."""

    @validate_params({"param": [Hidden(list), dict]})
    def f(param):
        pass

    # both the visible (dict) and the hidden (list) constraint accept values
    for accepted in ({"a": 1, "b": 2, "c": 3}, [1, 2, 3]):
        f(accepted)

    with pytest.raises(
        InvalidParameterError, match="The 'param' parameter"
    ) as exc_info:
        f(param="bad")

    message = str(exc_info.value)
    assert "an instance of 'dict'" in message
    # the hidden list option must not leak into the message
    assert "an instance of 'list'" not in message
|
||
|
|
||
|
|
||
|
def test_hidden_stroptions():
    """Two StrOptions constraints can coexist when one of them is hidden."""

    @validate_params({"param": [StrOptions({"auto"}), Hidden(StrOptions({"warn"}))]})
    def f(param):
        pass

    # the visible and the hidden option are both accepted
    for accepted in ("auto", "warn"):
        f(accepted)

    with pytest.raises(
        InvalidParameterError, match="The 'param' parameter"
    ) as exc_info:
        f(param="bad")

    message = str(exc_info.value)
    assert "auto" in message
    # the hidden option must not leak into the message
    assert "warn" not in message
|
||
|
|
||
|
|
||
|
def test_validate_params_set_param_constraints_attribute():
    """The decorator exposes the constraints on the decorated function or method
    through the ``_skl_parameter_constraints`` attribute.
    """
    for decorated in (_func, _Class()._method):
        assert hasattr(decorated, "_skl_parameter_constraints")
|
||
|
|
||
|
|
||
|
def test_boolean_constraint_deprecated_int():
    """An int passed for a boolean parameter still passes validation, but raises a
    deprecation warning.
    """

    @validate_params({"param": ["boolean"]})
    def f(param):
        pass

    # Python and numpy booleans are accepted
    for flag in (True, np.bool_(False)):
        f(flag)

    # ints are accepted too, with a FutureWarning
    with pytest.warns(
        FutureWarning, match="Passing an int for a boolean parameter is deprecated"
    ):
        f(1)
|
||
|
|
||
|
|
||
|
def test_no_validation():
    """A parameter declared as "no_validation" skips validation entirely."""

    @validate_params({"param1": [int, None], "param2": "no_validation"})
    def f(param1=None, param2=None):
        pass

    # param1 is still validated
    with pytest.raises(InvalidParameterError, match="The 'param1' parameter"):
        f(param1="wrong")

    # param2 is not validated: both a class and an instance are accepted
    class SomeType:
        pass

    for anything in (SomeType, SomeType()):
        f(param2=anything)
|
||
|
|
||
|
|
||
|
def test_pandas_na_constraint_with_pd_na():
    """`pandas.NA` satisfies _PandasNAConstraint, while a numpy array does not."""
    pd = pytest.importorskip("pandas")

    constraint = _PandasNAConstraint()
    assert constraint.is_satisfied_by(pd.NA)
    assert not constraint.is_satisfied_by(np.array([1, 2, 3]))
|
||
|
|
||
|
|
||
|
def test_iterable_not_string():
    """Check that the _IterablesNotString constraint accepts iterables such as
    lists and ranges, but rejects strings.
    """
    constraint = _IterablesNotString()
    assert constraint.is_satisfied_by([1, 2, 3])
    assert constraint.is_satisfied_by(range(10))
    assert not constraint.is_satisfied_by("some string")
|
||
|
|
||
|
|
||
|
def test_cv_objects():
    """The _CVObjects constraint accepts every currently supported way of passing
    a cv object, and rejects an arbitrary string.
    """
    constraint = _CVObjects()
    accepted_cvs = [
        5,  # an int
        LeaveOneOut(),  # a cross-validator instance
        [([1, 2], [3, 4]), ([3, 4], [1, 2])],  # an iterable of (train, test) pairs
        None,
    ]
    for cv in accepted_cvs:
        assert constraint.is_satisfied_by(cv)
    assert not constraint.is_satisfied_by("not a CV object")
|
||
|
|
||
|
|
||
|
def test_third_party_estimator():
    """A third-party estimator that inherits validation from a scikit-learn
    estimator is not required to have its parameters match the inherited dict of
    constraints.
    """

    class ThirdPartyEstimator(_Estimator):
        # only exposes ``b``, which the inherited constraints know nothing about
        def __init__(self, b):
            self.b = b
            super().__init__(a=0)

        def fit(self, X=None, y=None):
            super().fit(X, y)

    # "b" has no constraint, and the constraint on "a" targets a parameter the
    # estimator does not expose; neither situation may raise.
    estimator = ThirdPartyEstimator(b=0)
    estimator.fit()
|
||
|
|
||
|
|
||
|
def test_interval_real_not_int():
    """An Interval of type "real_not_int" accepts floats but rejects ints."""
    real_only = Interval("real_not_int", 0, 1, closed="both")
    assert real_only.is_satisfied_by(1.0)
    assert not real_only.is_satisfied_by(1)
|