71 lines
2.9 KiB
Python
71 lines
2.9 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
The JAX typing module is where JAX-specific static type annotations live.
|
|
This submodule is a work in progress; to see the proposal behind the types exported
|
|
here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html.
|
|
|
|
The currently-available types are:
|
|
|
|
- :class:`jax.Array`: annotation for any JAX array or tracer (i.e. representations of arrays
|
|
within JAX transforms).
|
|
- :class:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to
|
|
a JAX array; this includes :class:`jax.Array`, :class:`numpy.ndarray`, as well as Python
|
|
builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and numpy scalar values
|
|
(e.g. :class:`numpy.int32`, :class:`numpy.flota64`, etc.)
|
|
|
|
We may add additional types here in future releases.
|
|
|
|
JAX Typing Best Practices
|
|
-------------------------
|
|
When annotating JAX arrays in public API functions, we recommend using :class:`~jax.typing.ArrayLike`
|
|
for array inputs, and :class:`~jax.Array` for array outputs.
|
|
|
|
For example, your function might look like this::
|
|
|
|
import numpy as np
|
|
import jax.numpy as jnp
|
|
from jax import Array
|
|
from jax.typing import ArrayLike
|
|
|
|
def my_function(x: ArrayLike) -> Array:
|
|
# Runtime type validation, Python 3.10 or newer:
|
|
if not isinstance(x, ArrayLike):
|
|
raise TypeError(f"Expected arraylike input; got {x}")
|
|
# Runtime type validation, any Python version:
|
|
if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
|
|
raise TypeError(f"Expected arraylike input; got {x}")
|
|
|
|
# Convert input to jax.Array:
|
|
x_arr = jnp.asarray(x)
|
|
|
|
# ... do some computation; JAX functions will return Array types:
|
|
result = x_arr.sum(0) / x_arr.shape[0]
|
|
|
|
# return an Array
|
|
return result
|
|
|
|
Most of JAX's public APIs follow this pattern. Note in particular that we recommend JAX functions
|
|
to not accept sequences such as :class:`list` or :class:`tuple` in place of arrays, as this can
|
|
cause extra overhead in JAX transforms like :func:`~jax.jit` and can behave in unexpected ways with
|
|
batch-wise transforms like :func:`~jax.vmap` or :func:`jax.pmap`. For more information on this,
|
|
see `Non-array inputs NumPy vs JAX`_
|
|
|
|
.. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax
|
|
"""
|
|
from jax._src.typing import (
|
|
ArrayLike as ArrayLike
|
|
)
|