Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/stages.py
2023-06-19 00:49:18 +02:00

681 lines
24 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interfaces to JAX's compilation steps, and utilities for conforming to them.
This module defines a set of public-facing types that reflect the output of
intermediate stages in the process of compilation. Currently there are two
stages modeled: lowering (which produces compiler input), and compilation
(which produces compiler output).
It also defines some internal-facing types to guide what JAX can present in
this common form: an internal ``Lowering`` suffices to back a public-facing
``Lowered`` and an internal ``Executable`` suffices to back a public-facing
``Compiled``.
Finally, this module defines a couple more classes to commonly adapt our
various internal XLA-backed lowerings and executables into the lowering and
executable protocols described above.
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import (
Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple, Union)
import jax
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
# Register this file with the exclusion lists kept by source_info_util and
# traceback_util. NOTE(review): presumably this keeps frames from this module
# out of user-facing source locations and filtered tracebacks -- confirm
# against those modules.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
# Short alias for ``xc._xla``.
xla_extension = xc._xla
# Rebind ``map`` and ``zip`` to the ``util.safe_*`` variants for the rest of
# this module; the builtins stay reachable as ``unsafe_map``/``unsafe_zip``.
# The tuple assignment reads the builtin on the right-hand side before the
# name is rebound, so the order of these two lines' operands matters.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
# Type alias: maps a compiler option name to a string or boolean value.
CompilerOptions = Dict[str, Union[str, bool]]
# -- Internal protocols
class Executable(Protocol):
  """Internal protocol backing a user-facing ``Compiled``.

  Every method except ``create_cpp_call`` raises ``NotImplementedError`` by
  default; implementations override whatever they can support.
  """

  def call(self, *args_flat) -> Sequence[Any]:
    """Run the executable on flat arguments and return the flat outputs."""
    # TODO(frostig): improve annotation (sequences of arrays/buffers)
    raise NotImplementedError

  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Return the input shardings as a flat sequence.

    Raises:
      NotImplementedError: if the information is unavailable, e.g. because of
        the backend, compiler, or runtime.
    """
    raise NotImplementedError

  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Return the output shardings as a flat sequence.

    Raises:
      NotImplementedError: if the information is unavailable, e.g. because of
        the backend, compiler, or runtime.
    """
    raise NotImplementedError

  def as_text(self) -> str:
    """Return human-readable text describing this executable.

    Meant for visualization and debugging. The text is passed straight through
    to external callers and is not required to be a valid or reliable
    serialization.

    Raises:
      NotImplementedError: if the text is unavailable, e.g. because of the
        backend, compiler, or runtime.
    """
    raise NotImplementedError

  def cost_analysis(self) -> Any:
    """Return a summary of estimated execution cost.

    Meant for visualization and debugging. The result is a simple, easily
    printed or serialized structure (e.g. nested dicts, lists, and tuples with
    numeric leaves) whose layout is arbitrary: it may differ between versions
    of JAX and jaxlib, or even between invocations. It is passed straight
    through to external callers.

    Raises:
      NotImplementedError: if the analysis is unavailable, e.g. because of the
        backend, compiler, or runtime.
    """
    # TODO(frostig): improve annotation (arbitrary pytree)
    raise NotImplementedError

  def memory_analysis(self) -> Any:
    """Return a summary of estimated memory requirements.

    Meant for visualization and debugging. The result is a simple, easily
    printed or serialized structure (e.g. nested dicts, lists, and tuples with
    numeric leaves) whose layout is arbitrary: it may differ between versions
    of JAX and jaxlib, or even between invocations. It is passed straight
    through to external callers.

    Raises:
      NotImplementedError: if the analysis is unavailable, e.g. because of the
        backend, compiler, or runtime.
    """
    # TODO(frostig): improve annotation (arbitrary pytree)
    raise NotImplementedError

  def runtime_executable(self) -> Any:
    """Return an arbitrary object representing this executable.

    Meant for debugging. The object is passed straight through to external
    callers with no guarantee about its type, structure, or consistency across
    invocations, and it is not required to be a valid or reliable
    serialization.

    Raises:
      NotImplementedError: if no such object is available, e.g. because of the
        backend or compiler.
    """
    raise NotImplementedError

  def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any:
    """Optionally build a fast C++ dispatcher; the default provides none."""
    return None
