from itertools import product

import numpy as np

from .. import confusion_matrix
from ...utils import check_matplotlib_support
from ...utils.multiclass import unique_labels
from ...base import is_classifier


class ConfusionMatrixDisplay:
    """Confusion Matrix visualization.

    It is recommended to use
    :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
    :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
    create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
    attributes.

    Read more in the :ref:`User Guide <confusion_matrix>`.

    Parameters
    ----------
    confusion_matrix : ndarray of shape (n_classes, n_classes)
        Confusion matrix.

    display_labels : ndarray of shape (n_classes,), default=None
        Display labels for plot. If None, display labels are set from 0 to
        `n_classes - 1`.

    Attributes
    ----------
    im_ : matplotlib AxesImage
        Image representing the confusion matrix.

    text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text, \
            or None
        Array of matplotlib text elements, one per cell. `None` if
        `include_values` is false.

    ax_ : matplotlib Axes
        Axes with confusion matrix.

    figure_ : matplotlib Figure
        Figure containing the confusion matrix.

    See Also
    --------
    confusion_matrix : Compute Confusion Matrix to evaluate the accuracy of a
        classification.
    ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
        given an estimator, the data, and the label.
    ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
        given the true and predicted labels.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
    ...                                                     random_state=0)
    >>> clf = SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> predictions = clf.predict(X_test)
    >>> cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
    >>> disp = ConfusionMatrixDisplay(confusion_matrix=cm,
    ...                               display_labels=clf.classes_)
    >>> disp.plot()
    <...>
    >>> plt.show()
    """

    def __init__(self, confusion_matrix, *, display_labels=None):
        # Both arguments are stored untouched; nothing is validated or drawn
        # until `plot` is called.
        self.confusion_matrix = confusion_matrix
        self.display_labels = display_labels

    def plot(
        self,
        *,
        include_values=True,
        cmap="viridis",
        xticks_rotation="horizontal",
        values_format=None,
        ax=None,
        colorbar=True,
        im_kw=None,
        text_kw=None,
    ):
        """Plot visualization.

        Parameters
        ----------
        include_values : bool, default=True
            Includes values in confusion matrix.

        cmap : str or matplotlib Colormap, default='viridis'
            Colormap recognized by matplotlib.

        xticks_rotation : {'vertical', 'horizontal'} or float, \
                default='horizontal'
            Rotation of xtick labels.

        values_format : str, default=None
            Format specification for values in confusion matrix. If `None`,
            the format specification is 'd' or '.2g' whichever is shorter.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        colorbar : bool, default=True
            Whether or not to add a colorbar to the plot.

        im_kw : dict, default=None
            Dict with keywords passed to `matplotlib.pyplot.imshow` call.
            Entries override the defaults (`interpolation="nearest"` and
            the `cmap` argument).

        text_kw : dict, default=None
            Dict with keywords passed to `matplotlib.pyplot.text` call.
            Entries override the defaults (centered alignment and the
            automatically chosen color).

            .. versionadded:: 1.2

        Returns
        -------
        display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
            Returns a :class:`~sklearn.metrics.ConfusionMatrixDisplay` instance
            that contains all the information to plot the confusion matrix.
        """
        check_matplotlib_support("ConfusionMatrixDisplay.plot")
        import matplotlib.pyplot as plt

        if ax is not None:
            fig = ax.figure
        else:
            fig, ax = plt.subplots()

        matrix = self.confusion_matrix
        n_classes = matrix.shape[0]

        # User-supplied keywords take precedence over the defaults.
        imshow_kwargs = {
            "interpolation": "nearest",
            "cmap": cmap,
            **(im_kw or {}),
        }
        extra_text_kwargs = text_kw or {}

        def _format_cell(value):
            # An explicit format always wins; otherwise pick the shorter of
            # '.2g' and (for non-float matrices) 'd', preferring '.2g' on ties.
            if values_format is not None:
                return format(value, values_format)
            label = format(value, ".2g")
            if matrix.dtype.kind != "f":
                as_integer = format(value, "d")
                if len(as_integer) < len(label):
                    label = as_integer
            return label

        self.im_ = ax.imshow(matrix, **imshow_kwargs)
        self.text_ = None

        # The two extreme colors of the colormap actually used by the image.
        cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)

        if include_values:
            self.text_ = np.empty_like(matrix, dtype=object)

            # Cells below the midpoint of the value range get the color of
            # the top of the colormap, the others the color of the bottom,
            # so the text contrasts with its background.
            midpoint = (matrix.max() + matrix.min()) / 2.0

            for row, col in product(range(n_classes), repeat=2):
                value = matrix[row, col]
                if value < midpoint:
                    color = cmap_max
                else:
                    color = cmap_min

                cell_text_kwargs = {
                    "ha": "center",
                    "va": "center",
                    "color": color,
                    **extra_text_kwargs,
                }
                # x is the column (predicted), y is the row (true).
                self.text_[row, col] = ax.text(
                    col, row, _format_cell(value), **cell_text_kwargs
                )

        tick_labels = self.display_labels
        if tick_labels is None:
            tick_labels = np.arange(n_classes)

        if colorbar:
            fig.colorbar(self.im_, ax=ax)

        ax.set(
            xticks=np.arange(n_classes),
            yticks=np.arange(n_classes),
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ylabel="True label",
            xlabel="Predicted label",
        )

        # Inverted y-limits: row 0 at the top, with half a cell of margin.
        ax.set_ylim((n_classes - 0.5, -0.5))
        plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)

        self.figure_ = fig
        self.ax_ = ax
        return self

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        labels=None,
        sample_weight=None,
        normalize=None,
        display_labels=None,
        include_values=True,
        xticks_rotation="horizontal",
        values_format=None,
        cmap="viridis",
        ax=None,
        colorbar=True,
        im_kw=None,
        text_kw=None,
    ):
        """Plot Confusion Matrix given an estimator and some data.

        Read more in the :ref:`User Guide <confusion_matrix>`.

        .. versionadded:: 1.0

        Parameters
        ----------
        estimator : estimator instance
            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
            in which the last estimator is a classifier.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input values.

        y : array-like of shape (n_samples,)
            Target values.

        labels : array-like of shape (n_classes,), default=None
            List of labels to index the confusion matrix. This may be used to
            reorder or select a subset of labels. If `None` is given, those
            that appear at least once in `y_true` or `y_pred` are used in
            sorted order.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        normalize : {'true', 'pred', 'all'}, default=None
            Whether and how to normalize the counts displayed in the matrix:

            - if `'true'`, the confusion matrix is normalized over the true
              conditions (e.g. rows);
            - if `'pred'`, the confusion matrix is normalized over the
              predicted conditions (e.g. columns);
            - if `'all'`, the confusion matrix is normalized by the total
              number of samples;
            - if `None` (default), the confusion matrix will not be normalized.

        display_labels : array-like of shape (n_classes,), default=None
            Target names used for plotting. By default, `labels` will be used
            if it is defined, otherwise the unique labels of `y_true` and
            `y_pred` will be used.

        include_values : bool, default=True
            Includes values in confusion matrix.

        xticks_rotation : {'vertical', 'horizontal'} or float, \
                default='horizontal'
            Rotation of xtick labels.

        values_format : str, default=None
            Format specification for values in confusion matrix. If `None`, the
            format specification is 'd' or '.2g' whichever is shorter.

        cmap : str or matplotlib Colormap, default='viridis'
            Colormap recognized by matplotlib.

        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        colorbar : bool, default=True
            Whether or not to add a colorbar to the plot.

        im_kw : dict, default=None
            Dict with keywords passed to `matplotlib.pyplot.imshow` call.

        text_kw : dict, default=None
            Dict with keywords passed to `matplotlib.pyplot.text` call.

            .. versionadded:: 1.2

        Returns
        -------
        display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`

        See Also
        --------
        ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
            given the true and predicted labels.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import ConfusionMatrixDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...         X, y, random_state=0)
        >>> clf = SVC(random_state=0)
        >>> clf.fit(X_train, y_train)
        SVC(random_state=0)
        >>> ConfusionMatrixDisplay.from_estimator(
        ...     clf, X_test, y_test)
        <...>
        >>> plt.show()
        """
        method_name = f"{cls.__name__}.from_estimator"
        check_matplotlib_support(method_name)

        if not is_classifier(estimator):
            raise ValueError(f"{method_name} only supports classifiers")

        # Predict once, then delegate everything else to `from_predictions`.
        predicted = estimator.predict(X)
        return cls.from_predictions(
            y,
            predicted,
            labels=labels,
            sample_weight=sample_weight,
            normalize=normalize,
            display_labels=display_labels,
            include_values=include_values,
            xticks_rotation=xticks_rotation,
            values_format=values_format,
            cmap=cmap,
            ax=ax,
            colorbar=colorbar,
            im_kw=im_kw,
            text_kw=text_kw,
        )

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        *,
        labels=None,
        sample_weight=None,
        normalize=None,
        display_labels=None,
        include_values=True,
        xticks_rotation="horizontal",
        values_format=None,
        cmap="viridis",
        ax=None,
        colorbar=True,
        im_kw=None,
        text_kw=None,
    ):
        """Plot Confusion Matrix given true and predicted labels.

        Read more in the :ref:`User Guide <confusion_matrix>`.

        .. versionadded:: 1.0

        Parameters
        ----------
        y_true : array-like of shape (n_samples,)
            True labels.

        y_pred : array-like of shape (n_samples,)
            The predicted labels given by the method `predict` of a
            classifier.

        labels : array-like of shape (n_classes,), default=None
            List of labels to index the confusion matrix. This may be used to
            reorder or select a subset of labels. If `None` is given, those
            that appear at least once in `y_true` or `y_pred` are used in
            sorted order.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        normalize : {'true', 'pred', 'all'}, default=None
            Whether and how to normalize the counts displayed in the matrix:

            - if `'true'`, the confusion matrix is normalized over the true
              conditions (e.g. rows);
            - if `'pred'`, the confusion matrix is normalized over the
              predicted conditions (e.g. columns);
            - if `'all'`, the confusion matrix is normalized by the total
              number of samples;
            - if `None` (default), the confusion matrix will not be normalized.

        display_labels : array-like of shape (n_classes,), default=None
            Target names used for plotting. By default, `labels` will be used
            if it is defined, otherwise the unique labels of `y_true` and
            `y_pred` will be used.

        include_values : bool, default=True
            Includes values in confusion matrix.

        xticks_rotation : {'vertical', 'horizontal'} or float, \
                default='horizontal'
            Rotation of xtick labels.

        values_format : str, default=None
            Format specification for values in confusion matrix. If `None`, the
            format specification is 'd' or '.2g' whichever is shorter.

        cmap : str or matplotlib Colormap, default='viridis'
            Colormap recognized by matplotlib.

        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        colorbar : bool, default=True
            Whether or not to add a colorbar to the plot.

        im_kw : dict, default=None
            Dict with keywords passed to `matplotlib.pyplot.imshow` call.

        text_kw : dict, default=None
            Dict with keywords passed to `matplotlib.pyplot.text` call.

            .. versionadded:: 1.2

        Returns
        -------
        display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`

        See Also
        --------
        ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
            given an estimator, the data, and the label.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import ConfusionMatrixDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...         X, y, random_state=0)
        >>> clf = SVC(random_state=0)
        >>> clf.fit(X_train, y_train)
        SVC(random_state=0)
        >>> y_pred = clf.predict(X_test)
        >>> ConfusionMatrixDisplay.from_predictions(
        ...    y_test, y_pred)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        # Resolve tick labels: explicit `display_labels`, then `labels`,
        # then whatever labels occur in the data.
        if display_labels is None:
            display_labels = (
                labels if labels is not None else unique_labels(y_true, y_pred)
            )

        matrix = confusion_matrix(
            y_true,
            y_pred,
            sample_weight=sample_weight,
            labels=labels,
            normalize=normalize,
        )

        display = cls(confusion_matrix=matrix, display_labels=display_labels)
        return display.plot(
            include_values=include_values,
            cmap=cmap,
            ax=ax,
            xticks_rotation=xticks_rotation,
            values_format=values_format,
            colorbar=colorbar,
            im_kw=im_kw,
            text_kw=text_kw,
        )