356 lines
12 KiB
Python
356 lines
12 KiB
Python
![]() |
import warnings
|
||
|
|
||
|
import pytest
|
||
|
import numpy as np
|
||
|
from numpy.testing import assert_allclose
|
||
|
|
||
|
from sklearn.base import BaseEstimator
|
||
|
from sklearn.base import ClassifierMixin
|
||
|
from sklearn.datasets import make_classification
|
||
|
from sklearn.linear_model import LogisticRegression
|
||
|
from sklearn.datasets import load_iris
|
||
|
from sklearn.datasets import make_multilabel_classification
|
||
|
from sklearn.tree import DecisionTreeRegressor
|
||
|
from sklearn.tree import DecisionTreeClassifier
|
||
|
|
||
|
from sklearn.inspection import DecisionBoundaryDisplay
|
||
|
from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method
|
||
|
|
||
|
|
||
|
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
# Module-wide mark: ignore the np.bool_ DeprecationWarning emitted from
# matplotlib modules. The filter spec is "action:message:category:module",
# assembled piecewise below via implicit string concatenation.
pytestmark = pytest.mark.filterwarnings(
    "ignore"
    ":In future, it will be an error for 'np.bool_'"
    ":DeprecationWarning"
    ":matplotlib.*"
)
|
||
|
|
||
|
|
||
|
# Shared two-feature classification dataset used by the fixture and most tests.
# `n_features=2` matches the display's requirement of exactly two features.
X, y = make_classification(
    n_features=2,
    n_informative=1,
    n_redundant=1,
    n_clusters_per_class=1,
    random_state=42,
)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def fitted_clf():
    """LogisticRegression fitted once per module on the shared ``X, y``."""
    clf = LogisticRegression()
    clf.fit(X, y)
    return clf
|
||
|
|
||
|
|
||
|
def test_input_data_dimension(pyplot):
    """A ValueError is raised when `X` does not have exactly two features."""
    X_wide, y_wide = make_classification(n_samples=10, n_features=4, random_state=0)
    estimator = LogisticRegression().fit(X_wide, y_wide)

    expected = "n_features must be equal to 2. Got 4 instead."
    with pytest.raises(ValueError, match=expected):
        DecisionBoundaryDisplay.from_estimator(estimator=estimator, X=X_wide)
|
||
|
|
||
|
|
||
|
def test_check_boundary_response_method_auto():
    """Check which bound method `_check_boundary_response_method` picks for "auto".

    Estimators exposing only one candidate get that method back. When both
    `predict_proba` and `decision_function` exist, `decision_function` wins.
    """

    class OnlyDecision:
        def decision_function(self):
            pass

    class OnlyProba:
        def predict_proba(self):
            pass

    class ProbaAndDecision:
        def predict_proba(self):
            pass

        def decision_function(self):
            pass

    class OnlyPredict:
        def predict(self):
            pass

    cases = [
        (OnlyDecision(), "decision_function"),
        (OnlyProba(), "predict_proba"),
        (ProbaAndDecision(), "decision_function"),
        (OnlyPredict(), "predict"),
    ]
    for estimator, expected_name in cases:
        chosen = _check_boundary_response_method(estimator, "auto")
        assert chosen == getattr(estimator, expected_name)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_multiclass_error(pyplot, response_method):
    """Multiclass estimators reject `predict_proba` and `decision_function`."""
    X_full, y_multi = make_classification(n_classes=3, n_informative=3, random_state=0)
    X_two = X_full[:, [0, 1]]
    estimator = LogisticRegression().fit(X_two, y_multi)

    expected = (
        "Multiclass classifiers are only supported when response_method is 'predict' or"
        " 'auto'"
    )
    with pytest.raises(ValueError, match=expected):
        DecisionBoundaryDisplay.from_estimator(
            estimator, X_two, response_method=response_method
        )
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("response_method", ["auto", "predict"])
def test_multiclass(pyplot, response_method):
    """Check multiclass gives expected results.

    The display's grid and response must match a grid rebuilt by hand from the
    same `eps` and `grid_resolution`, evaluated with `predict`.
    """
    grid_resolution = 10
    eps = 1.0
    X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
    X = X[:, [0, 1]]
    lr = LogisticRegression(random_state=0).fit(X, y)

    # Use the same `eps` that builds the reference grid below, so the display
    # and the expected values cannot drift apart if one of them is edited.
    disp = DecisionBoundaryDisplay.from_estimator(
        lr, X, response_method=response_method, grid_resolution=grid_resolution, eps=eps
    )

    x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx0, xx1 = np.meshgrid(
        np.linspace(x0_min, x0_max, grid_resolution),
        np.linspace(x1_min, x1_max, grid_resolution),
    )
    response = lr.predict(np.c_[xx0.ravel(), xx1.ravel()])
    assert_allclose(disp.response, response.reshape(xx0.shape))
    assert_allclose(disp.xx0, xx0)
    assert_allclose(disp.xx1, xx1)
