726 lines
24 KiB
Python
726 lines
24 KiB
Python
![]() |
import unittest
|
||
|
import sys
|
||
|
|
||
|
import numpy as np
|
||
|
import scipy.sparse as sp
|
||
|
import joblib
|
||
|
|
||
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
||
|
from sklearn.utils import deprecated
|
||
|
from sklearn.utils._testing import (
|
||
|
assert_raises,
|
||
|
assert_raises_regex,
|
||
|
assert_warns,
|
||
|
ignore_warnings,
|
||
|
MinimalClassifier,
|
||
|
MinimalRegressor,
|
||
|
MinimalTransformer,
|
||
|
SkipTest,
|
||
|
)
|
||
|
from sklearn.utils.estimator_checks import check_estimator, _NotAnArray
|
||
|
from sklearn.utils.estimator_checks \
|
||
|
import check_class_weight_balanced_linear_classifier
|
||
|
from sklearn.utils.estimator_checks import set_random_state
|
||
|
from sklearn.utils.estimator_checks import _set_checking_parameters
|
||
|
from sklearn.utils.estimator_checks import check_estimators_unfitted
|
||
|
from sklearn.utils.estimator_checks import check_fit_score_takes_y
|
||
|
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
|
||
|
from sklearn.utils.estimator_checks import check_classifier_data_not_an_array
|
||
|
from sklearn.utils.estimator_checks import check_regressor_data_not_an_array
|
||
|
from sklearn.utils.estimator_checks import \
|
||
|
check_estimator_get_tags_default_keys
|
||
|
from sklearn.utils.validation import check_is_fitted
|
||
|
from sklearn.utils.estimator_checks import check_outlier_corruption
|
||
|
from sklearn.utils.fixes import np_version, parse_version
|
||
|
from sklearn.ensemble import RandomForestClassifier
|
||
|
from sklearn.linear_model import LinearRegression, SGDClassifier
|
||
|
from sklearn.mixture import GaussianMixture
|
||
|
from sklearn.cluster import MiniBatchKMeans
|
||
|
from sklearn.decomposition import NMF
|
||
|
from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
|
||
|
from sklearn.svm import SVC, NuSVC
|
||
|
from sklearn.neighbors import KNeighborsRegressor
|
||
|
from sklearn.utils.validation import check_array
|
||
|
from sklearn.utils import all_estimators
|
||
|
from sklearn.exceptions import SkipTestWarning
|
||
|
|
||
|
|
||
|
class CorrectNotFittedError(ValueError):
    """Test-only error signalling use of an estimator before fitting.

    Mirrors ``NotFittedError`` by deriving from ``ValueError``; unlike it,
    this class deliberately does not derive from ``AttributeError``.
    """
|
||
|
|
||
|
|
||
|
class BaseBadClassifier(ClassifierMixin, BaseEstimator):
    """Classifier stub that performs no input validation at all.

    ``fit`` is a no-op and ``predict`` labels every sample with ``1``.
    Subclasses override pieces of it to exercise individual checks.
    """

    def fit(self, X, y):
        # Intentionally neither X nor y is validated.
        return self

    def predict(self, X):
        n_samples = X.shape[0]
        return np.ones(n_samples)
|
||
|
|
||
|
|
||
|
class ChangesDict(BaseEstimator):
    """Estimator that (wrongly) mutates its own state inside ``predict``.

    Parameters
    ----------
    key : int, default=0
        Parameter that ``predict`` overwrites with ``1000``.
    """

    def __init__(self, key=0):
        self.key = key

    def fit(self, X, y=None):
        X, y = self._validate_data(X, y)
        return self

    def predict(self, X):
        validated = check_array(X)
        # Deliberate defect: predict must not modify the estimator's __dict__.
        self.key = 1000
        return np.ones(validated.shape[0])
