Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/nn/functions.py

526 lines
15 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared neural network activations and other functions."""
from functools import partial
import operator
import warnings
import numpy as np
from typing import Any, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.ops.special import logsumexp as _logsumexp
# Alias used in the type annotations throughout this module. It is simply
# `Any`, so annotations written with it impose no constraint and no runtime
# check.
Array = Any
# activations
@custom_jvp
@jax.jit
def relu(x: Array) -> Array:
  r"""Rectified linear unit activation function.

  Element-wise, returns

  .. math::
    \mathrm{relu}(x) = \max(x, 0)

  Its derivative is given by a custom JVP rule that picks the value zero at
  the kink:

  .. math::
    \nabla \mathrm{relu}(0) = 0

  For background, see
  `Numerical influence of ReLU(0) on backpropagation
  <https://openreview.net/forum?id=urrcVI-_jRm>`_.

  Args:
    x : input array
  """
  return jnp.maximum(x, 0)


def _relu_jvp(g, ans, x):
  # Tangent rule for relu: pass ``g`` through where x > 0 and emit zeros
  # elsewhere, including exactly at x == 0.
  # See https://openreview.net/forum?id=urrcVI-_jRm for the choice at 0.
  del ans  # the primal output is not needed by this rule
  zeros = lax.full_like(g, 0)
  return lax.select(x > 0, g, zeros)


relu.defjvps(_relu_jvp)
@jax.jit
def softplus(x: Array) -> Array:
  r"""Softplus activation function.

  Applies, element-wise,

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  evaluated as ``jnp.logaddexp(x, 0)``, i.e. :math:`\log(e^x + e^0)`.

  Args:
    x : input array
  """
  zero = 0
  return jnp.logaddexp(x, zero)
@jax.jit
def soft_sign(x: Array) -> Array:
  r"""Soft-sign activation function.

  Applies, element-wise,

  .. math::
    \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

  Args:
    x : input array
  """
  denominator = jnp.abs(x) + 1
  return x / denominator
@jax.jit
def sigmoid(x: Array) -> Array:
  r"""Sigmoid (logistic) activation function.

  Applies, element-wise,

  .. math::
    \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

  using the ``lax.logistic`` primitive.

  Args:
    x : input array
  """
  logistic = lax.logistic(x)
  return logistic
@jax.jit
def silu(x: Array) -> Array:
  r"""SiLU (a.k.a. swish) activation function.

  Applies, element-wise,

  .. math::
    \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

  Args:
    x : input array
  """
  gate = lax.logistic(x)  # same primitive that ``sigmoid`` evaluates
  return x * gate


swish = silu
@jax.jit
def log_sigmoid(x: Array) -> Array:
  r"""Log-sigmoid activation function.

  Applies, element-wise,

  .. math::
    \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

  which equals ``-softplus(-x)``; the softplus expression is written inline.

  Args:
    x : input array
  """
  softplus_of_neg = jnp.logaddexp(-x, 0)
  return -softplus_of_neg
@jax.jit
def elu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Exponential linear unit activation function.

  Applies, element-wise,

  .. math::
    \mathrm{elu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(x) - 1\right), & x \le 0
    \end{cases}

  Args:
    x : input array
    alpha : scalar or array of alpha values (default: 1.0)
  """
  positive = x > 0
  # Positive entries are zeroed before expm1 so the branch that jnp.where
  # discards never sees large inputs (keeps it finite, also for gradients).
  negative_branch = alpha * jnp.expm1(jnp.where(positive, 0., x))
  return jnp.where(positive, x, negative_branch)
@jax.jit
def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
  r"""Leaky rectified linear unit activation function.

  Applies, element-wise,

  .. math::
    \mathrm{leaky\_relu}(x) = \begin{cases}
      x, & x \ge 0\\
      \alpha x, & x < 0
    \end{cases}

  with :math:`\alpha` given by :code:`negative_slope`.

  Args:
    x : input array
    negative_slope : array or scalar specifying the negative slope (default: 0.01)
  """
  non_negative = x >= 0
  scaled = negative_slope * x
  return jnp.where(non_negative, x, scaled)
