from typing import Optional, Tuple import numpy as np import pandas as pd class TablePlotter: """ Layout some DataFrames in vertical/horizontal layout for explanation. Used in merging.rst """ def __init__( self, cell_width: float = 0.37, cell_height: float = 0.25, font_size: float = 7.5, ): self.cell_width = cell_width self.cell_height = cell_height self.font_size = font_size def _shape(self, df: pd.DataFrame) -> Tuple[int, int]: """ Calculate table shape considering index levels. """ row, col = df.shape return row + df.columns.nlevels, col + df.index.nlevels def _get_cells(self, left, right, vertical) -> Tuple[int, int]: """ Calculate appropriate figure size based on left and right data. """ if vertical: # calculate required number of cells vcells = max(sum(self._shape(df)[0] for df in left), self._shape(right)[0]) hcells = max(self._shape(df)[1] for df in left) + self._shape(right)[1] else: vcells = max([self._shape(df)[0] for df in left] + [self._shape(right)[0]]) hcells = sum([self._shape(df)[1] for df in left] + [self._shape(right)[1]]) return hcells, vcells def plot(self, left, right, labels=None, vertical: bool = True): """ Plot left / right DataFrames in specified layout. Parameters ---------- left : list of DataFrames before operation is applied right : DataFrame of operation result labels : list of str to be drawn as titles of left DataFrames vertical : bool, default True If True, use vertical layout. If False, use horizontal layout. """ import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt if not isinstance(left, list): left = [left] left = [self._conv(df) for df in left] right = self._conv(right) hcells, vcells = self._get_cells(left, right, vertical) if vertical: figsize = self.cell_width * hcells, self.cell_height * vcells else: # include margin for titles figsize = self.cell_width * hcells, self.cell_height * vcells fig = plt.figure(figsize=figsize) if vertical: gs = gridspec.GridSpec(len(left), hcells) # left max_left_cols = max(self._shape(df)[1] for df in left) max_left_rows = max(self._shape(df)[0] for df in left) for i, (l, label) in enumerate(zip(left, labels)): ax = fig.add_subplot(gs[i, 0:max_left_cols]) self._make_table(ax, l, title=label, height=1.0 / max_left_rows) # right ax = plt.subplot(gs[:, max_left_cols:]) self._make_table(ax, right, title="Result", height=1.05 / vcells) fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95) else: max_rows = max(self._shape(df)[0] for df in left + [right]) height = 1.0 / np.max(max_rows) gs = gridspec.GridSpec(1, hcells) # left i = 0 for df, label in zip(left, labels): sp = self._shape(df) ax = fig.add_subplot(gs[0, i : i + sp[1]]) self._make_table(ax, df, title=label, height=height) i += sp[1] # right ax = plt.subplot(gs[0, i:]) self._make_table(ax, right, title="Result", height=height) fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95) return fig def _conv(self, data): """ Convert each input to appropriate for table outplot. """ if isinstance(data, pd.Series): if data.name is None: data = data.to_frame(name="") else: data = data.to_frame() data = data.fillna("NaN") return data def _insert_index(self, data): # insert is destructive data = data.copy() idx_nlevels = data.index.nlevels if idx_nlevels == 1: data.insert(0, "Index", data.index) else: for i in range(idx_nlevels): data.insert(i, f"Index{i}", data.index._get_level_values(i)) col_nlevels = data.columns.nlevels if col_nlevels > 1: col = data.columns._get_level_values(0) values = [ data.columns._get_level_values(i)._values for i in range(1, col_nlevels) ] col_df = pd.DataFrame(values) data.columns = col_df.columns data = pd.concat([col_df, data]) data.columns = col return data def _make_table(self, ax, df, title: str, height: Optional[float] = None): if df is None: ax.set_visible(False) return import pandas.plotting as plotting idx_nlevels = df.index.nlevels col_nlevels = df.columns.nlevels # must be convert here to get index levels for colorization df = self._insert_index(df) tb = plotting.table(ax, df, loc=9) tb.set_fontsize(self.font_size) if height is None: height = 1.0 / (len(df) + 1) props = tb.properties() for (r, c), cell in props["celld"].items(): if c == -1: cell.set_visible(False) elif r < col_nlevels and c < idx_nlevels: cell.set_visible(False) elif r < col_nlevels or c < idx_nlevels: cell.set_facecolor("#AAAAAA") cell.set_height(height) ax.set_title(title, size=self.font_size) ax.axis("off") if __name__ == "__main__": import matplotlib.pyplot as plt p = TablePlotter() df1 = pd.DataFrame({"A": [10, 11, 12], "B": [20, 21, 22], "C": [30, 31, 32]}) df2 = pd.DataFrame({"A": [10, 12], "C": [30, 32]}) p.plot([df1, df2], pd.concat([df1, df2]), labels=["df1", "df2"], vertical=True) plt.show() df3 = pd.DataFrame({"X": [10, 12], "Z": [30, 32]}) p.plot( [df1, df3], pd.concat([df1, df3], axis=1), labels=["df1", "df2"], vertical=False ) plt.show() idx = pd.MultiIndex.from_tuples( [(1, "A"), (1, "B"), (1, "C"), (2, "A"), (2, "B"), (2, "C")] ) col = pd.MultiIndex.from_tuples([(1, "A"), (1, "B")]) df3 = pd.DataFrame({"v1": [1, 2, 3, 4, 5, 6], "v2": [5, 6, 7, 8, 9, 10]}, index=idx) df3.columns = col p.plot(df3, df3, labels=["df3"]) plt.show()