Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/scipy/ndimage.py
2023-06-19 00:49:18 +02:00

142 lines
5.2 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import operator
import textwrap
from typing import Callable, Dict, List, Sequence, Tuple
import scipy.ndimage
from jax._src import api
from jax._src import util
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.typing import ArrayLike, Array
from jax._src.util import safe_zip as zip
def _nonempty_prod(arrs: Sequence[Array]) -> Array:
  """Return the product of every entry of ``arrs``.

  The fold is seeded with the first element rather than with an identity
  value, so the result is built purely from the given operands. A
  consequence is that ``arrs`` must be non-empty: an empty sequence makes
  ``functools.reduce`` raise ``TypeError``.
  """
  return functools.reduce(lambda acc, factor: acc * factor, arrs)
def _nonempty_sum(arrs: Sequence[Array]) -> Array:
  """Return the sum of every entry of ``arrs``.

  The fold starts from the first element (no additive identity is injected),
  so ``arrs`` must be non-empty; an empty sequence makes
  ``functools.reduce`` raise ``TypeError``.
  """
  return functools.reduce(lambda total, term: total + term, arrs)
def _mirror_index_fixer(index: Array, size: int) -> Array:
  """Fold ``index`` into ``[0, size)`` following SciPy's 'mirror' mode.

  'mirror' reflects about the *centres* of the boundary samples, so the edge
  sample is not repeated (``d c b | a b c d | c b a``). The folded index is
  a triangular wave of ``index`` with period ``2 * (size - 1)``.

  Args:
    index: integer array of (possibly out-of-range) sample positions.
    size: length of the axis being indexed; a Python int (callers pass an
      entry of ``input.shape``).

  Returns:
    Integer array shaped like ``index`` with values in ``[0, size)``.
  """
  if size == 1:
    # Degenerate axis: every position mirrors onto the single sample. Without
    # this guard the formula below would compute ``x % 0``, which is not a
    # meaningful fold (its integer result depends on how the backend treats
    # a zero divisor).
    return jnp.zeros_like(index)
  s = size - 1  # Half-wavelength of the triangular wave.
  # Scaled, integer-valued version of the triangular wave |x - round(x)|.
  return jnp.abs((index + s) % (2 * s) - s)
def _reflect_index_fixer(index: Array, size: int) -> Array:
  """Fold ``index`` into ``[0, size)`` following SciPy's 'reflect' mode.

  'reflect' mirrors about the outer *edges* of the boundary samples, so the
  edge sample is repeated (``d c b a | a b c d | d c b a``). This equals
  'mirror' applied on a grid of half-integer positions. ``i -> 2*i + 1``
  places sample ``i`` on the odd points of a grid of length ``2*size + 1``.
  The mirror fold is applied on that grid, and ``(m - 1) // 2`` maps the
  result back to sample units.
  """
  odd_position = 2 * index + 1
  folded = _mirror_index_fixer(odd_position, 2 * size + 1)
  return (folded - 1) // 2
def _constant_index_fixer(index: Array, size: int) -> Array:
  """'constant' mode: return ``index`` unchanged.

  Out-of-range positions are not remapped here; ``_map_coordinates`` masks
  them to ``cval`` instead.
  """
  del size  # Unused; present only to match the (index, size) signature.
  return index


def _nearest_index_fixer(index: Array, size: int) -> Array:
  """'nearest' mode: clamp ``index`` into ``[0, size - 1]``."""
  return jnp.clip(index, 0, size - 1)


def _wrap_index_fixer(index: Array, size: int) -> Array:
  """'wrap' mode: take ``index`` modulo ``size``."""
  return index % size


# Boundary-mode name -> function ``(index, size) -> index`` that maps integer
# sample positions along one axis of length ``size`` according to that mode.
_INDEX_FIXERS: Dict[str, Callable[[Array, int], Array]] = dict(
    constant=_constant_index_fixer,
    nearest=_nearest_index_fixer,
    wrap=_wrap_index_fixer,
    mirror=_mirror_index_fixer,
    reflect=_reflect_index_fixer,
)
def _round_half_away_from_zero(a: Array) -> Array:
  """Round ``a`` to the nearest integer, breaking ties away from zero.

  Integer arrays are already integral and are returned as-is (the very same
  object). Floating inputs go through ``lax.round``, whose default rounding
  method is away-from-zero; ``jnp.round`` would instead round half to even.
  """
  if jnp.issubdtype(a.dtype, jnp.integer):
    return a
  return lax.round(a)
def _nearest_indices_and_weights(coordinate: Array) -> List[Tuple[Array, ArrayLike]]:
  """Interpolation nodes for ``order=0`` (nearest-neighbour) sampling.

  Returns a single ``(index, weight)`` node. The index is ``coordinate``
  rounded half away from zero and cast to int32. The weight is one, in
  ``coordinate``'s own dtype.
  """
  rounded = _round_half_away_from_zero(coordinate)
  unit_weight = coordinate.dtype.type(1)
  return [(rounded.astype(jnp.int32), unit_weight)]
def _linear_indices_and_weights(coordinate: Array) -> List[Tuple[Array, ArrayLike]]:
  """Interpolation nodes for ``order=1`` (linear) sampling.

  Returns two ``(index, weight)`` nodes: the int32 neighbours
  ``floor(coordinate)`` and ``floor(coordinate) + 1``, weighted by ``1 - t``
  and ``t`` respectively, where ``t = coordinate - floor(coordinate)``.
  """
  base = jnp.floor(coordinate)
  frac = coordinate - base
  left = base.astype(jnp.int32)
  return [(left, 1 - frac), (left + 1, frac)]