|
||
|
|
||
|
|
||
|
class SetsWrongAttribute(BaseEstimator):
    """Estimator whose ``fit`` adds a public attribute lacking a trailing
    underscore (``wrong_attribute``), which ``check_estimator`` must reject.
    """

    def __init__(self, acceptable_key=0):
        self.acceptable_key = acceptable_key

    def fit(self, X, y=None):
        # The defect: a new public, non-underscored attribute set during fit.
        # It is set before validation, exactly as in the original fixture.
        self.wrong_attribute = 0
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class ChangesWrongAttribute(BaseEstimator):
    """Estimator whose ``fit`` mutates the ``wrong_attribute`` parameter.

    Parameters
    ----------
    wrong_attribute : int, default=0
        Constructor parameter that ``fit`` overwrites with ``1``.
    """

    def __init__(self, wrong_attribute=0):
        self.wrong_attribute = wrong_attribute

    def fit(self, X, y=None):
        # The defect: fit changes a constructor parameter.
        self.wrong_attribute = 1
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class ChangesUnderscoreAttribute(BaseEstimator):
    """Well-behaved estimator: ``fit`` only sets a private (``_``-prefixed)
    attribute, which ``check_estimator`` must accept.
    """

    def fit(self, X, y=None):
        # Private attributes may be set during fit.
        self._good_attribute = 1
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class RaisesErrorInSetParams(BaseEstimator):
    """Estimator whose ``set_params`` rejects a negative ``p`` with
    ``ValueError``.

    Parameters
    ----------
    p : int, default=0
        Value validated in ``set_params``.
    """

    def __init__(self, p=0):
        self.p = p

    def set_params(self, **kwargs):
        # ``p`` is handled here; everything else goes to BaseEstimator.
        if 'p' in kwargs:
            new_p = kwargs.pop('p')
            if new_p < 0:
                raise ValueError("p can't be less than 0")
            self.p = new_p
        return super().set_params(**kwargs)

    def fit(self, X, y=None):
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class HasMutableParameters(BaseEstimator):
    """Estimator whose default for ``p`` is an ``object()`` instance.

    ``check_estimator`` is expected to reject this mutable default.
    """

    # The default is evaluated once at class-definition time, so every
    # instance shares the same object: intentionally bad.
    def __init__(self, p=object()):
        self.p = p

    def fit(self, X, y=None):
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class HasImmutableParameters(BaseEstimator):
    """Estimator whose defaults are all immutable and must pass the checks."""

    # ``r`` defaults to the class ``object`` itself (not an instance of it),
    # which is why it counts as immutable.
    def __init__(self, p=42, q=np.int32(42), r=object):
        self.p = p
        self.q = q
        self.r = r

    def fit(self, X, y=None):
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
class ModifiesValueInsteadOfRaisingError(BaseEstimator):
    """Estimator that silently clips a negative ``p`` to ``0`` in
    ``set_params`` instead of raising.

    As a result ``get_params`` no longer returns what was passed.
    """

    def __init__(self, p=0):
        self.p = p

    def set_params(self, **kwargs):
        if 'p' in kwargs:
            requested = kwargs.pop('p')
            # The defect: the value is replaced rather than rejected.
            self.p = 0 if requested < 0 else requested
        return super().set_params(**kwargs)

    def fit(self, X, y=None):
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class ModifiesAnotherValue(BaseEstimator):
    """Estimator whose ``set_params`` overrides ``b`` when ``a`` is ``None``.

    Passing ``a=None`` discards the caller's ``b`` and forces
    ``b='method2'``, so ``get_params`` disagrees with ``set_params``.
    """

    def __init__(self, a=0, b='method1'):
        self.a = a
        self.b = b

    def set_params(self, **kwargs):
        if 'a' not in kwargs:
            return super().set_params(**kwargs)

        new_a = kwargs.pop('a')
        self.a = new_a
        if new_a is None:
            # Drop the requested ``b`` (KeyError if absent) and override it.
            kwargs.pop('b')
            self.b = 'method2'
        return super().set_params(**kwargs)

    def fit(self, X, y=None):
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class NoCheckinPredict(BaseBadClassifier):
    """Classifier that validates its input in ``fit`` only.

    ``predict`` is inherited unchanged from ``BaseBadClassifier`` and does
    no validation of ``X``.
    """

    def fit(self, X, y):
        X, y = self._validate_data(X, y)
        return self