class Lowering(Protocol):
  """Internal protocol backing a user-facing ``Lowered``.

  Every method raises ``NotImplementedError`` by default; implementations
  override whatever they can support.
  """

  def compile(
      self, compiler_options: Optional[CompilerOptions] = None) -> Executable:
    """Compile this lowering and return the resulting ``Executable``."""
    raise NotImplementedError

  def as_text(self, dialect: Optional[str] = None) -> str:
    """Return human-readable text describing this lowering.

    Meant for visualization and debugging. The text is passed straight through
    to external callers and is not required to be a valid or reliable
    serialization.
    """
    raise NotImplementedError

  def compiler_ir(self, dialect: Optional[str] = None) -> Any:
    """Return an arbitrary object representing this lowering.

    Meant for debugging. The object is passed straight through to external
    callers with no guarantee about its type, structure, or consistency across
    invocations, and it is not required to be a valid or reliable
    serialization.

    Args:
      dialect: Optional string naming the representation dialect
        (e.g. "stablehlo").

    Raises:
      NotImplementedError: if no such object is available, e.g. because of the
        backend or compiler.
    """
    raise NotImplementedError

  def cost_analysis(self) -> Any:
    """Return a summary of estimated execution cost.

    Meant for visualization and debugging. The result is a simple, easily
    printed or serialized structure (e.g. nested dicts, lists, and tuples with
    numeric leaves) whose layout is arbitrary: it may differ between versions
    of JAX and jaxlib, or even between invocations. It is passed straight
    through to external callers.

    The estimate is made in the absence of compiler optimizations, which may
    drastically change the cost. For post-optimization estimates, compile this
    lowering and see ``Compiled.cost_analysis``.

    Raises:
      NotImplementedError: if the analysis is unavailable, e.g. because of the
        backend, compiler, or runtime.
    """
    # TODO(frostig): improve annotation (arbitrary pytree)
    raise NotImplementedError
# -- Internal adapters from XLA-related objects to the above protocols
def _is_unimplemented_error(err: BaseException) -> bool:
  """Report whether ``err``'s first argument is an "UNIMPLEMENTED..." string.

  An exception constructed without arguments is treated as not matching, so
  the caller re-raises it unchanged instead of failing while unpacking
  ``err.args``.
  """
  first_arg = err.args[0] if err.args else None
  return isinstance(first_arg, str) and first_arg.startswith("UNIMPLEMENTED")


