Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/ad_util.py

131 lines
3.9 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import types
from typing import Any, Callable, Dict, TypeVar, Union, cast
from jax._src import core
from jax._src import traceback_util
from jax._src.core import (lattice_join, Primitive, valid_jaxtype,
raise_to_shaped, get_aval)
from jax._src.tree_util import register_pytree_node
from jax._src.typing import Array, ArrayLike
from jax._src.util import safe_map
traceback_util.register_exclusion(__file__)
T = TypeVar('T')
map = safe_map
jaxval_adders: Dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}
def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
return add_jaxvals_p.bind(x, y)
add_jaxvals_p: Primitive = Primitive('add_any')
add_any_p = add_jaxvals_p
@add_jaxvals_p.def_impl
def add_impl(xs, ys):
return jaxval_adders[type(xs)](xs, ys)
@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
return lattice_join(xs, ys)
jaxval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
def instantiate(z: Union[Zero, Array]) -> Array:
if type(z) is Zero:
return zeros_like_aval(z.aval)
return cast(Array, z)
def zeros_like_aval(aval: core.AbstractValue) -> Array:
return aval_zeros_likers[type(aval)](aval)
aval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
def zeros_like_jaxval(val: ArrayLike) -> Array:
return zeros_like_p.bind(val)
zeros_like_p: Primitive = Primitive('zeros_like')
@zeros_like_p.def_impl
def zeros_like_impl(example):
return jaxval_zeros_likers[type(example)](example)
zeros_like_p.def_abstract_eval(lambda x: x)
class Zero:
__slots__ = ['aval']
def __init__(self, aval: core.AbstractValue):
self.aval = aval
def __repr__(self) -> str:
return f'Zero({self.aval})'
@staticmethod
def from_value(val: Any) -> Zero:
return Zero(raise_to_shaped(get_aval(val)))
register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))
def _stop_gradient_impl(x: T) -> T:
if not valid_jaxtype(x):
raise TypeError("stop_gradient only works on valid JAX arrays, but "
f"input argument is: {x}")
return x
stop_gradient_p : Primitive = Primitive('stop_gradient')
stop_gradient_p.def_impl(_stop_gradient_impl)
stop_gradient_p.def_abstract_eval(lambda x: x)
class SymbolicZero:
def __init__(self, aval: core.AbstractValue) -> None:
self.aval = aval
def __repr__(self) -> str:
return self.__class__.__name__
# TODO(mattjj,frostig): this forwards attr lookup to self.aval delegate;
# should dedup with core.Tracer.__getattr__ which does the same thing
def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
try:
attr = getattr(self.aval, name)
except KeyError as err:
raise AttributeError(
f"{self.__class__.__name__} has no attribute {name}"
) from err
else:
t = type(attr)
if t is core.aval_property:
return attr.fget(self)
elif t is core.aval_method:
return types.MethodType(attr.fun, self)
else:
return attr
JaxTypeOrTracer = Any
def replace_internal_symbolic_zeros(
x: Union[JaxTypeOrTracer, Zero]) -> Union[JaxTypeOrTracer, SymbolicZero]:
return SymbolicZero(x.aval) if type(x) is Zero else x
def replace_rule_output_symbolic_zeros(
x: Union[JaxTypeOrTracer, SymbolicZero]) -> Union[JaxTypeOrTracer, Zero]:
return Zero(x.aval) if type(x) is SymbolicZero else x