# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Set default logging level before any logging happens. import os as _os _os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') del _os # Set Cloud TPU env vars if necessary before transitively loading C++ backend from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init try: _cloud_tpu_init() except Exception as exc: # Defensively swallow any exceptions to avoid making jax unimportable from warnings import warn as _warn _warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report " f"an issue at https://github.com/google/jax/issues") del _warn del _cloud_tpu_init # Confusingly there are two things named "config": the module and the class. # We want the exported object to be the class, so we first import the module # to make sure a later import doesn't overwrite the class. from jax import config as _config_module del _config_module # Force early import, allowing use of `jax.core` after importing `jax`. import jax.core as _core del _core # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.basearray import Array as Array from jax import typing as typing from jax._src.config import ( config as config, enable_checks as enable_checks, check_tracer_leaks as check_tracer_leaks, checking_leaks as checking_leaks, enable_custom_prng as enable_custom_prng, softmax_custom_jvp as softmax_custom_jvp, enable_custom_vjp_by_custom_transpose as enable_custom_vjp_by_custom_transpose, debug_nans as debug_nans, debug_infs as debug_infs, log_compiles as log_compiles, default_device as default_device, default_matmul_precision as default_matmul_precision, default_prng_impl as default_prng_impl, numpy_dtype_promotion as numpy_dtype_promotion, numpy_rank_promotion as numpy_rank_promotion, jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions, transfer_guard as transfer_guard, transfer_guard_host_to_device as transfer_guard_host_to_device, transfer_guard_device_to_device as transfer_guard_device_to_device, transfer_guard_device_to_host as transfer_guard_device_to_host, spmd_mode as spmd_mode, ) from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval from jax._src.environment_info import print_environment_info as print_environment_info from jax._src.lib import xla_client as _xc Device = _xc.Device del _xc from jax._src.api import effects_barrier as effects_barrier from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies from jax._src.api import clear_backends as clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.util import curry as _deprecated_curry from jax._src.custom_derivatives import custom_gradient as custom_gradient from jax._src.custom_derivatives import custom_jvp as custom_jvp from jax._src.custom_derivatives import custom_vjp as custom_vjp from jax._src.xla_bridge import default_backend as default_backend from jax._src.xla_bridge import device_count as device_count from jax._src.api import device_get as device_get from jax._src.api import device_put as device_put from jax._src.api import device_put_sharded as device_put_sharded from jax._src.api import device_put_replicated as device_put_replicated from jax._src.xla_bridge import devices as devices from jax._src.api import disable_jit as disable_jit from jax._src.api import eval_shape as eval_shape from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs from jax._src.dtypes import float0 as float0 from jax._src.api import grad as grad from jax._src.api import hessian as hessian from jax._src.xla_bridge import host_count as host_count from jax._src.xla_bridge import host_id as host_id from jax._src.xla_bridge import host_ids as host_ids from jax._src.api import jacobian as jacobian from jax._src.api import jacfwd as jacfwd from jax._src.api import jacrev as jacrev from jax._src.api import jit as jit from jax._src.api import jvp as jvp from jax._src.xla_bridge import local_device_count as local_device_count from jax._src.xla_bridge import local_devices as local_devices from jax._src.api import linearize as linearize from jax._src.api import linear_transpose as linear_transpose from jax._src.api import live_arrays as live_arrays from jax._src.api import make_jaxpr as make_jaxpr from jax._src.api import named_call as named_call from jax._src.api import named_scope as named_scope from jax._src.api import pmap as pmap from jax._src.xla_bridge import process_count as process_count from jax._src.xla_bridge import process_index as process_index from jax._src.callback import pure_callback_api as pure_callback from jax._src.ad_checkpoint import checkpoint_wrapper as remat from jax._src.core import ShapedArray as _deprecated_ShapedArray from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap from jax._src.api import xla_computation as xla_computation from jax.interpreters import ad as _deprecated_ad import jax.interpreters.batching import jax.interpreters.mlir from jax.interpreters import partial_eval as _deprecated_partial_eval from jax.interpreters import pxla as _deprecated_pxla from jax.interpreters import xla as _deprecated_xla from jax._src.array import ( make_array_from_single_device_arrays as make_array_from_single_device_arrays, make_array_from_callback as make_array_from_callback, ) from jax.version import __version__ as __version__ from jax.version import __version_info__ as __version_info__ from jax._src.tree_util import ( tree_map as tree_map, # TODO(jakevdp): remove these deprecated routines after October 2022 _deprecated_treedef_is_leaf as treedef_is_leaf, _deprecated_tree_flatten as tree_flatten, _deprecated_tree_leaves as tree_leaves, _deprecated_tree_structure as tree_structure, _deprecated_tree_transpose as tree_transpose, _deprecated_tree_unflatten as tree_unflatten, ) # These submodules are separate because they are in an import cycle with # jax and rely on the names imported above. from jax import abstract_arrays as _deprecated_abstract_arrays from jax import custom_derivatives as custom_derivatives from jax import custom_batching as custom_batching from jax import custom_transpose as custom_transpose from jax import api_util as api_util from jax import distributed as distributed from jax import debug as debug from jax import dtypes as dtypes from jax import errors as errors from jax import image as image from jax import lax as lax from jax import linear_util as linear_util from jax import monitoring as monitoring from jax import nn as nn from jax import numpy as numpy from jax import ops as ops from jax import profiler as profiler from jax import random as random from jax import scipy as scipy from jax import sharding as sharding from jax import stages as stages from jax import tree_util as tree_util from jax import util as util # Also circular dependency. from jax._src.array import Shard as Shard import jax.experimental.compilation_cache.compilation_cache as _ccache del _ccache _deprecations = { # Added 06 June 2023 "abstract_arrays": ( "jax.abstract_arrays is deprecated. Refer to jax.core.", _deprecated_abstract_arrays ), # Added 28 March 2023 "ShapedArray": ( "jax.ShapedArray is deprecated. Use jax.core.ShapedArray", _deprecated_ShapedArray, ), "ad": ( "jax.ad is deprecated. Use jax.interpreters.ad", _deprecated_ad, ), "partial_eval": ( "jax.partial_eval is deprecated. Use jax.interpreters.partial_eval", _deprecated_partial_eval, ), "pxla": ( "jax.pxla is deprecated. Use jax.interpreters.pxla", _deprecated_pxla, ), "xla": ( "jax.xla is deprecated. Use jax.interpreters.xla", _deprecated_xla, ), "curry": ( "jax.curry is deprecated. Use curry = lambda f: partial(partial, f)", _deprecated_curry, ), "flatten_fun_nokwargs": ( "jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.", _deprecated_flatten_fun_nokwargs, ), } import typing as _typing if _typing.TYPE_CHECKING: from jax._src import abstract_arrays as abstract_arrays from jax._src.core import ShapedArray as ShapedArray from jax.interpreters import ad as ad from jax.interpreters import partial_eval as partial_eval from jax.interpreters import pxla as pxla from jax.interpreters import xla as xla from jax._src.util import curry as curry from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing # TODO(yashkatariya): Remove after 2 jax releases from 0.4.6 if not config.jax_jit_pjit_api_merge: raise ValueError( 'jax.config.jax_jit_pjit_api_merge cannot be disabled after jax 0.4.7' ' release. Please downgrade to jax and jaxlib 0.4.6 if you want to' ' disable jax.config.jax_jit_pjit_api_merge.' ) import jax.lib # TODO(phawkins): remove this export. # trailer