362 lines
12 KiB
Python
362 lines
12 KiB
Python
|
# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
||
|
# Raghav RV <rvraghav93@gmail.com>
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
import inspect
|
||
|
import warnings
|
||
|
import importlib
|
||
|
|
||
|
from pkgutil import walk_packages
|
||
|
from inspect import signature
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
# make it possible to discover experimental estimators when calling `all_estimators`
|
||
|
from sklearn.experimental import enable_iterative_imputer # noqa
|
||
|
from sklearn.experimental import enable_halving_search_cv # noqa
|
||
|
|
||
|
import sklearn
|
||
|
from sklearn.utils import IS_PYPY
|
||
|
from sklearn.utils._testing import check_docstring_parameters
|
||
|
from sklearn.utils._testing import _get_func_name
|
||
|
from sklearn.utils._testing import ignore_warnings
|
||
|
from sklearn.utils import all_estimators
|
||
|
from sklearn.utils.estimator_checks import _enforce_estimator_tags_y
|
||
|
from sklearn.utils.estimator_checks import _enforce_estimator_tags_X
|
||
|
from sklearn.utils.estimator_checks import _construct_instance
|
||
|
from sklearn.utils.fixes import sp_version, parse_version
|
||
|
from sklearn.utils.deprecation import _is_deprecated
|
||
|
from sklearn.datasets import make_classification
|
||
|
from sklearn.linear_model import LogisticRegression
|
||
|
from sklearn.preprocessing import FunctionTransformer
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
|
||
|
# walk_packages() ignores DeprecationWarnings, now we need to ignore
# FutureWarnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    # mypy error: Module has no attribute "__path__"
    sklearn_path = sklearn.__path__  # type: ignore  # mypy issue #1422
    # Keep every dotted module name found under the package path, except
    # those containing a private component ("._") or living in a test
    # package (".tests.").
    PUBLIC_MODULES = {
        modname
        for _, modname, _ in walk_packages(prefix="sklearn.", path=sklearn_path)
        if "._" not in modname and ".tests." not in modname
    }
|
||
|
|
||
|
# Dotted names whose arguments / docstrings are not checked.
# `test_docstring_parameters` skips a function when any of these entries is a
# substring of the name returned by `_get_func_name`.
_DOCSTRING_IGNORES = [
    "sklearn.utils.deprecation.load_mlcomp",
    "sklearn.pipeline.make_pipeline", "sklearn.pipeline.make_union",
    "sklearn.utils.extmath.safe_sparse_dot",
    "sklearn.utils._joblib",
]
|
||
|
|
||
|
# Methods for which a `y` parameter is left out of the docstring check when
# its default value is None.
_METHODS_IGNORE_NONE_Y = [
    "fit", "score",
    "fit_predict", "fit_transform",
    "partial_fit", "predict",
]
|
||
|
|
||
|
|
||
|
# numpydoc 0.8.0's docscrape tool raises because of collections.abc under
# Python 3.7
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.skipif(IS_PYPY, reason="test segfaults on PyPy")
def test_docstring_parameters():
    """Check documented parameters against signatures in all public modules.

    Each module named in ``PUBLIC_MODULES`` is imported. For its classes
    (``__init__`` plus the methods listed by numpydoc) and for the functions
    defined in the module itself, ``check_docstring_parameters`` is called and
    everything it reports is accumulated. A single ``AssertionError`` listing
    all reports is raised at the end if anything was found.
    """
    # Skip test if numpydoc is not found
    pytest.importorskip(
        "numpydoc", reason="numpydoc is required to test the docstrings"
    )

    # XXX unreached code as of v0.22
    from numpydoc import docscrape

    problems = []
    for modname in PUBLIC_MODULES:
        # ".conftest" modules are pytest tooling, not part of the scikit-learn
        # API; the docstrings of "sklearn.utils.fixes" cannot always be
        # controlled.
        if modname.endswith(".conftest") or modname == "sklearn.utils.fixes":
            continue

        with warnings.catch_warnings(record=True):
            module = importlib.import_module(modname)

        for cname, cls in inspect.getmembers(module, inspect.isclass):
            # Exclude non-scikit-learn classes
            if not cls.__module__.startswith("sklearn"):
                continue
            if cname.startswith("_") or cname in _DOCSTRING_IGNORES:
                continue
            if inspect.isabstract(cls):
                continue

            # Any warning emitted while numpydoc parses the class docstring
            # is turned into a hard error.
            with warnings.catch_warnings(record=True) as caught:
                cdoc = docscrape.ClassDoc(cls)
            if caught:
                raise RuntimeError(
                    "Error for __init__ of %s in %s:\n%s" % (cls, modname, caught[0])
                )

            cls_init = getattr(cls, "__init__", None)
            if _is_deprecated(cls_init):
                # A deprecated constructor excludes the whole class.
                continue

            class_problems = []
            if cls_init is not None:
                class_problems.extend(check_docstring_parameters(cls.__init__, cdoc))

            for method_name in cdoc.methods:
                method = getattr(cls, method_name)
                if _is_deprecated(method):
                    continue

                # Now skip docstring test for y when y is None
                # by default for API reason
                ignored = None
                if method_name in _METHODS_IGNORE_NONE_Y:
                    params = signature(method).parameters
                    if "y" in params and params["y"].default is None:
                        ignored = ["y"]  # ignore y for fit and score
                class_problems.extend(
                    check_docstring_parameters(method, ignore=ignored)
                )

            problems.extend(class_problems)

        for fname, func in inspect.getmembers(module, inspect.isfunction):
            # Exclude imported functions
            if func.__module__ != modname:
                continue
            # Don't test private methods / functions
            if fname.startswith("_"):
                continue
            if fname == "configuration" and modname.endswith("setup"):
                continue

            full_name = _get_func_name(func)
            if any(ignored in full_name for ignored in _DOCSTRING_IGNORES):
                continue
            if _is_deprecated(func):
                continue
            problems.extend(check_docstring_parameters(func))

    if problems:
        raise AssertionError("Docstring Error:\n" + "\n".join(problems))
