Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/state/types.py
2023-06-19 00:49:18 +02:00

163 lines
5.2 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state types."""
from __future__ import annotations
import math
from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
from jax._src import core
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src.util import safe_map, safe_zip
# ## JAX utilities
# Capture the builtins first, then shadow `map`/`zip` module-wide with the
# `safe_*` variants imported from `jax._src.util`. The `unsafe_*` names keep
# access to the original builtins where they are still wanted.
unsafe_map = map
unsafe_zip = zip
map = safe_map
zip = safe_zip

# Loose alias used in annotations (e.g. the return type of
# `AbstractRef._getitem`).
Array = Any
_ref_effect_color = pp.Color.GREEN


class RefEffect(effects.JaxprInputEffect):
  """Base class for effects that act on a `Ref`-typed jaxpr input.

  Subclasses provide `name` (the label used when printing). `input_index`
  is read but never set here; presumably it is set by
  `effects.JaxprInputEffect` — confirm there. It may be a `core.Var` or a
  plain positional index. An effect equals `other` only when `other` is an
  instance of this effect's class and refers to the same input.
  """
  name: str

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False
    return self.input_index == other.input_index

  def __hash__(self):
    # Include the class so the hash distinguishes effect kinds, mirroring
    # the class check in `__eq__`.
    return hash((self.__class__, self.input_index))

  def _pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
    """Render this effect as `Name<index>`, with `Name` colored."""
    if isinstance(self.input_index, core.Var):
      index_text = pp.text(core.pp_var(self.input_index, context))
    else:
      # `pp.text` takes a string, but a non-Var `input_index` is typically
      # an int position, so convert it explicitly rather than passing the
      # raw value through.
      index_text = pp.text(str(self.input_index))
    return pp.concat([
        pp.color(pp.text(self.name), foreground=_ref_effect_color),
        pp.text("<"),
        index_text,
        pp.text(">")])

  def __str__(self):
    return f"{self.name}<{self.input_index}>"
class ReadEffect(RefEffect):
  """Ref effect labelled `Read` (printed as `Read<index>`)."""
  name: str = "Read"


class WriteEffect(RefEffect):
  """Ref effect labelled `Write` (printed as `Write<index>`)."""
  name: str = "Write"


class AccumEffect(RefEffect):
  """Ref effect labelled `Accum` (printed as `Accum<index>`)."""
  name: str = "Accum"


# The three concrete ref-effect kinds; used to annotate the per-input sets
# returned by `get_ref_state_effects`.
StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]

# Register `RefEffect` in `effects.control_flow_allowed_effects`.
# NOTE(review): `add_type` presumably also admits subclasses such as
# `ReadEffect` — confirm in `jax._src.effects`.
effects.control_flow_allowed_effects.add_type(RefEffect)
# ## `Ref`s

# Type parameter of `AbstractRef`, bounded to abstract values.
Aval = TypeVar("Aval", bound=core.AbstractValue)


class AbstractRef(core.AbstractValue, Generic[Aval]):
  """Abstract value for a `Ref` wrapping the abstract value `inner_aval`.

  `Ref`s need their own aval so that `get` and `swap` can be represented in
  jaxprs. `shape`, `dtype`, `ndim` and `size` are forwarded to `inner_aval`
  and raise `ValueError` when the inner aval does not carry them.
  """
  __slots__ = ["inner_aval"]

  def __init__(self, inner_aval: core.AbstractValue):
    self.inner_aval = inner_aval

  def join(self, other):
    # Only refs can be joined with refs; the result wraps the joined contents.
    assert isinstance(other, AbstractRef)
    joined_inner = self.inner_aval.join(other.inner_aval)
    return AbstractRef(joined_inner)

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def size(self):
    return math.prod(self.shape)

  @property
  def shape(self):
    inner = self.inner_aval
    if isinstance(inner, core.ShapedArray):
      return inner.shape
    raise ValueError(f"`Ref{{{inner.str_short()}}} has no `shape`.")

  @property
  def dtype(self):
    inner = self.inner_aval
    if isinstance(inner, core.UnshapedArray):
      return inner.dtype
    raise ValueError(f"`Ref{{{inner.str_short()}}} has no `dtype`.")

  @core.aval_method
  @staticmethod
  def get(tracer, idx=()):
    from jax._src.state.primitives import ref_get  # pytype: disable=import-error
    return ref_get(tracer, idx)

  @core.aval_method
  @staticmethod
  def set(tracer, value, idx=()):
    from jax._src.state.primitives import ref_set  # pytype: disable=import-error
    # Note the argument order expected by `ref_set`: index before value.
    return ref_set(tracer, idx, value)

  def _getitem(self, tracer, idx) -> Array:
    from jax._src.state.primitives import ref_get  # pytype: disable=import-error
    # A bare (non-tuple) index is normalized to a 1-tuple.
    key = idx if isinstance(idx, tuple) else (idx,)
    return ref_get(tracer, key)

  def _setitem(self, tracer, idx, value) -> None:
    from jax._src.state.primitives import ref_set  # pytype: disable=import-error
    key = idx if isinstance(idx, tuple) else (idx,)
    return ref_set(tracer, key, value)

  def __repr__(self) -> str:
    inner_str = self.inner_aval.str_short()
    return f'Ref{{{inner_str}}}'

  def at_least_vspace(self):
    return AbstractRef(self.inner_aval.at_least_vspace())

  def __eq__(self, other):
    # Refs are equal only if they have exactly the same type and contents.
    if type(self) is not type(other):
      return False
    return self.inner_aval == other.inner_aval

  def __hash__(self):
    return hash((type(self), self.inner_aval))
def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type):
  """Apply `core.raise_to_shaped` to the ref's contents, re-wrapping the result."""
  shaped_inner = core.raise_to_shaped(ref_aval.inner_aval, weak_type)
  return AbstractRef(shaped_inner)


