135 lines
5.0 KiB
Python
135 lines
5.0 KiB
Python
|
# Sebastian Raschka 2014-2020
|
||
|
# mlxtend Machine Learning Library Extensions
|
||
|
#
|
||
|
# A function for plotting learning curves of classifiers.
|
||
|
# Author: Sebastian Raschka <sebastianraschka.com>
|
||
|
#
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
|
||
|
def plot_learning_curves(X_train, y_train,
                         X_test, y_test,
                         clf,
                         train_marker='o',
                         test_marker='^',
                         scoring='misclassification error',
                         suppress_plot=False, print_model=True,
                         style='fivethirtyeight',
                         legend_loc='best'):
    """Plots learning curves of a classifier.

    The classifier is refitted on ten growing prefixes of the training
    data (10%, 20%, ..., 100% of the rows, in their given order). After
    each fit it is scored on that prefix and on the full test set.

    Parameters
    ----------
    X_train : array-like, shape = [n_samples, n_features]
        Feature matrix of the training dataset. Must expose `.shape`
        and support row slicing (`X_train[:r]`).
    y_train : array-like, shape = [n_samples]
        True class labels of the training dataset.
    X_test : array-like, shape = [n_samples, n_features]
        Feature matrix of the test dataset.
    y_test : array-like, shape = [n_samples]
        True class labels of the test dataset.
    clf : Classifier object. Must have a .predict .fit method.
        Note that `clf` is refitted in place on every subset.
    train_marker : str (default: 'o')
        Marker for the training set line plot.
    test_marker : str (default: '^')
        Marker for the test set line plot.
    scoring : str (default: 'misclassification error')
        If not 'misclassification error', accepts the following metrics
        (from scikit-learn):
        {'accuracy', 'average_precision', 'f1', 'f1_micro', 'f1_macro',
        'f1_weighted', 'f1_samples', 'log_loss',
        'precision', 'recall', 'roc_auc',
        'adjusted_rand_score', 'mean_absolute_error', 'mean_squared_error',
        'median_absolute_error', 'r2'}
        Every metric is called as `metric(y_true, clf.predict(X))`, i.e.
        with the output of `.predict`, not with probabilities or
        decision scores.
    suppress_plot : bool (default: False)
        Suppress matplotlib plots if True. Recommended
        for testing purposes.
    print_model : bool (default: True)
        Print model parameters in plot title if True.
    style : str (default: 'fivethirtyeight')
        Matplotlib style
    legend_loc : str (default: 'best')
        Where to place the plot legend:
        {'best', 'upper left', 'upper right', 'lower left', 'lower right'}

    Returns
    -------
    errors : (training_error, test_error): tuple of lists
        Ten scores each, one per training subset size.

    Raises
    ------
    AttributeError
        If `scoring` is not one of the supported metric names.

    Examples
    --------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/plot_learning_curves/

    """
    if scoring != 'misclassification error':
        # Imported lazily so scikit-learn is only required when one of
        # its metrics is actually requested.
        from functools import partial
        from sklearn import metrics

        scoring_func = {
            'accuracy': metrics.accuracy_score,
            'average_precision': metrics.average_precision_score,
            'f1': metrics.f1_score,
            # The averaging mode is part of the metric name, so it has to
            # be bound explicitly; plain f1_score would use its default.
            'f1_micro': partial(metrics.f1_score, average='micro'),
            'f1_macro': partial(metrics.f1_score, average='macro'),
            'f1_weighted': partial(metrics.f1_score, average='weighted'),
            'f1_samples': partial(metrics.f1_score, average='samples'),
            'log_loss': metrics.log_loss,
            'precision': metrics.precision_score,
            'recall': metrics.recall_score,
            'roc_auc': metrics.roc_auc_score,
            'adjusted_rand_score': metrics.adjusted_rand_score,
            'mean_absolute_error': metrics.mean_absolute_error,
            'mean_squared_error': metrics.mean_squared_error,
            'median_absolute_error': metrics.median_absolute_error,
            'r2': metrics.r2_score}

        if scoring not in scoring_func:
            raise AttributeError(
                'scoring must be in {}'.format(sorted(scoring_func)))

    else:
        def misclf_err(y_true, y_pred):
            # Fraction of samples whose prediction differs from the label.
            mismatches = np.asarray(y_true) != np.asarray(y_pred)
            return mismatches.sum() / float(len(y_true))

        scoring_func = {
            'misclassification error': misclf_err}

    score = scoring_func[scoring]
    training_errors = []
    test_errors = []

    # Ten subset sizes at 10%, 20%, ..., 100% of the training rows.
    # Clamp to at least one row: with fewer than ten samples the integer
    # truncation would otherwise produce an empty first subset.
    n_samples = X_train.shape[0]
    subset_sizes = [max(1, int(size))
                    for size in np.linspace(0, n_samples, 11)[1:]]

    model = None
    for r in subset_sizes:
        model = clf.fit(X_train[:r], y_train[:r])

        y_train_predict = clf.predict(X_train[:r])
        y_test_predict = clf.predict(X_test)

        training_errors.append(score(y_train[:r], y_train_predict))
        test_errors.append(score(y_test, y_test_predict))

    if not suppress_plot:
        percentages = np.arange(10, 101, 10)
        with plt.style.context(style):
            plt.plot(percentages, training_errors,
                     label='training set', marker=train_marker)
            plt.plot(percentages, test_errors,
                     label='test set', marker=test_marker)
            plt.xlabel('Training set size in percent')
            plt.ylabel('Performance ({})'.format(scoring))
            if print_model:
                plt.title('Learning Curves\n\n{}\n'.format(model))
            plt.legend(loc=legend_loc, numpoints=1)
            plt.xlim([0, 110])

            max_y = max(max(test_errors), max(training_errors))
            min_y = min(min(test_errors), min(training_errors))
            # Pad by 15% of the magnitude so the limits always move
            # outward, also for negative scores such as r2.
            lower = min_y - abs(min_y) * 0.15
            upper = max_y + abs(max_y) * 0.15
            # Identical limits (e.g. all scores are 0) would make the axis
            # singular; leave matplotlib's autoscaling in place instead.
            if lower < upper:
                plt.ylim([lower, upper])

    errors = (training_errors, test_errors)
    return errors
|