|
||
|
|
||
|
|
||
|
class NoSparseClassifier(BaseBadClassifier):
    """Classifier that accepts sparse input in validation, then fails on it
    with an unhelpful error message.

    ``predict`` also never checks whether the estimator was fitted.
    """

    def fit(self, X, y):
        X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'])
        if not sp.issparse(X):
            return self
        # Deliberately unhelpful error for sparse data.
        raise ValueError("Nonsensical Error")

    def predict(self, X):
        validated = check_array(X)
        return np.ones(validated.shape[0])
|
||
|
|
||
|
|
||
|
class CorrectNotFittedErrorClassifier(BaseBadClassifier):
    """Classifier whose ``predict`` calls ``check_is_fitted`` before doing
    anything else, so it fails correctly when unfitted.
    """

    def fit(self, X, y):
        X, y = self._validate_data(X, y)
        n_features = X.shape[1]
        self.coef_ = np.ones(n_features)
        return self

    def predict(self, X):
        check_is_fitted(self)
        validated = check_array(X)
        return np.ones(validated.shape[0])
|
||
|
|
||
|
|
||
|
class NoSampleWeightPandasSeriesType(BaseEstimator):
    """Estimator that rejects a ``sample_weight`` given as a pandas Series.

    ``check_estimator`` is expected to flag it, because estimators should
    accept ``pandas.Series`` sample weights.
    """

    def fit(self, X, y, sample_weight=None):
        # Convert data
        X, y = self._validate_data(
            X, y,
            accept_sparse=("csr", "csc"),
            multi_output=True,
            y_numeric=True)
        # Imported lazily: the test only uses this estimator after a
        # successful ``from pandas import Series``.
        from pandas import Series
        if isinstance(sample_weight, Series):
            # Note the trailing space: the two literals are concatenated.
            raise ValueError("Estimator does not accept 'sample_weight' "
                             "of type pandas.Series")
        return self

    def predict(self, X):
        X = check_array(X)
        return np.ones(X.shape[0])
|
||
|
|
||
|
|
||
|
class BadBalancedWeightsClassifier(BaseBadClassifier):
    """Classifier that computes ``class_weight='balanced'`` wrongly on
    purpose and exposes the resulting weights as ``coef_``.
    """

    def __init__(self, class_weight=None):
        self.class_weight = class_weight

    def fit(self, X, y):
        from sklearn.preprocessing import LabelEncoder
        from sklearn.utils import compute_class_weight

        classes = LabelEncoder().fit(y).classes_
        weights = compute_class_weight(self.class_weight, classes=classes,
                                       y=y)

        # Simulated bug: shift the balanced weights so the check fails.
        if self.class_weight == "balanced":
            weights += 1.

        # The (possibly corrupted) weights are exposed directly as coef_.
        self.coef_ = weights
        return self
|
||
|
|
||
|
|
||
|
class BadTransformerWithoutMixin(BaseEstimator):
    """Identity transformer that does not inherit ``TransformerMixin``.

    The test expects ``check_estimator`` to fail with an AttributeError
    mentioning ``fit_transform``.
    """

    def fit(self, X, y=None):
        X = self._validate_data(X)
        return self

    def transform(self, X):
        return check_array(X)
|
||
|
|
||
|
|
||
|
class NotInvariantPredict(BaseEstimator):
    """Estimator whose prediction for a sample depends on the batch size.

    ``predict`` returns ones for batches of more than one sample and zeros
    for a single sample, so predicting on a subset gives different results.
    """

    def fit(self, X, y):
        # Convert data
        X, y = self._validate_data(X, y, accept_sparse=("csr", "csc"),
                                   multi_output=True, y_numeric=True)
        return self

    def predict(self, X):
        X = check_array(X)
        n_samples = X.shape[0]
        filler = np.ones if n_samples > 1 else np.zeros
        return filler(n_samples)
