# projektAI/venv/Lib/site-packages/mpl_toolkits/axisartist/grid_finder.py
#
# 282 lines
# 9.9 KiB
# Python
# Raw Normal View History
#
# 2021-06-06 22:13:05 +02:00
import numpy as np
from matplotlib import _api, ticker as mticker
from matplotlib.transforms import Bbox, Transform
from .clip_path import clip_line_to_rect
class ExtremeFinderSimple:
    """
    A helper class to figure out the range of grid lines that need to be drawn.
    """

    def __init__(self, nx, ny):
        """
        Parameters
        ----------
        nx, ny : int
            The number of samples in each direction.
        """
        self.nx = nx
        self.ny = ny

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Compute an approximation of the bounding box obtained by applying
        *transform_xy* to the box delimited by ``(x1, y1, x2, y2)``.

        The intended use is to have ``(x1, y1, x2, y2)`` in axes coordinates,
        and have *transform_xy* be the transform from axes coordinates to data
        coordinates; this method then returns the range of data coordinates
        that span the actual axes.

        The computation is done by sampling ``nx * ny`` equispaced points in
        the ``(x1, y1, x2, y2)`` box and finding the resulting points with
        extremal coordinates; then adding some padding to take into account the
        finite sampling.

        As each sampling step covers a relative range of *1/nx* or *1/ny*,
        the padding is computed by expanding the span covered by the extremal
        coordinates by these fractions.

        Parameters
        ----------
        transform_xy : callable
            Called as ``transform_xy(x, y)`` with two flat arrays of sample
            coordinates; must return the two transformed arrays.
        x1, y1, x2, y2 : float
            Corners of the box to sample.

        Returns
        -------
        tuple
            ``(x_min, x_max, y_min, y_max)`` of the transformed samples,
            after padding.
        """
        # Regular nx-by-ny lattice covering the box, endpoints included.
        x, y = np.meshgrid(
            np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
        xt, yt = transform_xy(np.ravel(x), np.ravel(y))
        return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max())

    def _add_pad(self, x_min, x_max, y_min, y_max):
        """Perform the padding mentioned in `__call__`."""
        # Widen each side by one sampling step's share of the span.
        dx = (x_max - x_min) / self.nx
        dy = (y_max - y_min) / self.ny
        return x_min - dx, x_max + dx, y_min - dy, y_max + dy
class GridFinder:
    """
    Compute grid lines, tick locations and tick labels for a pair of grid
    coordinates (called *lon* and *lat* throughout this class) that are
    related to the drawing coordinates by a transform.

    The work is delegated to pluggable helpers stored as attributes:
    ``extreme_finder``, ``grid_locator1``/``grid_locator2`` and
    ``tick_formatter1``/``tick_formatter2``.
    """

    # Sides of the clipping rectangle, in the order in which per-side tick
    # data is paired with the output of clip_line_to_rect.
    _SIDES = ("left", "bottom", "right", "top")

    # Attribute names that `update` is allowed to replace.
    _UPDATABLE = ("extreme_finder",
                  "grid_locator1",
                  "grid_locator2",
                  "tick_formatter1",
                  "tick_formatter2")

    def __init__(self,
                 transform,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        Parameters
        ----------
        transform : Transform or (callable, callable)
            Either a `~matplotlib.transforms.Transform` or a pair
            ``(transform_xy, inv_transform_xy)`` of callables.  The forward
            direction maps grid (lon, lat) coordinates to the coordinates in
            which the box given to `get_grid_info` is expressed; the inverse
            direction maps that box back to grid coordinates.
        extreme_finder : callable, optional
            Called as ``extreme_finder(inv_transform_xy, x1, y1, x2, y2)``
            and returning ``(lon_min, lon_max, lat_min, lat_max)``.
            Defaults to ``ExtremeFinderSimple(20, 20)``.
        grid_locator1, grid_locator2 : callable, optional
            Grid locators for the 1st and 2nd coordinate, called as
            ``locator(v_min, v_max)`` and returning ``(levels, n, factor)``.
            Default to ``MaxNLocator()``.
        tick_formatter1, tick_formatter2 : callable, optional
            Tick formatters for the 1st and 2nd coordinate, called as
            ``formatter(direction, factor, levels)``.  Default to
            ``FormatterPrettyPrint()``.
        """
        if extreme_finder is None:
            extreme_finder = ExtremeFinderSimple(20, 20)
        if grid_locator1 is None:
            grid_locator1 = MaxNLocator()
        if grid_locator2 is None:
            grid_locator2 = MaxNLocator()
        if tick_formatter1 is None:
            tick_formatter1 = FormatterPrettyPrint()
        if tick_formatter2 is None:
            tick_formatter2 = FormatterPrettyPrint()
        self.extreme_finder = extreme_finder
        self.grid_locator1 = grid_locator1
        self.grid_locator2 = grid_locator2
        self.tick_formatter1 = tick_formatter1
        self.tick_formatter2 = tick_formatter2
        self.update_transform(transform)

    def get_grid_info(self, x1, y1, x2, y2):
        """
        Compute the grid lines and ticks for the box ``(x1, y1, x2, y2)``.

        Returns
        -------
        dict
            With keys ``"extremes"`` (the output of the extreme finder),
            ``"lon_lines"`` and ``"lat_lines"`` (the unclipped grid lines),
            and ``"lon"`` and ``"lat"`` (the dicts built by
            `_clip_grid_lines_and_find_ticks`, each extended with a
            ``"tick_labels"`` entry mapping every side to its formatted
            labels).
        """
        extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

        # min & max range of lat (or lon) over which each grid line will be
        # drawn, i.e., the gridline of lon=0 will be drawn from lat_min to
        # lat_max.
        lon_min, lon_max, lat_min, lat_max = extremes
        lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)

        # Locators return levels scaled by *factor*; undo the scaling to get
        # the actual coordinate values.
        lon_values = lon_levs[:lon_n] / lon_factor
        lat_values = lat_levs[:lat_n] / lat_factor

        lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
                                                        lat_values,
                                                        lon_min, lon_max,
                                                        lat_min, lat_max)

        # Clip against a box enlarged by a tiny relative margin
        # (presumably so that points lying exactly on the border are kept).
        ddx = (x2-x1)*1.e-10
        ddy = (y2-y1)*1.e-10
        bb = Bbox.from_extents(x1-ddx, y1-ddy, x2+ddx, y2+ddy)

        grid_info = {
            "extremes": extremes,
            "lon_lines": lon_lines,
            "lat_lines": lat_lines,
            "lon": self._clip_grid_lines_and_find_ticks(
                lon_lines, lon_values, lon_levs, bb),
            "lat": self._clip_grid_lines_and_find_ticks(
                lat_lines, lat_values, lat_levs, bb),
        }

        # Format the tick labels side by side, first coordinate first.
        for key, formatter, factor in [
                ("lon", self.tick_formatter1, lon_factor),
                ("lat", self.tick_formatter2, lat_factor)]:
            tick_levels = grid_info[key]["tick_levels"]
            grid_info[key]["tick_labels"] = {
                direction: formatter(direction, factor,
                                     tick_levels[direction])
                for direction in self._SIDES}

        return grid_info

    def _get_raw_grid_lines(self,
                            lon_values, lat_values,
                            lon_min, lon_max, lat_min, lat_max):
        """
        Return ``(lon_lines, lat_lines)``: for every value in *lon_values*
        (resp. *lat_values*) the transformed line of constant lon (resp.
        lat), sampled at 100 points across the other coordinate's range.
        """
        lons_i = np.linspace(lon_min, lon_max, 100)  # for interpolation
        lats_i = np.linspace(lat_min, lat_max, 100)

        lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
                     for lon in lon_values]
        lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
                     for lat in lat_values]

        return lon_lines, lat_lines

    def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
        """
        Clip each of *lines* to the box *bb* and collect, per side of the
        box, the tick levels and tick locations reported by
        ``clip_line_to_rect``.

        Returns a dict with keys ``"values"``, ``"levels"``,
        ``"tick_levels"``, ``"tick_locs"`` and ``"lines"``.  Lines that are
        entirely clipped away contribute nothing.
        """
        gi = {
            "values": [],
            "levels": [],
            "tick_levels": {side: [] for side in self._SIDES},
            "tick_locs": {side: [] for side in self._SIDES},
            "lines": [],
        }

        tck_levels = gi["tick_levels"]
        tck_locs = gi["tick_locs"]
        for (lx, ly), v, lev in zip(lines, values, levs):
            xy, tcks = clip_line_to_rect(lx, ly, bb)
            if not xy:
                continue
            # NOTE(review): the entry from *values* is stored under
            # "levels" and "values" is never filled.  Kept as is because
            # consumers may rely on it -- confirm before changing.
            gi["levels"].append(v)
            gi["lines"].append(xy)

            # tcks holds one collection of ticks per side, in _SIDES order.
            for tck, direction in zip(tcks, self._SIDES):
                for t in tck:
                    tck_levels[direction].append(lev)
                    tck_locs[direction].append(t)

        return gi

    def update_transform(self, aux_trans):
        """
        Set the transform used by `transform_xy` and `inv_transform_xy`.

        Parameters
        ----------
        aux_trans : Transform or (callable, callable)
            A `~matplotlib.transforms.Transform` or a pair
            ``(transform_xy, inv_transform_xy)`` of callables.

        Raises
        ------
        TypeError
            If *aux_trans* is neither of the above.
        """
        if not isinstance(aux_trans, Transform):
            try:
                is_callable_pair = (len(aux_trans) == 2
                                    and all(map(callable, aux_trans)))
            except TypeError:
                # Not sized / not iterable: certainly not a pair.
                is_callable_pair = False
            if not is_callable_pair:
                raise TypeError("'aux_trans' must be either a Transform "
                                "instance or a pair of callables")
        self._aux_transform = aux_trans

    def transform_xy(self, x, y):
        """Apply the forward transform to *x*, *y*; return the new x, y."""
        aux_trf = self._aux_transform
        if isinstance(aux_trf, Transform):
            return aux_trf.transform(np.column_stack([x, y])).T
        else:
            transform_xy, _ = aux_trf
            return transform_xy(x, y)

    def inv_transform_xy(self, x, y):
        """Apply the inverse transform to *x*, *y*; return the new x, y."""
        aux_trf = self._aux_transform
        if isinstance(aux_trf, Transform):
            return aux_trf.inverted().transform(np.column_stack([x, y])).T
        else:
            _, inv_transform_xy = aux_trf
            return inv_transform_xy(x, y)

    def update(self, **kw):
        """
        Replace any of ``extreme_finder``, ``grid_locator1``,
        ``grid_locator2``, ``tick_formatter1`` and ``tick_formatter2``.

        Raises
        ------
        ValueError
            If an unknown keyword is given; in that case no attribute is
            modified.
        """
        # Validate everything first so a bad key cannot leave the object
        # half-updated.
        for k in kw:
            if k not in self._UPDATABLE:
                raise ValueError("Unknown update property '%s'" % k)
        for k, v in kw.items():
            setattr(self, k, v)