class XlaExecutable(Executable):
  """Adapter from an XLA extension executable to the ``Executable`` protocol.

  Subclasses must override ``xla_extension_executable`` and ``call``. The
  text and analysis methods are implemented here on top of
  ``xla_extension_executable``; each of them turns an ``XlaRuntimeError``
  whose message starts with "UNIMPLEMENTED" into ``NotImplementedError`` and
  re-raises any other ``XlaRuntimeError`` unchanged.
  """

  def xla_extension_executable(self) -> xc.LoadedExecutable:
    """Return the underlying XLA extension executable (must be overridden)."""
    raise NotImplementedError("must override")

  def call(self, *args_flat) -> Sequence[Any]:
    """Execute on flat arguments, returning flat outputs (must be overridden)."""
    raise NotImplementedError("must override")

  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Raise ``NotImplementedError``: no input sharding information by default."""
    raise NotImplementedError(
        "compiled executable carries no input sharding information")

  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Raise ``NotImplementedError``: no output sharding information by default."""
    raise NotImplementedError(
        "compiled executable carries no output sharding information")

  def as_text(self) -> str:
    """Return the executable's HLO modules as text, separated by blank lines.

    Raises:
      NotImplementedError: if the executable has no ``hlo_modules`` attribute
        or the backend reports the operation as unimplemented.
    """
    xla_ext_exe = self.xla_extension_executable()
    err_msg = ("text view unsupported on current XLA backend: "
               f"{type(xla_ext_exe)}")
    if not hasattr(xla_ext_exe, "hlo_modules"):
      raise NotImplementedError(err_msg)
    try:
      return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()])
    except xla_extension.XlaRuntimeError as e:
      if _is_unimplemented_error(e):
        raise NotImplementedError(err_msg) from e
      raise

  def cost_analysis(self) -> List[Dict[str, float]]:
    """Return cost estimates for the underlying executable.

    When the executable has a ``client`` attribute, the result is a list with
    one ``hlo_module_cost_analysis`` entry per HLO module. Otherwise, when it
    has a ``cost_analysis`` method, that method's result is returned as-is.

    Raises:
      NotImplementedError: if neither path exists or the backend reports the
        operation as unimplemented.
    """
    xla_ext_exe = self.xla_extension_executable()
    err_msg = ("cost analysis unsupported on current XLA backend: "
               f"{type(xla_ext_exe)}")
    # TODO(b/259255524): Unify/merge the two cost_analysis calls below.
    if hasattr(xla_ext_exe, "client"):
      try:
        return [
            xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
            for m in xla_ext_exe.hlo_modules()
        ]
      except xla_extension.XlaRuntimeError as e:
        if _is_unimplemented_error(e):
          raise NotImplementedError(err_msg) from e
        raise
    elif hasattr(xla_ext_exe, "cost_analysis"):
      try:
        return xla_ext_exe.cost_analysis()
      except xla_extension.XlaRuntimeError as e:
        if _is_unimplemented_error(e):
          raise NotImplementedError(err_msg) from e
        raise
    else:
      raise NotImplementedError(err_msg)

  def memory_analysis(self) -> Any:
    """Return the executable's ``get_compiled_memory_stats()`` result.

    Raises:
      NotImplementedError: if the executable has no
        ``get_compiled_memory_stats`` attribute or the backend reports the
        operation as unimplemented.
    """
    xla_ext_exe = self.xla_extension_executable()
    err_msg = ("memory analysis unsupported on current XLA backend: "
               f"{type(xla_ext_exe)}")
    if not hasattr(xla_ext_exe, "get_compiled_memory_stats"):
      raise NotImplementedError(err_msg)
    try:
      return xla_ext_exe.get_compiled_memory_stats()
    except xla_extension.XlaRuntimeError as e:
      if _is_unimplemented_error(e):
        raise NotImplementedError(err_msg) from e
      raise

  def runtime_executable(self) -> Any:
    """Return the underlying XLA extension executable."""
    return self.xla_extension_executable()