|
||
|
|
||
|
|
||
|
class NotInvariantSampleOrder(BaseEstimator):
    """Estimator whose predictions change when the input rows are reordered.

    ``fit`` keeps the validated training data. ``predict`` returns zeros
    when ``X`` has the same per-column values as that data (compared after
    sorting each column) but is not identical to it. Otherwise it returns
    ``X``'s first column.
    """

    def fit(self, X, y):
        X, y = self._validate_data(X, y, accept_sparse=("csr", "csc"),
                                   multi_output=True, y_numeric=True)
        # Remember the original X so a reordered copy can be detected later.
        self._X = X
        return self

    def predict(self, X):
        X = check_array(X)
        same_values = np.array_equiv(np.sort(X, axis=0),
                                     np.sort(self._X, axis=0))
        # The element-wise comparison runs only once the shapes are known to
        # be compatible (short-circuit ``and``).
        if same_values and (X != self._X).any():
            return np.zeros(X.shape[0])
        return X[:, 0]
|
||
|
|
||
|
|
||
|
class LargeSparseNotSupportedClassifier(BaseEstimator):
    """Classifier that accepts sparse input but rejects 64-bit indices.

    ``fit`` raises ``ValueError`` in two cases:

    * a COO matrix has int64 ``row``/``col`` arrays;
    * a CSR/CSC matrix has int64 ``indices``/``indptr`` arrays.

    Validation is called with ``accept_large_sparse=True``, presumably so
    that int64 indices reach ``fit`` and are rejected here. It does not fail
    gracefully, which is what the check is meant to detect.
    """

    def fit(self, X, y):
        X, y = self._validate_data(
            X, y,
            accept_sparse=("csr", "csc", "coo"),
            accept_large_sparse=True,
            multi_output=True,
            y_numeric=True)
        if sp.issparse(X):
            # ``.format`` replaces the deprecated ``getformat()``.
            if X.format == "coo":
                index_dtypes = (X.row.dtype, X.col.dtype)
            elif X.format in ("csc", "csr"):
                index_dtypes = (X.indices.dtype, X.indptr.dtype)
            else:
                index_dtypes = ()
            # Raise ValueError for every format. A bare ``assert`` would be
            # stripped under ``python -O`` and would raise a different
            # exception type than the COO branch.
            if "int64" in index_dtypes:
                raise ValueError("Estimator doesn't support 64-bit indices")
        return self
|
||
|
|
||
|
|
||
|
class SparseTransformer(BaseEstimator):
    """Transformer whose output is a CSR matrix of the (dense) input.

    Used as a non-regression case for estimators that transform to sparse
    data.
    """

    def fit(self, X, y=None):
        validated = self._validate_data(X)
        self.X_shape_ = validated.shape
        return self

    def fit_transform(self, X, y=None):
        fitted = self.fit(X, y)
        return fitted.transform(X)

    def transform(self, X):
        X = check_array(X)
        n_features_seen = self.X_shape_[1]
        if X.shape[1] != n_features_seen:
            raise ValueError('Bad number of features')
        return sp.csr_matrix(X)
|
||
|
|
||
|
|
||
|
class EstimatorInconsistentForPandas(BaseEstimator):
    """Estimator that, on purpose, learns different values from pandas and
    array inputs.

    With a ``DataFrame`` it stores ``X.iloc[0, 0]``; otherwise it stores
    ``X[1, 0]`` of the validated array. ``predict`` repeats the stored
    value, so its output depends on the container type used in ``fit``.
    """

    def fit(self, X, y):
        # pandas is optional. Only the import belongs in the try, so an
        # ImportError raised elsewhere is not silently treated as
        # "pandas missing".
        try:
            from pandas import DataFrame
        except ImportError:
            DataFrame = None

        if DataFrame is not None and isinstance(X, DataFrame):
            self.value_ = X.iloc[0, 0]
        else:
            X = check_array(X)
            self.value_ = X[1, 0]
        return self

    def predict(self, X):
        X = check_array(X)
        return np.array([self.value_] * X.shape[0])
