Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/numpy/util.py
2023-06-19 00:49:18 +02:00

427 lines
17 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import re
import textwrap
from typing import (
Any, Callable, Dict, List, NamedTuple, Optional, Sequence, TypeVar
)
import warnings
from jax._src import dtypes
from jax._src import api
from jax._src import core
from jax._src.config import config
from jax._src.lax import lax
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
import numpy as np
# From here on, `zip` and `map` in this module refer to safe_zip / safe_map
# imported from jax._src.util; the builtins stay reachable as unsafe_zip and
# unsafe_map.
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
# Type variable used in the return annotation of _wraps.
_T = TypeVar("_T")
# Matches a newline that is immediately followed by a letter or underscore,
# i.e. the boundary before a line that starts at column 0 with an identifier.
_parameter_break = re.compile("\n(?=[A-Za-z_])")
# Matches a newline that is followed by a 3-15 character line and then a line
# beginning with 3-15 dashes (a section title and its underline).
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
# Matches a signature line: an optional "name, name =" prefix, a dotted name,
# and a parenthesised argument list that ends at the end of a line.
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE)
# Matches the start of an indented ".. versionadded::" directive.
_versionadded = re.compile(r'^\s+\.\.\s+versionadded::', re.MULTILINE)
# Matches a :doc:`text <target>` role; group 1 captures the link text.
_docreference = re.compile(r':doc:`(.*?)\s*<.*?>`')
class ParsedDoc(NamedTuple):
  """Pieces of a numpy-style docstring after parsing.

  Fields:
    docstr: the complete docstring text (may be None).
    signature: the call signature found in the docstring, if any.
    summary: the summary line.
    front_matter: text that precedes the first titled section.
    sections: mapping from each section title to that section's text.
  """
  docstr: Optional[str]
  signature: str = ""
  summary: str = ""
  front_matter: str = ""
  # NOTE: this default is one dict object shared by every instance created
  # without an explicit `sections` argument; do not mutate it in place.
  sections: Dict[str, str] = {}
def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
  """Split a numpy-style docstring into its component parts.

  Args:
    docstr: the raw docstring from a function; may be None or blank.
  Returns:
    ParsedDoc: the signature, summary, front matter and titled sections. For a
    None or whitespace-only input only the ``docstr`` field is populated.
  """
  if docstr is None or not docstr.strip():
    return ParsedDoc(docstr)

  # Replace every :doc: role by its link text alone to avoid sphinx errors.
  docstr = _docreference.sub(lambda m: f"{m.group(1)}", docstr)

  def _take_signature(text, current):
    # Consume a signature at the start of `text` if there is one; otherwise
    # keep the signature found so far and leave `text` untouched.
    found = _numpy_signature_re.match(text)
    if not found:
      return current, text
    return found.group(), text[found.end():]

  # A signature may lead the raw docstring ...
  signature, rest = _take_signature(docstr, "")
  summary, _, rest = rest.partition('\n')
  rest = textwrap.dedent(rest.lstrip('\n'))
  # ... or lead the dedented body that follows the first line.
  signature, rest = _take_signature(rest, signature)

  if not summary:
    # The first line was empty: the summary is the next line of the body.
    summary, _, rest = rest.lstrip('\n').partition('\n')
    rest = textwrap.dedent(rest.lstrip('\n'))

  pieces = _section_break.split("\n" + rest)
  if _section_break.match(pieces[0]):
    front_matter = ""
  else:
    front_matter = pieces[0]
    pieces = pieces[1:]

  # Each section is keyed by its first line (the title).
  sections = {}
  for piece in pieces:
    sections[piece.split('\n', 1)[0]] = piece

  return ParsedDoc(docstr=docstr, signature=signature, summary=summary,
                   front_matter=front_matter, sections=sections)
def _parse_parameters(body: str) -> Dict[str, str]:
  """Map each parameter name in a "Parameters" section to its full entry text.

  ``body`` must begin with the line ``Parameters`` followed by an underline made
  only of dashes. An entry's key is the text before `` : ``, cut at the first
  ``, `` when present.
  """
  title, underline, content = body.split('\n', 2)
  assert title == 'Parameters'
  assert underline and not underline.strip('-')
  by_name = {}
  for entry in _parameter_break.split(content):
    declared = entry.partition(' : ')[0]
    by_name[declared.partition(', ')[0]] = entry
  return by_name
