146 lines
5.0 KiB
Python
146 lines
5.0 KiB
Python
|
# Sebastian Raschka 2014-2020
|
||
|
# mlxtend Machine Learning Library Extensions
|
||
|
# Author: Sebastian Raschka <sebastianraschka.com>
|
||
|
#
|
||
|
# A function for plotting a confusion matrix.
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
def plot_confusion_matrix(conf_mat,
                          hide_spines=False,
                          hide_ticks=False,
                          figsize=None,
                          cmap=None,
                          colorbar=False,
                          show_absolute=True,
                          show_normed=False,
                          class_names=None,
                          figure=None,
                          axis=None):
    """Plot a confusion matrix via matplotlib.

    Parameters
    -----------
    conf_mat : array-like, shape = [n_classes, n_classes]
        Confusion matrix from evaluate.confusion matrix.
        Converted with `numpy.asarray` before use. The absolute
        coefficients are formatted as integers.
    hide_spines : bool (default: False)
        Hides axis spines if True.
    hide_ticks : bool (default: False)
        Hides axis ticks if True
    figsize : tuple (default: None)
        Size passed to `matplotlib.pyplot.subplots` when a new figure
        is created, i.e. when both `figure` and `axis` are None.
        If None, (1.25 * n_classes, 1.25 * n_classes) is used.
    cmap : matplotlib colormap (default: `None`)
        Uses matplotlib.pyplot.cm.Blues if `None`
    colorbar : bool (default: False)
        Shows a colorbar if True
    show_absolute : bool (default: True)
        Shows absolute confusion matrix coefficients if True.
        At least one of `show_absolute` or `show_normed`
        must be True.
    show_normed : bool (default: False)
        Shows normed confusion matrix coefficients if True.
        The normed coefficients are the absolute coefficients
        divided by the sum of their row. When True, the cell
        colors are based on the normed matrix as well.
        At least one of `show_absolute` or `show_normed`
        must be True.
    class_names : array-like, shape = [n_classes] (default: None)
        List of class names.
        If not `None`, ticks will be set to these values.
    figure : None or Matplotlib figure (default: None)
        If None will create a new figure, unless `axis` is given,
        in which case the figure owning `axis` is used.
    axis : None or Matplotlib figure axis (default: None)
        If None will create a new axis.

    Returns
    -----------
    fig, ax : matplotlib.pyplot subplot objects
        Figure and axis elements of the subplot.

    Raises
    -----------
    AssertionError
        If both `show_absolute` and `show_normed` are False, or if
        `len(class_names)` differs from `len(conf_mat)`.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/

    """
    if not (show_absolute or show_normed):
        raise AssertionError('Both show_absolute and show_normed are False')
    if class_names is not None and len(class_names) != len(conf_mat):
        raise AssertionError('len(class_names) should be equal to number of '
                             'classes in the dataset')

    # Accept any array-like (nested lists, etc.), as the docstring promises.
    conf_mat = np.asarray(conf_mat)

    # Row-wise normalization: each cell divided by the sum of its row.
    total_samples = conf_mat.sum(axis=1)[:, np.newaxis]
    normed_conf_mat = conf_mat.astype('float') / total_samples

    # The default size must be known *before* the figure is created,
    # otherwise it never reaches plt.subplots.
    if figsize is None:
        figsize = (len(conf_mat) * 1.25, len(conf_mat) * 1.25)

    if figure is None and axis is None:
        fig, ax = plt.subplots(figsize=figsize)
    elif axis is None:
        fig = figure
        ax = fig.add_subplot(1, 1, 1)
    else:
        ax = axis
        # Fall back to the axis' own figure so that the colorbar and the
        # return value work when only `axis` was supplied.
        fig = figure if figure is not None else ax.get_figure()

    ax.grid(False)
    if cmap is None:
        cmap = plt.cm.Blues

    # The normed matrix drives the colors whenever it is shown; the text
    # color flips to white for cells above half of the color range.
    if show_normed:
        shown_mat = normed_conf_mat
        white_threshold = 0.5
    else:
        shown_mat = conf_mat
        # Hoisted out of the cell loop; guarded so an empty matrix does
        # not raise on max().
        white_threshold = conf_mat.max() / 2 if conf_mat.size else 0

    matshow = ax.matshow(shown_mat, cmap=cmap)

    if colorbar:
        fig.colorbar(matshow)

    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            if show_absolute:
                cell_text = format(conf_mat[i, j], 'd')
                if show_normed:
                    cell_text += "\n" + '('
                    cell_text += format(normed_conf_mat[i, j], '.2f') + ')'
            else:
                cell_text = format(normed_conf_mat[i, j], '.2f')

            ax.text(x=j,
                    y=i,
                    s=cell_text,
                    va='center',
                    ha='center',
                    color="white" if shown_mat[i, j] > white_threshold
                    else "black")

    # Use the Axes API rather than pyplot so that the labels land on `ax`
    # even when it is not pyplot's current axes.
    if class_names is not None:
        tick_marks = np.arange(len(class_names))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(class_names, rotation=45)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(class_names)

    if hide_spines:
        for side in ('right', 'top', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if hide_ticks:
        ax.axes.get_yaxis().set_ticks([])
        ax.axes.get_xaxis().set_ticks([])

    ax.set_xlabel('predicted label')
    ax.set_ylabel('true label')
    return fig, ax
|