|
||
|
|
||
|
|
||
|
@ignore_warnings(category=FutureWarning)
def test_tabs():
    """Check that no scikit-learn module source contains a tab character.

    Every module found by ``walk_packages`` under ``sklearn.__path__`` is
    imported and its source is scanned for ``"\\t"``. Modules whose source
    cannot be retrieved are skipped.
    """
    for _, modname, _ in walk_packages(sklearn.__path__, prefix="sklearn."):
        # Under PyPy, the modules matching these names are skipped.
        if IS_PYPY and (
            "_svmlight_format_io" in modname
            or "feature_extraction._hashing_fast" in modname
        ):
            continue

        # `modname` is only a dotted name: import it to obtain a module object
        # whose source `inspect.getsource` can read.
        mod = importlib.import_module(modname)

        try:
            source = inspect.getsource(mod)
        except IOError:  # user probably should have run "make clean"
            continue

        # The message must be one formatted string. The previous tuple applied
        # `%` to a fragment without any format specifier, so a failure raised
        # TypeError instead of reporting the offending module.
        assert "\t" not in source, (
            f'"{modname}" has tabs, please remove them or add it to the ignore list'
        )
|
||
|
|
||
|
|
||
|
def _construct_searchcv_instance(SearchCV):
    """Instantiate ``SearchCV`` around a LogisticRegression with a two-value grid on ``C``."""
    base_estimator = LogisticRegression()
    param_grid = {"C": [0.1, 1]}
    return SearchCV(base_estimator, param_grid)
|
||
|
|
||
|
|
||
|
def _construct_compose_pipeline_instance(Estimator):
    """Build a minimal instance of a composite estimator, selected by class name.

    Minimal / degenerate instances: only useful to test the docstrings.
    Returns None when the class name is not ``ColumnTransformer``,
    ``Pipeline`` or ``FeatureUnion``.
    """
    kind = Estimator.__name__

    if kind == "ColumnTransformer":
        return Estimator(transformers=[("transformer", "passthrough", [0, 1])])

    if kind == "Pipeline":
        return Estimator(steps=[("clf", LogisticRegression())])

    if kind == "FeatureUnion":
        return Estimator(transformer_list=[("transformer", FunctionTransformer())])

    return None
|
||
|
|
||
|
|
||
|
def _construct_sparse_coder(Estimator):
    """Instantiate ``Estimator`` with a fixed 5 x 3 float64 dictionary."""
    # XXX: hard-coded assumption that n_features=3
    atoms = [
        [0, 1, 0],
        [-1, -1, 2],
        [1, 1, 1],
        [0, 1, 1],
        [0, 2, 1],
    ]
    dictionary = np.array(atoms, dtype=np.float64)
    return Estimator(dictionary=dictionary)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("name, Estimator", all_estimators())
