# Tests for scikit-learn's confusion-matrix plotting utilities
# (plot_confusion_matrix / ConfusionMatrixDisplay).
import pytest
|
||
|
import numpy as np
|
||
|
from numpy.testing import assert_allclose
|
||
|
from numpy.testing import assert_array_equal
|
||
|
|
||
|
from sklearn.compose import make_column_transformer
|
||
|
from sklearn.datasets import make_classification
|
||
|
from sklearn.exceptions import NotFittedError
|
||
|
from sklearn.linear_model import LogisticRegression
|
||
|
from sklearn.pipeline import make_pipeline
|
||
|
from sklearn.preprocessing import StandardScaler
|
||
|
from sklearn.svm import SVC, SVR
|
||
|
|
||
|
from sklearn.metrics import confusion_matrix
|
||
|
from sklearn.metrics import plot_confusion_matrix
|
||
|
from sklearn.metrics import ConfusionMatrixDisplay
|
||
|
|
||
|
|
||
|
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
# Silence a numpy deprecation warning triggered from within matplotlib.
_NP_BOOL_DEPRECATION_FILTER = (
    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
    "matplotlib.*"
)
pytestmark = pytest.mark.filterwarnings(_NP_BOOL_DEPRECATION_FILTER)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def n_classes():
    """Number of target classes in the shared synthetic dataset."""
    return 5
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def data(n_classes):
    """Synthetic multi-class classification problem (X, y) shared module-wide."""
    return make_classification(n_samples=100, n_informative=5,
                               n_classes=n_classes, random_state=0)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def fitted_clf(data):
    """A linear SVC fitted once on the shared dataset."""
    X, y = data
    return SVC(kernel='linear', C=0.01).fit(X, y)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def y_pred(data, fitted_clf):
    """Predictions of the fitted classifier on the training inputs."""
    X = data[0]
    return fitted_clf.predict(X)
|
||
|
|
||
|
|
||
|
def test_error_on_regressor(pyplot, data):
    """A regressor must be rejected with an informative ValueError."""
    X, y = data
    regressor = SVR().fit(X, y)

    msg = "plot_confusion_matrix only supports classifiers"
    with pytest.raises(ValueError, match=msg):
        plot_confusion_matrix(regressor, X, y)
|
||
|
|
||
|
|
||
|
def test_error_on_invalid_option(pyplot, fitted_clf, data):
    """An unknown ``normalize`` value must raise a descriptive ValueError."""
    X, y = data
    msg = (r"normalize must be one of \{'true', 'pred', 'all', "
           r"None\}")
    with pytest.raises(ValueError, match=msg):
        plot_confusion_matrix(fitted_clf, X, y, normalize='invalid')
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
def test_plot_confusion_matrix_custom_labels(pyplot, data, y_pred, fitted_clf,
                                             n_classes, with_labels,
                                             with_display_labels):
    """``labels`` and ``display_labels`` must control the tick labels."""
    X, y = data
    ax = pyplot.gca()
    labels = [2, 1, 0, 3, 4] if with_labels else None
    display_labels = ['b', 'd', 'a', 'e', 'f'] if with_display_labels else None

    cm = confusion_matrix(y, y_pred, labels=labels)
    disp = plot_confusion_matrix(fitted_clf, X, y,
                                 ax=ax, display_labels=display_labels,
                                 labels=labels)

    assert_allclose(disp.confusion_matrix, cm)

    # Precedence: explicit display_labels > labels > plain class indices.
    if with_display_labels:
        expected = display_labels
    elif with_labels:
        expected = labels
    else:
        expected = list(range(n_classes))
    expected_str = [str(name) for name in expected]

    x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]

    assert_array_equal(disp.display_labels, expected)
    assert_array_equal(x_ticks, expected_str)
    assert_array_equal(y_ticks, expected_str)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("include_values", [True, False])
