# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# A function for plotting decision regions of classifiers.
# Author: Sebastian Raschka
#
# License: BSD 3 clause

import warnings
from itertools import cycle
from math import ceil
from math import floor

import matplotlib.pyplot as plt
import numpy as np

from mlxtend.utils import check_Xy, format_kwarg_dictionaries


def get_feature_range_mask(X, filler_feature_values=None,
                           filler_feature_ranges=None):
    """Construct a boolean mask selecting the samples of `X` that lie
    inside the feature ranges given by `filler_feature_values` and
    `filler_feature_ranges`.

    For every feature index ``i`` in `filler_feature_ranges`, a sample is
    kept only if ``value - width < X[:, i] < value + width`` (both bounds
    exclusive), where ``value = filler_feature_values[i]`` and
    ``width = filler_feature_ranges[i]``.

    Parameters
    ----------
    X : np.ndarray, shape = [n_samples, n_features]
        Samples to filter.

    filler_feature_values : dict
        Feature index -> center value.

    filler_feature_ranges : dict
        Feature index -> half-width of the accepted interval. Every key
        must also be present in `filler_feature_values`.

    Returns
    ---------
    mask : np.ndarray of bool, shape = [n_samples]

    Raises
    ---------
    ValueError
        If `X` is not a 2D NumPy array, or if either dictionary is None.

    """
    if not isinstance(X, np.ndarray) or X.ndim != 2:
        raise ValueError('X must be a 2D array')
    if filler_feature_values is None:
        raise ValueError('filler_feature_values must not be None')
    if filler_feature_ranges is None:
        raise ValueError('filler_feature_ranges must not be None')

    mask = np.ones(X.shape[0], dtype=bool)
    for feature_idx, feature_width in filler_feature_ranges.items():
        feature_value = filler_feature_values[feature_idx]
        column = X[:, feature_idx]
        mask &= ((column > feature_value - feature_width) &
                 (column < feature_value + feature_width))
    return mask


def _validate_feature_index(X, feature_index):
    """Return the (x_index, y_index) pair to plot, validating a
    user-supplied `feature_index` against the columns of `X`."""
    if feature_index is None:
        # Default: first feature on the x-axis, second on the y-axis
        # (the y index is unused when X has a single feature).
        return 0, 1

    if X.shape[1] == 1:
        raise ValueError(
            'feature_index requires more than one training feature')
    try:
        x_index, y_index = feature_index
    except ValueError as err:
        raise ValueError(
            'Unable to unpack feature_index. Make sure feature_index '
            'only has two dimensions.') from err
    try:
        X[:, x_index], X[:, y_index]
    except IndexError as err:
        raise IndexError(
            'feature_index values out of range. X.shape is {}, but '
            'feature_index is {}'.format(X.shape, feature_index)) from err
    return x_index, y_index


def _validate_filler_features(dim, x_index, y_index,
                              filler_feature_values, filler_feature_ranges):
    """For X with more than two features, check that the filler
    dictionaries agree on their keys and that every column of X is either
    plotted or assigned a filler value."""
    if dim <= 2:
        return

    if filler_feature_values is None:
        raise ValueError('Filler values must be provided when '
                         'X has more than 2 training features.')

    if (filler_feature_ranges is not None and
            set(filler_feature_values) != set(filler_feature_ranges)):
        raise ValueError(
            'filler_feature_values and filler_feature_ranges must '
            'have the same keys')

    # Check that all columns in X are accounted for
    column_check = np.zeros(dim, dtype=bool)
    for idx in filler_feature_values:
        column_check[idx] = True
    for idx in (x_index, y_index):
        column_check[idx] = True
    if not column_check.all():
        missing_cols = np.argwhere(~column_check).flatten()
        raise ValueError(
            'Column(s) {} need to be accounted for in either '
            'feature_index or filler_feature_values'.format(missing_cols))