# Install the handler in core's `raise_to_shaped_mappings` table, keyed by
# aval type.
core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped
def _map_ref(size, axis, ref_aval):
  """Apply `core.mapped_aval` to the ref's contents, re-wrapping the result."""
  mapped_inner = core.mapped_aval(size, axis, ref_aval.inner_aval)
  return AbstractRef(mapped_inner)


def _unmap_ref(size, axis_name, axis, ref_aval):
  """Apply `core.unmapped_aval` to the ref's contents, re-wrapping the result."""
  unmapped_inner = core.unmapped_aval(
      size, axis_name, axis, ref_aval.inner_aval)
  return AbstractRef(unmapped_inner)


# Install the (map, unmap) pair in core's `aval_mapping_handlers` table.
core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)
def get_ref_state_effects(
    avals: Sequence[core.AbstractValue],
    effects: core.Effects) -> List[Set[StateEffect]]:
  """Group the state effects in `effects` by the input they act on.

  Args:
    avals: Input avals; only their count (positions) is used.
    effects: Effects to scan. Only `ReadEffect`, `WriteEffect` and
      `AccumEffect` instances are considered; everything else is ignored.

  Returns:
    A list with one set per entry of `avals`. The set at position `i` holds
    the state effects whose `input_index` equals `i`. Effects whose
    `input_index` matches no position are dropped.
  """
  per_input: List[Set[StateEffect]] = [set() for _ in avals]
  # Map each position to its bucket so every effect is placed with one
  # lookup, rather than rescanning `effects` once per aval. The lookup hashes
  # `input_index`, which `RefEffect.__hash__` already requires to be hashable.
  bucket_for_index = dict(enumerate(per_input))
  for eff in effects:
    if not isinstance(eff, (ReadEffect, WriteEffect, AccumEffect)):
      continue
    bucket = bucket_for_index.get(eff.input_index)
    if bucket is not None:
      bucket.add(eff)
  return per_input
def shaped_array_ref(shape: Tuple[int, ...], dtype,
                     weak_type: bool = False,
                     named_shape = None) -> AbstractRef[core.AbstractValue]:
  """Build an `AbstractRef` whose contents are a `core.ShapedArray`.

  `shape`, `dtype`, `weak_type` and `named_shape` are passed through
  unchanged to the `core.ShapedArray` constructor.
  """
  inner = core.ShapedArray(
      shape, dtype, weak_type=weak_type, named_shape=named_shape)
  return AbstractRef(inner)