329 lines
13 KiB
Python
329 lines
13 KiB
Python
|
# Sebastian Raschka 2014-2020
|
||
|
# mlxtend Machine Learning Library Extensions
|
||
|
#
|
||
|
# A function for plotting decision regions of classifiers.
|
||
|
# Author: Sebastian Raschka <sebastianraschka.com>
|
||
|
#
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
from itertools import cycle
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
from mlxtend.utils import check_Xy, format_kwarg_dictionaries
|
||
|
import warnings
|
||
|
from math import floor
|
||
|
from math import ceil
|
||
|
|
||
|
|
||
|
def get_feature_range_mask(X, filler_feature_values=None,
|
||
|
filler_feature_ranges=None):
|
||
|
"""
|
||
|
Function that constucts a boolean array to get rid of samples
|
||
|
in X that are outside the feature range specified by filler_feature_values
|
||
|
and filler_feature_ranges
|
||
|
"""
|
||
|
|
||
|
if not isinstance(X, np.ndarray) or not len(X.shape) == 2:
|
||
|
raise ValueError('X must be a 2D array')
|
||
|
elif filler_feature_values is None:
|
||
|
raise ValueError('filler_feature_values must not be None')
|
||
|
elif filler_feature_ranges is None:
|
||
|
raise ValueError('filler_feature_ranges must not be None')
|
||
|
|
||
|
mask = np.ones(X.shape[0], dtype=bool)
|
||
|
for feature_idx in filler_feature_ranges:
|
||
|
feature_value = filler_feature_values[feature_idx]
|
||
|
feature_width = filler_feature_ranges[feature_idx]
|
||
|
upp_limit = feature_value + feature_width
|
||
|
low_limit = feature_value - feature_width
|
||
|
feature_mask = (X[:, feature_idx] > low_limit) & \
|
||
|
(X[:, feature_idx] < upp_limit)
|
||
|
mask = mask & feature_mask
|
||
|
|
||
|
return mask
|
||
|
|
||
|
|
||
|
def plot_decision_regions(X, y, clf,
|
||
|
feature_index=None,
|
||
|
filler_feature_values=None,
|
||
|
filler_feature_ranges=None,
|
||
|
ax=None,
|
||
|
X_highlight=None,
|
||
|
res=None,
|
||
|
zoom_factor=1.,
|
||
|
legend=1,
|
||
|
hide_spines=True,
|
||
|
markers='s^oxv<>',
|
||
|
colors=('#1f77b4,#ff7f0e,#3ca02c,#d62728,'
|
||
|
'#9467bd,#8c564b,#e377c2,'
|
||
|
'#7f7f7f,#bcbd22,#17becf'),
|
||
|
scatter_kwargs=None,
|
||
|
contourf_kwargs=None,
|
||
|
scatter_highlight_kwargs=None):
|
||
|
"""Plot decision regions of a classifier.
|
||
|
|
||
|
Please note that this functions assumes that class labels are
|
||
|
labeled consecutively, e.g,. 0, 1, 2, 3, 4, and 5. If you have class
|
||
|
labels with integer labels > 4, you may want to provide additional colors
|
||
|
and/or markers as `colors` and `markers` arguments.
|
||
|
See http://matplotlib.org/examples/color/named_colors.html for more
|
||
|
information.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
X : array-like, shape = [n_samples, n_features]
|
||
|
Feature Matrix.
|
||
|
y : array-like, shape = [n_samples]
|
||
|
True class labels.
|
||
|
clf : Classifier object.
|
||
|
Must have a .predict method.
|
||
|
feature_index : array-like (default: (0,) for 1D, (0, 1) otherwise)
|
||
|
Feature indices to use for plotting. The first index in
|
||
|
`feature_index` will be on the x-axis, the second index will be
|
||
|
on the y-axis.
|
||
|
filler_feature_values : dict (default: None)
|
||
|
Only needed for number features > 2. Dictionary of feature
|
||
|
index-value pairs for the features not being plotted.
|
||
|
filler_feature_ranges : dict (default: None)
|
||
|
Only needed for number features > 2. Dictionary of feature
|
||
|
index-value pairs for the features not being plotted. Will use the
|
||
|
ranges provided to select training samples for plotting.
|
||
|
ax : matplotlib.axes.Axes (default: None)
|
||
|
An existing matplotlib Axes. Creates
|
||
|
one if ax=None.
|
||
|
X_highlight : array-like, shape = [n_samples, n_features] (default: None)
|
||
|
An array with data points that are used to highlight samples in `X`.
|
||
|
res : float or array-like, shape = (2,) (default: None)
|
||
|
This parameter was used to define the grid width,
|
||
|
but it has been deprecated in favor of
|
||
|
determining the number of points given the figure DPI and size
|
||
|
automatically for optimal results and computational efficiency.
|
||
|
To increase the resolution, it's is recommended to use to provide
|
||
|
a `dpi argument via matplotlib, e.g., `plt.figure(dpi=600)`.
|
||
|
zoom_factor : float (default: 1.0)
|
||
|
Controls the scale of the x- and y-axis of the decision plot.
|
||
|
hide_spines : bool (default: True)
|
||
|
Hide axis spines if True.
|
||
|
legend : int (default: 1)
|
||
|
Integer to specify the legend location.
|
||
|
No legend if legend is 0.
|
||
|
markers : str (default: 's^oxv<>')
|
||
|
Scatterplot markers.
|
||
|
colors : str (default: 'red,blue,limegreen,gray,cyan')
|
||
|
Comma separated list of colors.
|
||
|
scatter_kwargs : dict (default: None)
|
||
|
Keyword arguments for underlying matplotlib scatter function.
|
||
|
contourf_kwargs : dict (default: None)
|
||
|
Keyword arguments for underlying matplotlib contourf function.
|
||
|
scatter_highlight_kwargs : dict (default: None)
|
||
|
Keyword arguments for underlying matplotlib scatter function.
|
||
|
|
||
|
Returns
|
||
|
---------
|
||
|
ax : matplotlib.axes.Axes object
|
||
|
|
||
|
Examples
|
||
|
-----------
|
||
|
For usage examples, please see
|
||
|
http://rasbt.github.io/mlxtend/user_guide/plotting/plot_decision_regions/
|
||
|
|
||
|
"""
|
||
|
|
||
|
check_Xy(X, y, y_int=True) # Validate X and y arrays
|
||
|
dim = X.shape[1]
|
||
|
|
||
|
if ax is None:
|
||
|
ax = plt.gca()
|
||
|
|
||
|
if res is not None:
|
||
|
warnings.warn("The 'res' parameter has been deprecated."
|
||
|
"To increase the resolution, it's is recommended"
|
||
|
"to use to provide a `dpi argument via matplotlib,"
|
||
|
"e.g., `plt.figure(dpi=600)`.",
|
||
|
DeprecationWarning)
|
||
|
|
||
|
plot_testdata = True
|
||
|
if not isinstance(X_highlight, np.ndarray):
|
||
|
if X_highlight is not None:
|
||
|
raise ValueError('X_highlight must be a NumPy array or None')
|
||
|
else:
|
||
|
plot_testdata = False
|
||
|
elif len(X_highlight.shape) < 2:
|
||
|
raise ValueError('X_highlight must be a 2D array')
|
||
|
|
||
|
if feature_index is not None:
|
||
|
# Unpack and validate the feature_index values
|
||
|
if dim == 1:
|
||
|
raise ValueError(
|
||
|
'feature_index requires more than one training feature')
|
||
|
try:
|
||
|
x_index, y_index = feature_index
|
||
|
except ValueError:
|
||
|
raise ValueError(
|
||
|
'Unable to unpack feature_index. Make sure feature_index '
|
||
|
'only has two dimensions.')
|
||
|
try:
|
||
|
X[:, x_index], X[:, y_index]
|
||
|
except IndexError:
|
||
|
raise IndexError(
|
||
|
'feature_index values out of range. X.shape is {}, but '
|
||
|
'feature_index is {}'.format(X.shape, feature_index))
|
||
|
else:
|
||
|
feature_index = (0, 1)
|
||
|
x_index, y_index = feature_index
|
||
|
|
||
|
# Extra input validation for higher number of training features
|
||
|
if dim > 2:
|
||
|
if filler_feature_values is None:
|
||
|
raise ValueError('Filler values must be provided when '
|
||
|
'X has more than 2 training features.')
|
||
|
|
||
|
if filler_feature_ranges is not None:
|
||
|
if not set(filler_feature_values) == set(filler_feature_ranges):
|
||
|
raise ValueError(
|
||
|
'filler_feature_values and filler_feature_ranges must '
|
||
|
'have the same keys')
|
||
|
|
||
|
# Check that all columns in X are accounted for
|
||
|
column_check = np.zeros(dim, dtype=bool)
|
||
|
for idx in filler_feature_values:
|
||
|
column_check[idx] = True
|
||
|
for idx in feature_index:
|
||
|
column_check[idx] = True
|
||
|
if not all(column_check):
|
||
|
missing_cols = np.argwhere(~column_check).flatten()
|
||
|
raise ValueError(
|
||
|
'Column(s) {} need to be accounted for in either '
|
||
|
'feature_index or filler_feature_values'.format(missing_cols))
|
||
|
|
||
|
marker_gen = cycle(list(markers))
|
||
|
|
||
|
n_classes = np.unique(y).shape[0]
|
||
|
colors = colors.split(',')
|
||
|
colors_gen = cycle(colors)
|
||
|
colors = [next(colors_gen) for c in range(n_classes)]
|
||
|
|
||
|
# Get minimum and maximum
|
||
|
x_min, x_max = (X[:, x_index].min() - 1./zoom_factor,
|
||
|
X[:, x_index].max() + 1./zoom_factor)
|
||
|
if dim == 1:
|
||
|
y_min, y_max = -1, 1
|
||
|
else:
|
||
|
y_min, y_max = (X[:, y_index].min() - 1./zoom_factor,
|
||
|
X[:, y_index].max() + 1./zoom_factor)
|
||
|
|
||
|
xnum, ynum = plt.gcf().dpi * plt.gcf().get_size_inches()
|
||
|
xnum, ynum = floor(xnum), ceil(ynum)
|
||
|
xx, yy = np.meshgrid(np.linspace(x_min, x_max, num=xnum),
|
||
|
np.linspace(y_min, y_max, num=ynum))
|
||
|
|
||
|
if dim == 1:
|
||
|
X_predict = np.array([xx.ravel()]).T
|
||
|
else:
|
||
|
X_grid = np.array([xx.ravel(), yy.ravel()]).T
|
||
|
X_predict = np.zeros((X_grid.shape[0], dim))
|
||
|
X_predict[:, x_index] = X_grid[:, 0]
|
||
|
X_predict[:, y_index] = X_grid[:, 1]
|
||
|
if dim > 2:
|
||
|
for feature_idx in filler_feature_values:
|
||
|
X_predict[:, feature_idx] = filler_feature_values[feature_idx]
|
||
|
Z = clf.predict(X_predict.astype(X.dtype))
|
||
|
Z = Z.reshape(xx.shape)
|
||
|
# Plot decisoin region
|
||
|
# Make sure contourf_kwargs has backwards compatible defaults
|
||
|
contourf_kwargs_default = {'alpha': 0.45, 'antialiased': True}
|
||
|
contourf_kwargs = format_kwarg_dictionaries(
|
||
|
default_kwargs=contourf_kwargs_default,
|
||
|
user_kwargs=contourf_kwargs,
|
||
|
protected_keys=['colors', 'levels'])
|
||
|
cset = ax.contourf(xx, yy, Z,
|
||
|
colors=colors,
|
||
|
levels=np.arange(Z.max() + 2) - 0.5,
|
||
|
**contourf_kwargs)
|
||
|
|
||
|
ax.contour(xx, yy, Z, cset.levels,
|
||
|
colors='k',
|
||
|
linewidths=0.5,
|
||
|
antialiased=True)
|
||
|
|
||
|
ax.axis([xx.min(), xx.max(), yy.min(), yy.max()])
|
||
|
|
||
|
# Scatter training data samples
|
||
|
# Make sure scatter_kwargs has backwards compatible defaults
|
||
|
scatter_kwargs_default = {'alpha': 0.8, 'edgecolor': 'black'}
|
||
|
scatter_kwargs = format_kwarg_dictionaries(
|
||
|
default_kwargs=scatter_kwargs_default,
|
||
|
user_kwargs=scatter_kwargs,
|
||
|
protected_keys=['c', 'marker', 'label'])
|
||
|
for idx, c in enumerate(np.unique(y)):
|
||
|
if dim == 1:
|
||
|
y_data = [0 for i in X[y == c]]
|
||
|
x_data = X[y == c]
|
||
|
elif dim == 2:
|
||
|
y_data = X[y == c, y_index]
|
||
|
x_data = X[y == c, x_index]
|
||
|
elif dim > 2 and filler_feature_ranges is not None:
|
||
|
class_mask = y == c
|
||
|
feature_range_mask = get_feature_range_mask(
|
||
|
X, filler_feature_values=filler_feature_values,
|
||
|
filler_feature_ranges=filler_feature_ranges)
|
||
|
y_data = X[class_mask & feature_range_mask, y_index]
|
||
|
x_data = X[class_mask & feature_range_mask, x_index]
|
||
|
else:
|
||
|
continue
|
||
|
|
||
|
ax.scatter(x=x_data,
|
||
|
y=y_data,
|
||
|
c=colors[idx],
|
||
|
marker=next(marker_gen),
|
||
|
label=c,
|
||
|
**scatter_kwargs)
|
||
|
|
||
|
if hide_spines:
|
||
|
ax.spines['right'].set_visible(False)
|
||
|
ax.spines['top'].set_visible(False)
|
||
|
ax.spines['left'].set_visible(False)
|
||
|
ax.spines['bottom'].set_visible(False)
|
||
|
ax.yaxis.set_ticks_position('left')
|
||
|
ax.xaxis.set_ticks_position('bottom')
|
||
|
if dim == 1:
|
||
|
ax.axes.get_yaxis().set_ticks([])
|
||
|
|
||
|
if plot_testdata:
|
||
|
if dim == 1:
|
||
|
x_data = X_highlight
|
||
|
y_data = [0 for i in X_highlight]
|
||
|
elif dim == 2:
|
||
|
x_data = X_highlight[:, x_index]
|
||
|
y_data = X_highlight[:, y_index]
|
||
|
else:
|
||
|
feature_range_mask = get_feature_range_mask(
|
||
|
X_highlight, filler_feature_values=filler_feature_values,
|
||
|
filler_feature_ranges=filler_feature_ranges)
|
||
|
y_data = X_highlight[feature_range_mask, y_index]
|
||
|
x_data = X_highlight[feature_range_mask, x_index]
|
||
|
|
||
|
# Make sure scatter_highlight_kwargs backwards compatible defaults
|
||
|
scatter_highlight_defaults = {'c': '',
|
||
|
'edgecolor': 'black',
|
||
|
'alpha': 1.0,
|
||
|
'linewidths': 1,
|
||
|
'marker': 'o',
|
||
|
's': 80}
|
||
|
scatter_highlight_kwargs = format_kwarg_dictionaries(
|
||
|
default_kwargs=scatter_highlight_defaults,
|
||
|
user_kwargs=scatter_highlight_kwargs)
|
||
|
ax.scatter(x_data,
|
||
|
y_data,
|
||
|
**scatter_highlight_kwargs)
|
||
|
|
||
|
if legend:
|
||
|
if dim > 2 and filler_feature_ranges is None:
|
||
|
pass
|
||
|
else:
|
||
|
handles, labels = ax.get_legend_handles_labels()
|
||
|
ax.legend(handles, labels,
|
||
|
framealpha=0.3, scatterpoints=1, loc=legend)
|
||
|
|
||
|
return ax
|