def test_fit_docstring_attributes(name, Estimator):
    """Check that fitted attributes and the "Attributes" docstring section agree.

    The estimator is instantiated, given a few parameter overrides where
    needed, and fitted on a small dataset. Every documented attribute whose
    description does not contain "only " must then exist on the fitted
    estimator, and every fitted attribute reported by
    ``_get_all_fitted_attributes`` must be documented.
    """
    pytest.importorskip("numpydoc")
    from numpydoc import docscrape

    doc = docscrape.ClassDoc(Estimator)
    attributes = doc["Attributes"]

    est_name = Estimator.__name__

    # Estimators with a dedicated factory; everything else goes through
    # `_construct_instance`.
    factories = {
        "HalvingRandomSearchCV": _construct_searchcv_instance,
        "RandomizedSearchCV": _construct_searchcv_instance,
        "HalvingGridSearchCV": _construct_searchcv_instance,
        "GridSearchCV": _construct_searchcv_instance,
        "ColumnTransformer": _construct_compose_pipeline_instance,
        "Pipeline": _construct_compose_pipeline_instance,
        "FeatureUnion": _construct_compose_pipeline_instance,
        "SparseCoder": _construct_sparse_coder,
    }
    est = factories.get(est_name, _construct_instance)(Estimator)

    # Per-estimator parameter overrides, keyed by class name.
    overrides_by_name = {
        "SelectKBest": {"k": 2},
        "DummyClassifier": {"strategy": "stratified"},
        # default = 2 is invalid for single target
        "CCA": {"n_components": 1},
        # default="auto" raises an error with the shape of `X`
        "GaussianRandomProjection": {"n_components": 2},
        "SparseRandomProjection": {"n_components": 2},
        # default raises an error, perplexity must be less than n_samples
        "TSNE": {"perplexity": 2},
        # FIXME: TO BE REMOVED for 1.3 (avoid FutureWarning)
        "SequentialFeatureSelector": {"n_features_to_select": "auto"},
        # FIXME: TO BE REMOVED for 1.3 (avoid FutureWarning)
        "FastICA": {"whiten": "unit-variance"},
        # FIXME: TO BE REMOVED for 1.3 (avoid FutureWarning)
        "MiniBatchDictionaryLearning": {"batch_size": 5},
        # TODO(1.4): TO BE REMOVED for 1.4 (avoid FutureWarning)
        "KMeans": {"n_init": "auto"},
        "MiniBatchKMeans": {"n_init": "auto"},
        # TODO(1.4): TO BE REMOVED for 1.4 (avoid FutureWarning)
        "MultinomialNB": {"force_alpha": True},
        "ComplementNB": {"force_alpha": True},
        "BernoulliNB": {"force_alpha": True},
        "CategoricalNB": {"force_alpha": True},
        # TODO(1.4): TO BE REMOVED for 1.4 (avoid FutureWarning)
        "MDS": {"normalized_stress": "auto"},
    }
    if est_name.startswith("PLS"):
        # default = 2 is invalid for single target
        overrides = {"n_components": 1}
    elif est_name == "QuantileRegressor":
        # The solver is picked from the installed SciPy version.
        solver = "highs" if sp_version >= parse_version("1.6.0") else "interior-point"
        overrides = {"solver": solver}
    else:
        overrides = overrides_by_name.get(est_name)
    if overrides:
        est.set_params(**overrides)

    # In case we want to deprecate some attributes in the future
    skipped_attributes = {}

    if est_name.endswith("Vectorizer"):
        # Vectorizer require some specific input data
        if est_name in ("CountVectorizer", "HashingVectorizer", "TfidfVectorizer"):
            X = [
                "This is the first document.",
                "This document is the second document.",
                "And this is the third one.",
                "Is this the first document?",
            ]
        elif est_name == "DictVectorizer":
            X = [{"foo": 1, "bar": 2}, {"foo": 3, "baz": 1}]
        y = None
    else:
        X, y = make_classification(
            n_samples=20,
            n_features=3,
            n_redundant=0,
            n_classes=2,
            random_state=2,
        )

        y = _enforce_estimator_tags_y(est, y)
        X = _enforce_estimator_tags_X(est, X)

    # The "X_types" tag decides what is passed to `fit`.
    x_types = est._get_tags()["X_types"]
    if "1dlabels" in x_types:
        est.fit(y)
    elif "2dlabels" in x_types:
        est.fit(np.c_[y, y])
    else:
        est.fit(X, y)

    for attr in attributes:
        # As certain attributes are present "only" if a certain parameter is
        # provided, this checks if the word "only" is present in the attribute
        # description, and if not the attribute is required to be present.
        if attr.name in skipped_attributes or "only " in " ".join(attr.desc).lower():
            continue
        # ignore deprecation warnings
        with ignore_warnings(category=FutureWarning):
            assert hasattr(est, attr.name)

    documented_names = [attr.name for attr in attributes]
    undocumented_attrs = (
        set(_get_all_fitted_attributes(est))
        .difference(documented_names)
        .difference(skipped_attributes)
    )
    if undocumented_attrs:
        raise AssertionError(
            f"Undocumented attributes for {Estimator.__name__}: {undocumented_attrs}"
        )
|
||
|
|
||
|
|
||
|
def _get_all_fitted_attributes(estimator):
    """Get all the fitted attributes of an estimator including properties.

    Candidates are the keys of ``estimator.__dict__`` followed by every
    property defined on the estimator's class whose access on the instance
    neither raises ``AttributeError`` nor emits a ``FutureWarning``. Only
    names that end with an underscore and do not start with one are returned.
    """
    klass = estimator.__class__

    # Instance attributes come first.
    candidates = list(estimator.__dict__)

    with warnings.catch_warnings():
        # Escalate FutureWarning to an exception so that properties emitting
        # it can be caught and left out below.
        warnings.filterwarnings("error", category=FutureWarning)

        for attr_name in dir(klass):
            if isinstance(getattr(klass, attr_name), property):
                try:
                    getattr(estimator, attr_name)
                except (AttributeError, FutureWarning):
                    # Property raising AttributeError or deprecated: ignore.
                    pass
                else:
                    candidates.append(attr_name)

    return [
        key for key in candidates if key.endswith("_") and not key.startswith("_")
    ]
|