# Source: https://github.com/scikit-learn/scikit-learn
"""Utilities for input validation"""

# Authors: Olivier Grisel
#          Gael Varoquaux
#          Andreas Mueller
#          Lars Buitinck
#          Alexandre Gramfort
#          Nicolas Tresegnie
# License: BSD 3 clause


class NotFittedError(ValueError, AttributeError):
    """Exception class to raise if estimator is used before fitting.

    This class inherits from both ValueError and AttributeError to help with
    exception handling and backward compatibility.

    Examples
    --------
    >>> from sklearn.svm import LinearSVC
    >>> from sklearn.exceptions import NotFittedError
    >>> try:
    ...     LinearSVC().predict([[1, 2], [2, 3], [3, 4]])
    ... except NotFittedError as e:
    ...     print(repr(e))
    ...                        # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    NotFittedError('This LinearSVC instance is not fitted yet',)

    .. versionchanged:: 0.18
       Moved from sklearn.utils.validation.
    """


def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks if the estimator is fitted by verifying the presence of
    "all_or_any" of the passed attributes and raises a NotFittedError with
    the given message.

    Parameters
    ----------
    estimator : estimator instance.
        estimator instance for which the check is performed. It must expose
        a ``fit`` attribute.

    attributes : attribute name(s) given as string or a list/tuple of strings
        Eg.:
            ``["coef_", "estimator_", ...], "coef_"``

    msg : string
        The default error message is, "This %(name)s instance is not fitted
        yet. Call 'fit' with appropriate arguments before using this method."

        For custom messages if "%(name)s" is present in the message string,
        it is substituted for the estimator name.

        Eg. : "Estimator, %(name)s, must be fitted before sparsifying".

    all_or_any : callable, {all, any}, default all
        Specify whether all or any of the given attributes must exist.
        It is called with a list of booleans, one per attribute.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``estimator`` has no ``fit`` attribute.

    NotFittedError
        If the attributes are not found.
    """
    if msg is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this method.")

    if not hasattr(estimator, 'fit'):
        # Wrap in a one-element tuple: a bare ``% estimator`` would unpack a
        # tuple-valued argument as the format arguments and either fail with
        # an unrelated formatting error or print only its first element.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    # Accept a single attribute name as well as a list/tuple of names.
    if not isinstance(attributes, (list, tuple)):
        attributes = [attributes]

    # A list (not a generator) is passed so that any user-supplied
    # ``all_or_any`` callable keeps receiving the same kind of argument.
    if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
        raise NotFittedError(msg % {'name': type(estimator).__name__})