|
||
|
|
||
|
|
||
|
class UntaggedBinaryClassifier(SGDClassifier):
    """Toy SGD classifier that only supports binary classification.

    It raises ``ValueError`` once more than two classes are seen, but it
    does not declare the ``binary_only`` tag, so the checks fail on it.
    """

    def fit(self, X, y, coef_init=None, intercept_init=None,
            sample_weight=None):
        super().fit(X, y, coef_init, intercept_init, sample_weight)
        n_classes = len(self.classes_)
        if n_classes > 2:
            raise ValueError('Only 2 classes are supported')
        return self

    def partial_fit(self, X, y, classes=None, sample_weight=None):
        super().partial_fit(X=X, y=y, classes=classes,
                            sample_weight=sample_weight)
        n_classes = len(self.classes_)
        if n_classes > 2:
            raise ValueError('Only 2 classes are supported')
        return self
|
||
|
|
||
|
|
||
|
class TaggedBinaryClassifier(UntaggedBinaryClassifier):
    """The same binary-only toy classifier, declaring ``binary_only``."""

    def _more_tags(self):
        return dict(binary_only=True)
|
||
|
|
||
|
|
||
|
class EstimatorMissingDefaultTags(BaseEstimator):
    """Estimator whose ``_get_tags`` drops the default ``allow_nan`` key."""

    def _get_tags(self):
        # Work on a copy so the parent's tag dict is left untouched.
        tags = dict(super()._get_tags())
        tags.pop("allow_nan")  # KeyError if absent, as ``del`` would raise
        return tags
|
||
|
|
||
|
|
||
|
class RequiresPositiveYRegressor(LinearRegression):
    """Linear regressor that refuses non-positive targets and declares the
    ``requires_positive_y`` tag.

    Zeros are rejected as well as negative values, despite the wording of
    the error message.
    """

    def fit(self, X, y):
        X, y = self._validate_data(X, y, multi_output=True)
        has_non_positive = (y <= 0).any()
        if has_non_positive:
            raise ValueError('negative y values not supported!')
        return super().fit(X, y)

    def _more_tags(self):
        return dict(requires_positive_y=True)
|
||
|
|
||
|
|
||
|
class PoorScoreLogisticRegression(LogisticRegression):
    """Logistic regression with a shifted decision function, tagged
    ``poor_score`` so low-score checks tolerate it.
    """

    def decision_function(self, X):
        scores = super().decision_function(X)
        # Shift every score by +1 to degrade performance on purpose.
        return scores + 1

    def _more_tags(self):
        return dict(poor_score=True)
|
||
|
|
||
|
|
||
|
def test_not_an_array_array_function():
    # _NotAnArray must refuse np.sum via the __array_function__ protocol,
    # which only exists from numpy 1.17 on.
    if np_version < parse_version('1.17'):
        raise SkipTest("array_function protocol not supported in numpy <1.17")

    wrapped = _NotAnArray(np.ones(10))
    assert_raises_regex(TypeError, "Don't want to call array_function sum!",
                        np.sum, wrapped)
    # always returns True
    assert np.may_share_memory(wrapped, None)
|
||
|
|
||
|
|
||
|
def test_check_fit_score_takes_y_works_on_deprecated_fit():
    # check_fit_score_takes_y must cope with an estimator whose fit method
    # is wrapped by the ``deprecated`` decorator.

    class TestEstimatorWithDeprecatedFitMethod(BaseEstimator):
        @deprecated("Deprecated for the purpose of testing "
                    "check_fit_score_takes_y")
        def fit(self, X, y):
            return self

    estimator = TestEstimatorWithDeprecatedFitMethod()
    check_fit_score_takes_y("test", estimator)