def _parse_extra_params(extra_params: str) -> Dict[str, str]:
  """Map each parameter name in the ``extra_params`` text given to _wraps() to its entry.

  Leading and trailing newlines are stripped first. An entry's key is the text
  before `` : ``, cut at the first ``, `` when present.
  """
  by_name = {}
  for entry in _parameter_break.split(extra_params.strip('\n')):
    declared = entry.partition(' : ')[0]
    by_name[declared.partition(', ')[0]] = entry
  return by_name
def _wraps(
    fun: Optional[Callable[..., Any]],
    update_doc: bool = True,
    lax_description: str = "",
    sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
    skip_params: Sequence[str] = (),
    extra_params: Optional[str] = None,
    module: Optional[str] = None,
) -> Callable[[_T], _T]:
  """Specialized version of functools.wraps for wrapping numpy functions.

  This produces a wrapped function with a modified docstring. In particular, if
  `update_doc` is True, parameters listed in the wrapped function that are not
  supported by the decorated function will be removed from the docstring. For
  this reason, it is important that parameter names match those in the original
  numpy function.

  Args:
    fun: The function being wrapped. If None, the decorated function is returned
      with only ``__np_wrapped__`` set (to None).
    update_doc: whether to transform the numpy docstring to remove references of
      parameters that are supported by the numpy version but not the JAX version.
      If False, the Parameters section is kept as parsed.
    lax_description: a string description that will be inserted after the
      "LAX-backend implementation" line of the docstring.
    sections: a list of sections to include in the docstring. The default is
      ["Parameters", "Returns", "References"]
    skip_params: a list of strings containing names of parameters accepted by the
      function that should be skipped in the parameter list.
    extra_params: an optional string containing additional parameter descriptions.
      When ``update_doc=True``, these will be added to the list of parameter
      descriptions in the updated doc.
    module: an optional string specifying the module from which the wrapped function
      is imported. This is useful for objects such as ufuncs, where the module cannot
      be determined from the wrapped function itself.

  Returns:
    A decorator that sets ``__np_wrapped__`` and ``__doc__`` on the decorated
    function, copies ``__name__`` and ``__qualname__`` from ``fun`` when present,
    and returns the decorated function itself.

  Raises:
    ValueError: (from the decorator) if ``fun`` has no ``__module__``, ``module``
      is not given, and ``config.jax_enable_checks`` is set.
  """
  def wrap(op):
    op.__np_wrapped__ = fun
    # Allows this pattern: @wraps(getattr(np, 'new_function', None))
    if fun is None:
      return op
    docstr = getattr(fun, "__doc__", None)
    name = getattr(fun, "__name__", getattr(op, "__name__", str(op)))
    try:
      mod = module or fun.__module__
    except AttributeError:
      if config.jax_enable_checks:
        raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.")
      # With checks disabled the name is simply left unqualified.
    else:
      name = f"{mod}.{name}"
    if docstr:
      try:
        parsed = _parse_numpydoc(docstr)

        if update_doc and 'Parameters' in parsed.sections:
          code = getattr(getattr(op, "__wrapped__", op), "__code__", None)
          # Remove unrecognized parameter descriptions.
          parameters = _parse_parameters(parsed.sections['Parameters'])
          if extra_params:
            parameters.update(_parse_extra_params(extra_params))
          parameters = {p: desc for p, desc in parameters.items()
                        if (code is None or p in code.co_varnames)
                        and p not in skip_params}
          if parameters:
            # Keep only the text before any ".. versionadded::" directive.
            parsed.sections['Parameters'] = (
              "Parameters\n"
              "----------\n" +
              "\n".join(_versionadded.split(desc)[0].rstrip()
                        for desc in parameters.values())
            )
          else:
            del parsed.sections['Parameters']

        docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
        docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
        if lax_description:
          docstr += "\n" + lax_description.strip() + "\n"
        docstr += "\n*Original docstring below.*\n"

        # Signatures are deliberately left out of the docstring: they are
        # redundant at best and misleading at worst, e.g. JAX wrappers don't
        # implement all ufunc keyword arguments.

        if parsed.front_matter:
          docstr += "\n" + parsed.front_matter.strip() + "\n"
        # Must be a list, not a generator: a generator object is always truthy,
        # which would append empty padding when no section is kept.
        kept_sections = [content.strip()
                         for section, content in parsed.sections.items()
                         if section in sections]
        if kept_sections:
          docstr += "\n" + "\n\n".join(kept_sections) + "\n"
      except Exception:
        # Best effort: on any parsing failure fall back to the unmodified
        # docstring unless checks are enabled.
        if config.jax_enable_checks:
          raise
        docstr = fun.__doc__

    op.__doc__ = docstr
    for attr in ['__name__', '__qualname__']:
      try:
        value = getattr(fun, attr)
      except AttributeError:
        pass
      else:
        setattr(op, attr, value)
    return op
  return wrap