def test_plot_confusion_matrix(pyplot, data, y_pred, n_classes, fitted_clf,
                               normalize, include_values):
    """End-to-end check of plot_confusion_matrix rendering options."""
    import matplotlib as mpl

    X, y = data
    ax = pyplot.gca()
    cmap = 'plasma'
    cm = confusion_matrix(y, y_pred)
    disp = plot_confusion_matrix(fitted_clf, X, y,
                                 normalize=normalize,
                                 cmap=cmap, ax=ax,
                                 include_values=include_values)

    assert disp.ax_ == ax

    # Apply the same normalization the display is expected to perform.
    if normalize == 'true':
        cm = cm / cm.sum(axis=1, keepdims=True)
    elif normalize == 'pred':
        cm = cm / cm.sum(axis=0, keepdims=True)
    elif normalize == 'all':
        cm = cm / cm.sum()

    assert_allclose(disp.confusion_matrix, cm)
    assert isinstance(disp.im_, mpl.image.AxesImage)
    assert disp.im_.get_cmap().name == cmap
    assert isinstance(disp.ax_, pyplot.Axes)
    assert isinstance(disp.figure_, pyplot.Figure)

    assert disp.ax_.get_ylabel() == "True label"
    assert disp.ax_.get_xlabel() == "Predicted label"

    x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]

    # Without explicit labels, tick labels are the class indices.
    expected = list(range(n_classes))
    expected_str = [str(name) for name in expected]

    assert_array_equal(disp.display_labels, expected)
    assert_array_equal(x_ticks, expected_str)
    assert_array_equal(y_ticks, expected_str)

    image_data = disp.im_.get_array().data
    assert_allclose(image_data, cm)

    if include_values:
        assert disp.text_.shape == (n_classes, n_classes)
        fmt = '.2g'
        expected_text = np.array([format(v, fmt) for v in cm.ravel(order="C")])
        shown_text = np.array(
            [t.get_text() for t in disp.text_.ravel(order="C")])
        assert_array_equal(expected_text, shown_text)
    else:
        assert disp.text_ is None
|
||
|
|
||
|
|
||
|
def test_confusion_matrix_display(pyplot, data, fitted_clf, y_pred, n_classes):
    """Re-plotting via ``disp.plot`` must honour each keyword argument."""
    X, y = data

    cm = confusion_matrix(y, y_pred)
    disp = plot_confusion_matrix(fitted_clf, X, y, normalize=None,
                                 include_values=True, cmap='viridis',
                                 xticks_rotation=45.0)

    assert_allclose(disp.confusion_matrix, cm)
    assert disp.text_.shape == (n_classes, n_classes)

    rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
    assert_allclose(rotations, 45.0)

    image_data = disp.im_.get_array().data
    assert_allclose(image_data, cm)

    # Each subsequent plot() call overrides the corresponding option.
    disp.plot(cmap='plasma')
    assert disp.im_.get_cmap().name == 'plasma'

    disp.plot(include_values=False)
    assert disp.text_ is None

    disp.plot(xticks_rotation=90.0)
    rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
    assert_allclose(rotations, 90.0)

    disp.plot(values_format='e')
    expected_text = np.array([format(v, 'e') for v in cm.ravel(order="C")])
    shown_text = np.array(
        [t.get_text() for t in disp.text_.ravel(order="C")])
    assert_array_equal(expected_text, shown_text)
