74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
|
|
||
|
# Source: https://github.com/scikit-learn/scikit-learn
|
||
|
|
||
|
"""Utilities for input validation"""
|
||
|
|
||
|
# Authors: Olivier Grisel
|
||
|
# Gael Varoquaux
|
||
|
# Andreas Mueller
|
||
|
# Lars Buitinck
|
||
|
# Alexandre Gramfort
|
||
|
# Nicolas Tresegnie
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
|
||
|
class NotFittedError(ValueError, AttributeError):
    """Exception class to raise if estimator is used before fitting.

    This class inherits from both ValueError and AttributeError to help with
    exception handling and backward compatibility: callers that catch either
    base class will also catch this error.

    Examples
    --------
    >>> from sklearn.svm import LinearSVC
    >>> from sklearn.exceptions import NotFittedError
    >>> try:
    ...     LinearSVC().predict([[1, 2], [2, 3], [3, 4]])
    ... except NotFittedError as e:
    ...     print(repr(e))
    ...                     # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    NotFittedError(...This LinearSVC instance is not fitted yet...)

    .. versionchanged:: 0.18
       Moved from sklearn.utils.validation.
    """
|
||
|
|
||
|
|
||
|
def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks if the estimator is fitted by verifying the presence of
    "all_or_any" of the passed attributes and raises a NotFittedError with the
    given message.

    Parameters
    ----------
    estimator : estimator instance
        Estimator instance for which the check is performed. Objects that
        have no ``fit`` attribute are rejected with a TypeError.

    attributes : str, or list, tuple or set of str
        Attribute name(s) whose presence marks the estimator as fitted,
        e.g. ``"coef_"`` or ``["coef_", "estimator_", ...]``. A single
        name is treated as a one-element list.

    msg : str, optional
        The default error message is, "This %(name)s instance is not fitted
        yet. Call 'fit' with appropriate arguments before using this method."

        For custom messages, if "%(name)s" is present in the message string
        it is replaced by the estimator's class name,
        e.g. "Estimator, %(name)s, must be fitted before sparsifying".

    all_or_any : callable, {all, any}, default all
        Specify whether all or any of the given attributes must exist.
        It is called with a list of booleans, one per attribute name.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``estimator`` has no ``fit`` attribute.

    NotFittedError
        If the attributes are not found.
    """
    if msg is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this method.")

    if not hasattr(estimator, 'fit'):
        # Wrap in a 1-tuple: ``% (estimator)`` is just ``% estimator``, which
        # would unpack a tuple "estimator" as the format arguments.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    # Normalise a single attribute name into a list. Sets are accepted as
    # containers too, since ordering is irrelevant to all()/any().
    if not isinstance(attributes, (list, tuple, set, frozenset)):
        attributes = [attributes]

    # Pass a list (not a generator) so custom ``all_or_any`` callables that
    # expect a sequence keep working.
    if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
        raise NotFittedError(msg % {'name': type(estimator).__name__})
|