# dtypes.dtype with canonicalize=True pre-bound.
_dtype = partial(dtypes.dtype, canonicalize=True)
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Apply NumPy-style rank promotion so that args are shape-compatible for lax.

  Args:
    fun_name: name of the calling function, used in rank-promotion messages.
    *args: the operands.
  Returns:
    One array per operand. Operands are only converted with ``lax.asarray`` when
    there are fewer than two of them, when they all have the same rank, or when
    the non-scalar operands share a single rank; otherwise each operand is
    broadcast to the common rank by prepending size-1 dimensions.
  """
  if len(args) < 2:
    return [lax.asarray(a) for a in args]

  shapes = [np.shape(a) for a in args]

  if config.jax_dynamic_shapes:
    # With dynamic shapes we don't support singleton-dimension broadcasting;
    # we instead broadcast out to the full shape as a temporary workaround.
    # TODO(mattjj): revise this workaround
    full_shape = lax.broadcast_shapes(*shapes)  # Can raise an error!
    return [_broadcast_to(a, full_shape) for a in args]

  first_rank = len(shapes[0])
  if all(first_rank == len(s) for s in shapes[1:]):
    # no need for rank promotion, so rely on lax promotion
    return [lax.asarray(a) for a in args]

  ranks_of_nonscalars = {len(s) for s in shapes if s}
  if len(ranks_of_nonscalars) < 2:
    # rely on lax scalar promotion
    return [lax.asarray(a) for a in args]

  if config.jax_numpy_rank_promotion != "allow":
    _rank_promotion_warning_or_error(fun_name, shapes)
  out_rank = len(lax.broadcast_shapes(*shapes))
  promoted = []
  for a, s in zip(args, shapes):
    leading_ones = (1,) * (out_rank - len(s))
    promoted.append(_broadcast_to(a, leading_ones + s))
  return promoted
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
if config.jax_numpy_rank_promotion == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
"disable this warning; for more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
elif config.jax_numpy_rank_promotion == "raise":
msg = ("Operands could not be broadcast together for {} on shapes {} "
"and with the config option jax_numpy_rank_promotion='raise'. "
"For more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
def promote_dtypes(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
return [lax.asarray(arg) for arg in args]
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True)
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True)
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
for x in args]
def promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
  """Convert all args to a common numeric (non-bool) dtype.

  The common dtype is found by Numpy-style promotion, canonicalized, and then
  mapped with ``dtypes.to_numeric_dtype``.
  """
  common, weak = dtypes._lattice_result_type(*args)
  target = dtypes.to_numeric_dtype(dtypes.canonicalize_dtype(common))
  return [lax._convert_element_type(a, target, weak) for a in args]
def promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
  """Convert all args to a common complex dtype.

  The common dtype is found by Numpy-style promotion, canonicalized, and then
  mapped with ``dtypes.to_complex_dtype``.
  """
  common, weak = dtypes._lattice_result_type(*args)
  target = dtypes.to_complex_dtype(dtypes.canonicalize_dtype(common))
  return [lax._convert_element_type(a, target, weak) for a in args]
def _complex_elem_type(dtype: DTypeLike) -> DType:
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype
def _arraylike(x: ArrayLike) -> bool:
return (isinstance(x, np.ndarray) or isinstance(x, Array) or
hasattr(x, '__jax_array__') or np.isscalar(x))
def check_arraylike(fun_name: str, *args: Any):
  """Raise TypeError unless every arg fits JAX's definition of arraylike.

  The error reports the type and position of the first offending argument.
  """
  assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
  for pos, arg in enumerate(args):
    if _arraylike(arg):
      continue
    msg = "{} requires ndarray or scalar arguments, got {} at position {}."
    raise TypeError(msg.format(fun_name, type(arg), pos))
def check_arraylike_or_none(fun_name: str, *args: Any):
  """Raise TypeError unless every arg is arraylike or None.

  The error reports the type and position of the first offending argument.
  """
  assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
  for pos, arg in enumerate(args):
    if _arraylike(arg) or arg is None:
      continue
    msg = "{} requires ndarray, scalar, or None arguments, got {} at position {}."
    raise TypeError(msg.format(fun_name, type(arg), pos))
def _check_no_float0s(fun_name: str, *args: Any):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
raise TypeError(
f"Called {fun_name} with a float0 array. "
"float0s do not support any operations by design because they "
"are not compatible with non-trivial vector spaces. No implicit dtype "
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
"to cast a float0 array to a regular zeros array. \n"
"If you didn't expect to get a float0 you might have accidentally "
"taken a gradient with respect to an integer argument.")
def promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Validate args, then apply Numpy-style dtype promotion followed by shape promotion."""
  for validate in (check_arraylike, _check_no_float0s):
    validate(fun_name, *args)
  same_dtype_args = promote_dtypes(*args)
  return promote_shapes(fun_name, *same_dtype_args)
def promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Validate args, then promote them to a common numeric dtype and apply shape promotion."""
  for validate in (check_arraylike, _check_no_float0s):
    validate(fun_name, *args)
  numeric_args = promote_dtypes_numeric(*args)
  return promote_shapes(fun_name, *numeric_args)
def promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Validate args, then promote them to a common inexact dtype and apply shape promotion.

  Non-inexact types are promoted to an inexact type."""
  for validate in (check_arraylike, _check_no_float0s):
    validate(fun_name, *args)
  inexact_args = promote_dtypes_inexact(*args)
  return promote_shapes(fun_name, *inexact_args)
@partial(api.jit, inline=True)
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
  """Like Numpy's broadcast_arrays but doesn't return views.

  When there are no args, or all shapes already compare equal, the args are
  only converted with ``lax.asarray``; otherwise each one is broadcast to the
  common shape.
  """
  shapes = [np.shape(a) for a in args]
  # For an empty `shapes` the generator never runs, so shapes[0] is not touched.
  differs = any(not core.symbolic_equal_shape(shapes[0], s) for s in shapes)
  if not differs:
    return [lax.asarray(a) for a in args]
  target = lax.broadcast_shapes(*shapes)
  return [_broadcast_to(a, target) for a in args]
def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
  """Broadcast ``arr`` to ``shape``.

  Args:
    arr: an arraylike value; converted with ``lax.asarray`` unless it already
      is an ``Array``.
    shape: the requested shape. A non-tuple zero-dimensional value is wrapped
      into a 1-tuple.
  Returns:
    ``arr`` itself when its shape already equals ``shape``; otherwise the
    result of ``lax.broadcast_in_dim``.
  Raises:
    TypeError: if ``arr`` is not arraylike.
    ValueError: if ``shape`` has fewer dimensions than ``arr``, or a dimension
      of ``arr`` is neither 1 nor equal to the matching trailing dimension of
      ``shape``.
  """
  check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, Array) else lax.asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr

  nlead = len(shape) - len(arr_shape)
  shape_tail = shape[nlead:]
  # Check the rank before zipping: when nlead < 0 the two sequences have
  # different lengths, and safe_zip is expected to reject that before the
  # intended error below could be raised.
  compatible = nlead >= 0 and all(
      core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
      for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
  if not compatible:
    msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
    raise ValueError(msg.format(arr_shape, shape))
  # Indices of arr's dimensions that differ from the target; given the
  # compatibility check above these have size 1 and are squeezed out below.
  diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                         for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
  new_dims = tuple(range(nlead)) + tuple(nlead + diff)
  kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
  return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
  """Select elements from ``x`` or ``y`` depending on ``condition``.

  A non-boolean ``condition`` is first compared against zero with ``lax.ne``.
  ``x`` and ``y`` are promoted to a common dtype and all three operands are
  broadcast against each other before ``lax.select`` is applied.

  Raises:
    ValueError: if ``x`` or ``y`` is None.
  """
  if x is None or y is None:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not np.issubdtype(_dtype(condition), np.bool_):
    condition = lax.ne(condition, lax._zero(condition))
  x, y = promote_dtypes(x, y)
  condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
  try:
    is_always_empty = core.is_empty_shape(x_arr.shape)
  except Exception:
    # Deliberately broad, but no longer a bare `except:` so that
    # KeyboardInterrupt and SystemExit still propagate.
    is_always_empty = False  # can fail with dynamic shapes
  # With an empty result there is nothing to select; return x_arr as is.
  return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr