493 lines
15 KiB
Python
493 lines
15 KiB
Python
# being a bit too dynamic
|
|
from __future__ import annotations
|
|
|
|
from math import ceil
|
|
from typing import TYPE_CHECKING
|
|
import warnings
|
|
|
|
from matplotlib import ticker
|
|
import matplotlib.table
|
|
import numpy as np
|
|
|
|
from pandas.util._exceptions import find_stack_level
|
|
|
|
from pandas.core.dtypes.common import is_list_like
|
|
from pandas.core.dtypes.generic import (
|
|
ABCDataFrame,
|
|
ABCIndex,
|
|
ABCSeries,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import (
|
|
Iterable,
|
|
Sequence,
|
|
)
|
|
|
|
from matplotlib.axes import Axes
|
|
from matplotlib.axis import Axis
|
|
from matplotlib.figure import Figure
|
|
from matplotlib.lines import Line2D
|
|
from matplotlib.table import Table
|
|
|
|
from pandas import (
|
|
DataFrame,
|
|
Series,
|
|
)
|
|
|
|
|
|
def do_adjust_figure(fig: Figure) -> bool:
|
|
"""Whether fig has constrained_layout enabled."""
|
|
if not hasattr(fig, "get_constrained_layout"):
|
|
return False
|
|
return not fig.get_constrained_layout()
|
|
|
|
|
|
def maybe_adjust_figure(fig: Figure, *args, **kwargs) -> None:
|
|
"""Call fig.subplots_adjust unless fig has constrained_layout enabled."""
|
|
if do_adjust_figure(fig):
|
|
fig.subplots_adjust(*args, **kwargs)
|
|
|
|
|
|
def format_date_labels(ax: Axes, rot) -> None:
|
|
# mini version of autofmt_xdate
|
|
for label in ax.get_xticklabels():
|
|
label.set_horizontalalignment("right")
|
|
label.set_rotation(rot)
|
|
fig = ax.get_figure()
|
|
if fig is not None:
|
|
# should always be a Figure but can technically be None
|
|
maybe_adjust_figure(fig, bottom=0.2)
|
|
|
|
|
|
def table(
|
|
ax, data: DataFrame | Series, rowLabels=None, colLabels=None, **kwargs
|
|
) -> Table:
|
|
if isinstance(data, ABCSeries):
|
|
data = data.to_frame()
|
|
elif isinstance(data, ABCDataFrame):
|
|
pass
|
|
else:
|
|
raise ValueError("Input data must be DataFrame or Series")
|
|
|
|
if rowLabels is None:
|
|
rowLabels = data.index
|
|
|
|
if colLabels is None:
|
|
colLabels = data.columns
|
|
|
|
cellText = data.values
|
|
|
|
# error: Argument "cellText" to "table" has incompatible type "ndarray[Any,
|
|
# Any]"; expected "Sequence[Sequence[str]] | None"
|
|
return matplotlib.table.table(
|
|
ax,
|
|
cellText=cellText, # type: ignore[arg-type]
|
|
rowLabels=rowLabels,
|
|
colLabels=colLabels,
|
|
**kwargs,
|
|
)
|
|
|
|
|
|
def _get_layout(
|
|
nplots: int,
|
|
layout: tuple[int, int] | None = None,
|
|
layout_type: str = "box",
|
|
) -> tuple[int, int]:
|
|
if layout is not None:
|
|
if not isinstance(layout, (tuple, list)) or len(layout) != 2:
|
|
raise ValueError("Layout must be a tuple of (rows, columns)")
|
|
|
|
nrows, ncols = layout
|
|
|
|
if nrows == -1 and ncols > 0:
|
|
layout = nrows, ncols = (ceil(nplots / ncols), ncols)
|
|
elif ncols == -1 and nrows > 0:
|
|
layout = nrows, ncols = (nrows, ceil(nplots / nrows))
|
|
elif ncols <= 0 and nrows <= 0:
|
|
msg = "At least one dimension of layout must be positive"
|
|
raise ValueError(msg)
|
|
|
|
if nrows * ncols < nplots:
|
|
raise ValueError(
|
|
f"Layout of {nrows}x{ncols} must be larger than required size {nplots}"
|
|
)
|
|
|
|
return layout
|
|
|
|
if layout_type == "single":
|
|
return (1, 1)
|
|
elif layout_type == "horizontal":
|
|
return (1, nplots)
|
|
elif layout_type == "vertical":
|
|
return (nplots, 1)
|
|
|
|
layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)}
|
|
try:
|
|
return layouts[nplots]
|
|
except KeyError:
|
|
k = 1
|
|
while k**2 < nplots:
|
|
k += 1
|
|
|
|
if (k - 1) * k >= nplots:
|
|
return k, (k - 1)
|
|
else:
|
|
return k, k
|
|
|
|
|
|
# copied from matplotlib/pyplot.py and modified for pandas.plotting
|
|
|
|
|
|
def create_subplots(
|
|
naxes: int,
|
|
sharex: bool = False,
|
|
sharey: bool = False,
|
|
squeeze: bool = True,
|
|
subplot_kw=None,
|
|
ax=None,
|
|
layout=None,
|
|
layout_type: str = "box",
|
|
**fig_kw,
|
|
):
|
|
"""
|
|
Create a figure with a set of subplots already made.
|
|
|
|
This utility wrapper makes it convenient to create common layouts of
|
|
subplots, including the enclosing figure object, in a single call.
|
|
|
|
Parameters
|
|
----------
|
|
naxes : int
|
|
Number of required axes. Exceeded axes are set invisible. Default is
|
|
nrows * ncols.
|
|
|
|
sharex : bool
|
|
If True, the X axis will be shared amongst all subplots.
|
|
|
|
sharey : bool
|
|
If True, the Y axis will be shared amongst all subplots.
|
|
|
|
squeeze : bool
|
|
|
|
If True, extra dimensions are squeezed out from the returned axis object:
|
|
- if only one subplot is constructed (nrows=ncols=1), the resulting
|
|
single Axis object is returned as a scalar.
|
|
- for Nx1 or 1xN subplots, the returned object is a 1-d numpy object
|
|
array of Axis objects are returned as numpy 1-d arrays.
|
|
- for NxM subplots with N>1 and M>1 are returned as a 2d array.
|
|
|
|
If False, no squeezing is done: the returned axis object is always
|
|
a 2-d array containing Axis instances, even if it ends up being 1x1.
|
|
|
|
subplot_kw : dict
|
|
Dict with keywords passed to the add_subplot() call used to create each
|
|
subplots.
|
|
|
|
ax : Matplotlib axis object, optional
|
|
|
|
layout : tuple
|
|
Number of rows and columns of the subplot grid.
|
|
If not specified, calculated from naxes and layout_type
|
|
|
|
layout_type : {'box', 'horizontal', 'vertical'}, default 'box'
|
|
Specify how to layout the subplot grid.
|
|
|
|
fig_kw : Other keyword arguments to be passed to the figure() call.
|
|
Note that all keywords not recognized above will be
|
|
automatically included here.
|
|
|
|
Returns
|
|
-------
|
|
fig, ax : tuple
|
|
- fig is the Matplotlib Figure object
|
|
- ax can be either a single axis object or an array of axis objects if
|
|
more than one subplot was created. The dimensions of the resulting array
|
|
can be controlled with the squeeze keyword, see above.
|
|
|
|
Examples
|
|
--------
|
|
x = np.linspace(0, 2*np.pi, 400)
|
|
y = np.sin(x**2)
|
|
|
|
# Just a figure and one subplot
|
|
f, ax = plt.subplots()
|
|
ax.plot(x, y)
|
|
ax.set_title('Simple plot')
|
|
|
|
# Two subplots, unpack the output array immediately
|
|
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
|
|
ax1.plot(x, y)
|
|
ax1.set_title('Sharing Y axis')
|
|
ax2.scatter(x, y)
|
|
|
|
# Four polar axes
|
|
plt.subplots(2, 2, subplot_kw=dict(polar=True))
|
|
"""
|
|
import matplotlib.pyplot as plt
|
|
|
|
if subplot_kw is None:
|
|
subplot_kw = {}
|
|
|
|
if ax is None:
|
|
fig = plt.figure(**fig_kw)
|
|
else:
|
|
if is_list_like(ax):
|
|
if squeeze:
|
|
ax = flatten_axes(ax)
|
|
if layout is not None:
|
|
warnings.warn(
|
|
"When passing multiple axes, layout keyword is ignored.",
|
|
UserWarning,
|
|
stacklevel=find_stack_level(),
|
|
)
|
|
if sharex or sharey:
|
|
warnings.warn(
|
|
"When passing multiple axes, sharex and sharey "
|
|
"are ignored. These settings must be specified when creating axes.",
|
|
UserWarning,
|
|
stacklevel=find_stack_level(),
|
|
)
|
|
if ax.size == naxes:
|
|
fig = ax.flat[0].get_figure()
|
|
return fig, ax
|
|
else:
|
|
raise ValueError(
|
|
f"The number of passed axes must be {naxes}, the "
|
|
"same as the output plot"
|
|
)
|
|
|
|
fig = ax.get_figure()
|
|
# if ax is passed and a number of subplots is 1, return ax as it is
|
|
if naxes == 1:
|
|
if squeeze:
|
|
return fig, ax
|
|
else:
|
|
return fig, flatten_axes(ax)
|
|
else:
|
|
warnings.warn(
|
|
"To output multiple subplots, the figure containing "
|
|
"the passed axes is being cleared.",
|
|
UserWarning,
|
|
stacklevel=find_stack_level(),
|
|
)
|
|
fig.clear()
|
|
|
|
nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type)
|
|
nplots = nrows * ncols
|
|
|
|
# Create empty object array to hold all axes. It's easiest to make it 1-d
|
|
# so we can just append subplots upon creation, and then
|
|
axarr = np.empty(nplots, dtype=object)
|
|
|
|
# Create first subplot separately, so we can share it if requested
|
|
ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
|
|
|
|
if sharex:
|
|
subplot_kw["sharex"] = ax0
|
|
if sharey:
|
|
subplot_kw["sharey"] = ax0
|
|
axarr[0] = ax0
|
|
|
|
# Note off-by-one counting because add_subplot uses the MATLAB 1-based
|
|
# convention.
|
|
for i in range(1, nplots):
|
|
kwds = subplot_kw.copy()
|
|
# Set sharex and sharey to None for blank/dummy axes, these can
|
|
# interfere with proper axis limits on the visible axes if
|
|
# they share axes e.g. issue #7528
|
|
if i >= naxes:
|
|
kwds["sharex"] = None
|
|
kwds["sharey"] = None
|
|
ax = fig.add_subplot(nrows, ncols, i + 1, **kwds)
|
|
axarr[i] = ax
|
|
|
|
if naxes != nplots:
|
|
for ax in axarr[naxes:]:
|
|
ax.set_visible(False)
|
|
|
|
handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey)
|
|
|
|
if squeeze:
|
|
# Reshape the array to have the final desired dimension (nrow,ncol),
|
|
# though discarding unneeded dimensions that equal 1. If we only have
|
|
# one subplot, just return it instead of a 1-element array.
|
|
if nplots == 1:
|
|
axes = axarr[0]
|
|
else:
|
|
axes = axarr.reshape(nrows, ncols).squeeze()
|
|
else:
|
|
# returned axis array will be always 2-d, even if nrows=ncols=1
|
|
axes = axarr.reshape(nrows, ncols)
|
|
|
|
return fig, axes
|
|
|
|
|
|
def _remove_labels_from_axis(axis: Axis) -> None:
|
|
for t in axis.get_majorticklabels():
|
|
t.set_visible(False)
|
|
|
|
# set_visible will not be effective if
|
|
# minor axis has NullLocator and NullFormatter (default)
|
|
if isinstance(axis.get_minor_locator(), ticker.NullLocator):
|
|
axis.set_minor_locator(ticker.AutoLocator())
|
|
if isinstance(axis.get_minor_formatter(), ticker.NullFormatter):
|
|
axis.set_minor_formatter(ticker.FormatStrFormatter(""))
|
|
for t in axis.get_minorticklabels():
|
|
t.set_visible(False)
|
|
|
|
axis.get_label().set_visible(False)
|
|
|
|
|
|
def _has_externally_shared_axis(ax1: Axes, compare_axis: str) -> bool:
|
|
"""
|
|
Return whether an axis is externally shared.
|
|
|
|
Parameters
|
|
----------
|
|
ax1 : matplotlib.axes.Axes
|
|
Axis to query.
|
|
compare_axis : str
|
|
`"x"` or `"y"` according to whether the X-axis or Y-axis is being
|
|
compared.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
`True` if the axis is externally shared. Otherwise `False`.
|
|
|
|
Notes
|
|
-----
|
|
If two axes with different positions are sharing an axis, they can be
|
|
referred to as *externally* sharing the common axis.
|
|
|
|
If two axes sharing an axis also have the same position, they can be
|
|
referred to as *internally* sharing the common axis (a.k.a twinning).
|
|
|
|
_handle_shared_axes() is only interested in axes externally sharing an
|
|
axis, regardless of whether either of the axes is also internally sharing
|
|
with a third axis.
|
|
"""
|
|
if compare_axis == "x":
|
|
axes = ax1.get_shared_x_axes()
|
|
elif compare_axis == "y":
|
|
axes = ax1.get_shared_y_axes()
|
|
else:
|
|
raise ValueError(
|
|
"_has_externally_shared_axis() needs 'x' or 'y' as a second parameter"
|
|
)
|
|
|
|
axes_siblings = axes.get_siblings(ax1)
|
|
|
|
# Retain ax1 and any of its siblings which aren't in the same position as it
|
|
ax1_points = ax1.get_position().get_points()
|
|
|
|
for ax2 in axes_siblings:
|
|
if not np.array_equal(ax1_points, ax2.get_position().get_points()):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def handle_shared_axes(
|
|
axarr: Iterable[Axes],
|
|
nplots: int,
|
|
naxes: int,
|
|
nrows: int,
|
|
ncols: int,
|
|
sharex: bool,
|
|
sharey: bool,
|
|
) -> None:
|
|
if nplots > 1:
|
|
row_num = lambda x: x.get_subplotspec().rowspan.start
|
|
col_num = lambda x: x.get_subplotspec().colspan.start
|
|
|
|
is_first_col = lambda x: x.get_subplotspec().is_first_col()
|
|
|
|
if nrows > 1:
|
|
try:
|
|
# first find out the ax layout,
|
|
# so that we can correctly handle 'gaps"
|
|
layout = np.zeros((nrows + 1, ncols + 1), dtype=np.bool_)
|
|
for ax in axarr:
|
|
layout[row_num(ax), col_num(ax)] = ax.get_visible()
|
|
|
|
for ax in axarr:
|
|
# only the last row of subplots should get x labels -> all
|
|
# other off layout handles the case that the subplot is
|
|
# the last in the column, because below is no subplot/gap.
|
|
if not layout[row_num(ax) + 1, col_num(ax)]:
|
|
continue
|
|
if sharex or _has_externally_shared_axis(ax, "x"):
|
|
_remove_labels_from_axis(ax.xaxis)
|
|
|
|
except IndexError:
|
|
# if gridspec is used, ax.rowNum and ax.colNum may different
|
|
# from layout shape. in this case, use last_row logic
|
|
is_last_row = lambda x: x.get_subplotspec().is_last_row()
|
|
for ax in axarr:
|
|
if is_last_row(ax):
|
|
continue
|
|
if sharex or _has_externally_shared_axis(ax, "x"):
|
|
_remove_labels_from_axis(ax.xaxis)
|
|
|
|
if ncols > 1:
|
|
for ax in axarr:
|
|
# only the first column should get y labels -> set all other to
|
|
# off as we only have labels in the first column and we always
|
|
# have a subplot there, we can skip the layout test
|
|
if is_first_col(ax):
|
|
continue
|
|
if sharey or _has_externally_shared_axis(ax, "y"):
|
|
_remove_labels_from_axis(ax.yaxis)
|
|
|
|
|
|
def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray:
|
|
if not is_list_like(axes):
|
|
return np.array([axes])
|
|
elif isinstance(axes, (np.ndarray, ABCIndex)):
|
|
return np.asarray(axes).ravel()
|
|
return np.array(axes)
|
|
|
|
|
|
def set_ticks_props(
|
|
axes: Axes | Sequence[Axes],
|
|
xlabelsize: int | None = None,
|
|
xrot=None,
|
|
ylabelsize: int | None = None,
|
|
yrot=None,
|
|
):
|
|
import matplotlib.pyplot as plt
|
|
|
|
for ax in flatten_axes(axes):
|
|
if xlabelsize is not None:
|
|
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
|
|
if xrot is not None:
|
|
plt.setp(ax.get_xticklabels(), rotation=xrot)
|
|
if ylabelsize is not None:
|
|
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
|
|
if yrot is not None:
|
|
plt.setp(ax.get_yticklabels(), rotation=yrot)
|
|
return axes
|
|
|
|
|
|
def get_all_lines(ax: Axes) -> list[Line2D]:
|
|
lines = ax.get_lines()
|
|
|
|
if hasattr(ax, "right_ax"):
|
|
lines += ax.right_ax.get_lines()
|
|
|
|
if hasattr(ax, "left_ax"):
|
|
lines += ax.left_ax.get_lines()
|
|
|
|
return lines
|
|
|
|
|
|
def get_xlim(lines: Iterable[Line2D]) -> tuple[float, float]:
|
|
left, right = np.inf, -np.inf
|
|
for line in lines:
|
|
x = line.get_xdata(orig=False)
|
|
left = min(np.nanmin(x), left)
|
|
right = max(np.nanmax(x), right)
|
|
return left, right
|