|
||
|
|
||
|
|
||
|
def test_confusion_matrix_contrast(pyplot):
    """Cell text color must contrast with the cell background.

    Bug fix: in the original, the ``gray_r`` comments were attached to the
    wrong assert groups ("diagonal text is white" sat above the off-diagonal
    black-text asserts and vice versa).  The independent asserts are reordered
    so each group matches its comment; the set of checks is unchanged.
    """
    # Half-intensity identity: diagonal cells 0.5, off-diagonal cells 0.
    cm = np.eye(2) / 2
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    disp.plot(cmap=pyplot.cm.gray)
    # diagonal text is black (large values -> light cells under 'gray')
    assert_allclose(disp.text_[0, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
    assert_allclose(disp.text_[1, 1].get_color(), [0.0, 0.0, 0.0, 1.0])

    # off-diagonal text is white (zero values -> dark cells)
    assert_allclose(disp.text_[0, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
    assert_allclose(disp.text_[1, 0].get_color(), [1.0, 1.0, 1.0, 1.0])

    disp.plot(cmap=pyplot.cm.gray_r)
    # diagonal text is white (roles flip with the reversed colormap)
    assert_allclose(disp.text_[0, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
    assert_allclose(disp.text_[1, 1].get_color(), [1.0, 1.0, 1.0, 1.0])

    # off-diagonal text is black
    assert_allclose(disp.text_[0, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
    assert_allclose(disp.text_[1, 0].get_color(), [0.0, 0.0, 0.0, 1.0])

    # Regression test for #15920
    cm = np.array([[19, 34], [32, 58]])
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    disp.plot(cmap=pyplot.cm.Blues)
    min_color = pyplot.cm.Blues(0)
    max_color = pyplot.cm.Blues(255)
    assert_allclose(disp.text_[0, 0].get_color(), max_color)
    assert_allclose(disp.text_[0, 1].get_color(), max_color)
    assert_allclose(disp.text_[1, 0].get_color(), max_color)
    assert_allclose(disp.text_[1, 1].get_color(), min_color)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "clf", [LogisticRegression(),
            make_pipeline(StandardScaler(), LogisticRegression()),
            make_pipeline(make_column_transformer((StandardScaler(), [0, 1])),
                          LogisticRegression())])
def test_confusion_matrix_pipeline(pyplot, clf, data, n_classes):
    """Pipelines are supported; unfitted estimators raise NotFittedError."""
    X, y = data
    with pytest.raises(NotFittedError):
        plot_confusion_matrix(clf, X, y)

    clf.fit(X, y)
    predictions = clf.predict(X)

    disp = plot_confusion_matrix(clf, X, y)
    expected_cm = confusion_matrix(y, predictions)

    assert_allclose(disp.confusion_matrix, expected_cm)
    assert disp.text_.shape == (n_classes, n_classes)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("colorbar", [True, False])
def test_plot_confusion_matrix_colorbar(pyplot, data, fitted_clf, colorbar):
    """``colorbar`` must add/omit the colorbar, also when re-plotting."""
    X, y = data

    def _check_colorbar(disp, has_colorbar):
        # Compare by class name to avoid a direct matplotlib import here.
        if has_colorbar:
            assert disp.im_.colorbar is not None
            assert disp.im_.colorbar.__class__.__name__ == "Colorbar"
        else:
            assert disp.im_.colorbar is None

    disp = plot_confusion_matrix(fitted_clf, X, y, colorbar=colorbar)
    _check_colorbar(disp, colorbar)
    # Re-plotting with the opposite setting must toggle the colorbar.
    disp.plot(colorbar=not colorbar)
    _check_colorbar(disp, not colorbar)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("values_format", ['e', 'n'])
def test_confusion_matrix_text_format(pyplot, data, y_pred, n_classes,
                                      fitted_clf, values_format):
    """Cell text must be rendered with the requested ``values_format``."""
    X, y = data
    cm = confusion_matrix(y, y_pred)
    disp = plot_confusion_matrix(fitted_clf, X, y,
                                 include_values=True,
                                 values_format=values_format)

    assert disp.text_.shape == (n_classes, n_classes)

    expected = np.array([format(v, values_format) for v in cm.ravel()])
    shown = np.array([t.get_text() for t in disp.text_.ravel()])
    assert_array_equal(expected, shown)
|
||
|
|
||
|
|
||
|
def test_confusion_matrix_standard_format(pyplot):
    """Default formatting picks the shorter of 'd' and '.2g' per cell."""
    cm = np.array([[10000000, 0], [123456, 12345678]])
    plotted_text = ConfusionMatrixDisplay(
        cm, display_labels=[False, True]).plot().text_
    # Integers render as plain 'd' unless the scientific form is shorter:
    # 10000000 -> '1e+07' and 12345678 -> '1.2e+07'.
    shown = [t.get_text() for t in plotted_text.ravel()]
    assert shown == ['1e+07', '0', '123456', '1.2e+07']

    cm = np.array([[0.1, 10], [100, 0.525]])
    plotted_text = ConfusionMatrixDisplay(
        cm, display_labels=[False, True]).plot().text_
    # With floats present everything falls back to '.2g'
    # (e.g. 100 -> '1e+02', 0.525 -> '0.53').
    shown = [t.get_text() for t in plotted_text.ravel()]
    assert shown == ['0.1', '10', '1e+02', '0.53']
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("display_labels, expected_labels", [
    (None, ["0", "1"]),
    (["cat", "dog"], ["cat", "dog"]),
])
def test_default_labels(pyplot, display_labels, expected_labels):
    """Tick labels default to class indices when ``display_labels`` is None."""
    cm = np.array([[10, 0], [12, 120]])
    disp = ConfusionMatrixDisplay(cm, display_labels=display_labels).plot()

    # Both axes must carry the same labels.
    for get_labels in (disp.ax_.get_xticklabels, disp.ax_.get_yticklabels):
        shown = [tick.get_text() for tick in get_labels()]
        assert_array_equal(shown, expected_labels)
|
||
|
|
||
|
|
||
|
def test_error_on_a_dataset_with_unseen_labels(
    pyplot, fitted_clf, data, n_classes
):
    """Check that when labels=None, the unique values in `y_pred` and `y_true`
    will be used.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/pull/18405
    """
    X, y = data

    # Shift the targets so `y_true` contains a label that is absent from
    # `fitted_clf.classes_`.
    y = y + 1
    disp = plot_confusion_matrix(fitted_clf, X, y)

    shown_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    expected_labels = [str(i) for i in range(n_classes + 1)]
    assert_array_equal(expected_labels, shown_labels)
|