@jax.jit
def hard_tanh(x: Array) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Clips each element to :math:`[-1, 1]`:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      -1, & x < -1\\
      x, & -1 \le x \le 1\\
      1, & 1 < x
    \end{cases}

  Args:
    x : input array
  """
  below = x < -1
  above = x > 1
  # The identity branch is selected exactly when neither bound is exceeded.
  return jnp.where(below, -1, jnp.where(above, 1, x))
@jax.jit
def celu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Continuously-differentiable exponential linear unit activation.

  Applies, element-wise,

  .. math::
    \mathrm{celu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
    \end{cases}

  See `Continuously Differentiable Exponential Linear Units
  <https://arxiv.org/pdf/1704.07483.pdf>`_.

  Args:
    x : input array
    alpha : array or scalar (default: 1.0)
  """
  positive_part = jnp.maximum(x, 0.0)
  negative_part = alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)
  return positive_part + negative_part
@jax.jit
def selu(x: Array) -> Array:
  r"""Scaled exponential linear unit activation.

  Applies, element-wise,

  .. math::
    \mathrm{selu}(x) = \lambda \begin{cases}
      x, & x > 0\\
      \alpha e^x - \alpha, & x \le 0
    \end{cases}

  with :math:`\lambda = 1.0507009873554804934193349852946` and
  :math:`\alpha = 1.6732632423543772848170429916717`.

  See `Self-Normalizing Neural Networks
  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.

  Args:
    x : input array
  """
  alpha = 1.6732632423543772848170429916717
  scale = 1.0507009873554804934193349852946
  # ELU with the SELU alpha, written inline; positives are zeroed before
  # expm1 so the discarded branch stays finite.
  is_pos = x > 0
  elu_x = jnp.where(is_pos, x, alpha * jnp.expm1(jnp.where(is_pos, 0., x)))
  return scale * elu_x
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(jax.jit, static_argnames=("approximate",))
def gelu(x: Array, approximate: bool = True) -> Array:
  r"""Gaussian error linear unit activation function.

  If ``approximate=False``, computes the element-wise function:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
      \frac{x}{\sqrt{2}} \right) \right)

  If ``approximate=True``, uses the approximate formulation of GELU:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
      \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

  For more information, see `Gaussian Error Linear Units (GELUs)
  <https://arxiv.org/abs/1606.08415>`_, section 2.

  Args:
    x : input array (Python scalars and sequences are converted with
      ``jnp.asarray``). Integer inputs are promoted to a floating dtype.
    approximate: whether to use the approximate or exact formulation.

  Returns:
    An array of the promoted (inexact) dtype with the same shape as ``x``.
  """
  # Convert first so that Python scalars / lists work, then promote to the
  # nearest float-like dtype.
  x = jnp.asarray(x)
  x = x.astype(dtypes.to_inexact_dtype(x.dtype))
  if approximate:
    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
    return x * cdf
  else:
    # 1 + erf(z) == erfc(-z). Using erfc avoids the catastrophic cancellation
    # of ``erf(z) + 1`` for large negative z, where erf rounds to -1 and the
    # result would collapse to (-)0 instead of a tiny negative number.
    sqrt_2 = np.sqrt(2).astype(x.dtype)
    return jnp.array(x * lax.erfc(-x / sqrt_2) / 2, dtype=x.dtype)
@partial(jax.jit, static_argnames=("axis",))
def glu(x: Array, axis: int = -1) -> Array:
"""Gated linear unit activation function.
Args:
x : input array
axis: the axis along which the split should be computed (default: -1)
"""
size = x.shape[axis]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = jnp.split(x, 2, axis)
return x1 * sigmoid(x2)
# other functions
# Public re-export of the logsumexp implementation imported above from
# jax._src.ops.special.
logsumexp = _logsumexp
@partial(jax.jit, static_argnames=("axis",))
def log_softmax(x: Array,
                axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                where: Optional[Array] = None,
                initial: Optional[Array] = None) -> Array:
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`(-\infty, 0]`.

  .. math ::
    \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  Args:
    x : input array
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.
    where: Elements to include in the :code:`log_softmax`. Excluded elements
      are returned as ``-inf``.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None.
  """
  x = jnp.asarray(x)
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  # Masked-out entries do not take part in x_max, so they may exceed it by an
  # arbitrary amount and exp() would overflow to inf. The masked sum discards
  # them in the forward pass, but in the backward pass 0 * inf = NaN would
  # leak into the gradient. Replacing them with -inf makes their exp exactly 0.
  x_safe = x if where is None else jnp.where(where, x, -jnp.inf)
  shifted = x_safe - lax.stop_gradient(x_max)
  shifted_logsumexp = jnp.log(
      jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
  result = shifted - shifted_logsumexp
  if where is not None:
    return jnp.where(where, result, -jnp.inf)
  return result
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  r"""Softmax function.

  Rescales elements to the range :math:`[0, 1]` so that the elements along
  :code:`axis` sum to :math:`1`:

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  The implementation is selected by the ``jax_softmax_custom_jvp`` config
  flag: ``_softmax`` (custom JVP) when set, ``_softmax_deprecated`` otherwise.

  Args:
    x : input array
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
    where: Elements to include in the :code:`softmax`.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None.
  """
  use_custom_jvp = jax.config.jax_softmax_custom_jvp
  impl = _softmax if use_custom_jvp else _softmax_deprecated
  return impl(x, axis, where, initial)
# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _softmax(
    x,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None) -> Array:
  """Softmax whose derivative is given by the custom JVP rule _softmax_jvp.

  ``axis`` is a non-differentiable argument; positions excluded by ``where``
  are set to 0 in the output.
  """
  peak = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  exps = jnp.exp(x - peak)
  probs = exps / jnp.sum(exps, axis, where=where, keepdims=True)
  return probs if where is None else jnp.where(where, probs, 0)


@_softmax.defjvp
def _softmax_jvp(axis, primals, tangents):
  """JVP of softmax: y * (x_dot - sum(y * x_dot)) along ``axis``.

  Only the tangent of ``x`` is used; tangents of ``where``/``initial`` are
  ignored.
  """
  x, where, initial = primals
  x_dot = tangents[0]
  y = _softmax(x, axis, where, initial)
  weighted_mean = (y * x_dot).sum(axis, where=where, keepdims=True)
  return y, y * (x_dot - weighted_mean)
def _softmax_deprecated(x, axis, where, initial):
  """Softmax differentiated by ordinary autodiff (no custom JVP rule).

  Args:
    x: input array.
    axis: axis or axes along which the softmax is computed.
    where: optional mask of elements to include; excluded positions are 0 in
      the output.
    initial: shift lower bound forwarded to ``jnp.max``; must be given when
      ``where`` is not None.
  """
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  # Masked-out entries are not part of x_max and may exceed it, so exp() could
  # overflow to inf. Although they are excluded from the sum and zeroed in the
  # output, the backward pass would then compute 0 * inf = NaN (which reaches
  # every element via the normalizer). Use -inf there so their exp is 0.
  x_safe = x if where is None else jnp.where(where, x, -jnp.inf)
  unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max))
  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
  if where is not None:
    result = jnp.where(where, result, 0)
  return result
@partial(jax.jit, static_argnames=("axis",))
def standardize(x: Array,
                axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                mean: Optional[Array] = None,
                variance: Optional[Array] = None,
                epsilon: Array = 1e-5,
                where: Optional[Array] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.

  Args:
    x: input array.
    axis: axis or axes over which ``mean``/``variance`` are computed when not
      supplied.
    mean: optional precomputed mean (broadcastable against ``x``).
    variance: optional precomputed variance (broadcastable against ``x``).
    epsilon: added to the variance before taking the reciprocal square root.
    where: optional mask of elements included in the computed statistics.

  Returns:
    ``(x - mean) * rsqrt(variance + epsilon)``.
  """
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True, where=where)
  if variance is None:
    # this definition is traditionally seen as less accurate than jnp.var's
    # mean((x - mean(x))**2) but may be faster and even, given typical
    # activation distributions and low-precision arithmetic, more accurate
    # when used in neural network normalization layers
    variance = jnp.mean(
        jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
    # E[x^2] - E[x]^2 is mathematically non-negative but can round to a
    # negative value (cancellation when |mean| >> std); rsqrt would then return
    # NaN. Clamp to 0. jnp.where (not jnp.maximum) keeps the gradient for
    # non-negative variances exactly as before, and NaN still propagates.
    variance = jnp.where(variance < 0, 0., variance)
  return (x - mean) * lax.rsqrt(variance + epsilon)
def normalize(x: Array,
              axis: Optional[Union[int, Tuple[int, ...]]] = -1,
              mean: Optional[Array] = None,
              variance: Optional[Array] = None,
              epsilon: Array = 1e-5,
              where: Optional[Array] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.

  Deprecated alias of :func:`standardize`: emits a ``DeprecationWarning`` and
  forwards all arguments unchanged.
  """
  # stacklevel=2 attributes the warning to the caller's line rather than to
  # this module, so users can find the call site to migrate.
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.",
                DeprecationWarning, stacklevel=2)
  return standardize(x, axis, mean, variance, epsilon, where)
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Array, num_classes: int, *,
             dtype: Any, axis: Union[int, AxisName]) -> Array:
  """Jitted implementation behind ``one_hot``.

  ``num_classes``, ``dtype`` and ``axis`` are static arguments. ``axis`` is
  treated as a positional axis of the output when ``util.canonicalize_axis``
  accepts it; if that raises ``TypeError`` it is treated as a named axis.

  Raises:
    ValueError: on the named-axis path, if ``num_classes`` differs from the
      size of the named axis.
  """
  if not core.is_special_dim_size(num_classes):
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    # The output gains one dimension, so canonicalize against x.ndim + 1.
    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    # Named-axis path: psum(1, axis) yields the size of the named axis, and
    # each index along that axis compares x against its own axis_index.
    axis_size = lax.psum(1, axis)
    if num_classes != axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
    return jnp.asarray(x == axis_idx, dtype=dtype)
  axis = operator.index(axis)  # type: ignore[arg-type]
  # Insert a size-1 dim at `axis` and compare against an iota of class ids
  # broadcast along the same output position.
  lhs = lax.expand_dims(x, (axis,))
  rhs_shape = [1] * x.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
  rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
  return jnp.asarray(lhs == rhs, dtype=dtype)
def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Every index in ``x`` becomes a length-``num_classes`` vector of zeros with
  a one at position ``index``::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    Array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) produce all-zero vectors::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    Array([[0., 0., 0.],
           [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension. Must be concrete
      (not traced) unless it is a special dimension size.
    dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
    axis: where the one-hot dimension is placed: an integer output position,
      or an axis name.
  """
  must_be_concrete = not core.is_special_dim_size(num_classes)
  if must_be_concrete:
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
  return _one_hot(x, num_classes, dtype=dtype, axis=axis)
@jax.custom_jvp
@jax.jit
def relu6(x: Array) -> Array:
  r"""Rectified Linear Unit 6 activation function.

  Applies, element-wise,

  .. math::
    \mathrm{relu6}(x) = \min(\max(x, 0), 6)

  A custom JVP rule sets the derivative to zero at both kinks:

  .. math::
    \nabla \mathrm{relu6}(0) = 0

  and

  .. math::
    \nabla \mathrm{relu6}(6) = 0

  Args:
    x : input array
  """
  lower_clipped = jnp.maximum(x, 0)
  return jnp.minimum(lower_clipped, 6.)


def _relu6_jvp(g, ans, x):
  # Tangent rule for relu6: pass ``g`` through only on the open interval
  # (0, 6); zeros at and beyond both endpoints.
  del ans  # the primal output is not needed by this rule
  inside = (x > 0) & (x < 6)
  return lax.select(inside, g, lax.full_like(g, 0))


relu6.defjvps(_relu6_jvp)
@jax.jit
def hard_sigmoid(x: Array) -> Array:
  r"""Hard Sigmoid activation function.

  Applies, element-wise,

  .. math::
    \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

  Args:
    x : input array
  """
  shifted = x + 3.
  return relu6(shifted) / 6.
@jax.jit
def hard_silu(x: Array) -> Array:
  r"""Hard SiLU (a.k.a. hard swish) activation function.

  Applies, element-wise,

  .. math::
    \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

  Args:
    x : input array
  """
  gate = hard_sigmoid(x)
  return x * gate


hard_swish = hard_silu