projektAI/venv/Lib/site-packages/mlxtend/plotting/plot_confusion_matrix.py

146 lines
5.0 KiB
Python
Raw Normal View History

2021-06-06 22:13:05 +02:00
# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
# Author: Sebastian Raschka <sebastianraschka.com>
#
# A function for plotting a confusion matrix.
# License: BSD 3 clause
import matplotlib.pyplot as plt
import numpy as np
def plot_confusion_matrix(conf_mat,
                          hide_spines=False,
                          hide_ticks=False,
                          figsize=None,
                          cmap=None,
                          colorbar=False,
                          show_absolute=True,
                          show_normed=False,
                          class_names=None,
                          figure=None,
                          axis=None):
    """Plot a confusion matrix via matplotlib.

    Parameters
    -----------
    conf_mat : array-like, shape = [n_classes, n_classes]
        Confusion matrix, e.g. from evaluate.confusion_matrix.
        Rows are drawn along the y-axis ("true label") and columns
        along the x-axis ("predicted label"). It is converted with
        `np.asarray`, so nested lists are accepted. Absolute values
        are rendered with the integer format `'d'`.
    hide_spines : bool (default: False)
        Hides axis spines if True.
    hide_ticks : bool (default: False)
        Hides axis ticks if True.
    figsize : tuple or None (default: None)
        (width, height) passed to `plt.subplots` when a new figure
        is created, i.e. when both `figure` and `axis` are None.
        `None` leaves the size to matplotlib. Ignored otherwise.
    cmap : matplotlib colormap (default: `None`)
        Uses matplotlib.pyplot.cm.Blues if `None`.
    colorbar : bool (default: False)
        Shows a colorbar if True.
    show_absolute : bool (default: True)
        Shows absolute confusion matrix coefficients if True.
        At least one of `show_absolute` or `show_normed`
        must be True.
    show_normed : bool (default: False)
        Shows normed confusion matrix coefficients if True.
        Each cell is divided by the sum of its row, so a cell holds
        the fraction of that row's samples that fall into the
        column. The cell colors are then based on the normed values.
        At least one of `show_absolute` or `show_normed`
        must be True.
    class_names : array-like, shape = [n_classes] (default: None)
        List of class names.
        If not `None`, ticks will be set to these values.
    figure : None or Matplotlib figure (default: None)
        If None will create a new figure (or, when only `axis` is
        given, use the figure that `axis` belongs to).
    axis : None or Matplotlib figure axis (default: None)
        If None will create a new axis.

    Returns
    -----------
    fig, ax : matplotlib.pyplot subplot objects
        Figure and axis elements of the subplot.

    Raises
    -----------
    AssertionError
        If both `show_absolute` and `show_normed` are False, or if
        `len(class_names)` differs from `len(conf_mat)`.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/

    """
    if not (show_absolute or show_normed):
        raise AssertionError('Both show_absolute and show_normed are False')
    if class_names is not None and len(class_names) != len(conf_mat):
        raise AssertionError('len(class_names) should be equal to number of '
                             'classes in the dataset')

    # Accept any array-like (e.g. nested lists), as documented.
    conf_mat = np.asarray(conf_mat)

    # Row-wise normalisation; the extra axis makes the row sums broadcast
    # against the columns.
    # NOTE(review): a row whose sum is 0 yields non-finite values here.
    row_totals = conf_mat.sum(axis=1)[:, np.newaxis]
    normed_conf_mat = conf_mat.astype('float') / row_totals

    if figure is None and axis is None:
        fig, ax = plt.subplots(figsize=figsize)
    elif axis is None:
        fig = figure
        ax = fig.add_subplot(1, 1, 1)
    else:
        # When only an axis is supplied, fall back to the figure it lives
        # in so that `fig.colorbar` and the return value stay usable.
        fig = axis.figure if figure is None else figure
        ax = axis

    ax.grid(False)
    if cmap is None:
        cmap = plt.cm.Blues

    # The matrix that drives the cell colors, and the value above which a
    # cell is considered dark enough to need white text.
    if show_normed:
        shaded = normed_conf_mat
        dark_threshold = 0.5
    else:
        shaded = conf_mat
        dark_threshold = np.max(conf_mat) / 2

    matshow = ax.matshow(shaded, cmap=cmap)
    if colorbar:
        fig.colorbar(matshow)

    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            if show_absolute:
                cell_text = format(conf_mat[i, j], 'd')
                if show_normed:
                    # Normed value goes on a second line, in parentheses.
                    cell_text += ("\n" + '('
                                  + format(normed_conf_mat[i, j], '.2f')
                                  + ')')
            else:
                cell_text = format(normed_conf_mat[i, j], '.2f')

            ax.text(x=j,
                    y=i,
                    s=cell_text,
                    va='center',
                    ha='center',
                    color="white" if shaded[i, j] > dark_threshold
                    else "black")

    # Everything below goes through `ax` rather than the pyplot state
    # machine, so a caller-supplied axis is decorated even when it is not
    # pyplot's "current" axes.
    if class_names is not None:
        tick_marks = np.arange(len(class_names))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(class_names, rotation=45)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(class_names)

    if hide_spines:
        for side in ('right', 'top', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    if hide_ticks:
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])

    ax.set_xlabel('predicted label')
    ax.set_ylabel('true label')
    return fig, ax