|
||
|
|
||
|
|
||
|
def test_check_estimator():
    # tests that the estimator actually fails on "bad" estimators.
    # not a complete test of all checks, which are very extensive.

    # check that passing a class (instead of an instance) is rejected
    msg = "Passing a class was deprecated"
    assert_raises_regex(TypeError, msg, check_estimator, object)

    # check that the "default_constructible" test checks for mutable parameters
    msg = (
        "Parameter 'p' of estimator 'HasMutableParameters' is of type "
        "object which is not allowed"
    )
    check_estimator(HasImmutableParameters())  # should pass
    assert_raises_regex(
        AssertionError, msg, check_estimator, HasMutableParameters()
    )

    # check that values returned by get_params match set_params
    msg = "get_params result does not match what was passed to set_params"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesValueInsteadOfRaisingError())
    assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams())
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesAnotherValue())

    # check that we have a fit method
    msg = "object has no attribute 'fit'"
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator())

    # check that fit does input validation
    msg = "Did not raise"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier())

    # check that sample_weights in fit accepts pandas.Series type.
    # Only the optional import is guarded; an ImportError raised while
    # running the check itself must not silently skip the assertion.
    try:
        from pandas import Series  # noqa
    except ImportError:
        pass
    else:
        msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
               "'sample_weight' parameter is of type pandas.Series")
        assert_raises_regex(
            ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType())

    # check that predict does input validation (rejects NaN and inf)
    msg = "Estimator doesn't check for NaN and inf in predict"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NoCheckinPredict())

    # check that estimator state does not change
    # at transform/predict/predict_proba time
    msg = 'Estimator changes __dict__ during predict'
    assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict())

    # check that `fit` only changes attributes that
    # are private (start with an _ or end with a _).
    msg = ('Estimator ChangesWrongAttribute should not change or mutate '
           'the parameter wrong_attribute from 0 to 1 during fit.')
    assert_raises_regex(AssertionError, msg,
                        check_estimator, ChangesWrongAttribute())
    check_estimator(ChangesUnderscoreAttribute())

    # check that `fit` doesn't add any public attribute
    msg = (r'Estimator adds public attribute\(s\) during the fit method.'
           ' Estimators are only allowed to add private attributes'
           ' either started with _ or ended'
           ' with _ but wrong_attribute added')
    assert_raises_regex(AssertionError, msg,
                        check_estimator, SetsWrongAttribute())

    # check for sample order invariance
    # NOTE(review): there is no space between "dataset" and "with"; this
    # presumably mirrors the message emitted by the check. Confirm before
    # changing it.
    name = NotInvariantSampleOrder.__name__
    method = 'predict'
    msg = ("{method} of {name} is not invariant when applied to a dataset"
           "with different sample order.").format(method=method, name=name)
    assert_raises_regex(AssertionError, msg,
                        check_estimator, NotInvariantSampleOrder())

    # check for invariant method
    name = NotInvariantPredict.__name__
    method = 'predict'
    msg = ("{method} of {name} is not invariant when applied "
           "to a subset.").format(method=method, name=name)
    assert_raises_regex(AssertionError, msg,
                        check_estimator, NotInvariantPredict())

    # check for sparse matrix input handling
    name = NoSparseClassifier.__name__
    msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
    assert_raises_regex(
        AssertionError, msg, check_estimator, NoSparseClassifier()
    )

    # Large indices test on bad estimator
    msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to '
           r'support \S{3}_64 matrix, and is not failing gracefully.*')
    assert_raises_regex(AssertionError, msg, check_estimator,
                        LargeSparseNotSupportedClassifier())

    # does error on binary_only untagged estimator
    msg = 'Only 2 classes are supported'
    assert_raises_regex(ValueError, msg, check_estimator,
                        UntaggedBinaryClassifier())

    # non-regression test for estimators transforming to sparse data
    check_estimator(SparseTransformer())

    # doesn't error on actual estimator
    check_estimator(LogisticRegression())
    check_estimator(LogisticRegression(C=0.01))
    check_estimator(MultiTaskElasticNet())

    # doesn't error on binary_only tagged estimator
    check_estimator(TaggedBinaryClassifier())

    # Check regressor with requires_positive_y estimator tag
    msg = 'negative y values not supported!'
    assert_raises_regex(ValueError, msg, check_estimator,
                        RequiresPositiveYRegressor())

    # Does not raise error on classifier with poor_score tag
    check_estimator(PoorScoreLogisticRegression())
|
||
|
|
||
|
|
||
|
def test_check_outlier_corruption():
    # These decision values must make the check fail ...
    bad_decision = np.array([0., 1., 1.5, 2.])
    assert_raises(AssertionError, check_outlier_corruption, 1, 2,
                  bad_decision)

    # ... while these must pass.
    good_decision = np.array([0., 1., 1., 2.])
    check_outlier_corruption(1, 2, good_decision)
