146 lines
5.1 KiB
Python
146 lines
5.1 KiB
Python
|
from importlib import import_module
|
||
|
from inspect import signature
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from sklearn.utils._param_validation import generate_invalid_param_val
|
||
|
from sklearn.utils._param_validation import generate_valid_param
|
||
|
from sklearn.utils._param_validation import make_constraint
|
||
|
from sklearn.utils._param_validation import InvalidParameterError
|
||
|
|
||
|
|
||
|
def _get_func_info(func_module):
|
||
|
module_name, func_name = func_module.rsplit(".", 1)
|
||
|
module = import_module(module_name)
|
||
|
func = getattr(module, func_name)
|
||
|
|
||
|
func_sig = signature(func)
|
||
|
func_params = [
|
||
|
p.name
|
||
|
for p in func_sig.parameters.values()
|
||
|
if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
|
||
|
]
|
||
|
|
||
|
# The parameters `*args` and `**kwargs` are ignored since we cannot generate
|
||
|
# constraints.
|
||
|
required_params = [
|
||
|
p.name for p in func_sig.parameters.values() if p.default is p.empty
|
||
|
]
|
||
|
|
||
|
return func, func_name, func_params, required_params
|
||
|
|
||
|
|
||
|
def _check_function_param_validation(
|
||
|
func, func_name, func_params, required_params, parameter_constraints
|
||
|
):
|
||
|
"""Check that an informative error is raised when the value of a parameter does not
|
||
|
have an appropriate type or value.
|
||
|
"""
|
||
|
# generate valid values for the required parameters
|
||
|
valid_required_params = {}
|
||
|
for param_name in required_params:
|
||
|
if parameter_constraints[param_name] == "no_validation":
|
||
|
valid_required_params[param_name] = 1
|
||
|
else:
|
||
|
valid_required_params[param_name] = generate_valid_param(
|
||
|
make_constraint(parameter_constraints[param_name][0])
|
||
|
)
|
||
|
|
||
|
# check that there is a constraint for each parameter
|
||
|
if func_params:
|
||
|
validation_params = parameter_constraints.keys()
|
||
|
unexpected_params = set(validation_params) - set(func_params)
|
||
|
missing_params = set(func_params) - set(validation_params)
|
||
|
err_msg = (
|
||
|
"Mismatch between _parameter_constraints and the parameters of"
|
||
|
f" {func_name}.\nConsider the unexpected parameters {unexpected_params} and"
|
||
|
f" expected but missing parameters {missing_params}\n"
|
||
|
)
|
||
|
assert set(validation_params) == set(func_params), err_msg
|
||
|
|
||
|
# this object does not have a valid type for sure for all params
|
||
|
param_with_bad_type = type("BadType", (), {})()
|
||
|
|
||
|
for param_name in func_params:
|
||
|
constraints = parameter_constraints[param_name]
|
||
|
|
||
|
if constraints == "no_validation":
|
||
|
# This parameter is not validated
|
||
|
continue
|
||
|
|
||
|
match = (
|
||
|
rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead."
|
||
|
)
|
||
|
|
||
|
# First, check that the error is raised if param doesn't match any valid type.
|
||
|
with pytest.raises(InvalidParameterError, match=match):
|
||
|
func(**{**valid_required_params, param_name: param_with_bad_type})
|
||
|
|
||
|
# Then, for constraints that are more than a type constraint, check that the
|
||
|
# error is raised if param does match a valid type but does not match any valid
|
||
|
# value for this type.
|
||
|
constraints = [make_constraint(constraint) for constraint in constraints]
|
||
|
|
||
|
for constraint in constraints:
|
||
|
try:
|
||
|
bad_value = generate_invalid_param_val(constraint)
|
||
|
except NotImplementedError:
|
||
|
continue
|
||
|
|
||
|
with pytest.raises(InvalidParameterError, match=match):
|
||
|
func(**{**valid_required_params, param_name: bad_value})
|
||
|
|
||
|
|
||
|
# Dotted paths of the public functions exercised by
# `test_function_param_validation`.
PARAM_VALIDATION_FUNCTION_LIST = [
    "sklearn.cluster.kmeans_plusplus",
    "sklearn.metrics.accuracy_score",
    "sklearn.svm.l1_min_c",
]
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("func_module", PARAM_VALIDATION_FUNCTION_LIST)
def test_function_param_validation(func_module):
    """Validate the parameters of public functions that do not wrap an estimator.

    The constraints are read from the `_skl_parameter_constraints` attribute of
    the function itself.
    """
    func, func_name, func_params, required_params = _get_func_info(func_module)

    _check_function_param_validation(
        func,
        func_name,
        func_params,
        required_params,
        func._skl_parameter_constraints,
    )
|
||
|
|
||
|
|
||
|
# Pairs of (dotted path of the function, dotted path of the class it is tested
# against) used by `test_class_wrapper_param_validation`.
PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
    (
        "sklearn.decomposition.non_negative_factorization",
        "sklearn.decomposition.NMF",
    ),
]
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "func_module, class_module", PARAM_VALIDATION_CLASS_WRAPPER_LIST
)
def test_class_wrapper_param_validation(func_module, class_module):
    """Validate the parameters of public functions that wrap an estimator.

    The constraints are the class's `_parameter_constraints` overridden by the
    function's `_skl_parameter_constraints`, restricted to the parameters that
    appear in the function signature.
    """
    func, func_name, func_params, required_params = _get_func_info(func_module)

    class_module_name, class_name = class_module.rsplit(".", 1)
    klass = getattr(import_module(class_module_name), class_name)

    func_constraints = func._skl_parameter_constraints
    class_constraints = klass._parameter_constraints

    # Function-level constraints take precedence over the class-level ones.
    merged = {**class_constraints, **func_constraints}
    parameter_constraints = {
        name: constraint
        for name, constraint in merged.items()
        if name in func_params
    }

    _check_function_param_validation(
        func, func_name, func_params, required_params, parameter_constraints
    )
|