projektAI/venv/Lib/site-packages/mlxtend/plotting/enrichment_plot.py
2021-06-06 22:13:05 +02:00

129 lines
3.6 KiB
Python

# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# A function for plotting enrichment plots.
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from itertools import cycle
def enrichment_plot(df, colors='bgrkcy', markers=' ', linestyles='-',
                    alpha=0.5, lw=2, where='post', grid=True,
                    count_label='Count',
                    xlim='auto', ylim='auto', invert_axes=False,
                    legend_loc='best', ax=None):
    """Plot an enrichment plot (sorted values vs. running sample count).

    For every column of `df` the values are sorted in ascending order and
    drawn as a step curve against the running count 1..n_rows, so each
    curve shows how many samples have been reached at a given value.

    Parameters
    ----------
    df : pandas.DataFrame or pandas.Series
        A pandas DataFrame where columns represent the different
        categories. A Series is treated as a single-column DataFrame.
    colors : str (default: 'bgrkcy')
        The colors of the step curves; cycled through column by column.
    markers : str (default: ' ')
        Matplotlib markerstyles, e.g.,
        'sov' for square, circle, and triangle markers.
    linestyles : str (default: '-')
        Matplotlib linestyles, e.g.,
        '-,--' to cycle normal and dashed lines. Note
        that the different linestyles need to be separated by commas.
    alpha : float (default: 0.5)
        Transparency level from 0.0 to 1.0.
    lw : int or float (default: 2)
        Linewidth parameter.
    where : {'post', 'pre', 'mid'} (default: 'post')
        Starting location of the steps.
    grid : bool (default: `True`)
        Plots a grid if True.
    count_label : str (default: 'Count')
        Label for the "Count"-axis. No label is set if it is empty/False.
    xlim : 'auto' or array-like [min, max] (default: 'auto')
        Min and maximum position of the value axis (the x-axis unless
        `invert_axes=True`). 'auto' uses the overall data minimum - 1 and
        the overall data maximum + 1.
    ylim : 'auto' or array-like [min, max] (default: 'auto')
        Min and maximum position of the count axis (the y-axis unless
        `invert_axes=True`). 'auto' uses 0 and n_rows + 1.
    invert_axes : bool (default: False)
        Plots count on the x-axis if True. Note that `xlim` and `ylim`
        swap roles together with the axes: `ylim` then sets the x-axis
        (count) range and `xlim` sets the y-axis (value) range.
    legend_loc : str (default: 'best')
        Location of the plot legend
        {best, upper left, upper right, lower left, lower right}
        No legend if legend_loc=False
    ax : matplotlib axis, optional (default: None)
        Use this axis for plotting or make a new one otherwise

    Returns
    ----------
    ax : matplotlib axis

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/enrichment_plot/

    """
    if isinstance(df, pd.Series):
        df_temp = pd.DataFrame(df)
    else:
        df_temp = df

    if ax is None:
        ax = plt.gca()

    color_gen = cycle(colors)
    marker_gen = cycle(markers)
    linestyle_gen = cycle(linestyles.split(','))

    # Running sample count 1..n_rows shared by all curves.
    counts = range(1, len(df_temp.index) + 1)

    for lab in df_temp.columns:
        x, y = sorted(df_temp[lab]), counts
        if invert_axes:
            x, y = y, x

        ax.step(x,
                y,
                where=where,
                label=lab,
                color=next(color_gen),
                alpha=alpha,
                lw=lw,
                marker=next(marker_gen),
                linestyle=next(linestyle_gen))

    # Choose which axis carries counts and which carries values through
    # local references. The Axes methods themselves must not be reassigned,
    # otherwise the caller's axes object would keep swapped set_xlim/set_ylim
    # after this function returns.
    if invert_axes:
        set_count_lim, set_value_lim = ax.set_xlim, ax.set_ylim
        set_count_label = ax.set_xlabel
    else:
        set_count_lim, set_value_lim = ax.set_ylim, ax.set_xlim
        set_count_label = ax.set_ylabel

    # The isinstance guard keeps array-like limits (e.g. NumPy arrays) from
    # being compared elementwise against the string 'auto'.
    if isinstance(ylim, str) and ylim == 'auto':
        set_count_lim([np.min(counts) - 1, np.max(counts) + 1])
    else:
        set_count_lim(ylim)

    if isinstance(xlim, str) and xlim == 'auto':
        df_min, df_max = np.min(df_temp.min()), np.max(df_temp.max())
        set_value_lim([df_min - 1, df_max + 1])
    else:
        set_value_lim(xlim)

    # Decorate the axes that were actually drawn on, not pyplot's notion of
    # the "current" axes, which may be a different one when `ax` is given.
    if legend_loc:
        ax.legend(loc=legend_loc, scatterpoints=1)

    if grid:
        # Explicit True: a bare grid() call toggles visibility instead of
        # switching the grid on.
        ax.grid(True)

    if count_label:
        set_count_label(count_label)

    return ax