|
||
|
|
||
|
|
||
|
def test_check_estimator_transformer_no_mixin():
    # The transformer checks must run even without TransformerMixin; the
    # missing fit_transform then surfaces as an AttributeError.
    estimator = BadTransformerWithoutMixin()
    assert_raises_regex(AttributeError, '.*fit_transform.*',
                        check_estimator, estimator)
|
||
|
|
||
|
|
||
|
def test_check_estimator_clones():
    # check that check_estimator doesn't modify the estimator it receives,
    # first for an unfitted and then for a fitted instance.
    from sklearn.datasets import load_iris
    iris = load_iris()

    estimator_classes = [GaussianMixture, LinearRegression,
                         RandomForestClassifier, NMF, SGDClassifier,
                         MiniBatchKMeans]
    for Estimator in estimator_classes:
        for fit_first in (False, True):
            # FutureWarnings are ignored while building and checking the
            # estimator (the original note mentions SGDClassifier).
            with ignore_warnings(category=FutureWarning):
                est = Estimator()
                _set_checking_parameters(est)
                set_random_state(est)
                if fit_first:
                    est.fit(iris.data + 10, iris.target)
                old_hash = joblib.hash(est)
                check_estimator(est)
            assert old_hash == joblib.hash(est)
|
||
|
|
||
|
|
||
|
def test_check_estimators_unfitted():
    # NoSparseClassifier.predict does not check whether it was fitted, so the
    # check must report that calling predict unfitted "Did not raise".
    assert_raises_regex(AssertionError, "Did not raise",
                        check_estimators_unfitted,
                        "estimator", NoSparseClassifier())

    # CorrectNotFittedErrorClassifier.predict calls check_is_fitted first,
    # so the check passes for it.
    check_estimators_unfitted("estimator", CorrectNotFittedErrorClassifier())
|
||
|
|
||
|
|
||
|
def test_check_no_attributes_set_in_init():
    class NonConformantEstimatorPrivateSet(BaseEstimator):
        def __init__(self):
            # Sets a non-parameter attribute in __init__: forbidden.
            self.you_should_not_set_this_ = None

    class NonConformantEstimatorNoParamSet(BaseEstimator):
        def __init__(self, you_should_set_this_=None):
            # Accepts a parameter but never stores it: also forbidden.
            pass

    estimator_name = 'estimator_name'

    assert_raises_regex(AssertionError,
                        "Estimator estimator_name should not set any"
                        " attribute apart from parameters during init."
                        r" Found attributes \['you_should_not_set_this_'\].",
                        check_no_attributes_set_in_init,
                        estimator_name,
                        NonConformantEstimatorPrivateSet())

    assert_raises_regex(AttributeError,
                        "Estimator estimator_name should store all "
                        "parameters as an attribute during init.",
                        check_no_attributes_set_in_init,
                        estimator_name,
                        NonConformantEstimatorNoParamSet())
|
||
|
|
||
|
|
||
|
def test_check_estimator_pairwise():
    # check_estimator() must work on estimators with a precomputed kernel
    # (SVC) or a precomputed metric (KNeighborsRegressor). Each estimator is
    # built right before it is checked, in this order.
    pairwise_cases = [
        (SVC, {'kernel': 'precomputed'}),
        (KNeighborsRegressor, {'metric': 'precomputed'}),
    ]
    for Estimator, params in pairwise_cases:
        check_estimator(Estimator(**params))
|
||
|
|
||
|
|
||
|
def test_check_classifier_data_not_an_array():
    # Fitting on pandas vs. array input yields different predictions, which
    # the classifier check must flag.
    inconsistent = EstimatorInconsistentForPandas()
    assert_raises_regex(AssertionError, 'Not equal to tolerance',
                        check_classifier_data_not_an_array,
                        'estimator_name', inconsistent)
