from itertools import product import scipy.interpolate as osp_interpolate from jax.numpy import (asarray, broadcast_arrays, can_cast, empty, nan, searchsorted, where, zeros) from jax._src.tree_util import register_pytree_node from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps def _ndim_coords_from_arrays(points, ndim=None): """Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.""" if isinstance(points, tuple) and len(points) == 1: # handle argument tuple points = points[0] if isinstance(points, tuple): p = broadcast_arrays(*points) for p_other in p[1:]: if p_other.shape != p[0].shape: raise ValueError("coordinate arrays do not have the same shape") points = empty(p[0].shape + (len(points),), dtype=float) for j, item in enumerate(p): points = points.at[..., j].set(item) else: check_arraylike("_ndim_coords_from_arrays", points) points = asarray(points) # SciPy: asanyarray(points) if points.ndim == 1: if ndim is None: points = points.reshape(-1, 1) else: points = points.reshape(-1, ndim) return points @_wraps( osp_interpolate.RegularGridInterpolator, lax_description=""" In the JAX version, `bounds_error` defaults to and must always be `False` since no bound error may be raised under JIT. Furthermore, in contrast to SciPy no input validation is performed. """) class RegularGridInterpolator: # Based on SciPy's implementation which in turn is originally based on an # implementation by Johannes Buchner def __init__(self, points, values, method="linear", bounds_error=False, fill_value=nan): if method not in ("linear", "nearest"): raise ValueError(f"method {method!r} is not defined") self.method = method self.bounds_error = bounds_error if self.bounds_error: raise NotImplementedError("`bounds_error` takes no effect under JIT") check_arraylike("RegularGridInterpolator", values) if len(points) > values.ndim: ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions" raise ValueError(ve) values, = promote_dtypes_inexact(values) if fill_value is not None: check_arraylike("RegularGridInterpolator", fill_value) fill_value = asarray(fill_value) if not can_cast(fill_value.dtype, values.dtype, casting='same_kind'): ve = "fill_value must be either 'None' or of a type compatible with values" raise ValueError(ve) self.fill_value = fill_value # TODO: assert sanity of `points` similar to SciPy but in a JIT-able way check_arraylike("RegularGridInterpolator", *points) self.grid = tuple(asarray(p) for p in points) self.values = values @_wraps(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False) def __call__(self, xi, method=None): method = self.method if method is None else method if method not in ("linear", "nearest"): raise ValueError(f"method {method!r} is not defined") ndim = len(self.grid) xi = _ndim_coords_from_arrays(xi, ndim=ndim) if xi.shape[-1] != len(self.grid): raise ValueError("the requested sample points xi have dimension" f" {xi.shape[1]}, but this RegularGridInterpolator has" f" dimension {ndim}") xi_shape = xi.shape xi = xi.reshape(-1, xi_shape[-1]) indices, norm_distances, out_of_bounds = self._find_indices(xi.T) if method == "linear": result = self._evaluate_linear(indices, norm_distances) elif method == "nearest": result = self._evaluate_nearest(indices, norm_distances) else: raise AssertionError("method must be bound") if not self.bounds_error and self.fill_value is not None: bc_shp = result.shape[:1] + (1,) * (result.ndim - 1) result = where(out_of_bounds.reshape(bc_shp), self.fill_value, result) return result.reshape(xi_shape[:-1] + self.values.shape[ndim:]) def _evaluate_linear(self, indices, norm_distances): # slice for broadcasting over trailing dimensions in self.values vslice = (slice(None),) + (None,) * (self.values.ndim - len(indices)) # find relevant values # each i and i+1 represents a edge edges = product(*[[i, i + 1] for i in indices]) values = asarray(0.) for edge_indices in edges: weight = asarray(1.) for ei, i, yi in zip(edge_indices, indices, norm_distances): weight *= where(ei == i, 1 - yi, yi) values += self.values[edge_indices] * weight[vslice] return values def _evaluate_nearest(self, indices, norm_distances): idx_res = [ where(yi <= .5, i, i + 1) for i, yi in zip(indices, norm_distances) ] return self.values[tuple(idx_res)] def _find_indices(self, xi): # find relevant edges between which xi are situated indices = [] # compute distance to lower edge in unity units norm_distances = [] # check for out of bounds xi out_of_bounds = zeros((xi.shape[1],), dtype=bool) # iterate through dimensions for x, g in zip(xi, self.grid): i = searchsorted(g, x) - 1 i = where(i < 0, 0, i) i = where(i > g.size - 2, g.size - 2, i) indices.append(i) norm_distances.append((x - g[i]) / (g[i + 1] - g[i])) if not self.bounds_error: out_of_bounds += x < g[0] out_of_bounds += x > g[-1] return indices, norm_distances, out_of_bounds register_pytree_node( RegularGridInterpolator, lambda obj: ((obj.grid, obj.values, obj.fill_value), (obj.method, obj.bounds_error)), lambda aux, children: RegularGridInterpolator( *children[:2], # type: ignore[index] *aux, *children[2:]), # type: ignore[index] )