def plot_decision_regions(X, y, clf,
                          feature_index=None,
                          filler_feature_values=None,
                          filler_feature_ranges=None,
                          ax=None,
                          X_highlight=None,
                          res=None,
                          zoom_factor=1.,
                          legend=1,
                          hide_spines=True,
                          markers='s^oxv<>',
                          colors=('#1f77b4,#ff7f0e,#3ca02c,#d62728,'
                                  '#9467bd,#8c564b,#e377c2,'
                                  '#7f7f7f,#bcbd22,#17becf'),
                          scatter_kwargs=None,
                          contourf_kwargs=None,
                          scatter_highlight_kwargs=None):
    """Plot decision regions of a classifier.

    Please note that this function assumes that class labels are
    consecutive integers starting at 0, e.g., 0, 1, 2, 3, 4, and 5:
    the contour levels are derived from the largest predicted label and
    colors are assigned to the classes in sorted label order.
    If there are more classes than entries in `colors` or `markers`,
    those are reused cyclically; provide additional colors and/or markers
    via the `colors` and `markers` arguments to keep every class
    distinguishable. See
    http://matplotlib.org/examples/color/named_colors.html
    for more information.

    Parameters
    ----------
    X : array-like, shape = [n_samples, n_features]
        Feature Matrix.

    y : array-like, shape = [n_samples]
        True class labels.

    clf : Classifier object.
        Must have a .predict method.

    feature_index : array-like (default: None)
        Pair of feature indices to use for plotting. The first index in
        `feature_index` will be on the x-axis, the second index will be
        on the y-axis. If None, (0, 1) is used (only the first feature
        for single-feature `X`). Must be None if `X` has one feature.

    filler_feature_values : dict (default: None)
        Only needed for number features > 2. Dictionary of feature
        index-value pairs for the features not being plotted. Every
        column of `X` must appear either here or in `feature_index`.

    filler_feature_ranges : dict (default: None)
        Only needed for number features > 2. Dictionary of feature
        index-width pairs for the features not being plotted, with the
        same keys as `filler_feature_values`. Only samples (of `X` and
        `X_highlight`) whose filler features lie strictly within
        value +/- width are scattered. If None and `X` has more than
        2 features, no samples are scattered and no legend is drawn.

    ax : matplotlib.axes.Axes (default: None)
        An existing matplotlib Axes. If None, the current Axes is used
        (`plt.gca()`, which creates one if necessary).

    X_highlight : np.ndarray, shape = [n_samples, n_features] (default: None)
        An array with data points that are used to highlight samples in
        `X`.

    res : float or array-like, shape = (2,) (default: None)
        Deprecated and ignored (a DeprecationWarning is emitted). The
        number of grid points is now determined from the figure DPI and
        size. To increase the resolution, provide a `dpi` argument via
        matplotlib, e.g., `plt.figure(dpi=600)`.

    zoom_factor : float (default: 1.0)
        Controls the scale of the x- and y-axis of the decision plot:
        each axis extends 1 / zoom_factor beyond the data range.

    legend : int (default: 1)
        Integer to specify the legend location. No legend if legend is 0.

    hide_spines : bool (default: True)
        Hide axis spines if True.

    markers : str (default: 's^oxv<>')
        Scatterplot markers.

    colors : str (default: '#1f77b4,#ff7f0e,#3ca02c,#d62728,#9467bd,
             #8c564b,#e377c2,#7f7f7f,#bcbd22,#17becf')
        Comma separated list of colors.

    scatter_kwargs : dict (default: None)
        Keyword arguments for underlying matplotlib scatter function.

    contourf_kwargs : dict (default: None)
        Keyword arguments for underlying matplotlib contourf function.

    scatter_highlight_kwargs : dict (default: None)
        Keyword arguments for the underlying matplotlib scatter function
        that draws `X_highlight`.

    Returns
    ---------
    ax : matplotlib.axes.Axes object

    Raises
    ---------
    ValueError
        On malformed `X_highlight`, `feature_index`, or filler
        dictionaries.
    IndexError
        If `feature_index` refers to columns outside of `X`.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/plot_decision_regions/

    """
    check_Xy(X, y, y_int=True)  # Validate X and y arrays
    dim = X.shape[1]

    if ax is None:
        ax = plt.gca()

    if res is not None:
        warnings.warn("The 'res' parameter has been deprecated. "
                      "To increase the resolution, it's recommended "
                      "to provide a `dpi` argument via matplotlib, "
                      "e.g., `plt.figure(dpi=600)`.",
                      DeprecationWarning)

    if X_highlight is None:
        plot_testdata = False
    elif not isinstance(X_highlight, np.ndarray):
        raise ValueError('X_highlight must be a NumPy array or None')
    elif X_highlight.ndim < 2:
        raise ValueError('X_highlight must be a 2D array')
    else:
        plot_testdata = True

    x_index, y_index = _validate_feature_index(X, feature_index)
    _validate_filler_features(dim, x_index, y_index,
                              filler_feature_values, filler_feature_ranges)

    # Without filler ranges there is no way to decide which samples belong
    # to the plotted 2D slice of a >2D feature space, so none are drawn
    # (neither training samples nor highlighted ones).
    plot_samples = not (dim > 2 and filler_feature_ranges is None)

    marker_gen = cycle(list(markers))

    n_classes = np.unique(y).shape[0]
    colors_gen = cycle(colors.split(','))
    colors = [next(colors_gen) for _ in range(n_classes)]

    # Axis limits: the data range padded by 1 / zoom_factor on each side.
    pad = 1. / zoom_factor
    x_min, x_max = X[:, x_index].min() - pad, X[:, x_index].max() + pad
    if dim == 1:
        y_min, y_max = -1, 1
    else:
        y_min, y_max = X[:, y_index].min() - pad, X[:, y_index].max() + pad

    # Roughly one grid point per pixel of the figure that owns `ax` (not
    # the "current" figure, which may differ when `ax` is passed in).
    fig = ax.get_figure()
    if not hasattr(fig, 'get_size_inches'):
        # Container without its own size (e.g. a sub-figure).
        fig = plt.gcf()
    xnum, ynum = fig.dpi * fig.get_size_inches()
    xnum, ynum = floor(xnum), ceil(ynum)
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, num=xnum),
                         np.linspace(y_min, y_max, num=ynum))

    if dim == 1:
        X_predict = xx.ravel()[:, np.newaxis]
    else:
        X_predict = np.zeros((xx.size, dim))
        X_predict[:, x_index] = xx.ravel()
        X_predict[:, y_index] = yy.ravel()
        if dim > 2:
            # Non-plotted features are held constant at their filler value.
            for feature_idx, value in filler_feature_values.items():
                X_predict[:, feature_idx] = value

    # NOTE(review): the grid is cast to X's dtype, presumably so the model
    # sees the dtype it was trained on; for integer X this truncates the
    # grid coordinates -- confirm this is intended.
    Z = clf.predict(X_predict.astype(X.dtype))
    Z = Z.reshape(xx.shape)

    # Plot decision region
    # Make sure contourf_kwargs has backwards compatible defaults
    contourf_kwargs_default = {'alpha': 0.45, 'antialiased': True}
    contourf_kwargs = format_kwarg_dictionaries(
                        default_kwargs=contourf_kwargs_default,
                        user_kwargs=contourf_kwargs,
                        protected_keys=['colors', 'levels'])
    # Levels at k - 0.5 put each integer label k in its own band.
    cset = ax.contourf(xx, yy, Z,
                       colors=colors,
                       levels=np.arange(Z.max() + 2) - 0.5,
                       **contourf_kwargs)

    ax.contour(xx, yy, Z, cset.levels,
               colors='k',
               linewidths=0.5,
               antialiased=True)

    ax.axis([xx.min(), xx.max(), yy.min(), yy.max()])

    # Scatter training data samples
    # Make sure scatter_kwargs has backwards compatible defaults
    scatter_kwargs_default = {'alpha': 0.8, 'edgecolor': 'black'}
    scatter_kwargs = format_kwarg_dictionaries(
                        default_kwargs=scatter_kwargs_default,
                        user_kwargs=scatter_kwargs,
                        protected_keys=['c', 'marker', 'label'])

    if plot_samples:
        # The feature-range mask does not depend on the class: compute once.
        if dim > 2:
            range_mask = get_feature_range_mask(
                X,
                filler_feature_values=filler_feature_values,
                filler_feature_ranges=filler_feature_ranges)
        else:
            range_mask = None

        for idx, c in enumerate(np.unique(y)):
            class_mask = y == c
            if dim == 1:
                x_data = X[class_mask]
                y_data = np.zeros(x_data.shape[0])
            else:
                if range_mask is not None:
                    class_mask = class_mask & range_mask
                x_data = X[class_mask, x_index]
                y_data = X[class_mask, y_index]

            ax.scatter(x=x_data,
                       y=y_data,
                       c=colors[idx],
                       marker=next(marker_gen),
                       label=c,
                       **scatter_kwargs)

    if hide_spines:
        for side in ('right', 'top', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if dim == 1:
        ax.axes.get_yaxis().set_ticks([])

    if plot_testdata and plot_samples:
        if dim == 1:
            x_data = X_highlight
            y_data = np.zeros(X_highlight.shape[0])
        elif dim == 2:
            x_data = X_highlight[:, x_index]
            y_data = X_highlight[:, y_index]
        else:
            highlight_mask = get_feature_range_mask(
                X_highlight,
                filler_feature_values=filler_feature_values,
                filler_feature_ranges=filler_feature_ranges)
            x_data = X_highlight[highlight_mask, x_index]
            y_data = X_highlight[highlight_mask, y_index]

        # Make sure scatter_highlight_kwargs backwards compatible defaults.
        # 'none' is matplotlib's transparent color (hollow markers); an
        # empty string is not a valid color specification.
        scatter_highlight_defaults = {'c': 'none',
                                      'edgecolor': 'black',
                                      'alpha': 1.0,
                                      'linewidths': 1,
                                      'marker': 'o',
                                      's': 80}
        scatter_highlight_kwargs = format_kwarg_dictionaries(
                                    default_kwargs=scatter_highlight_defaults,
                                    user_kwargs=scatter_highlight_kwargs)
        ax.scatter(x_data, y_data, **scatter_highlight_kwargs)

    if legend and plot_samples:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels,
                  framealpha=0.3, scatterpoints=1, loc=legend)

    return ax