2023-06-19 00:49:18 +02:00
# Copyright 2022 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
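# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and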
# limitations under the License.
import abc
from typing import Any, Iterable, List, Tuple, Union
import jax
from jax._src import core
from jax._src.numpy.util import promote_dtypes
from jax._src.numpy.lax_numpy import (
arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose
from jax._src.typing import Array, ArrayLike
import numpy as np
# Public names exported by this module; each one is bound at module level
# further down in this file.
__all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"]
def _make_1d_grid_from_slice(s: slice, op_name: str) -> Array:
  """Build the 1-D array described by a single ``start:stop:step`` slice.

  Each slice field is passed through ``core.concrete_or_error`` with a message
  that names the field and the calling operation.

  Args:
    s: the slice to expand. A missing (or zero) ``start`` is treated as ``0``
      and a missing (or zero) ``step`` as ``1``.
    op_name: name of the public operation (e.g. ``"mgrid"``), used only to
      build the ``jnp.{op_name}`` part of the error messages.

  Returns:
    ``linspace(start, stop, int(abs(step)))`` when ``step`` is complex,
    otherwise ``arange(start, stop, step)``.
  """
  start = core.concrete_or_error(None, s.start,
                                 f"slice start of jnp.{op_name}") or 0
  stop = core.concrete_or_error(None, s.stop,
                                f"slice stop of jnp.{op_name}")
  step = core.concrete_or_error(None, s.step,
                                f"slice step of jnp.{op_name}") or 1
  if np.iscomplex(step):
    # An imaginary step encodes a point count rather than a stride: its
    # magnitude is handed to linspace as the number of samples.
    return linspace(start, stop, int(abs(step)))
  # The arange call must only run for real steps; running it unconditionally
  # would discard the linspace result computed above.
  return arange(start, stop, step)
class _Mgrid:
  """Return dense multi-dimensional "meshgrid".

  LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for
  functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``.

  See Also:
    jnp.ogrid: open/sparse version of jnp.mgrid

  Examples:
    Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:

    >>> jnp.mgrid[0:4:1]
    Array([0, 1, 2, 3], dtype=int32)

    Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:

    >>> jnp.mgrid[0:1:4j]
    Array([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)

    Multiple slices can be used to create broadcasted grids of indices:

    >>> jnp.mgrid[:2, :3]
    Array([[[0, 0, 0],
            [1, 1, 1]],
           [[0, 1, 2],
            [0, 1, 2]]], dtype=int32)
  """

  def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Array:
    """Expand ``key`` into a dense grid.

    Args:
      key: a single slice, or a tuple of slices (one per grid dimension).

    Returns:
      For a single slice, the 1-D array it describes. For a tuple, the
      per-slice arrays are promoted to a common dtype, meshed with
      ``indexing='ij'`` and stacked along a new leading axis. An empty tuple
      yields ``arange(0)``.
    """
    if isinstance(key, slice):
      return _make_1d_grid_from_slice(key, op_name="mgrid")
    output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="mgrid") for k in key)
    # Promotion is done under the 'standard' setting regardless of the
    # caller's current dtype-promotion configuration.
    with jax.numpy_dtype_promotion('standard'):
      output = promote_dtypes(*output)
    output_arr = meshgrid(*output, indexing='ij', sparse=False)
    if not output_arr:
      # Nothing to stack when the key was an empty tuple.
      return arange(0)
    return stack(output_arr, 0)


mgrid = _Mgrid()
class _Ogrid:
  """Return open multi-dimensional "meshgrid".

  LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for
  functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``.

  See Also:
    jnp.mgrid: dense version of jnp.ogrid

  Examples:
    Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:

    >>> jnp.ogrid[0:4:1]
    Array([0, 1, 2, 3], dtype=int32)

    Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:

    >>> jnp.ogrid[0:1:4j]
    Array([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)

    Multiple slices can be used to create sparse grids of indices:

    >>> jnp.ogrid[:2, :3]
    [Array([[0],
           [1]], dtype=int32),
     Array([[0, 1, 2]], dtype=int32)]
  """

  def __getitem__(
      self, key: Union[slice, Tuple[slice, ...]]
  ) -> Union[Array, List[Array]]:
    """Expand ``key`` into an open (sparse) grid.

    Args:
      key: a single slice, or a tuple of slices (one per grid dimension).

    Returns:
      For a single slice, the 1-D array it describes. For a tuple, the result
      of ``meshgrid(..., indexing='ij', sparse=True)`` applied to the
      per-slice arrays after promoting them to a common dtype.
    """
    if isinstance(key, slice):
      return _make_1d_grid_from_slice(key, op_name="ogrid")
    output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key)
    # Promotion is done under the 'standard' setting regardless of the
    # caller's current dtype-promotion configuration.
    with jax.numpy_dtype_promotion('standard'):
      output = promote_dtypes(*output)
    return meshgrid(*output, indexing='ij', sparse=True)


ogrid = _Ogrid()
# Anything that may appear as one item of an ``r_`` / ``c_`` index expression.
_IndexType = Union[ArrayLike, str, slice]


class _AxisConcat(abc.ABC):
  """Concatenates slices, scalars and array-like objects along a given axis.

  Subclasses provide the four class attributes below as defaults. A string
  given as the first item of the index expression overrides them for that
  call: either ``"r"`` / ``"c"``, or up to three comma-separated integers
  ``"axis,ndmin,trans1d"``.
  """
  # Default axis along which the items are concatenated.
  axis: int
  # Minimum number of dimensions every item is promoted to.
  ndmin: int
  # Controls where the original axes of a promoted item end up; -1 disables
  # the transposition step entirely.
  trans1d: int
  # Operation name forwarded to the slice helper for its error messages.
  op_name: str

  def __getitem__(self, key: Union[_IndexType, Tuple[_IndexType, ...]]) -> Array:
    """Concatenate the items of ``key``.

    Args:
      key: a single item or a tuple of items. Slices are expanded to 1-D
        arrays, everything else is converted with ``array``. A string is
        accepted only as the first item, where it acts as a directive.

    Returns:
      The concatenation of all items along the selected axis. When an ``"r"``
      or ``"c"`` directive was given and the result is 1-D, an extra axis is
      inserted at position 0 or 1 respectively.

    Raises:
      ValueError: if the directive cannot be parsed as integers, or if a
        string appears anywhere but in the first position.
    """
    key_tup: Tuple[_IndexType, ...] = key if isinstance(key, tuple) else (key,)

    # Layout: [axis, ndmin, trans1d, matrix]; matrix == -1 means that no
    # "r"/"c" directive was given.
    params = [self.axis, self.ndmin, self.trans1d, -1]

    directive = key_tup[0]
    if isinstance(directive, str):
      key_tup = key_tup[1:]
      # check two special cases: matrix directives
      if directive == "r":
        params[-1] = 0
      elif directive == "c":
        params[-1] = 1
      else:
        vec: List[Any] = directive.split(",")
        k = len(vec)
        if k < 4:
          # fewer than four fields given: fill the remainder from the defaults
          vec += params[k:]
        else:
          # ignore everything after the first three comma-separated ints
          vec = vec[:3] + [params[-1]]
        try:
          params = list(map(int, vec))
        except ValueError as err:
          raise ValueError(
            f"could not understand directive {directive!r}"
          ) from err

    axis, ndmin, trans1d, matrix = params

    output = []
    for item in key_tup:
      if isinstance(item, slice):
        newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
        # Slices are treated as rank 0 for the transposition test below.
        item_ndim = 0
      elif isinstance(item, str):
        raise ValueError("string directive must be placed at the beginning")
      else:
        newobj = array(item, copy=False)
        item_ndim = newobj.ndim

      newobj = array(newobj, copy=False, ndmin=ndmin)

      # Only items that actually gained dimensions are transposed.
      if trans1d != -1 and ndmin - item_ndim > 0:
        shape_obj = tuple(range(ndmin))
        # Calculate number of left shifts, with overflow protection by mod
        num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
        shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])

        newobj = transpose(newobj, shape_obj)

      # Without this append the concatenation below would always see an
      # empty sequence.
      output.append(newobj)

    res = concatenate(tuple(output), axis=axis)

    if matrix != -1 and res.ndim == 1:
      # insert 2nd dim at axis 0 or 1
      res = expand_dims(res, matrix)

    return res

  def __len__(self) -> int:
    """Always report a length of zero."""
    return 0
class RClass(_AxisConcat):
  """Concatenate slices, scalars and array-like objects along the first axis.

  LAX-backend implementation of :obj:`numpy.r_`.

  See Also:
    ``jnp.c_``: Concatenates slices, scalars and array-like objects along the last axis.

  Examples:
    Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects:

    >>> jnp.r_[-1:5:1, 0, 0, jnp.array([1,2,3])]
    Array([-1,  0,  1,  2,  3,  4,  0,  0,  1,  2,  3], dtype=int32)

    An imaginary value for ``step`` will create a ``jnp.linspace`` object instead,
    which includes the right endpoint:

    >>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])]
    Array([-1.        , -0.6       , -0.20000002,  0.20000005,
           0.6       ,  1.        ,  0.        ,  1.        ,
           2.        ,  3.        ], dtype=float32)

    Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to
    specify concatenation axis, minimum number of dimensions, and the position of the
    upgraded array's original dimensions in the resulting array's shape tuple:

    >>> jnp.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)

    >>> jnp.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front
    Array([[1],
           [2],
           [3],
           [4],
           [5],
           [6]], dtype=int32)

    Negative values for ``trans1d`` offset the last axis towards the start
    of the shape tuple:

    >>> jnp.r_['0,2,-2', [1,2,3], [4,5,6]]
    Array([[1],
           [2],
           [3],
           [4],
           [5],
           [6]], dtype=int32)

    Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
    to create an array with an extra row or column axis, respectively:

    >>> jnp.r_['r',[1,2,3], [4,5,6]]
    Array([[1, 2, 3, 4, 5, 6]], dtype=int32)

    >>> jnp.r_['c',[1,2,3], [4,5,6]]
    Array([[1],
           [2],
           [3],
           [4],
           [5],
           [6]], dtype=int32)

    For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"``
    give the same result.
  """
  # Defaults: concatenate along axis 0, promote items to at least 1-D, and
  # skip the transposition step (trans1d == -1).
  axis = 0
  ndmin = 1
  trans1d = -1
  op_name = "r_"


r_ = RClass()
class CClass(_AxisConcat):
  """Concatenate slices, scalars and array-like objects along the last axis.

  LAX-backend implementation of :obj:`numpy.c_`.

  See Also:
    ``jnp.r_``: Concatenates slices, scalars and array-like objects along the first axis.

  Examples:

    >>> a = jnp.arange(6).reshape((2,3))
    >>> jnp.c_[a,a]
    Array([[0, 1, 2, 0, 1, 2],
           [3, 4, 5, 3, 4, 5]], dtype=int32)

    Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to specify
    concatenation axis, minimum number of dimensions, and the position of the upgraded array's
    original dimensions in the resulting array's shape tuple:

    >>> jnp.c_['0,2', [1,2,3], [4,5,6]]
    Array([[1],
           [2],
           [3],
           [4],
           [5],
           [6]], dtype=int32)

    >>> jnp.c_['0,2,-1', [1,2,3], [4,5,6]]
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)

    Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
    to create an array with inputs stacked along the last axis:

    >>> jnp.c_['r',[1,2,3], [4,5,6]]
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
  """
  # Defaults: concatenate along the last axis, promote items to at least 2-D,
  # and use trans1d == 0 for the placement of a promoted item's original axes.
  axis = -1
  ndmin = 2
  trans1d = 0
  op_name = "c_"


c_ = CClass()
# Aliases of NumPy's ``s_`` and ``index_exp`` objects, re-exported unchanged
# under this module's namespace.
s_ = np.s_
index_exp = np.index_exp