projektAI/venv/Lib/site-packages/mlxtend/externals/estimator_checks.py
2021-06-06 22:13:05 +02:00

74 lines
2.6 KiB
Python

# Source: https://github.com/scikit-learn/scikit-learn
"""Utilities for input validation"""
# Authors: Olivier Grisel
# Gael Varoquaux
# Andreas Mueller
# Lars Buitinck
# Alexandre Gramfort
# Nicolas Tresegnie
# License: BSD 3 clause
class NotFittedError(ValueError, AttributeError):
    """Error signalling that an estimator was used before being fitted.

    The class derives from both ``ValueError`` and ``AttributeError``, so
    existing ``except ValueError`` / ``except AttributeError`` handlers keep
    catching it (exception handling and backward compatibility).

    Examples
    --------
    >>> from sklearn.svm import LinearSVC
    >>> from sklearn.exceptions import NotFittedError
    >>> try:
    ...     LinearSVC().predict([[1, 2], [2, 3], [3, 4]])
    ... except NotFittedError as e:
    ...     print(repr(e))
    ...                        # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    NotFittedError('This LinearSVC instance is not fitted yet'...)

    .. versionchanged:: 0.18
       Moved from sklearn.utils.validation.
    """
def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks if the estimator is fitted by verifying the presence of
    "all_or_any" of the passed attributes and raises a NotFittedError with
    the given message.

    Parameters
    ----------
    estimator : estimator instance.
        Estimator instance for which the check is performed. It must expose
        a ``fit`` attribute.

    attributes : attribute name(s) given as string or a list/tuple of strings
        Eg.: ``["coef_", "estimator_", ...], "coef_"``
        A value that is not a list or tuple is treated as a single
        attribute name.

    msg : string
        The default error message is, "This %(name)s instance is not fitted
        yet. Call 'fit' with appropriate arguments before using this method."

        For custom messages if "%(name)s" is present in the message string,
        it is substituted for the estimator name.

        Eg. : "Estimator, %(name)s, must be fitted before sparsifying".

    all_or_any : callable, {all, any}, default all
        Specify whether all or any of the given attributes must exist.
        It is called with a list of booleans, one per attribute.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``estimator`` has no ``fit`` attribute.
    NotFittedError
        If the attributes are not found.
    """
    if msg is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this method.")

    if not hasattr(estimator, 'fit'):
        # Wrap in a one-element tuple: a bare ``% estimator`` would unpack
        # tuple inputs as multiple format arguments and lose this message.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    if not isinstance(attributes, (list, tuple)):
        attributes = [attributes]

    if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
        # The message is %-formatted with the estimator's class name.
        raise NotFittedError(msg % {'name': type(estimator).__name__})