@functools.partial(api.jit, static_argnums=(2, 3, 4))
def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
                     order: int, mode: str, cval: ArrayLike) -> Array:
  """Jitted implementation of ``map_coordinates``.

  ``order``, ``mode`` and ``cval`` are static arguments
  (``static_argnums=(2, 3, 4)``). They must therefore be hashable, and each
  distinct combination is traced and compiled separately.

  Args:
    input: array to sample from.
    coordinates: sequence of ``input.ndim`` arrays; entry ``k`` holds the
      (possibly fractional) positions along axis ``k``. After conversion to
      integer nodes they are used together as advanced indices, so the
      result takes their broadcast shape.
    order: 0 for nearest-neighbour, 1 for linear interpolation.
    mode: boundary mode; must be a key of ``_INDEX_FIXERS``.
    cval: fill value for out-of-range samples when ``mode == 'constant'``;
      cast to ``input``'s dtype.

  Returns:
    The interpolated values cast to ``input``'s dtype. For integer dtypes
    the values are first rounded half away from zero.

  Raises:
    ValueError: if ``len(coordinates) != input.ndim`` or ``input`` is 0-d.
    NotImplementedError: for an unsupported ``mode`` or ``order``.
  """
  input_arr = jnp.asarray(input)
  coordinate_arrs = [jnp.asarray(c) for c in coordinates]
  cval = jnp.asarray(cval, input_arr.dtype)
  if len(coordinates) != input_arr.ndim:
    raise ValueError('coordinates must be a sequence of length input.ndim, but '
                     '{} != {}'.format(len(coordinates), input_arr.ndim))
  if input_arr.ndim == 0:
    # With no axes there are no interpolation nodes: the weight product below
    # would reduce an empty tuple and fail with an opaque ``reduce()``
    # TypeError, so reject 0-d input explicitly.
    raise ValueError('jax.scipy.ndimage.map_coordinates requires input.ndim '
                     '>= 1, but got a 0-dimensional input')

  index_fixer = _INDEX_FIXERS.get(mode)
  if index_fixer is None:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
        'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))

  if mode == 'constant':
    # Validity is judged on the *unfixed* index: a node outside [0, size)
    # along any axis contributes ``cval`` instead of an input value.
    is_valid = lambda index, size: (0 <= index) & (index < size)
  else:
    # Every other mode remaps all indices into range, so validity is the
    # Python constant True, which enables the unmasked fast path below.
    is_valid = lambda index, size: True

  if order == 0:
    interp_fun = _nearest_indices_and_weights
  elif order == 1:
    interp_fun = _linear_indices_and_weights
  else:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates currently requires order<=1')

  # Per axis: the list of (in-range index, validity, weight) nodes.
  valid_1d_interpolations = []
  for coordinate, size in zip(coordinate_arrs, input_arr.shape):
    interp_nodes = interp_fun(coordinate)
    valid_interp = []
    for index, weight in interp_nodes:
      fixed_index = index_fixer(index, size)
      valid = is_valid(index, size)
      valid_interp.append((fixed_index, valid, weight))
    valid_1d_interpolations.append(valid_interp)

  # The N-d interpolant is the tensor product of the 1-d ones. There is one
  # term per combination of per-axis nodes: 1 term for order 0 and 2**ndim
  # terms for order 1.
  outputs = []
  for items in itertools.product(*valid_1d_interpolations):
    indices, validities, weights = util.unzip3(items)
    if all(valid is True for valid in validities):
      # Fast path: every node is statically valid, so no masking is needed.
      contribution = input_arr[indices]
    else:
      all_valid = functools.reduce(operator.and_, validities)
      # Values gathered at out-of-range positions are discarded in favour
      # of ``cval``.
      contribution = jnp.where(all_valid, input_arr[indices], cval)
    outputs.append(_nonempty_prod(weights) * contribution)
  result = _nonempty_sum(outputs)
  if jnp.issubdtype(input_arr.dtype, jnp.integer):
    result = _round_half_away_from_zero(result)
  return result.astype(input_arr.dtype)
@_wraps(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\
    Only nearest neighbor (``order=0``), linear interpolation (``order=1``) and
    modes ``'constant'``, ``'nearest'``, ``'wrap'``, ``'mirror'`` and ``'reflect'`` are currently supported.
    Note that interpolation near boundaries differs from the scipy function,
    because we fixed an outstanding bug (https://github.com/scipy/scipy/issues/2640);
    this function interprets the ``mode`` argument as documented by SciPy, but
    not as implemented by SciPy.
    """))
def map_coordinates(
    input: ArrayLike, coordinates: Sequence[ArrayLike], order: int, mode: str = 'constant', cval: ArrayLike = 0.0,
):
  # Public entry point. Its documentation comes from the ``_wraps`` decorator
  # (presumably SciPy's map_coordinates docs plus the JAX note above), so
  # no docstring is written here.
  #
  # All work is delegated to the jitted ``_map_coordinates``, which marks
  # ``order``, ``mode`` and ``cval`` as static arguments. ``cval`` therefore
  # has to be hashable (e.g. a Python scalar, as the default is), and each
  # new value of these three triggers a fresh compilation.
  return _map_coordinates(input, coordinates, order, mode, cval)