Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/numpy/ufuncs.py

741 lines
25 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implements ufuncs for jax.numpy.
"""
from functools import partial
import operator
from textwrap import dedent
from typing import Any, Callable, Tuple, Union, overload
import numpy as np
from jax._src import core
from jax._src import dtypes
from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, _wraps)
_lax_const = lax._const
_INT_DTYPES = {
16: np.int16,
32: np.int32,
64: np.int64,
}
UnOp = Callable[[ArrayLike], Array]
BinOp = Callable[[ArrayLike, ArrayLike], Array]
def _constant_like(x, const):
return np.array(const, dtype=dtypes.dtype(x))
def _replace_inf(x: ArrayLike) -> Array:
  """Return ``x`` with every entry whose real part is +inf replaced by zero."""
  pos_inf_mask = isposinf(real(x))
  zeros = lax._zeros(x)
  return lax.select(pos_inf_mask, zeros, x)
def _one_to_one_unop(
    numpy_fn: Callable[..., Any], lax_fn: UnOp,
    promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp:
  """Wrap a lax unary primitive as a jitted jnp ufunc documented like NumPy's.

  Args:
    numpy_fn: NumPy function being mirrored. It supplies the name used for
      argument promotion and the docstring.
    lax_fn: lax implementation applied to the promoted argument.
    promote_to_inexact: promote the argument to an inexact dtype first.
    lax_doc: attach ``lax_fn``'s docstring, minus its first paragraph, as the
      lax description.
  """
  name = numpy_fn.__name__
  promote = promote_args_inexact if promote_to_inexact else promote_args
  fn = lambda x, /: lax_fn(*promote(name, x))
  fn.__qualname__ = f"jax.numpy.{name}"
  fn = jit(fn, inline=True)
  wraps_kwargs = {'module': 'numpy'}
  if lax_doc:
    body = lax_fn.__doc__.split('\n\n')[1:]  # type: ignore[union-attr]
    wraps_kwargs['lax_description'] = dedent('\n\n'.join(body)).strip()
  return _wraps(numpy_fn, **wraps_kwargs)(fn)
def _one_to_one_binop(
    numpy_fn: Callable[..., Any], lax_fn: BinOp,
    promote_to_inexact: bool = False, lax_doc: bool = False,
    promote_to_numeric: bool = False) -> BinOp:
  """Wrap a lax binary primitive as a jitted jnp ufunc documented like NumPy's.

  Args:
    numpy_fn: NumPy function being mirrored. It supplies the name used for
      argument promotion and the docstring.
    lax_fn: lax implementation applied to the promoted arguments.
    promote_to_inexact: promote to a common inexact dtype. This takes
      precedence over ``promote_to_numeric``.
    lax_doc: attach ``lax_fn``'s docstring, minus its first paragraph, as the
      lax description.
    promote_to_numeric: promote to a common numeric dtype.
  """
  name = numpy_fn.__name__
  if promote_to_inexact:
    promote = promote_args_inexact
  elif promote_to_numeric:
    promote = promote_args_numeric
  else:
    promote = promote_args
  fn = lambda x1, x2, /: lax_fn(*promote(name, x1, x2))
  fn.__qualname__ = f"jax.numpy.{name}"
  fn = jit(fn, inline=True)
  wraps_kwargs = {'module': 'numpy'}
  if lax_doc:
    body = lax_fn.__doc__.split('\n\n')[1:]  # type: ignore[union-attr]
    wraps_kwargs['lax_description'] = dedent('\n\n'.join(body)).strip()
  return _wraps(numpy_fn, **wraps_kwargs)(fn)
def _maybe_bool_binop(
    numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp,
    lax_doc: bool = False) -> BinOp:
  """Wrap a lax binary op, sending boolean inputs to ``bool_lax_fn``.

  After standard promotion, boolean operands use ``bool_lax_fn``. For example,
  ``add`` on booleans is ``bitwise_or``. Every other dtype uses ``lax_fn``.
  With ``lax_doc``, ``lax_fn``'s docstring (minus its first paragraph) is
  attached as the lax description.
  """
  name = numpy_fn.__name__

  def fn(x1, x2, /):
    x1, x2 = promote_args(name, x1, x2)
    impl = bool_lax_fn if x1.dtype == np.bool_ else lax_fn
    return impl(x1, x2)

  fn.__qualname__ = f"jax.numpy.{name}"
  fn = jit(fn, inline=True)
  if not lax_doc:
    return _wraps(numpy_fn, module='numpy')(fn)
  paragraphs = lax_fn.__doc__.split('\n\n')  # type: ignore[union-attr]
  doc = dedent('\n\n'.join(paragraphs[1:])).strip()
  return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
  """Wrap a lax comparison as a jitted jnp ufunc that also orders complex values.

  Complex operands compare lexicographically on (real, imag). The real parts
  decide, unless they are equal, in which case the imaginary parts decide.
  """
  name = numpy_fn.__name__

  def fn(x1, x2, /):
    x1, x2 = promote_args(name, x1, x2)
    if not dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
      return lax_fn(x1, x2)
    re1, re2 = lax.real(x1), lax.real(x2)
    real_parts_tie = lax.eq(re1, re2)
    return lax.select(real_parts_tie,
                      lax_fn(lax.imag(x1), lax.imag(x2)),
                      lax_fn(re1, re2))

  fn.__qualname__ = f"jax.numpy.{name}"
  return _wraps(numpy_fn, module='numpy')(jit(fn, inline=True))
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ...
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ...
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]: ...
def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]:
  """Build a jnp logical ufunc from a lax bitwise op.

  Each argument that is not already boolean is first mapped to ``x != 0``.
  The boolean operands are then promoted and combined with ``bitwise_op``.
  """
  def as_bool(x):
    if dtypes.issubdtype(dtypes.dtype(x), np.bool_):
      return x
    return lax.ne(x, lax.full_like(x, shape=(), fill_value=0))

  @_wraps(np_op, update_doc=False, module='numpy')
  @partial(jit, inline=True)
  def op(*args):
    bool_args = [as_bool(x) for x in args]
    return bitwise_op(*promote_args(np_op.__name__, *bool_args))
  return op
# Unary ufuncs computed in an inexact (floating or complex) dtype.
fabs = _one_to_one_unop(np.fabs, lax.abs, promote_to_inexact=True)
floor = _one_to_one_unop(np.floor, lax.floor, promote_to_inexact=True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, promote_to_inexact=True)
exp = _one_to_one_unop(np.exp, lax.exp, promote_to_inexact=True)
log = _one_to_one_unop(np.log, lax.log, promote_to_inexact=True)
expm1 = _one_to_one_unop(np.expm1, lax.expm1, promote_to_inexact=True)
log1p = _one_to_one_unop(np.log1p, lax.log1p, promote_to_inexact=True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, promote_to_inexact=True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, promote_to_inexact=True)

# Trigonometric and hyperbolic functions and their inverses.
sin = _one_to_one_unop(np.sin, lax.sin, promote_to_inexact=True)
cos = _one_to_one_unop(np.cos, lax.cos, promote_to_inexact=True)
tan = _one_to_one_unop(np.tan, lax.tan, promote_to_inexact=True)
arcsin = _one_to_one_unop(np.arcsin, lax.asin, promote_to_inexact=True)
arccos = _one_to_one_unop(np.arccos, lax.acos, promote_to_inexact=True)
arctan = _one_to_one_unop(np.arctan, lax.atan, promote_to_inexact=True)
sinh = _one_to_one_unop(np.sinh, lax.sinh, promote_to_inexact=True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, promote_to_inexact=True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, promote_to_inexact=True)
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, promote_to_inexact=True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, promote_to_inexact=True)

# Unary ufuncs that keep the (standard-promoted) input dtype.
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: lax.asarray(x))

# Arithmetic. On booleans, add and multiply act as bitwise or / and.
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
subtract = _one_to_one_binop(np.subtract, lax.sub)
minimum = _one_to_one_binop(np.minimum, lax.min)
maximum = _one_to_one_binop(np.maximum, lax.max)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, promote_to_inexact=True)
float_power = _one_to_one_binop(np.float_power, lax.pow, promote_to_inexact=True)
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter,
                              promote_to_inexact=True, lax_doc=True)

# Bitwise operations and left shift.
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True)

# Comparisons. (In)equality maps straight onto lax; the orderings go through
# _comparison_op so that complex inputs are ordered lexicographically.
equal = _one_to_one_binop(np.equal, lax.eq)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
greater_equal = _comparison_op(np.greater_equal, lax.ge)
greater = _comparison_op(np.greater, lax.gt)
less_equal = _comparison_op(np.less_equal, lax.le)
less = _comparison_op(np.less, lax.lt)

# Logical operations. Non-boolean inputs are compared against zero first.
logical_and: BinOp = _logical_op(np.logical_and, lax.bitwise_and)
logical_not: UnOp = _logical_op(np.logical_not, lax.bitwise_not)
logical_or: BinOp = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor)
@_wraps(np.arccosh, module='numpy')
@jit
def arccosh(x: ArrayLike, /) -> Array:
  # arccosh is multi-valued for complex input, and lax.acosh picks a different
  # branch than np.arccosh. Negating complex results whose real part is
  # negative converts to NumPy's convention.
  promoted, = promote_args_inexact("arccosh", x)
  result = lax.acosh(promoted)
  if not dtypes.issubdtype(result.dtype, np.complexfloating):
    return result
  return _where(real(result) < 0, lax.neg(result), result)
@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2)
lax_fn = lax.shift_right_logical if \
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)
@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
check_arraylike('absolute', x)
dt = dtypes.dtype(x)
return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
abs = _wraps(np.abs, module='numpy')(absolute)
@_wraps(np.rint, module='numpy')
@jit
def rint(x: ArrayLike, /) -> Array:
check_arraylike('rint', x)
dtype = dtypes.dtype(x)
if dtype == bool or dtypes.issubdtype(dtype, np.integer):
return lax.convert_element_type(x, dtypes.float_)
if dtypes.issubdtype(dtype, np.complexfloating):
return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike, /) -> Array:
  check_arraylike('sign', x)
  if not dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
    return lax.sign(x)
  # Complex input: take the sign of the real part, or of the imaginary part
  # when the real part is zero. The result is complex with zero imaginary part.
  re = lax.real(x)
  component = _where(re != 0, re, lax.imag(x))
  return lax.complex(lax.sign(component), _constant_like(re, 0))
@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # |x1| carrying the sign bit of x2. Complex inputs raise TypeError.
  x1, x2 = promote_args_inexact("copysign", x1, x2)
  if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
    raise TypeError("copysign does not support complex-valued inputs")
  take_negative = signbit(x2).astype(bool)
  magnitude = lax.abs(x1)
  return _where(take_negative, -magnitude, magnitude)
@_wraps(np.true_divide, module='numpy')
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_inexact("true_divide", x1, x2)
return lax.div(x1, x2)
divide = true_divide
@_wraps(np.floor_divide, module='numpy')
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Floor division, dispatched on the promoted dtype into three paths:
  # integer, complex, and (otherwise) floating point.
  x1, x2 = promote_args_numeric("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    quotient = lax.div(x1, x2)
    # Where the operand signs differ and the division is inexact, step the
    # quotient down by one so it rounds toward -inf. This assumes lax.div
    # truncates toward zero, as the correction implies.
    select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(select, quotient - 1, quotient)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    # Computes floor(Re(x1 / x2)) = floor((x1r*x2r + x1i*x2i) / |x2|**2).
    # Numerator and denominator are both divided by x2r when |x2r| >= |x2i|
    # (`which`), and by x2i otherwise. Presumably this avoids overflow when
    # forming |x2|**2.
    x1r = lax.real(x1)
    x1i = lax.imag(x1)
    x2r = lax.real(x2)
    x2i = lax.imag(x2)
    which = lax.ge(lax.abs(x2r), lax.abs(x2i))
    rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
    rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
    out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
                            lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
    # The real-valued floor is converted back to the input's complex dtype.
    return lax.convert_element_type(out, dtype)
  else:
    # Floating point: the quotient half of the CPython-style float divmod.
    return _float_divmod(x1, x2)[0]
@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
  # Returns (floor quotient, remainder) after numeric promotion.
  x1, x2 = promote_args_numeric("divmod", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    # Non-integer dtypes compute both results in one shared pass.
    return _float_divmod(x1, x2)
  return floor_divide(x1, x2), remainder(x1, x2)
def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> Tuple[Array, Array]:
  """Floored (quotient, remainder) for inexact inputs.

  Mirrors float_divmod in CPython's floatobject.c.
  """
  trunc_rem = lax.rem(x1, x2)
  quot = lax.div(lax.sub(x1, trunc_rem), x2)
  # A nonzero remainder whose sign disagrees with the divisor's is moved into
  # the divisor's sign, and the quotient is lowered by one to match.
  needs_fix = lax.bitwise_and(trunc_rem != 0, lax.sign(x2) != lax.sign(trunc_rem))
  floored_rem = lax.select(needs_fix, trunc_rem + x2, trunc_rem)
  floored_quot = lax.select(needs_fix, quot - _constant_like(quot, 1), quot)
  return lax.round(floored_quot), floored_rem
@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
  """Elementwise ``x1 ** x2``: the general (non-fast-path) case of `power`.

  After numeric promotion, non-integer dtypes use ``lax.pow``. Integer dtypes
  use binary (square-and-multiply) exponentiation over the low ``bits`` bits
  of ``x2``.
  """
  x1, x2 = promote_args_numeric("power", x1, x2)
  dtype = dtypes.dtype(x1)
  if not dtypes.issubdtype(dtype, np.integer):
    return lax.pow(x1, x2)
  # Integer power => use binary exponentiation.
  # TODO(phawkins): add integer pow support to XLA.
  # Only exponent bits 0..5 are consumed (exponents up to 63); any higher bits
  # of x2 are ignored.
  bits = 6 # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
  acc = _where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
  # NOTE(review): negative exponents are not rejected here; their
  # two's-complement low bits are used as-is. Confirm this is intended.
  for _ in range(bits):
    # Multiply in the current power of x1 when the lowest remaining bit of x2
    # is set, then square x1 and drop that bit.
    acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
    x1 = lax.mul(x1, x1)
    x2 = lax.shift_right_logical(x2, one)
  return acc
@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  check_arraylike("power", x1, x2)
  # A concrete integer exponent (the common ``x ** 2`` pattern) takes the
  # lax.integer_pow path, so the result is precise. The general lax.pow may be
  # imprecise for floating-point values.
  exponent = None
  if isinstance(core.get_aval(x2), core.ConcreteArray):
    try:
      exponent = operator.index(x2)  # type: ignore[arg-type]
    except TypeError:
      pass  # Concrete but not an integer: fall through to the general path.
  if exponent is None:
    return _power(x1, x2)
  base, = promote_dtypes_numeric(x1)
  return lax.integer_pow(base, exponent)
@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # log(exp(x1) + exp(x2)), rewritten around the larger operand so that exp
  # is never applied to a large positive value. The derivative rule is
  # registered separately with `logaddexp.defjvp`.
  x1, x2 = promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    # Real inputs: amax + log1p(exp(-|x1 - x2|)).
    delta = lax.sub(x1, x2)
    return lax.select(lax._isnan(delta),
                      lax.add(x1, x2), # NaNs or infinities of the same sign.
                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
  else:
    # Complex inputs: delta = x1 + x2 - 2*amax, i.e. the other operand minus
    # amax. NOTE(review): which operand lax.max selects for complex values is
    # not visible here. Confirm it matches the intended ordering.
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    # Wrap the imaginary part into [-pi, pi].
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
def _wrap_between(x, _a):
  """Wrap ``x`` periodically (period ``2 * _a``) into ``[-_a, _a]``."""
  half_period = _constant_like(x, _a)
  period = _constant_like(x, 2 * _a)
  # Shift by a and reduce modulo 2a. Any negative remainder is folded back
  # into [0, 2a) before shifting back down by a.
  reduced = lax.rem(lax.add(x, half_period), period)
  is_negative = lax.lt(reduced, _constant_like(x, 0))
  reduced = lax.select(is_negative, lax.add(reduced, period), reduced)
  return lax.sub(reduced, half_period)
@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
  """JVP of logaddexp: d/dxi = exp(xi - out), weighted by each tangent."""
  (x1, x2), (t1, t2) = primals, tangents
  x1, x2, t1, t2 = promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
  primal_out = logaddexp(x1, x2)
  # +inf entries are zeroed before subtracting, which avoids inf - inf = NaN
  # when an input and the output are both +inf.
  safe_out = _replace_inf(primal_out)
  weight1 = exp(lax.sub(_replace_inf(x1), safe_out))
  weight2 = exp(lax.sub(_replace_inf(x2), safe_out))
  tangent_out = lax.add(lax.mul(t1, weight1), lax.mul(t2, weight2))
  return primal_out, tangent_out
@custom_jvp
@_wraps(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # log2(2**x1 + 2**x2), rewritten around the larger operand. log1p results
  # are divided by ln(2) to convert to base 2. The derivative rule is
  # registered separately with `logaddexp2.defjvp`.
  x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    # Real inputs: amax + log1p(2**-|x1 - x2|) / ln(2).
    delta = lax.sub(x1, x2)
    return lax.select(lax._isnan(delta),
                      lax.add(x1, x2), # NaNs or infinities of the same sign.
                      lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
                                            _constant_like(x1, np.log(2)))))
  else:
    # Complex inputs: delta = x1 + x2 - 2*amax, i.e. the other operand minus
    # amax. NOTE(review): which operand lax.max selects for complex values is
    # not visible here. Confirm it matches the intended ordering.
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2))))
    # Wrap the imaginary part into [-pi/ln(2), pi/ln(2)].
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
  """JVP of logaddexp2: d/dxi = 2**(xi - out), weighted by each tangent."""
  (x1, x2), (t1, t2) = primals, tangents
  x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
  primal_out = logaddexp2(x1, x2)
  # +inf entries are zeroed before subtracting, which avoids inf - inf = NaN
  # when an input and the output are both +inf.
  safe_out = _replace_inf(primal_out)
  weight1 = exp2(lax.sub(_replace_inf(x1), safe_out))
  weight2 = exp2(lax.sub(_replace_inf(x2), safe_out))
  tangent_out = lax.add(lax.mul(t1, weight1), lax.mul(t2, weight2))
  return primal_out, tangent_out
@_wraps(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
  # Change of base: log2(x) = ln(x) / ln(2).
  x, = promote_args_inexact("log2", x)
  ln_base = lax.log(_constant_like(x, 2))
  return lax.div(lax.log(x), ln_base)
@_wraps(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
  # Change of base: log10(x) = ln(x) / ln(10).
  x, = promote_args_inexact("log10", x)
  ln_base = lax.log(_constant_like(x, 10))
  return lax.div(lax.log(x), ln_base)
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
  # 2**x evaluated as exp(ln(2) * x).
  x, = promote_args_inexact("exp2", x)
  ln2 = lax.log(_constant_like(x, 2))
  return lax.exp(lax.mul(ln2, x))
@_wraps(np.signbit, module='numpy')
@jit
def signbit(x: ArrayLike, /) -> Array:
  # Boolean array that is True where the sign bit of x is set.
  x, = promote_args("signbit", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.integer):
    # Integers: the sign bit is set exactly for negative values.
    return lax.lt(x, _constant_like(x, 0))
  elif dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  elif not dtypes.issubdtype(dtype, np.floating):
    # Any remaining non-floating dtype (e.g. complex) is rejected.
    raise ValueError(
        "jax.numpy.signbit is not well defined for %s" % dtype)
  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.dtype('float32')
    x = lax.convert_element_type(x, dtype)
  info = dtypes.finfo(dtype)
  if info.bits not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  int_type = _INT_DTYPES[info.bits]
  # Reinterpret the float bits as a same-width signed integer and shift the
  # top bit (position nexp + nmant) down. `>>` on signed ints is arithmetic
  # (see right_shift), giving -1 or 0, which converts to True or False.
  x = lax.bitcast_convert_type(x, int_type)
  return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_)
def _normalize_float(x):
  """Return ``(bits, exponent_offset)`` for the floating-point array ``x``.

  Entries with magnitude below ``finfo.tiny`` (subnormals and zeros) are
  multiplied by ``2**nmant``. Subnormals thereby become normal; zeros stay
  zero. ``exponent_offset`` is ``-nmant`` for those entries and ``0``
  elsewhere, so callers can undo the scaling. ``bits`` is the (possibly
  scaled) value reinterpreted as the signed integer dtype of the same width.
  """
  info = dtypes.finfo(dtypes.dtype(x))
  int_type = _INT_DTYPES[info.bits]
  cond = lax.abs(x) < info.tiny
  x1 = _where(cond, x * _lax_const(x, 1 << info.nmant), x)
  x2 = _where(cond, int_type(-info.nmant), int_type(0))
  return lax.bitcast_convert_type(x1, int_type), x2
@_wraps(np.ldexp, module='numpy')
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Computes x1 * 2**x2 by rewriting the exponent field of x1's bit pattern.
  # Subnormal results, underflow and overflow are handled explicitly.
  check_arraylike("ldexp", x1, x2)
  x1_dtype = dtypes.dtype(x1)
  x2_dtype = dtypes.dtype(x2)
  # Raises ValueError for complex x1 or an inexact (float/complex) exponent x2.
  if (dtypes.issubdtype(x1_dtype, np.complexfloating)
      or dtypes.issubdtype(x2_dtype, np.inexact)):
    raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")
  x1, x2 = promote_shapes("ldexp", x1, x2)

  # x1 is handled in its canonical inexact dtype, and x2 in the signed integer
  # type of the same bit width.
  dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype))
  info = dtypes.finfo(dtype)
  int_type = _INT_DTYPES[info.bits]

  x1 = lax.convert_element_type(x1, dtype)
  x2 = lax.convert_element_type(x2, int_type)

  # All-ones exponent-field mask and the exponent bias, 2**(nexp - 1) - 1.
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1
  # x: bit pattern of x1 (subnormals pre-scaled); e: matching exponent offset.
  x, e = _normalize_float(x1)
  # x2 becomes the unbiased exponent of the result.
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = less(x2, -(bias + info.nmant))
  overflow_cond = greater(x2, bias)

  # m is a post-multiplier applied after the bit edit: 1, or 2**-nmant for
  # subnormal results.
  m = lax.full_like(x, 1, dtype=dtype)

  # denormals
  # Results whose exponent is below the smallest normal exponent (1 - bias)
  # are built with the exponent raised by nmant, then scaled back down via m.
  cond = less(x2, -bias + 1)
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  # NOTE(review): this int32 round trip looks redundant, since x2 is converted
  # back to int_type below. Confirm before removing.
  x2 = lax.convert_element_type(x2, np.int32)
  # Clear the old exponent field and write the new biased exponent.
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  # NOTE(review): this yields +0 regardless of the sign of x1. Confirm.
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
@_wraps(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
  # Splits x into (mantissa, exponent) with x == mantissa * 2**exponent, by
  # reading and rewriting the exponent field of x's bit pattern. Complex input
  # raises TypeError.
  check_arraylike("frexp", x)
  x, = promote_dtypes_inexact(x)
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")

  dtype = dtypes.dtype(x)
  info = dtypes.finfo(dtype)
  # All-ones exponent-field mask and the exponent bias, 2**(nexp - 1) - 1.
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  # x1: bit pattern (subnormals pre-scaled); x2: starts as the exponent offset.
  x1, x2 = _normalize_float(x)
  # Exponent that places the mantissa in [0.5, 1): unbiased exponent + 1.
  x2 += ((x1 >> info.nmant) & mask) - bias + 1
  # Replace the exponent field with bias - 1 (a factor of 2**-1). Sign and
  # fraction bits are kept.
  x1 &= ~(mask << info.nmant)
  x1 |= (bias - 1) << info.nmant
  x1 = lax.bitcast_convert_type(x1, dtype)

  # inf, nan and 0 are returned unchanged, with exponent 0.
  cond = isinf(x) | isnan(x) | (x == 0)
  x2 = _where(cond, lax._zeros(x2), x2)
  return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
@_wraps(np.remainder, module='numpy')
@jit
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Floored modulo: a nonzero result takes the sign of the divisor x2.
  x1, x2 = promote_args_numeric("remainder", x1, x2)
  zero = _constant_like(x1, 0)
  if dtypes.issubdtype(x2.dtype, np.integer):
    # Integer divisors of 0 are replaced by 1, so those entries become
    # x1 % 1 == 0 instead of dividing by zero.
    x2 = _where(x2 == 0, lax._ones(x2), x2)
  trunc = lax.rem(x1, x2)
  # Where the truncated remainder is nonzero and its sign differs from x2's,
  # add x2 to move it to x2's sign.
  signs_differ = lax.ne(lax.lt(trunc, zero), lax.lt(x2, zero))
  needs_fix = lax.bitwise_and(signs_differ, lax.ne(trunc, zero))
  return lax.select(needs_fix, lax.add(trunc, x2), trunc)
mod = _wraps(np.mod, module='numpy')(remainder)
@_wraps(np.fmod, module='numpy')
@jit
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("fmod", x1, x2)
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
x2 = _where(x2 == 0, lax._ones(x2), x2)
return lax.rem(*promote_args_numeric("fmod", x1, x2))
@_wraps(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
check_arraylike("square", x)
x, = promote_dtypes_numeric(x)
return lax.integer_pow(x, 2)
@_wraps(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
  # Degrees to radians: scale by pi / 180 in an inexact dtype.
  x, = promote_args_inexact("deg2rad", x)
  scale = _lax_const(x, np.pi / 180)
  return lax.mul(x, scale)
@_wraps(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
  # Radians to degrees: scale by 180 / pi in an inexact dtype.
  x, = promote_args_inexact("rad2deg", x)
  scale = _lax_const(x, 180 / np.pi)
  return lax.mul(x, scale)
# NumPy-style aliases.
degrees = rad2deg
radians = deg2rad
@_wraps(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
check_arraylike("conjugate", x)
return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)
conj = conjugate
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
check_arraylike("imag", val)
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
@_wraps(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
  """Return ``(fractional, integral)`` parts of ``x``.

  The integral part is ``x`` truncated toward zero. For ``x = +/-inf`` the
  result is ``(+/-0.0, +/-inf)``, matching NumPy and C ``modf``.

  Raises:
    NotImplementedError: if ``out`` is given.
  """
  check_arraylike("modf", x)
  x, = promote_dtypes_inexact(x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
  # Truncate toward zero: floor for non-negative values, ceil for negative ones.
  whole = _where(lax.ge(x, lax._zero(x)), floor(x), ceil(x))
  # For infinite x, x - whole would be inf - inf = NaN. Use a zero that
  # carries the sign of x instead.
  frac = _where(isinf(x), copysign(lax._zeros(x), x), x - whole)
  return frac, whole
@_wraps(np.isfinite, module='numpy')
@jit
def isfinite(x: ArrayLike, /) -> Array:
  check_arraylike("isfinite", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    # A complex value is finite only when both components are finite.
    return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
  if dtypes.issubdtype(dtype, np.floating):
    return lax.is_finite(x)
  # Non-inexact dtypes (e.g. integers, booleans) are always finite.
  return lax.full_like(x, True, dtype=np.bool_)
@_wraps(np.isinf, module='numpy')
@jit
def isinf(x: ArrayLike, /) -> Array:
  check_arraylike("isinf", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    # A complex value is infinite when either component is +/-inf.
    re, im = lax.real(x), lax.imag(x)
    re_inf = lax.eq(lax.abs(re), _constant_like(re, np.inf))
    im_inf = lax.eq(lax.abs(im), _constant_like(im, np.inf))
    return lax.bitwise_or(re_inf, im_inf)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(lax.abs(x), _constant_like(x, np.inf))
  # Non-inexact dtypes (e.g. integers, booleans) cannot hold infinity.
  return lax.full_like(x, False, dtype=np.bool_)
def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
  """Elementwise ``x == infinity`` for real floating ``x``.

  ``infinity`` is ``np.inf`` or ``-np.inf``. Other real (non-floating) dtypes
  give all-False. Complex input raises ValueError, and a non-None ``out``
  raises NotImplementedError.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  if not dtypes.issubdtype(dtype, np.floating):
    return lax.full_like(x, False, dtype=np.bool_)
  return lax.eq(x, _constant_like(x, infinity))
