# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
# Author: Sebastian Raschka
#
# A function for plotting a confusion matrix.
# License: BSD 3 clause

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(conf_mat,
                          hide_spines=False,
                          hide_ticks=False,
                          figsize=None,
                          cmap=None,
                          colorbar=False,
                          show_absolute=True,
                          show_normed=False,
                          class_names=None,
                          figure=None,
                          axis=None):
    """Plot a confusion matrix via matplotlib.

    Parameters
    -----------
    conf_mat : array-like, shape = [n_classes, n_classes]
        Confusion matrix from evaluate.confusion matrix.
        Converted with `np.asarray`, so nested lists are accepted.
    hide_spines : bool (default: False)
        Hides axis spines if True.
    hide_ticks : bool (default: False)
        Hides axis ticks if True
    figsize : tuple (default: None)
        Width and height of the figure in inches. If `None`,
        1.25 inches per class are used in both directions
        (i.e. (2.5, 2.5) for a binary problem). Only used when
        this function creates the figure itself, that is when
        both `figure` and `axis` are `None`.
    cmap : matplotlib colormap (default: `None`)
        Uses matplotlib.pyplot.cm.Blues if `None`
    colorbar : bool (default: False)
        Shows a colorbar if True
    show_absolute : bool (default: True)
        Shows absolute confusion matrix coefficients if True.
        At least one of `show_absolute` or `show_normed`
        must be True.
    show_normed : bool (default: False)
        Shows normed confusion matrix coefficients if True.
        Each row of the matrix is divided by its row sum, so
        every cell gives the fraction of that row's examples
        falling into the respective column. The cell colors are
        based on the normed values in this case.
        At least one of `show_absolute` or `show_normed`
        must be True.
    class_names : array-like, shape = [n_classes] (default: None)
        List of class names.
        If not `None`, ticks will be set to these values.
    figure : None or Matplotlib figure (default: None)
        If None will create a new figure, unless `axis` is given,
        in which case the figure owning `axis` is used.
    axis : None or Matplotlib figure axis (default: None)
        If None will create a new axis.

    Returns
    -----------
    fig, ax : matplotlib.pyplot subplot objects
        Figure and axis elements of the subplot.

    Raises
    -----------
    AssertionError
        If both `show_absolute` and `show_normed` are False, or if
        `len(class_names)` differs from the number of rows
        of `conf_mat`.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/

    """
    if not (show_absolute or show_normed):
        raise AssertionError('Both show_absolute and show_normed are False')

    # Accept any array-like; the code below relies on ndarray methods.
    conf_mat = np.asarray(conf_mat)
    n_classes = len(conf_mat)

    if class_names is not None and len(class_names) != n_classes:
        raise AssertionError('len(class_names) should be equal to number of '
                             'classes in the dataset')

    # Row-wise normalization: each row is divided by its own total.
    total_samples = conf_mat.sum(axis=1)[:, np.newaxis]
    normed_conf_mat = conf_mat.astype('float') / total_samples

    # The size default has to be resolved *before* the figure is created,
    # otherwise it can never take effect.
    if figsize is None:
        figsize = (n_classes * 1.25, n_classes * 1.25)

    if figure is None and axis is None:
        fig, ax = plt.subplots(figsize=figsize)
    elif axis is None:
        fig = figure
        ax = fig.add_subplot(1, 1, 1)
    else:
        ax = axis
        # An axis passed without its figure still needs a figure object
        # for the colorbar and for the return value.
        fig = figure if figure is not None else axis.get_figure()

    ax.grid(False)
    if cmap is None:
        cmap = plt.cm.Blues

    shown_mat = normed_conf_mat if show_normed else conf_mat
    matshow = ax.matshow(shown_mat, cmap=cmap)

    if colorbar:
        fig.colorbar(matshow)

    # Cells darker than the midpoint of the color scale get white text.
    half_max = np.max(conf_mat) / 2

    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            normed_text = format(normed_conf_mat[i, j], '.2f')
            if show_absolute:
                cell_text = format(conf_mat[i, j], 'd')
                if show_normed:
                    cell_text += "\n" + '(' + normed_text + ')'
            else:
                cell_text = normed_text

            if show_normed:
                is_dark = normed_conf_mat[i, j] > 0.5
            else:
                is_dark = conf_mat[i, j] > half_max

            ax.text(x=j,
                    y=i,
                    s=cell_text,
                    va='center',
                    ha='center',
                    color="white" if is_dark else "black")

    # All decorations go through `ax` (not pyplot's "current axes") so that
    # a caller-supplied axis is the one that actually gets modified.
    if class_names is not None:
        tick_marks = np.arange(len(class_names))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(class_names, rotation=45)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(class_names)

    if hide_spines:
        for side in ('right', 'top', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
    # matshow moves the x ticks to the top; put them back at the bottom.
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if hide_ticks:
        ax.yaxis.set_ticks([])
        ax.xaxis.set_ticks([])

    ax.set_xlabel('predicted label')
    ax.set_ylabel('true label')
    return fig, ax