Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/__init__.py
2023-06-19 00:49:18 +02:00

253 lines
9.8 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Set default logging level before any logging happens.
import os as _os
_os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
del _os
# Set Cloud TPU env vars if necessary before transitively loading C++ backend
from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
try:
_cloud_tpu_init()
except Exception as exc:
# Defensively swallow any exceptions to avoid making jax unimportable
from warnings import warn as _warn
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
f"an issue at https://github.com/google/jax/issues")
del _warn
del _cloud_tpu_init
# Confusingly there are two things named "config": the module and the class.
# We want the exported object to be the class, so we first import the module
# to make sure a later import doesn't overwrite the class.
from jax import config as _config_module
del _config_module
# Force early import, allowing use of `jax.core` after importing `jax`.
import jax.core as _core
del _core
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.basearray import Array as Array
from jax import typing as typing
from jax._src.config import (
config as config,
enable_checks as enable_checks,
check_tracer_leaks as check_tracer_leaks,
checking_leaks as checking_leaks,
enable_custom_prng as enable_custom_prng,
softmax_custom_jvp as softmax_custom_jvp,
enable_custom_vjp_by_custom_transpose as enable_custom_vjp_by_custom_transpose,
debug_nans as debug_nans,
debug_infs as debug_infs,
log_compiles as log_compiles,
default_device as default_device,
default_matmul_precision as default_matmul_precision,
default_prng_impl as default_prng_impl,
numpy_dtype_promotion as numpy_dtype_promotion,
numpy_rank_promotion as numpy_rank_promotion,
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
transfer_guard as transfer_guard,
transfer_guard_host_to_device as transfer_guard_host_to_device,
transfer_guard_device_to_device as transfer_guard_device_to_device,
transfer_guard_device_to_host as transfer_guard_device_to_host,
spmd_mode as spmd_mode,
)
from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval
from jax._src.environment_info import print_environment_info as print_environment_info
from jax._src.lib import xla_client as _xc
Device = _xc.Device
del _xc
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.api import clear_caches as clear_caches
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.util import curry as _deprecated_curry
from jax._src.custom_derivatives import custom_gradient as custom_gradient
from jax._src.custom_derivatives import custom_jvp as custom_jvp
from jax._src.custom_derivatives import custom_vjp as custom_vjp
from jax._src.xla_bridge import default_backend as default_backend
from jax._src.xla_bridge import device_count as device_count
from jax._src.api import device_get as device_get
from jax._src.api import device_put as device_put
from jax._src.api import device_put_sharded as device_put_sharded
from jax._src.api import device_put_replicated as device_put_replicated
from jax._src.xla_bridge import devices as devices
from jax._src.api import disable_jit as disable_jit
from jax._src.api import eval_shape as eval_shape
from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs
from jax._src.dtypes import float0 as float0
from jax._src.api import grad as grad
from jax._src.api import hessian as hessian
from jax._src.xla_bridge import host_count as host_count
from jax._src.xla_bridge import host_id as host_id
from jax._src.xla_bridge import host_ids as host_ids
from jax._src.api import jacobian as jacobian
from jax._src.api import jacfwd as jacfwd
from jax._src.api import jacrev as jacrev
from jax._src.api import jit as jit
from jax._src.api import jvp as jvp
from jax._src.xla_bridge import local_device_count as local_device_count
from jax._src.xla_bridge import local_devices as local_devices
from jax._src.api import linearize as linearize
from jax._src.api import linear_transpose as linear_transpose
from jax._src.api import live_arrays as live_arrays
from jax._src.api import make_jaxpr as make_jaxpr
from jax._src.api import named_call as named_call
from jax._src.api import named_scope as named_scope
from jax._src.api import pmap as pmap
from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index
from jax._src.callback import pure_callback_api as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.core import ShapedArray as _deprecated_ShapedArray
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as xla_computation
from jax.interpreters import ad as _deprecated_ad
import jax.interpreters.batching
import jax.interpreters.mlir
from jax.interpreters import partial_eval as _deprecated_partial_eval
from jax.interpreters import pxla as _deprecated_pxla
from jax.interpreters import xla as _deprecated_xla
from jax._src.array import (
make_array_from_single_device_arrays as make_array_from_single_device_arrays,
make_array_from_callback as make_array_from_callback,
)
from jax.version import __version__ as __version__
from jax.version import __version_info__ as __version_info__
from jax._src.tree_util import (
tree_map as tree_map,
# TODO(jakevdp): remove these deprecated routines after October 2022
_deprecated_treedef_is_leaf as treedef_is_leaf,
_deprecated_tree_flatten as tree_flatten,
_deprecated_tree_leaves as tree_leaves,
_deprecated_tree_structure as tree_structure,
_deprecated_tree_transpose as tree_transpose,
_deprecated_tree_unflatten as tree_unflatten,
)
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from jax import abstract_arrays as _deprecated_abstract_arrays
from jax import custom_derivatives as custom_derivatives
from jax import custom_batching as custom_batching
from jax import custom_transpose as custom_transpose
from jax import api_util as api_util
from jax import distributed as distributed
from jax import debug as debug
from jax import dtypes as dtypes
from jax import errors as errors
from jax import image as image
from jax import lax as lax
from jax import linear_util as linear_util
from jax import monitoring as monitoring
from jax import nn as nn
from jax import numpy as numpy
from jax import ops as ops
from jax import profiler as profiler
from jax import random as random
from jax import scipy as scipy
from jax import sharding as sharding
from jax import stages as stages
from jax import tree_util as tree_util
from jax import util as util
# Also circular dependency.
from jax._src.array import Shard as Shard
import jax.experimental.compilation_cache.compilation_cache as _ccache
del _ccache
_deprecations = {
# Added 06 June 2023
"abstract_arrays": (
"jax.abstract_arrays is deprecated. Refer to jax.core.",
_deprecated_abstract_arrays
),
# Added 28 March 2023
"ShapedArray": (
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
_deprecated_ShapedArray,
),
"ad": (
"jax.ad is deprecated. Use jax.interpreters.ad",
_deprecated_ad,
),
"partial_eval": (
"jax.partial_eval is deprecated. Use jax.interpreters.partial_eval",
_deprecated_partial_eval,
),
"pxla": (
"jax.pxla is deprecated. Use jax.interpreters.pxla",
_deprecated_pxla,
),
"xla": (
"jax.xla is deprecated. Use jax.interpreters.xla",
_deprecated_xla,
),
"curry": (
"jax.curry is deprecated. Use curry = lambda f: partial(partial, f)",
_deprecated_curry,
),
"flatten_fun_nokwargs": (
"jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.",
_deprecated_flatten_fun_nokwargs,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays
from jax._src.core import ShapedArray as ShapedArray
from jax.interpreters import ad as ad
from jax.interpreters import partial_eval as partial_eval
from jax.interpreters import pxla as pxla
from jax.interpreters import xla as xla
from jax._src.util import curry as curry
from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
# TODO(yashkatariya): Remove after 2 jax releases from 0.4.6
if not config.jax_jit_pjit_api_merge:
raise ValueError(
'jax.config.jax_jit_pjit_api_merge cannot be disabled after jax 0.4.7'
' release. Please downgrade to jax and jaxlib 0.4.6 if you want to'
' disable jax.config.jax_jit_pjit_api_merge.'
)
import jax.lib # TODO(phawkins): remove this export.
# trailer