""" The :mod:`sklearn.utils.discovery` module includes utilities to discover objects (i.e. estimators, displays, functions) from the `sklearn` package. """ import inspect import pkgutil from importlib import import_module from operator import itemgetter from pathlib import Path _MODULE_TO_IGNORE = { "tests", "externals", "setup", "conftest", "experimental", "estimator_checks", } def all_estimators(type_filter=None): """Get a list of all estimators from `sklearn`. This function crawls the module and gets all classes that inherit from BaseEstimator. Classes that are defined in test-modules are not included. Parameters ---------- type_filter : {"classifier", "regressor", "cluster", "transformer"} \ or list of such str, default=None Which kind of estimators should be returned. If None, no filter is applied and all estimators are returned. Possible values are 'classifier', 'regressor', 'cluster' and 'transformer' to get estimators only of these specific types, or a list of these to get the estimators that fit at least one of the types. Returns ------- estimators : list of tuples List of (name, class), where ``name`` is the class name as string and ``class`` is the actual type of the class. Examples -------- >>> from sklearn.utils.discovery import all_estimators >>> estimators = all_estimators() >>> type(estimators) >>> type(estimators[0]) >>> estimators[:2] [('ARDRegression', ), ('AdaBoostClassifier', )] >>> classifiers = all_estimators(type_filter="classifier") >>> classifiers[:2] [('AdaBoostClassifier', ), ('BaggingClassifier', )] >>> regressors = all_estimators(type_filter="regressor") >>> regressors[:2] [('ARDRegression', ), ('AdaBoostRegressor', )] >>> both = all_estimators(type_filter=["classifier", "regressor"]) >>> both[:2] [('ARDRegression', ), ('AdaBoostClassifier', )] """ # lazy import to avoid circular imports from sklearn.base from ..base import ( BaseEstimator, ClassifierMixin, ClusterMixin, RegressorMixin, TransformerMixin, ) from ._testing import ignore_warnings from .fixes import _IS_PYPY def is_abstract(c): if not (hasattr(c, "__abstractmethods__")): return False if not len(c.__abstractmethods__): return False return True all_classes = [] root = str(Path(__file__).parent.parent) # sklearn package # Ignore deprecation warnings triggered at import time and from walking # packages with ignore_warnings(category=FutureWarning): for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."): module_parts = module_name.split(".") if ( any(part in _MODULE_TO_IGNORE for part in module_parts) or "._" in module_name ): continue module = import_module(module_name) classes = inspect.getmembers(module, inspect.isclass) classes = [ (name, est_cls) for name, est_cls in classes if not name.startswith("_") ] # TODO: Remove when FeatureHasher is implemented in PYPY # Skips FeatureHasher for PYPY if _IS_PYPY and "feature_extraction" in module_name: classes = [ (name, est_cls) for name, est_cls in classes if name == "FeatureHasher" ] all_classes.extend(classes) all_classes = set(all_classes) estimators = [ c for c in all_classes if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator") ] # get rid of abstract base classes estimators = [c for c in estimators if not is_abstract(c[1])] if type_filter is not None: if not isinstance(type_filter, list): type_filter = [type_filter] else: type_filter = list(type_filter) # copy filtered_estimators = [] filters = { "classifier": ClassifierMixin, "regressor": RegressorMixin, "transformer": TransformerMixin, "cluster": ClusterMixin, } for name, mixin in filters.items(): if name in type_filter: type_filter.remove(name) filtered_estimators.extend( [est for est in estimators if issubclass(est[1], mixin)] ) estimators = filtered_estimators if type_filter: raise ValueError( "Parameter type_filter must be 'classifier', " "'regressor', 'transformer', 'cluster' or " "None, got" f" {repr(type_filter)}." ) # drop duplicates, sort for reproducibility # itemgetter is used to ensure the sort does not extend to the 2nd item of # the tuple return sorted(set(estimators), key=itemgetter(0)) def all_displays(): """Get a list of all displays from `sklearn`. Returns ------- displays : list of tuples List of (name, class), where ``name`` is the display class name as string and ``class`` is the actual type of the class. Examples -------- >>> from sklearn.utils.discovery import all_displays >>> displays = all_displays() >>> displays[0] ('CalibrationDisplay', ) """ # lazy import to avoid circular imports from sklearn.base from ._testing import ignore_warnings all_classes = [] root = str(Path(__file__).parent.parent) # sklearn package # Ignore deprecation warnings triggered at import time and from walking # packages with ignore_warnings(category=FutureWarning): for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."): module_parts = module_name.split(".") if ( any(part in _MODULE_TO_IGNORE for part in module_parts) or "._" in module_name ): continue module = import_module(module_name) classes = inspect.getmembers(module, inspect.isclass) classes = [ (name, display_class) for name, display_class in classes if not name.startswith("_") and name.endswith("Display") ] all_classes.extend(classes) return sorted(set(all_classes), key=itemgetter(0)) def _is_checked_function(item): if not inspect.isfunction(item): return False if item.__name__.startswith("_"): return False mod = item.__module__ if not mod.startswith("sklearn.") or mod.endswith("estimator_checks"): return False return True def all_functions(): """Get a list of all functions from `sklearn`. Returns ------- functions : list of tuples List of (name, function), where ``name`` is the function name as string and ``function`` is the actual function. Examples -------- >>> from sklearn.utils.discovery import all_functions >>> functions = all_functions() >>> name, function = functions[0] >>> name 'accuracy_score' """ # lazy import to avoid circular imports from sklearn.base from ._testing import ignore_warnings all_functions = [] root = str(Path(__file__).parent.parent) # sklearn package # Ignore deprecation warnings triggered at import time and from walking # packages with ignore_warnings(category=FutureWarning): for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."): module_parts = module_name.split(".") if ( any(part in _MODULE_TO_IGNORE for part in module_parts) or "._" in module_name ): continue module = import_module(module_name) functions = inspect.getmembers(module, _is_checked_function) functions = [ (func.__name__, func) for name, func in functions if not name.startswith("_") ] all_functions.extend(functions) # drop duplicates, sort for reproducibility # itemgetter is used to ensure the sort does not extend to the 2nd item of # the tuple return sorted(set(all_functions), key=itemgetter(0))