# jnp.isposinf / jnp.isneginf: elementwise tests for +inf / -inf. Both delegate
# to _isposneginf, which raises NotImplementedError for a non-None `out` and
# ValueError for complex input. Unlike most functions here, they are not
# jit-wrapped.
isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])(
    lambda x, /, out=None: _isposneginf(np.inf, x, out)
)

isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
    lambda x, /, out=None: _isposneginf(-np.inf, x, out)
)
@_wraps(np.isnan, module='numpy')
@jit
def isnan(x: ArrayLike, /) -> Array:
check_arraylike("isnan", x)
return lax.ne(x, x)
@_wraps(np.heaviside, module='numpy')
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Step function: 0 for x1 < 0, 1 for x1 > 0, and x2 where x1 == 0 (or
  # where x1 is NaN, since both comparisons are then False).
  check_arraylike("heaviside", x1, x2)
  x1, x2 = promote_dtypes_inexact(x1, x2)
  zero = _lax_const(x1, 0)
  one = _lax_const(x1, 1)
  non_negative_branch = _where(lax.gt(x1, zero), one, x2)
  return _where(lax.lt(x1, zero), zero, non_negative_branch)
@_wraps(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  """Elementwise ``sqrt(x1**2 + x2**2)`` without intermediate overflow.

  The larger magnitude ``a`` is factored out as ``a * sqrt(1 + (b / a)**2)``,
  so nothing large is ever squared. If either input is infinite the result is
  ``+inf``, even when the other input is NaN, matching NumPy / C99 ``hypot``.
  """
  check_arraylike("hypot", x1, x2)
  x1, x2 = promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  # Record infinities before scaling: with two infinite inputs the formula
  # below evaluates inf / inf = NaN. logical_or (rather than lax.bitwise_or)
  # is used because x1 and x2 have not been broadcast to a common shape yet.
  has_inf = logical_or(isinf(x1), isinf(x2))
  x1, x2 = maximum(x1, x2), minimum(x1, x2)
  # When the larger magnitude is 0 both inputs are 0. Divide by 1 there to
  # avoid 0 / 0, and return 0.
  safe_x1 = lax.select(x1 == 0, lax._ones(x1), x1)
  scaled = x1 * lax.sqrt(1 + lax.square(lax.div(x2, safe_x1)))
  result = lax.select(x1 == 0, x1, scaled)
  return _where(has_inf, _lax_const(result, np.inf), result)
@_wraps(np.reciprocal, module='numpy')
@partial(jit, inline=True)
def reciprocal(x: ArrayLike, /) -> Array:
check_arraylike("reciprocal", x)
x, = promote_dtypes_inexact(x)
return lax.integer_pow(x, -1)
@_wraps(np.sinc, update_doc=False)
@jit
def sinc(x: ArrayLike, /) -> Array:
  # Normalized sinc: sin(pi x) / (pi x), with the x == 0 entries supplied by
  # _sinc_maclaurin (the k = 0 term of sin(t)/t at t = 0).
  check_arraylike("sinc", x)
  x, = promote_dtypes_inexact(x)
  at_origin = lax.eq(x, _lax_const(x, 0))
  pi_x = lax.mul(_lax_const(x, np.pi), x)
  # Substitute 1 for the zero argument so the division never sees 0 / 0.
  denom = _where(at_origin, _lax_const(x, 1), pi_x)
  limit = _sinc_maclaurin(0, pi_x)
  ratio = lax.div(lax.sin(denom), denom)
  return _where(at_origin, limit, ratio)
@partial(custom_jvp, nondiff_argnums=(0,))
def _sinc_maclaurin(k, x):
  # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
  # compute the monomial term in the jvp rule)
  # TODO(mattjj): see https://github.com/google/jax/issues/10750
  # `x * 0` keeps the result shaped like x and data-dependent on it.
  zero_like_x = x * 0
  if k % 2:
    # sin(t)/t is even, so its odd derivatives vanish at zero.
    return zero_like_x
  # Even k: the derivative at zero is (-1)**(k/2) / (k + 1).
  return zero_like_x + _lax_const(x, (-1) ** (k // 2) / (k + 1))
@_sinc_maclaurin.defjvp
def _sinc_maclaurin_jvp(k, primals, tangents):
  # The tangent of the k-th term is the (k+1)-th term times the input tangent.
  # This lets higher-order derivatives of sinc at zero be generated.
  x, = primals
  t, = tangents
  primal_out = _sinc_maclaurin(k, x)
  tangent_out = _sinc_maclaurin(k + 1, x) * t
  return primal_out, tangent_out