|
||
|
|
||
|
|
||
|
# `error_msg` is passed to `pytest.raises(match=...)`, which applies `re.search`.
# Literal dots are therefore escaped so they do not match arbitrary characters.
@pytest.mark.parametrize(
    "kwargs, error_msg",
    [
        (
            {"plot_method": "hello_world"},
            r"plot_method must be one of contourf, contour, pcolormesh\. Got hello_world"
            r" instead\.",
        ),
        (
            {"grid_resolution": 1},
            r"grid_resolution must be greater than 1\. Got 1 instead",
        ),
        (
            {"grid_resolution": -1},
            r"grid_resolution must be greater than 1\. Got -1 instead",
        ),
        ({"eps": -1.1}, r"eps must be greater than or equal to 0\. Got -1\.1 instead"),
    ],
)
def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf):
    """Check input validation from_estimator.

    Each invalid keyword argument must raise a ValueError whose message
    matches the corresponding regex.
    """
    with pytest.raises(ValueError, match=error_msg):
        DecisionBoundaryDisplay.from_estimator(fitted_clf, X, **kwargs)
|
||
|
|
||
|
|
||
|
def test_display_plot_input_error(pyplot, fitted_clf):
    """`plot` raises a ValueError for an unknown `plot_method`."""
    display = DecisionBoundaryDisplay.from_estimator(fitted_clf, X, grid_resolution=5)

    with pytest.raises(ValueError, match="plot_method must be 'contourf'"):
        display.plot(plot_method="hello_world")
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "response_method", ["auto", "predict", "predict_proba", "decision_function"]
)
@pytest.mark.parametrize("plot_method", ["contourf", "contour"])
def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_method):
    """Check that decision boundary is correct.

    The grid must span the data range padded by `eps` on both axes. A second
    call to `plot` with a different `plot_method` must rebind the surface,
    axes and figure to the new target.
    """
    fig, ax = pyplot.subplots()
    eps = 2.0
    disp = DecisionBoundaryDisplay.from_estimator(
        fitted_clf,
        X,
        grid_resolution=5,
        response_method=response_method,
        plot_method=plot_method,
        eps=eps,
        ax=ax,
    )
    # Both "contourf" and "contour" are expected to yield a QuadContourSet.
    assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet)
    # Identity, not equality: the display must reuse the exact Axes/Figure given.
    assert disp.ax_ is ax
    assert disp.figure_ is fig

    x0, x1 = X[:, 0], X[:, 1]

    x0_min, x0_max = x0.min() - eps, x0.max() + eps
    x1_min, x1_max = x1.min() - eps, x1.max() + eps

    assert disp.xx0.min() == pytest.approx(x0_min)
    assert disp.xx0.max() == pytest.approx(x0_max)
    assert disp.xx1.min() == pytest.approx(x1_min)
    assert disp.xx1.max() == pytest.approx(x1_max)

    fig2, ax2 = pyplot.subplots()
    # change plotting method for second plot
    disp.plot(plot_method="pcolormesh", ax=ax2, shading="auto")
    assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh)
    assert disp.ax_ is ax2
    assert disp.figure_ is fig2
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "response_method, msg",
    [
        (
            "predict_proba",
            "MyClassifier has none of the following attributes: predict_proba",
        ),
        (
            "decision_function",
            "MyClassifier has none of the following attributes: decision_function",
        ),
        (
            "auto",
            "MyClassifier has none of the following attributes: decision_function, "
            "predict_proba, predict",
        ),
        (
            "bad_method",
            "MyClassifier has none of the following attributes: bad_method",
        ),
    ],
)
def test_error_bad_response(pyplot, response_method, msg):
    """A classifier lacking the requested response method raises a ValueError."""

    class MyClassifier(BaseEstimator, ClassifierMixin):
        # Minimal classifier: fit() sets attributes but defines no response methods.
        def fit(self, X, y):
            self.fitted_ = True
            self.classes_ = [0, 1]
            return self

    estimator = MyClassifier()
    estimator = estimator.fit(X, y)

    with pytest.raises(ValueError, match=msg):
        DecisionBoundaryDisplay.from_estimator(
            estimator, X, response_method=response_method
        )
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"])
def test_multilabel_classifier_error(pyplot, response_method):
    """A multilabel classifier is rejected with an explicit ValueError."""
    X_ml, y_ml = make_multilabel_classification(random_state=0)
    X_two = X_ml[:, :2]
    estimator = DecisionTreeClassifier().fit(X_two, y_ml)

    expected = "Multi-label and multi-output multi-class classifiers are not supported"
    with pytest.raises(ValueError, match=expected):
        DecisionBoundaryDisplay.from_estimator(
            estimator, X_two, response_method=response_method
        )
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"])
def test_multi_output_multi_class_classifier_error(pyplot, response_method):
    """A multi-output multi-class classifier is rejected with a ValueError."""
    X_small = np.asarray([[0, 1], [1, 2]])
    # Two string-valued outputs per sample -> multi-output multi-class target.
    y_small = np.asarray([["tree", "cat"], ["cat", "tree"]])
    estimator = DecisionTreeClassifier().fit(X_small, y_small)

    expected = "Multi-label and multi-output multi-class classifiers are not supported"
    with pytest.raises(ValueError, match=expected):
        DecisionBoundaryDisplay.from_estimator(
            estimator, X_small, response_method=response_method
        )