|
||
|
|
||
|
|
||
|
def test_check_regressor_data_not_an_array():
    # Fitting on pandas vs. array input yields different predictions, which
    # the regressor check must flag.
    inconsistent = EstimatorInconsistentForPandas()
    assert_raises_regex(AssertionError, 'Not equal to tolerance',
                        check_regressor_data_not_an_array,
                        'estimator_name', inconsistent)
|
||
|
|
||
|
|
||
|
def test_check_estimator_get_tags_default_keys():
    # An estimator whose _get_tags() drops a default key must be reported.
    bad_estimator = EstimatorMissingDefaultTags()
    expected_msg = (
        r"EstimatorMissingDefaultTags._get_tags\(\) is missing entries"
        r" for the following default tags: {'allow_nan'}"
    )
    assert_raises_regex(
        AssertionError,
        expected_msg,
        check_estimator_get_tags_default_keys,
        type(bad_estimator).__name__,
        bad_estimator,
    )

    # noop check when _get_tags is not available
    minimal = MinimalTransformer()
    check_estimator_get_tags_default_keys(type(minimal).__name__, minimal)
|
||
|
|
||
|
|
||
|
def run_tests_without_pytest(module=None):
    """Runs the tests in this file without using pytest.

    Every callable attribute of ``module`` whose name starts with ``test_``
    is wrapped in a ``unittest.FunctionTestCase``. The collected cases are
    run with a ``unittest.TextTestRunner``.

    Parameters
    ----------
    module : module, default=None
        Module to collect tests from. Defaults to ``sys.modules['__main__']``,
        i.e. this file when it is executed as a script.

    Returns
    -------
    result : unittest.TestResult
        Outcome of the run, so callers can check ``wasSuccessful()``
        instead of the result being silently discarded.
    """
    if module is None:
        module = sys.modules['__main__']
    # Skip non-callable ``test_*`` names; FunctionTestCase would error on
    # them when the test is run.
    test_functions = [
        getattr(module, name) for name in dir(module)
        if name.startswith('test_') and callable(getattr(module, name))
    ]
    suite = unittest.TestSuite(
        unittest.FunctionTestCase(fn) for fn in test_functions
    )
    return unittest.TextTestRunner().run(suite)
|
||
|
|
||
|
|
||
|
def test_check_class_weight_balanced_linear_classifier():
    # The check must raise on ill-computed balanced class weights. The class
    # itself (not an instance) is passed, as in the original test.
    expected_msg = ("Classifier estimator_name is not computing"
                    " class_weight=balanced properly.")
    assert_raises_regex(AssertionError, expected_msg,
                        check_class_weight_balanced_linear_classifier,
                        'estimator_name', BadBalancedWeightsClassifier)
|
||
|
|
||
|
|
||
|
def test_all_estimators_all_public():
    # all_estimators should not fail when pytest is not installed and should
    # return only public estimators.
    estimators = all_estimators()
    # all_estimators() returns (name, class) pairs (see its API docs).
    # Inspecting ``est.__class__.__name__`` on each pair only ever sees
    # "tuple", so the estimator name must be checked instead.
    for name, _ in estimators:
        assert not name.startswith("_")
|
||
|
|
||
|
|
||
|
def test_xfail_ignored_in_check_estimator():
    # Make sure checks marked as xfail are just ignored and not run by
    # check_estimator(), but still raise a warning.
    assert_warns(SkipTestWarning, check_estimator, NuSVC())


# FIXME: this test should be enabled again when the checks are granular
# enough. In 0.24, these tests fail due to low estimator performance.
def test_minimal_class_implementation_checks():
    # Check that third-party library can run tests without inheriting from
    # BaseEstimator.
    # FIXME
    raise SkipTest
    # Unreachable until the FIXME above is resolved; kept to document what
    # the test will run once re-enabled.
    minimal_estimators = [
        MinimalTransformer(), MinimalRegressor(), MinimalClassifier()
    ]
    for estimator in minimal_estimators:
        check_estimator(estimator)


if __name__ == '__main__':
    # This module is run as a script to check that we have no dependency on
    # pytest for estimator checks. The guard must come after every test_*
    # definition; otherwise the tests defined below it are never collected
    # when the file is executed as a script.
    run_tests_without_pytest()
|