class XlaLowering(Lowering):
  """Adapts our various internal XLA-backed computations into a ``Lowering``.

  Subclasses must override ``stablehlo``, ``compile`` and ``cost_analysis``;
  the other representations and the dialect dispatch are derived here from
  ``stablehlo``.
  """

  # Compilation arguments; ``hlo`` reads the "tuple_args" entry.
  compile_args: Dict[str, Any]

  def hlo(self) -> xc.XlaComputation:
    """Return an HLO representation of this computation."""
    module_text = mlir.module_to_string(self.stablehlo())
    tuple_args = self.compile_args["tuple_args"]
    return xla_extension.mlir.mlir_module_to_xla_computation(
        module_text, use_tuple_args=tuple_args)

  def mhlo(self) -> ir.Module:
    """Return an MHLO representation of this computation."""
    bytecode = mlir.module_to_bytecode(self.stablehlo())
    converted = xla_extension.mlir.stablehlo_to_mhlo(bytecode)
    # Parse the converted module inside the StableHLO module's MLIR context.
    with self.stablehlo().context:
      return ir.Module.parse(converted)

  def stablehlo(self) -> ir.Module:
    """Return a StableHLO representation of this computation."""
    raise NotImplementedError("must override")

  def compile(
      self, compiler_options: Optional[CompilerOptions] = None) -> Executable:
    """Compile and return an ``Executable`` (must be overridden)."""
    raise NotImplementedError("must override")

  def as_text(self, dialect: Optional[str] = None) -> str:
    """Return the computation as text in the requested dialect.

    Args:
      dialect: One of "stablehlo", "mhlo" or "hlo"; ``None`` selects
        "stablehlo".

    Raises:
      ValueError: if ``dialect`` is any other value.
    """
    chosen = "stablehlo" if dialect is None else dialect
    if chosen == "mhlo":
      return str(self.mhlo())
    if chosen == "stablehlo":
      return str(self.stablehlo())
    if chosen == "hlo":
      return self.hlo().as_hlo_text()
    raise ValueError(f"unknown dialect: {chosen}")

  def compiler_ir(self, dialect: Optional[str] = None) -> Any:
    """Return the computation as an IR object in the requested dialect.

    Args:
      dialect: One of "stablehlo", "mhlo" or "hlo"; ``None`` selects
        "stablehlo".

    Raises:
      ValueError: if ``dialect`` is any other value.
    """
    chosen = "stablehlo" if dialect is None else dialect
    if chosen == "mhlo":
      return self.mhlo()
    if chosen == "stablehlo":
      return self.stablehlo()
    if chosen == "hlo":
      return self.hlo()
    raise ValueError(f"unknown dialect: {chosen}")

  def cost_analysis(self) -> Dict[str, float]:
    """Return execution cost estimates (must be overridden)."""
    raise NotImplementedError("must override")
# -- Public-facing API, plus helpers
@dataclass
class ArgInfo:
  """Per-argument record: an abstract value plus a donation flag."""

  aval: core.AbstractValue  # abstract value of the argument
  donated: bool  # whether the argument is donated
class Stage:
  """Shared base for stages that carry a pytree of ``ArgInfo``.

  All three properties are computed from ``args_info`` on each access.
  """

  args_info: Any  # PyTree of ArgInfo

  @property
  def in_tree(self) -> tree_util.PyTreeDef:
    """Tree structure of the pair (positional arguments, keyword arguments)."""
    structure = tree_util.tree_structure(self.args_info)
    return structure

  @property
  def in_avals(self):
    """Tree of input avals."""

    def aval_of(info):
      return info.aval

    return tree_util.tree_map(aval_of, self.args_info)

  @property
  def donate_argnums(self):
    """Flat tuple of donated argument indices."""
    flat_infos = tree_util.tree_leaves(self.args_info)
    return tuple(
        [idx for idx, info in enumerate(flat_infos) if info.donated])
def make_args_info(in_tree, in_avals, donate_argnums):
  """Build a pytree of ``ArgInfo`` with the structure of ``in_tree``.

  Args:
    in_tree: ``PyTreeDef`` used to unflatten the resulting ``ArgInfo`` list.
    in_avals: Pytree of avals; its leaves are taken in flattened order.
    donate_argnums: Iterable of flat argument indices marked as donated.

  Returns:
    ``in_tree`` unflattened over one ``ArgInfo`` per flattened aval.
  """
  donated_indices = frozenset(donate_argnums)
  flat_avals = tree_util.tree_leaves(in_avals)  # todo: remove
  infos = []
  for position, aval in enumerate(flat_avals):
    infos.append(ArgInfo(aval, position in donated_indices))
  return in_tree.unflatten(infos)
class CompiledCallParams(NamedTuple):
  """Bundle of values that ``Compiled.call`` receives as its first argument."""

  executable: Executable  # executable that is run on the flat arguments
  no_kwargs: bool  # when true, calls with keyword arguments are rejected
  in_tree: tree_util.PyTreeDef  # expected structure of (args, kwargs)
  out_tree: tree_util.PyTreeDef  # structure used to unflatten the outputs
class Compiled(Stage):
  """Compiled representation of a function specialized to types/values.

  A compiled computation is associated with an executable and the
  remaining information needed to execute it. It also provides a
  common API for querying properties of compiled computations across
  JAX's various compilation paths and backends.
  """

  # NOTE: ``_params`` and ``_call`` (assigned in ``__init__``) are not listed
  # here; they live in the instance ``__dict__``, which exists because the
  # ``Stage`` base class declares no ``__slots__``.
  __slots__ = ["args_info", "out_tree", "_executable", "_no_kwargs"]

  args_info: Any                # PyTree of ArgInfo
  out_tree: tree_util.PyTreeDef
  _executable: Executable
  _no_kwargs: bool

  def __init__(self, executable, args_info, out_tree, no_kwargs=False):
    """Wrap ``executable`` together with its argument and output structure.

    Args:
      executable: The ``Executable`` that backs this object.
      args_info: PyTree of ``ArgInfo`` describing the inputs.
      out_tree: ``PyTreeDef`` of the outputs.
      no_kwargs: If true, calling with keyword arguments raises an error.
    """
    self._executable = executable
    self._no_kwargs = no_kwargs
    self.args_info = args_info
    self.out_tree = out_tree
    self._params = CompiledCallParams(self._executable, self._no_kwargs,
                                      self.in_tree, self.out_tree)
    # Dispatcher used by ``__call__``; created lazily on the first call.
    self._call = None

  def compiler_ir(self):
    """Post-compilation IR.

    Compilation typically involves code transformation and
    optimization. This method exists to reflect the compiler's
    representation of the program after such passes, whenever
    possible.

    Deprecated: emits a ``DeprecationWarning``; use ``runtime_executable()``
    instead. Returns ``None`` when no runtime executable is available.
    """
    # TODO(frostig): remove (deprecated)
    # stacklevel=2 attributes the warning to the caller of compiler_ir().
    warnings.warn(
        "compiler_ir() is deprecated, consider runtime_executable() instead",
        DeprecationWarning, stacklevel=2)
    exe = self.runtime_executable()
    return exe.hlo_modules() if exe is not None else None

  def as_text(self) -> Optional[str]:
    """A human-readable text representation of this executable.

    Intended for visualization and debugging purposes. This is not a valid nor
    reliable serialization.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    try:
      return self._executable.as_text()
    except NotImplementedError:
      return None

  def cost_analysis(self) -> Optional[Any]:
    """A summary of execution cost estimates.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or serialized
    (e.g. nested dicts, lists, and tuples with numeric leaves). However, its
    structure can be arbitrary: it may be inconsistent across versions of JAX
    and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    # TODO(frostig): improve annotation (basic pytree of arbitrary structure)
    try:
      return self._executable.cost_analysis()
    except NotImplementedError:
      return None

  def memory_analysis(self) -> Optional[Any]:
    """A summary of estimated memory requirements.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or serialized
    (e.g. nested dicts, lists, and tuples with numeric leaves). However, its
    structure can be arbitrary: it may be inconsistent across versions of JAX
    and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    # TODO(frostig): improve annotation (basic pytree of arbitrary structure)
    try:
      return self._executable.memory_analysis()
    except NotImplementedError:
      return None

  def runtime_executable(self) -> Optional[Any]:
    """An arbitrary object representation of this executable.

    Intended for debugging purposes. This is not valid nor reliable
    serialization. The output has no guarantee of consistency across
    invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    # Translate "unavailable" into None, as documented above and as the
    # sibling accessors (as_text, cost_analysis, memory_analysis) do.
    try:
      return self._executable.runtime_executable()
    except NotImplementedError:
      return None

  @property
  def input_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
    """Input shardings, unflattened into the structure of ``in_tree``.

    A ``NotImplementedError`` raised by the executable is not caught here.
    """
    shardings_flat = self._executable.input_shardings()
    return tree_util.tree_unflatten(self.in_tree, shardings_flat)  # pytype: disable=attribute-error

  @property
  def output_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
    """Output shardings, unflattened into the structure of ``out_tree``.

    A ``NotImplementedError`` raised by the executable is not caught here.
    """
    shardings_flat = self._executable.output_shardings()
    return tree_util.tree_unflatten(self.out_tree, shardings_flat)  # pytype: disable=attribute-error

  @staticmethod
  def call(*args, **kwargs):
    """Run a compiled executable on pytree arguments.

    The first positional argument must be the ``CompiledCallParams`` for the
    executable; the remaining positional and keyword arguments are the
    arguments of the call itself.

    Returns:
      A triple ``(outs, out_flat, args_flat)``: the outputs unflattened with
      ``params.out_tree``, the flat outputs, and the flat inputs.

    Raises:
      NotImplementedError: if ``jax.config.jax_dynamic_shapes`` is set, or if
        keyword arguments are passed while ``params.no_kwargs`` is true.
      TypeError: if the structure of ``(args, kwargs)`` differs from
        ``params.in_tree``, or if the executable raises ``TypeError`` (which
        is replaced by a more specific message when an argument is a Tracer).
    """
    # This is because `__call__` passes in `self._params` as the first argument.
    # Instead of making the call signature `call(params, *args, **kwargs)`
    # extract it from args because `params` can be passed as a kwarg by users
    # which might conflict here.
    params = args[0]
    args = args[1:]
    if jax.config.jax_dynamic_shapes:
      raise NotImplementedError
    if params.no_kwargs and kwargs:
      kws = ', '.join(kwargs.keys())
      raise NotImplementedError(
          "function was compiled by a transformation that does not support "
          f"keyword arguments, but called with keyword arguments: {kws}")
    args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
    if in_tree != params.in_tree:
      # TODO(frostig): provide more info about the source function
      # and transformation
      raise TypeError(
          f"function compiled for {params.in_tree}, called with {in_tree}")
    try:
      out_flat = params.executable.call(*args_flat)
    except TypeError as e:
      # We can't transform ahead-of-time compiled calls, since we've
      # lowered and compiled for a fixed function signature, and JAX
      # transformations change signatures. We interpret a Tracer
      # argument as an indication of a transformation attempt. We
      # could check this before the executable call, but we'd rather
      # avoid isinstance checks on the call path. Seeing a TypeError
      # might mean that arguments have JAX-invalid types, which in
      # turn might mean some are Tracers.
      for arg in args_flat:
        if isinstance(arg, core.Tracer):
          raise TypeError(
              "Cannot apply JAX transformations to a function lowered and "
              "compiled for a particular signature. Detected argument of "
              f"Tracer type {type(arg)}.") from e
      # No Tracer found: propagate the original TypeError unchanged.
      raise
    outs = tree_util.tree_unflatten(params.out_tree, out_flat)
    return outs, out_flat, args_flat

  def __call__(self, *args, **kwargs):
    """Execute the compiled computation on the given arguments.

    On the first call the dispatcher is chosen and cached in ``self._call``:
    the executable's ``create_cpp_call`` result if it is not ``None``,
    otherwise a Python fallback that goes through ``Compiled.call``.
    """
    if self._call is None:
      self._call = self._executable.create_cpp_call(self._no_kwargs,
                                                    self.in_tree,
                                                    self.out_tree)
      if self._call is None:
        params = self._params

        def cpp_call_fallback(*args, **kwargs):
          outs, _, _ = Compiled.call(params, *args, **kwargs)
          return outs

        self._call = cpp_call_fallback
    return self._call(*args, **kwargs)