class MaxNLocator(mticker.MaxNLocator):
    """
    A `matplotlib.ticker.MaxNLocator` that can be called directly with a
    value range, following the grid-locator protocol of this module:
    ``locator(v1, v2)`` returns ``(levels, n_levels, factor)``.
    """

    def __init__(self, nbins=10, steps=None,
                 trim=True,
                 integer=False,
                 symmetric=False,
                 prune=None):
        """
        Parameters are forwarded to `matplotlib.ticker.MaxNLocator`, except
        *trim*, which is accepted but ignored.
        """
        # trim argument has no effect. It has been left for API compatibility
        super().__init__(nbins, steps=steps, integer=integer,
                         symmetric=symmetric, prune=prune)
        # The locator is used standalone: its bounds are supplied through
        # set_bounds() in __call__, which relies on the dummy axis.
        self.create_dummy_axis()
        self._factor = 1

    def __call__(self, v1, v2):
        """
        Return ``(levels, n_levels, factor)`` for the range *v1* to *v2*.

        The bounds are multiplied by the factor before locating, so the
        returned levels are in scaled units.
        """
        self.set_bounds(v1 * self._factor, v2 * self._factor)
        locs = super().__call__()
        return np.array(locs), len(locs), self._factor

    @_api.deprecated("3.3")
    def set_factor(self, f):
        """Set the factor by which bounds are scaled in `__call__`."""
        self._factor = f