|
||
|
|
||
|
|
||
|
def test_multioutput_regressor_error(pyplot):
    """A regressor fitted on a two-column target is rejected with a ValueError."""
    X_small = np.asarray([[0, 1], [1, 2]])
    y_two_outputs = np.asarray([[0, 1], [4, 1]])
    regressor = DecisionTreeRegressor().fit(X_small, y_two_outputs)

    expected = "Multi-output regressors are not supported"
    with pytest.raises(ValueError, match=expected):
        DecisionBoundaryDisplay.from_estimator(regressor, X_small)
|
||
|
|
||
|
|
||
|
@pytest.mark.filterwarnings(
    # We expect to raise the following warning because the classifier is fit on a
    # NumPy array
    "ignore:X has feature names, but LogisticRegression was fitted without"
)
def test_dataframe_labels_used(pyplot, fitted_clf):
    """Check that column names are used for pandas.

    Axis labels come from the DataFrame columns unless the Axes already has
    labels, or explicit `xlabel`/`ylabel` are passed to `plot` or
    `from_estimator`.
    """
    pd = pytest.importorskip("pandas")
    df = pd.DataFrame(X, columns=["col_x", "col_y"])

    # pandas column names are used by default
    _, ax = pyplot.subplots()
    disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, df, ax=ax)
    assert ax.get_xlabel() == "col_x"
    assert ax.get_ylabel() == "col_y"

    # second call to plot will have the names
    _, ax = pyplot.subplots()
    disp.plot(ax=ax)
    assert ax.get_xlabel() == "col_x"
    assert ax.get_ylabel() == "col_y"

    # axes with a label will not get overridden
    _, ax = pyplot.subplots()
    ax.set(xlabel="hello", ylabel="world")
    disp.plot(ax=ax)
    assert ax.get_xlabel() == "hello"
    assert ax.get_ylabel() == "world"

    # labels get overridden only if provided to the `plot` method
    disp.plot(ax=ax, xlabel="overwritten_x", ylabel="overwritten_y")
    assert ax.get_xlabel() == "overwritten_x"
    assert ax.get_ylabel() == "overwritten_y"

    # labels do not get inferred if provided to `from_estimator`
    _, ax = pyplot.subplots()
    DecisionBoundaryDisplay.from_estimator(
        fitted_clf, df, ax=ax, xlabel="overwritten_x", ylabel="overwritten_y"
    )
    assert ax.get_xlabel() == "overwritten_x"
    assert ax.get_ylabel() == "overwritten_y"
|
||
|
|
||
|
|
||
|
def test_string_target(pyplot):
    """The display works with a classifier trained on string class labels."""
    iris = load_iris()
    X_two = iris.data[:, [0, 1]]
    # Map integer targets to their string names so the classifier sees strings.
    y_names = iris.target_names[iris.target]
    classifier = LogisticRegression().fit(X_two, y_names)

    # Does not raise
    DecisionBoundaryDisplay.from_estimator(
        classifier, X_two, grid_resolution=5, response_method="predict"
    )
|
||
|
|
||
|
|
||
|
def test_dataframe_support(pyplot):
    """Fitting on a DataFrame and passing it to the Display raises no UserWarning.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/23311
    """
    pd = pytest.importorskip("pandas")
    frame = pd.DataFrame(X, columns=["col_x", "col_y"])
    model = LogisticRegression().fit(frame, y)

    # Promote UserWarning to an error so a feature-name mismatch warning fails
    # the test; catch_warnings restores the filter state afterwards.
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        DecisionBoundaryDisplay.from_estimator(model, frame, response_method="predict")
|