Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/interpreters/batching.py
2023-06-19 00:49:18 +02:00

79 lines
2.6 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.interpreters.batching import (
Array as Array,
AxisSize as AxisSize,
BatchTrace as BatchTrace,
BatchTracer as BatchTracer,
BatchingRule as BatchingRule,
RaggedAxis as RaggedAxis,
Elt as Elt,
FromEltHandler as FromEltHandler,
GetIdx as GetIdx,
IndexedAxisSize as IndexedAxisSize,
MakeIotaHandler as MakeIotaHandler,
MapSpec as MapSpec,
NotMapped as NotMapped,
Pile as Pile,
PileAxis as PileAxis,
PileTy as PileTy,
ToEltHandler as ToEltHandler,
Vmappable as Vmappable,
Zero as Zero,
ZeroIfMapped as ZeroIfMapped,
add_batched as add_batched,
axis_primitive_batchers as axis_primitive_batchers,
batch as batch,
batch_custom_jvp_subtrace as batch_custom_jvp_subtrace,
batch_custom_vjp_bwd as batch_custom_vjp_bwd,
batch_jaxpr as batch_jaxpr,
batch_jaxpr2 as batch_jaxpr2,
batch_jaxpr_axes as batch_jaxpr_axes,
batch_subtrace as batch_subtrace,
bdim_at_front as bdim_at_front,
broadcast as broadcast,
broadcast_batcher as broadcast_batcher,
defbroadcasting as defbroadcasting,
defreducer as defreducer,
defvectorized as defvectorized,
flatten_fun_for_vmap as flatten_fun_for_vmap,
from_elt as from_elt,
from_elt_handlers as from_elt_handlers,
is_vmappable as is_vmappable,
make_iota as make_iota,
make_iota_handlers as make_iota_handlers,
matchaxis as matchaxis,
moveaxis as moveaxis,
not_mapped as not_mapped,
pile_axis as pile_axis,
primitive_batchers as primitive_batchers,
reducer_batcher as reducer_batcher,
register_vmappable as register_vmappable,
spec_types as spec_types,
spmd_axis_primitive_batchers as spmd_axis_primitive_batchers,
to_elt as to_elt,
to_elt_handlers as to_elt_handlers,
unregister_vmappable as unregister_vmappable,
vectorized_batcher as vectorized_batcher,
vmappables as vmappables,
vtile as vtile,
zero_if_mapped as zero_if_mapped,
zeros_like_batched as zeros_like_batched,
zeros_like_p as zeros_like_p,
)