class FixedLocator:
    """
    Grid locator that selects, from a fixed collection of locations, those
    falling inside the requested range.
    """

    def __init__(self, locs):
        """
        Parameters
        ----------
        locs : iterable
            The candidate locations.
        """
        self._locs = locs
        self._factor = 1

    def __call__(self, v1, v2):
        """
        Return ``(levels, n_levels, factor)`` where *levels* is an array of
        the stored locations lying between *v1* and *v2* (inclusive, after
        both bounds are multiplied by the factor).  The bounds may be given
        in either order.
        """
        v1, v2 = sorted([v1 * self._factor, v2 * self._factor])
        locs = np.array([loc for loc in self._locs if v1 <= loc <= v2])
        return locs, len(locs), self._factor

    @_api.deprecated("3.3")
    def set_factor(self, f):
        """Set the factor by which bounds are scaled in `__call__`."""
        self._factor = f
# Tick Formatter
class FormatterPrettyPrint:
    """
    Tick formatter that delegates to a `matplotlib.ticker.ScalarFormatter`
    created with offsets disabled.
    """

    def __init__(self, useMathText=True):
        """
        Parameters
        ----------
        useMathText : bool, default: True
            Forwarded to `matplotlib.ticker.ScalarFormatter`.
        """
        self._fmt = mticker.ScalarFormatter(
            useMathText=useMathText, useOffset=False)
        self._fmt.create_dummy_axis()

    def __call__(self, direction, factor, values):
        """
        Return the formatted labels for *values*.

        *direction* and *factor* are accepted for compatibility with the
        formatter protocol of this module and are not used.
        """
        return self._fmt.format_ticks(values)
class DictFormatter:
    """
    Tick formatter that looks labels up in a dictionary and falls back to
    another formatter (or to empty strings) for values that are missing.
    """

    def __init__(self, format_dict, formatter=None):
        """
        Parameters
        ----------
        format_dict : dict
            Mapping from tick value to the label to be used for it.
        formatter : callable, optional
            Fall-back formatter, called as
            ``formatter(direction, factor, values)`` and returning one
            label per value.
        """
        super().__init__()
        self._format_dict = format_dict
        self._fallback_formatter = formatter

    def __call__(self, direction, factor, values):
        """
        Return one label per entry of *values*.

        *factor* is ignored if the value is found in the dictionary; it is
        only passed on to the fall-back formatter.  Without a fall-back
        formatter, values missing from the dictionary get an empty label.
        """
        # Compare against None rather than by truthiness, so that a
        # callable that happens to be falsy is still used.
        if self._fallback_formatter is not None:
            fallback_strings = self._fallback_formatter(
                direction, factor, values)
        else:
            fallback_strings = [""] * len(values)
        return [self._format_dict.get(k, v)
                for k, v in zip(values, fallback_strings)]