Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/abstract_arrays.py

89 lines
3.2 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial
from typing import Set
import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
UnshapedArray = core.UnshapedArray
ShapedArray = core.ShapedArray
ConcreteArray = core.ConcreteArray
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped
def zeros_like_array(x):
dtype, weak_type = dtypes._lattice_result_type(x)
dtype = dtypes.canonicalize_dtype(dtype)
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
return ad_util.zeros_like_aval(aval)
numpy_scalar_types: Set[type] = { # pylint: disable=g-bare-generic
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.complex64, np.complex128,
np.bool_, np.longlong, np.intc,
} | set(np.dtype(dt).type for dt in dtypes._float_types)
if dtypes.int4 is not None:
numpy_scalar_types.add(dtypes.int4)
if dtypes.uint4 is not None:
numpy_scalar_types.add(dtypes.uint4)
array_types: Set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
weak_type=weak_type)
def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
ad_util.jaxval_zeros_likers[np.ma.MaskedArray] = masked_array_error
for t in array_types:
core.pytype_aval_mappings[t] = canonical_concrete_aval
ad_util.jaxval_zeros_likers[t] = zeros_like_array
core.literalable_types.update(array_types)
def _zeros_like_python_scalar(t, x):
dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
aval = core.ShapedArray((), dtype, weak_type=True)
return ad_util.zeros_like_aval(aval)
def _make_concrete_python_scalar(t, x):
dtype = dtypes._scalar_type_to_dtype(t, x)
weak_type = dtypes.is_weakly_typed(x)
return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) # type: ignore