class Lowered(Stage):
  """Lowering of a function specialized to argument types and values.

  A lowering is a computation ready for compilation. This class
  carries a lowering together with the remaining information needed to
  later compile and execute it. It also provides a common API for
  querying properties of lowered computations across JAX's various
  lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
  """

  __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]

  args_info: Any                # PyTree of ArgInfo
  out_tree: tree_util.PyTreeDef
  _lowering: XlaLowering
  _no_kwargs: bool

  def __init__(
      self,
      lowering: XlaLowering,
      args_info,  # PyTree of ArgInfo
      out_tree: tree_util.PyTreeDef,
      no_kwargs: bool = False):
    """Store the lowering together with its argument and output structure."""
    self._lowering = lowering
    self._no_kwargs = no_kwargs
    self.args_info = args_info
    self.out_tree = out_tree

  @classmethod
  def from_flat_info(cls,
                     lowering: XlaLowering,
                     in_tree: tree_util.PyTreeDef,
                     in_avals,
                     donate_argnums: Tuple[int, ...],
                     out_tree: tree_util.PyTreeDef,
                     no_kwargs: bool = False):
    """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.

    Args:
      lowering: The lowering to wrap.
      in_tree: The ``PyTreeDef`` of (args, kwargs).
      in_avals: Input avals, forwarded to ``make_args_info``.
      donate_argnums: Donated argument indices, forwarded to
        ``make_args_info``.
      out_tree: The ``PyTreeDef`` of the outputs.
      no_kwargs: If ``True`` the transformation, and the
        ``Compiled`` returned from this object will not support keyword
        arguments (an error will be raised if some are provided).
    """
    args_info = make_args_info(in_tree, in_avals, donate_argnums)
    return cls(lowering, args_info, out_tree, no_kwargs=no_kwargs)

  def compile(
      self, compiler_options: Optional[CompilerOptions] = None) -> Compiled:
    """Compile, returning a corresponding ``Compiled`` instance."""
    executable = self._lowering.compile(compiler_options=compiler_options)  # pytype: disable=wrong-keyword-args
    return Compiled(
        executable,
        self.args_info,
        self.out_tree,
        no_kwargs=self._no_kwargs,
    )

  def as_text(self, dialect: Optional[str] = None) -> str:
    """A human-readable text representation of this lowering.

    Intended for visualization and debugging purposes. This need not be a valid
    nor reliable serialization. It is relayed directly to external callers.

    Args:
      dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
    """
    text = self._lowering.as_text(dialect)
    return text

  def compiler_ir(self, dialect: Optional[str] = None) -> Optional[Any]:
    """An arbitrary object representation of this lowering.

    Intended for debugging purposes. This is not a valid nor reliable
    serialization. The output has no guarantee of consistency across
    invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.

    Args:
      dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
    """
    try:
      ir_object = self._lowering.compiler_ir(dialect)
    except NotImplementedError:
      return None
    return ir_object

  def cost_analysis(self) -> Optional[Any]:
    """A summary of execution cost estimates.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or serialized
    (e.g. nested dicts, lists, and tuples with numeric leaves). However, its
    structure can be arbitrary: it may be inconsistent across versions of JAX
    and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    # TODO(frostig): improve annotation (basic pytree of arbitrary structure)
    try:
      estimate = self._lowering.cost_analysis()
    except NotImplementedError:
      return None
    return estimate
class Wrapped(Protocol):
  """A function ready to be specialized, lowered, and compiled.

  This protocol describes what functions such as ``jax.jit`` return.
  Calling the object lowers, compiles, and executes just in time. It can
  also be lowered explicitly before compilation, and the result compiled
  before execution.
  """

  def __call__(self, *args, **kwargs):
    """Execute the wrapped function, lowering and compiling as needed."""
    raise NotImplementedError

  def lower(self, *args, **kwargs) -> Lowered:
    """Explicitly lower this function for the given arguments.

    A lowered function has been staged out of Python and translated to a
    compiler's input language, possibly in a backend-dependent way. It is
    ready to be compiled but has not been compiled yet.

    Returns:
      A ``Lowered`` instance representing the lowering.
    """
    raise NotImplementedError