# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The JAX typing module is where JAX-specific static type annotations live.
This submodule is a work in progress; to see the proposal behind the types exported
here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html.

The currently available types are:

- :class:`jax.Array`: annotation for any JAX array or tracer (i.e. representations of arrays
  within JAX transforms).
- :class:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to
  a JAX array; this includes :class:`jax.Array` and :class:`numpy.ndarray`, as well as Python
  builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and NumPy scalar values
  (e.g. :class:`numpy.int32`, :class:`numpy.float64`, etc.).

We may add additional types here in future releases.

JAX Typing Best Practices
-------------------------

When annotating JAX arrays in public API functions, we recommend using :class:`~jax.typing.ArrayLike`
for array inputs, and :class:`~jax.Array` for array outputs.

For example, your function might look like this::

    import numpy as np
    import jax.numpy as jnp
    from jax import Array
    from jax.typing import ArrayLike

    def my_function(x: ArrayLike) -> Array:
      # Runtime type validation, Python 3.10 or newer:
      if not isinstance(x, ArrayLike):
        raise TypeError(f"Expected arraylike input; got {x}")

      # Runtime type validation, any Python version:
      if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
        raise TypeError(f"Expected arraylike input; got {x}")

      # Convert input to jax.Array:
      x_arr = jnp.asarray(x)

      # ... do some computation; JAX functions will return Array types:
      result = x_arr.sum(0) / x_arr.shape[0]

      # return an Array
      return result

Most of JAX's public APIs follow this pattern. Note in particular that we recommend that JAX functions
not accept sequences such as :class:`list` or :class:`tuple` in place of arrays: each element of such
a sequence is traced as a separate input, which adds overhead in JAX transforms like :func:`~jax.jit`,
and sequences can behave in unexpected ways with batch-wise transforms like :func:`~jax.vmap` or
:func:`~jax.pmap`. For more information on this, see `Non-array inputs NumPy vs JAX`_.

.. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax
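
To make the overhead concrete, the sketch below traces a hypothetical ``total`` function with
:func:`jax.make_jaxpr`. A Python list is treated as a pytree, so each element becomes a separate
traced input, whereas a NumPy array is a single input. It also shows that
:func:`jax.numpy.sum` raises ``TypeError`` for a list rather than converting it implicitly. Exact
error messages and jaxpr contents may vary across JAX versions.

```python
import numpy as np
import jax
import jax.numpy as jnp


def total(x):
  # Hypothetical function that converts its input internally.
  return jnp.asarray(x).sum()


# Under tracing, a Python list is a pytree: every element becomes its own input.
list_jaxpr = jax.make_jaxpr(total)([1.0, 2.0, 3.0])
array_jaxpr = jax.make_jaxpr(total)(np.array([1.0, 2.0, 3.0]))

n_list_inputs = len(list_jaxpr.jaxpr.invars)
n_array_inputs = len(array_jaxpr.jaxpr.invars)
print(n_list_inputs, n_array_inputs)

# jax.numpy reductions refuse list inputs outright instead of converting them.
try:
  jnp.sum([1.0, 2.0, 3.0])
  list_rejected = False
except TypeError:
  list_rejected = True
```

Converting once at the call site, e.g. ``total(jnp.asarray(values))``, keeps the traced signature
to a single array input.
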
"""
from jax._src.typing import (
    ArrayLike as ArrayLike
)