import warnings

import numpy as np
import pytest
from numpy.testing import assert_allclose

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import (
    load_iris,
    make_classification,
    make_multilabel_classification,
)
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
pytestmark = pytest.mark.filterwarnings(
    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
    "matplotlib.*"
)

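# Synthetic two-feature binary classification problem shared by the tests below.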
X, y = make_classification(
    n_informative=1,
    n_redundant=1,
    n_clusters_per_class=1,
    n_features=2,
    random_state=42,
)

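# Module-scoped fixture: fit the shared classifier once and reuse it across tests.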
@pytest.fixture(scope="module")
def fitted_clf():
    return LogisticRegression().fit(X, y)

def test_input_data_dimension(pyplot):
    """Check that we raise an error when `X` does not have exactly 2 features."""
    X, y = make_classification(n_samples=10, n_features=4, random_state=0)
    clf = LogisticRegression().fit(X, y)

    msg = "n_features must be equal to 2. Got 4 instead."
    with pytest.raises(ValueError, match=msg):
        DecisionBoundaryDisplay.from_estimator(estimator=clf, X=X)

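# With response_method="auto", `_check_boundary_response_method` resolves the
# method in priority order: decision_function, then predict_proba, then predict,
# as the four minimal classes below verify.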
def test_check_boundary_response_method_auto():
    """Check _check_boundary_response_method behavior with 'auto'."""

    class A:
        def decision_function(self):
            pass

    a_inst = A()
    method = _check_boundary_response_method(a_inst, "auto")
    assert method == a_inst.decision_function

    class B:
        def predict_proba(self):
            pass

    b_inst = B()
    method = _check_boundary_response_method(b_inst, "auto")
    assert method == b_inst.predict_proba

    class C:
        def predict_proba(self):
            pass

        def decision_function(self):
            pass

    c_inst = C()
    # decision_function takes precedence when both methods are available
    method = _check_boundary_response_method(c_inst, "auto")
    assert method == c_inst.decision_function

    class D:
        def predict(self):
            pass

    d_inst = D()
    method = _check_boundary_response_method(d_inst, "auto")
    assert method == d_inst.predict

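# For multiclass problems, predict_proba and decision_function return one value
# per class rather than a single surface, so only "predict" (or "auto", which
# falls back to it) can be plotted; other response methods must raise.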
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_multiclass_error(pyplot, response_method):
    """Check multiclass errors."""
    X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
    X = X[:, [0, 1]]
    lr = LogisticRegression().fit(X, y)

    msg = (
        "Multiclass classifiers are only supported when response_method is 'predict' or"
        " 'auto'"
    )
    with pytest.raises(ValueError, match=msg):
        DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method)

@pytest.mark.parametrize("response_method", ["auto", "predict"])
def test_multiclass(pyplot, response_method):
    """Check multiclass gives expected results."""
    grid_resolution = 10
    eps = 1.0
    X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
    X = X[:, [0, 1]]
    lr = LogisticRegression(random_state=0).fit(X, y)
    disp = DecisionBoundaryDisplay.from_estimator(
        lr,
        X,
        response_method=response_method,
        grid_resolution=grid_resolution,
        eps=eps,
    )

    # rebuild the same grid that `from_estimator` constructs internally and
    # check that the stored response matches the estimator's predictions
    x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx0, xx1 = np.meshgrid(
        np.linspace(x0_min, x0_max, grid_resolution),
        np.linspace(x1_min, x1_max, grid_resolution),
    )
    response = lr.predict(np.c_[xx0.ravel(), xx1.ravel()])
    assert_allclose(disp.response, response.reshape(xx0.shape))
    assert_allclose(disp.xx0, xx0)
    assert_allclose(disp.xx1, xx1)

@pytest.mark.parametrize(
    "kwargs, error_msg",
    [
        (
            {"plot_method": "hello_world"},
            r"plot_method must be one of contourf, contour, pcolormesh. Got hello_world"
            r" instead.",
        ),
        (
            {"grid_resolution": 1},
            r"grid_resolution must be greater than 1. Got 1 instead",
        ),
        (
            {"grid_resolution": -1},
            r"grid_resolution must be greater than 1. Got -1 instead",
        ),
        ({"eps": -1.1}, r"eps must be greater than or equal to 0. Got -1.1 instead"),
    ],
)
def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf):
    """Check input validation from_estimator."""
    with pytest.raises(ValueError, match=error_msg):
        DecisionBoundaryDisplay.from_estimator(fitted_clf, X, **kwargs)

def test_display_plot_input_error(pyplot, fitted_clf):
    """Check input validation for `plot`."""
    disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, X, grid_resolution=5)

    with pytest.raises(ValueError, match="plot_method must be 'contourf'"):
        disp.plot(plot_method="hello_world")

@pytest.mark.parametrize(
    "response_method", ["auto", "predict", "predict_proba", "decision_function"]
)
@pytest.mark.parametrize("plot_method", ["contourf", "contour"])
def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_method):
    """Check that decision boundary is correct."""
    fig, ax = pyplot.subplots()
    eps = 2.0
    disp = DecisionBoundaryDisplay.from_estimator(
        fitted_clf,
        X,
        grid_resolution=5,
        response_method=response_method,
        plot_method=plot_method,
        eps=eps,
        ax=ax,
    )
    assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet)
    assert disp.ax_ == ax
    assert disp.figure_ == fig

    x0, x1 = X[:, 0], X[:, 1]
    x0_min, x0_max = x0.min() - eps, x0.max() + eps
    x1_min, x1_max = x1.min() - eps, x1.max() + eps
    assert disp.xx0.min() == pytest.approx(x0_min)
    assert disp.xx0.max() == pytest.approx(x0_max)
    assert disp.xx1.min() == pytest.approx(x1_min)
    assert disp.xx1.max() == pytest.approx(x1_max)

    fig2, ax2 = pyplot.subplots()
    # change plotting method for second plot
    disp.plot(plot_method="pcolormesh", ax=ax2, shading="auto")
    assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh)
    assert disp.ax_ == ax2
    assert disp.figure_ == fig2

@pytest.mark.parametrize(
    "response_method, msg",
    [
        (
            "predict_proba",
            "MyClassifier has none of the following attributes: predict_proba",
        ),
        (
            "decision_function",
            "MyClassifier has none of the following attributes: decision_function",
        ),
        (
            "auto",
            "MyClassifier has none of the following attributes: decision_function, "
            "predict_proba, predict",
        ),
        (
            "bad_method",
            "MyClassifier has none of the following attributes: bad_method",
        ),
    ],
)
def test_error_bad_response(pyplot, response_method, msg):
    """Check errors for bad response."""

    class MyClassifier(BaseEstimator, ClassifierMixin):
        def fit(self, X, y):
            self.fitted_ = True
            self.classes_ = [0, 1]
            return self

    clf = MyClassifier().fit(X, y)

    with pytest.raises(ValueError, match=msg):
        DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method)

@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"])
def test_multilabel_classifier_error(pyplot, response_method):
    """Check that multilabel classifier raises correct error."""
    X, y = make_multilabel_classification(random_state=0)
    X = X[:, :2]
    tree = DecisionTreeClassifier().fit(X, y)

    msg = "Multi-label and multi-output multi-class classifiers are not supported"
    with pytest.raises(ValueError, match=msg):
        DecisionBoundaryDisplay.from_estimator(
            tree,
            X,
            response_method=response_method,
        )

@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"])
def test_multi_output_multi_class_classifier_error(pyplot, response_method):
    """Check that multi-output multi-class classifier raises correct error."""
    X = np.asarray([[0, 1], [1, 2]])
    y = np.asarray([["tree", "cat"], ["cat", "tree"]])
    tree = DecisionTreeClassifier().fit(X, y)

    msg = "Multi-label and multi-output multi-class classifiers are not supported"
    with pytest.raises(ValueError, match=msg):
        DecisionBoundaryDisplay.from_estimator(
            tree,
            X,
            response_method=response_method,
        )

def test_multioutput_regressor_error(pyplot):
    """Check that multioutput regressor raises correct error."""
    X = np.asarray([[0, 1], [1, 2]])
    y = np.asarray([[0, 1], [4, 1]])
    tree = DecisionTreeRegressor().fit(X, y)
    with pytest.raises(ValueError, match="Multi-output regressors are not supported"):
        DecisionBoundaryDisplay.from_estimator(tree, X)

@pytest.mark.filterwarnings(
    # We expect the following warning to be raised because the classifier is
    # fit on a NumPy array while a dataframe is passed for plotting
    "ignore:X has feature names, but LogisticRegression was fitted without"
)
def test_dataframe_labels_used(pyplot, fitted_clf):
    """Check that column names are used for pandas."""
    pd = pytest.importorskip("pandas")
    df = pd.DataFrame(X, columns=["col_x", "col_y"])

    # pandas column names are used by default
    _, ax = pyplot.subplots()
    disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, df, ax=ax)
    assert ax.get_xlabel() == "col_x"
    assert ax.get_ylabel() == "col_y"

    # second call to plot will have the names
    fig, ax = pyplot.subplots()
    disp.plot(ax=ax)
    assert ax.get_xlabel() == "col_x"
    assert ax.get_ylabel() == "col_y"

    # axes with a label will not get overridden
    fig, ax = pyplot.subplots()
    ax.set(xlabel="hello", ylabel="world")
    disp.plot(ax=ax)
    assert ax.get_xlabel() == "hello"
    assert ax.get_ylabel() == "world"

    # labels get overridden only if provided to the `plot` method
    disp.plot(ax=ax, xlabel="overwritten_x", ylabel="overwritten_y")
    assert ax.get_xlabel() == "overwritten_x"
    assert ax.get_ylabel() == "overwritten_y"

    # labels are not inferred from the dataframe if provided to `from_estimator`
    _, ax = pyplot.subplots()
    disp = DecisionBoundaryDisplay.from_estimator(
        fitted_clf, df, ax=ax, xlabel="overwritten_x", ylabel="overwritten_y"
    )
    assert ax.get_xlabel() == "overwritten_x"
    assert ax.get_ylabel() == "overwritten_y"

def test_string_target(pyplot):
    """Check that decision boundary works with classifiers trained on string labels."""
    iris = load_iris()
    X = iris.data[:, [0, 1]]

    # use strings as target
    y = iris.target_names[iris.target]
    log_reg = LogisticRegression().fit(X, y)

    # does not raise
    DecisionBoundaryDisplay.from_estimator(
        log_reg,
        X,
        grid_resolution=5,
        response_method="predict",
    )

def test_dataframe_support(pyplot):
    """Check that passing a dataframe at fit and to the Display does not
    raise warnings.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/23311
    """
    pd = pytest.importorskip("pandas")
    df = pd.DataFrame(X, columns=["col_x", "col_y"])
    estimator = LogisticRegression().fit(df, y)

    with warnings.catch_warnings():
        # no warnings linked to feature names validation should be raised
        warnings.simplefilter("error", UserWarning)
        DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict")