projektAI/venv/Lib/site-packages/mlxtend/plotting/stacked_barplot.py
2021-06-06 22:13:05 +02:00

101 lines
2.7 KiB
Python

# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# Matplotlib wrapper for generating stacked barplots.
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import matplotlib.pyplot as plt
import numpy as np
from itertools import cycle
def stacked_barplot(df, bar_width='auto', colors='bgrcky',
labels='index', rotation=90, legend_loc='best'):
"""
Function to plot stacked barplots
Parameters
----------
df : pandas.DataFrame
A pandas DataFrame where the index denotes the
x-axis labels, and the columns contain the different
measurements for each row.
bar_width: 'auto' or float (default: 'auto')
Parameter to set the widths of the bars. if
'auto', the width is automatically determined by
the number of columns in the dataset.
colors: str (default: 'bgrcky')
The colors of the bars.
labels: 'index' or iterable (default: 'index')
If 'index', the DataFrame index will be used as
x-tick labels.
rotation: int (default: 90)
Parameter to rotate the x-axis labels.
legend_loc : str (default: 'best')
Location of the plot legend
{best, upper left, upper right, lower left, lower right}
No legend if legend_loc=False
Returns
---------
fig : matplotlib.pyplot figure object
Examples
-----------
For usage examples, please see
http://rasbt.github.io/mlxtend/user_guide/plotting/stacked_barplot/
"""
# Setting the positions and width for the bars
pos = np.array(range(len(df.index)))
if bar_width == 'auto':
width = 1 / (len(df.columns[1:]) * 2)
else:
width = bar_width
if labels == 'index':
labels = df.index
color_gen = cycle(colors)
label_pos = [pos]
# Plotting the bars
fig, ax = plt.subplots(figsize=(12, 6))
plt.bar(pos,
df.iloc[:, 0],
width,
alpha=0.8,
color=next(color_gen),
label=df.columns[0])
for i, c in enumerate(df.columns[1:]):
bar_pos = [p + width*(i+1) for p in pos]
label_pos.append(bar_pos)
plt.bar(bar_pos,
df.iloc[:, i+1],
width,
alpha=0.5,
color=next(color_gen),
label=c)
# Setting axis labels and ticks
label_pos = np.asarray(label_pos).mean(axis=0) + width*0.5
ax.set_xticks(label_pos)
ax.set_xticklabels(labels, rotation=rotation, horizontalalignment='center')
# Setting the x-axis and y-axis limits
plt.xlim(min(pos)-width, max(pos) + width*7)
if legend_loc:
plt.legend(loc=legend_loc, scatterpoints=1)
return fig