
# Source listing metadata: 404 lines, 16 KiB; last change 2024-05-26 19:49:15 +02:00.

from importlib import import_module
from inspect import signature
from numbers import Integral, Real

import pytest

from sklearn.utils._param_validation import (
    Interval,
    InvalidParameterError,
    generate_invalid_param_val,
    generate_valid_param,
    make_constraint,
)


def _get_func_info(func_module):
module_name, func_name = func_module.rsplit(".", 1)
module = import_module(module_name)
func = getattr(module, func_name)
func_sig = signature(func)
func_params = [
for p in func_sig.parameters.values()
if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
# The parameters `*args` and `**kwargs` are ignored since we cannot generate
# constraints.
required_params = [
for p in func_sig.parameters.values()
if p.default is p.empty and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
return func, func_name, func_params, required_params
def _check_function_param_validation(
    func, func_name, func_params, required_params, parameter_constraints
):
    """Check that an informative error is raised when the value of a parameter does not
    have an appropriate type or value.

    Parameters
    ----------
    func : callable
        Function under test. It is only ever called with keyword arguments.
    func_name : str
        Name used in the expected error pattern and in failure messages.
    func_params : list of str
        Names of the parameters to check.
    required_params : list of str
        Names of the parameters that are passed on every call.
    parameter_constraints : dict
        Maps each parameter name to its constraints, or to the string
        ``"no_validation"`` when the parameter is not validated.

    Raises
    ------
    ValueError
        If the constraints of a parameter contain both an ``Interval`` of type
        ``Integral`` and an ``Interval`` of type ``Real``.
    """
    # generate valid values for the required parameters
    valid_required_params = {}
    for param_name in required_params:
        if parameter_constraints[param_name] == "no_validation":
            valid_required_params[param_name] = 1
        else:
            # NOTE(review): the argument of this call was missing from the source.
            # Building the value from the first declared constraint is a
            # reconstruction -- confirm against upstream.
            valid_required_params[param_name] = generate_valid_param(
                make_constraint(parameter_constraints[param_name][0])
            )

    # check that there is a constraint for each parameter
    if func_params:
        validation_params = set(parameter_constraints)
        expected_params = set(func_params)
        unexpected_params = validation_params - expected_params
        missing_params = expected_params - validation_params
        err_msg = (
            "Mismatch between _parameter_constraints and the parameters of"
            f" {func_name}.\nConsider the unexpected parameters {unexpected_params} and"
            f" expected but missing parameters {missing_params}\n"
        )
        assert validation_params == expected_params, err_msg

    # this object does not have a valid type for sure for all params
    param_with_bad_type = type("BadType", (), {})()

    for param_name in func_params:
        constraints = parameter_constraints[param_name]

        if constraints == "no_validation":
            # This parameter is not validated
            continue

        # Mixing an interval of reals and an interval of integers must be avoided.
        if any(
            isinstance(constraint, Interval) and constraint.type == Integral
            for constraint in constraints
        ) and any(
            isinstance(constraint, Interval) and constraint.type == Real
            for constraint in constraints
        ):
            raise ValueError(
                f"The constraint for parameter {param_name} of {func_name} can't have a"
                " mix of intervals of Integral and Real types. Use the type"
                " RealNotInt instead of Real."
            )

        match = (
            rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead."
        )

        err_msg = (
            f"{func_name} does not raise an informative error message when the "
            f"parameter {param_name} does not have a valid type. If any Python type "
            "is valid, the constraint should be 'no_validation'."
        )

        # First, check that the error is raised if param doesn't match any valid type.
        with pytest.raises(InvalidParameterError, match=match):
            func(**{**valid_required_params, param_name: param_with_bad_type})
            # Only reached when `func` returned without raising: fail with the
            # informative message instead of the generic "DID NOT RAISE".
            pytest.fail(err_msg)

        # Then, for constraints that are more than a type constraint, check that the
        # error is raised if param does match a valid type but does not match any valid
        # value for this type.
        constraints = [make_constraint(constraint) for constraint in constraints]

        for constraint in constraints:
            try:
                bad_value = generate_invalid_param_val(constraint)
            except NotImplementedError:
                # NOTE(review): the handler body was missing from the source;
                # skipping constraints without a generated invalid value is a
                # reconstruction -- confirm against upstream.
                continue

            err_msg = (
                f"{func_name} does not raise an informative error message when the "
                f"parameter {param_name} does not have a valid value.\n"
                "Constraints should be disjoint. For instance "
                "[StrOptions({'a_string'}), str] is not a acceptable set of "
                "constraint because generating an invalid string for the first "
                "constraint will always produce a valid string for the second "
                "constraint."
            )

            with pytest.raises(InvalidParameterError, match=match):
                func(**{**valid_required_params, param_name: bad_value})
                # Only reached when `func` accepted the invalid value.
                pytest.fail(err_msg)
@pytest.mark.parametrize("func_module", PARAM_VALIDATION_FUNCTION_LIST)
def test_function_param_validation(func_module):
    """Check param validation for public functions that are not wrappers around
    a class.

    The constraints are read from the ``_skl_parameter_constraints`` attribute of
    the function itself.
    """
    func, func_name, func_params, required_params = _get_func_info(func_module)

    parameter_constraints = getattr(func, "_skl_parameter_constraints")

    _check_function_param_validation(
        func, func_name, func_params, required_params, parameter_constraints
    )
# Pairs of (dotted path of a public function, dotted path of the class whose
# `_parameter_constraints` are merged with the function's own constraints).
# NOTE(review): only the entries visible in this excerpt are listed; confirm that
# no entries were lost together with the enclosing brackets.
PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
    ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"),
    ("sklearn.cluster.dbscan", "sklearn.cluster.DBSCAN"),
    ("sklearn.cluster.k_means", "sklearn.cluster.KMeans"),
    ("sklearn.cluster.mean_shift", "sklearn.cluster.MeanShift"),
    ("sklearn.cluster.spectral_clustering", "sklearn.cluster.SpectralClustering"),
    ("sklearn.covariance.graphical_lasso", "sklearn.covariance.GraphicalLasso"),
    ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"),
    ("sklearn.covariance.oas", "sklearn.covariance.OAS"),
    (
        "sklearn.decomposition.dict_learning",
        "sklearn.decomposition.DictionaryLearning",
    ),
    ("sklearn.decomposition.fastica", "sklearn.decomposition.FastICA"),
    ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"),
    ("sklearn.preprocessing.maxabs_scale", "sklearn.preprocessing.MaxAbsScaler"),
    ("sklearn.preprocessing.minmax_scale", "sklearn.preprocessing.MinMaxScaler"),
    (
        "sklearn.preprocessing.power_transform",
        "sklearn.preprocessing.PowerTransformer",
    ),
    ("sklearn.preprocessing.robust_scale", "sklearn.preprocessing.RobustScaler"),
]
@pytest.mark.parametrize(
    "func_module, class_module", PARAM_VALIDATION_CLASS_WRAPPER_LIST
)
def test_class_wrapper_param_validation(func_module, class_module):
    """Check param validation for public functions that are wrappers around
    a class.

    The constraints checked are the union of the class-level
    ``_parameter_constraints`` and the function-level
    ``_skl_parameter_constraints``, restricted to the parameters the function
    actually accepts.
    """
    func, func_name, func_params, required_params = _get_func_info(func_module)

    module_name, class_name = class_module.rsplit(".", 1)
    module = import_module(module_name)
    klass = getattr(module, class_name)

    parameter_constraints_func = getattr(func, "_skl_parameter_constraints")
    parameter_constraints_class = getattr(klass, "_parameter_constraints")
    # NOTE(review): the body of this merge was missing from the source. Letting the
    # function-level constraints override the class-level ones is a
    # reconstruction -- confirm the precedence against upstream.
    parameter_constraints = {
        **parameter_constraints_class,
        **parameter_constraints_func,
    }
    # Drop constraints for parameters the function does not expose.
    parameter_constraints = {
        k: v for k, v in parameter_constraints.items() if k in func_params
    }

    _check_function_param_validation(
        func, func_name, func_params, required_params, parameter_constraints
    )