681 lines
24 KiB
Python
681 lines
24 KiB
Python
![]() |
# Copyright 2022 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""
|
||
|
Interfaces to JAX's compilation steps, and utilities for conforming to them.
|
||
|
|
||
|
This module defines a set of public-facing types that reflect the output of
|
||
|
intermediate stages in the process of compilation. Currently there are two
|
||
|
stages modeled: lowering (which produces compiler input), and compilation
|
||
|
(which produces compiler output).
|
||
|
|
||
|
It also defines some internal-facing types to guide what JAX can present in
|
||
|
this common form: an internal ``Lowering`` suffices to back a public-facing
|
||
|
``Lowered`` and an internal ``Executable`` suffices to back a public-facing
|
||
|
``Compiled``.
|
||
|
|
||
|
Finally, this module defines a couple more classes to commonly adapt our
|
||
|
various internal XLA-backed lowerings and executables into the lowering and
|
||
|
executable protocols described above.
|
||
|
"""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import warnings
|
||
|
|
||
|
from dataclasses import dataclass
|
||
|
from typing import (
|
||
|
Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple, Union)
|
||
|
|
||
|
import jax
|
||
|
|
||
|
from jax._src import core
|
||
|
from jax._src import source_info_util
|
||
|
from jax._src import traceback_util
|
||
|
from jax._src import tree_util
|
||
|
from jax._src import util
|
||
|
from jax._src.interpreters import mlir
|
||
|
from jax._src.lib.mlir import ir
|
||
|
from jax._src.lib import xla_client as xc
|
||
|
|
||
|
|
||
|
# Register this file with the source-info and traceback exclusion lists.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)


xla_extension = xc._xla

# Rebind `map`/`zip` to `util.safe_map`/`util.safe_zip`, keeping the builtins
# reachable under the names `unsafe_map`/`unsafe_zip`.
unsafe_map = map
map = util.safe_map
unsafe_zip = zip
zip = util.safe_zip

# Mapping from compiler option name to its (string or boolean) value.
CompilerOptions = Dict[str, Union[str, bool]]
|
||
|
|
||
|
|
||
|
# -- Internal protocols
|
||
|
|
||
|
class Executable(Protocol):
  """Internal protocol for the executable wrapped by a public ``Compiled``.

  The default method bodies raise ``NotImplementedError``, except for
  ``create_cpp_call``, whose default returns ``None``.
  """

  def call(self, *args_flat) -> Sequence[Any]:
    """Run on a flat argument list and return the flat list of outputs."""
    # TODO(frostig): improve annotation (sequences of arrays/buffers)
    raise NotImplementedError

  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Return the input shardings as a flat sequence.

    Raises:
      NotImplementedError: possibly, when the information is unavailable
        (e.g. depending on backend, compiler, or runtime).
    """
    raise NotImplementedError

  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Return the output shardings as a flat sequence.

    Raises:
      NotImplementedError: possibly, when the information is unavailable
        (e.g. depending on backend, compiler, or runtime).
    """
    raise NotImplementedError

  def as_text(self) -> str:
    """Return human-readable text describing this executable.

    Meant for visualization and debugging. The text is not required to be a
    valid or reliable serialization, and is handed to external callers as is.

    Raises:
      NotImplementedError: possibly, when the text is unavailable (e.g.
        depending on backend, compiler, or runtime).
    """
    raise NotImplementedError

  def cost_analysis(self) -> Any:
    """Return a summary of estimated execution costs.

    Meant for visualization and debugging. The result is a simple structure
    that prints or serializes easily (e.g. nested dicts, lists, and tuples
    with numeric leaves), but its layout is otherwise arbitrary: it need not
    be stable across JAX/jaxlib versions, nor across invocations. It is handed
    to external callers as is.

    Raises:
      NotImplementedError: possibly, when the analysis is unavailable (e.g.
        depending on backend, compiler, or runtime).
    """
    # TODO(frostig): improve annotation (arbitrary pytree)
    raise NotImplementedError

  def memory_analysis(self) -> Any:
    """Return a summary of estimated memory requirements.

    Meant for visualization and debugging. The result is a simple structure
    that prints or serializes easily (e.g. nested dicts, lists, and tuples
    with numeric leaves), but its layout is otherwise arbitrary: it need not
    be stable across JAX/jaxlib versions, nor across invocations. It is handed
    to external callers as is.

    Raises:
      NotImplementedError: possibly, when the analysis is unavailable (e.g.
        depending on backend, compiler, or runtime).
    """
    # TODO(frostig): improve annotation (arbitrary pytree)
    raise NotImplementedError

  def runtime_executable(self) -> Any:
    """Return an arbitrary object that represents this executable.

    Meant for debugging. The object is not required to be a valid or reliable
    serialization. It is handed to external callers as is, without any
    guarantee about its type, structure, or consistency across invocations.

    Raises:
      NotImplementedError: possibly, when no such object is available (e.g.
        depending on backend or compiler).
    """
    raise NotImplementedError

  def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any:
    """Optionally build a fast C++ dispatcher; the default returns ``None``."""
    return None
|
||
|
|
||
|
|
||
|
class Lowering(Protocol):
  """Internal protocol for the lowering wrapped by a public ``Lowered``.

  The default method bodies all raise ``NotImplementedError``.
  """

  def compile(
      self, compiler_options: Optional[CompilerOptions] = None) -> Executable:
    """Compile this lowering and return the resulting ``Executable``."""
    raise NotImplementedError

  def as_text(self, dialect: Optional[str] = None) -> str:
    """Return human-readable text describing this lowering.

    Meant for visualization and debugging. The text is not required to be a
    valid or reliable serialization, and is handed to external callers as is.
    """
    raise NotImplementedError

  def compiler_ir(self, dialect: Optional[str] = None) -> Any:
    """Return an arbitrary object that represents this lowering.

    Meant for debugging. The object is not required to be a valid or reliable
    serialization. It is handed to external callers as is, without any
    guarantee about its type, structure, or consistency across invocations.

    Args:
      dialect: Optional string specifying a representation dialect
        (e.g. "stablehlo")

    Raises:
      NotImplementedError: possibly, when no such object is available (e.g.
        depending on backend or compiler).
    """
    raise NotImplementedError

  def cost_analysis(self) -> Any:
    """Return a summary of estimated execution costs.

    Meant for visualization and debugging. The result is a simple structure
    that prints or serializes easily (e.g. nested dicts, lists, and tuples
    with numeric leaves), but its layout is otherwise arbitrary: it need not
    be stable across JAX/jaxlib versions, nor across invocations. It is handed
    to external callers as is.

    The estimate is made without compiler optimizations, which may change the
    cost drastically. For an estimate after optimizations, compile this
    lowering and see ``Compiled.cost_analysis``.

    Raises:
      NotImplementedError: possibly, when the analysis is unavailable (e.g.
        depending on backend, compiler, or runtime).
    """
    # TODO(frostig): improve annotation (arbitrary pytree)
    raise NotImplementedError
|
||
|
|
||
|
|
||
|
# -- Internal adapters from XLA-related objects to the above protocols
|
||
|
|
||
|
class XlaExecutable(Executable):
  """Adapts an XLA-backed executable to the ``Executable`` protocol.

  Subclasses must override ``xla_extension_executable`` and ``call``. The
  text, cost and memory queries are answered from the object returned by
  ``xla_extension_executable``; XLA runtime errors whose message starts with
  ``"UNIMPLEMENTED"`` are converted to ``NotImplementedError``.
  """

  @staticmethod
  def _is_unimplemented(e: BaseException) -> bool:
    """Whether the first argument of ``e`` is an "UNIMPLEMENTED..." string.

    An exception with empty ``args`` is treated as not matching, so the
    original error is re-raised rather than hidden by an unpacking failure.
    """
    first = e.args[0] if e.args else None
    return isinstance(first, str) and first.startswith("UNIMPLEMENTED")

  def xla_extension_executable(self) -> xc.LoadedExecutable:
    """Return the underlying XLA executable object. Must be overridden."""
    raise NotImplementedError("must override")

  def call(self, *args_flat) -> Sequence[Any]:
    """Execute on flat arguments. Must be overridden."""
    raise NotImplementedError("must override")

  def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Input shardings; this base implementation always raises."""
    raise NotImplementedError(
        "compiled executable carries no input sharding information")

  def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]:
    """Output shardings; this base implementation always raises."""
    raise NotImplementedError(
        "compiled executable carries no output sharding information")

  def as_text(self) -> str:
    """Return the executable's HLO modules as text, separated by blank lines.

    Raises:
      NotImplementedError: if the executable has no ``hlo_modules`` attribute
        or XLA reports the operation as UNIMPLEMENTED.
    """
    xla_ext_exe = self.xla_extension_executable()
    err_msg = ("text view unsupported on current XLA backend: "
               f"{type(xla_ext_exe)}")
    if not hasattr(xla_ext_exe, "hlo_modules"):
      raise NotImplementedError(err_msg)
    try:
      return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()])
    except xla_extension.XlaRuntimeError as e:
      if self._is_unimplemented(e):
        raise NotImplementedError(err_msg) from e
      raise

  def cost_analysis(self) -> List[Dict[str, float]]:
    """Return cost estimates obtained from the XLA executable.

    If the executable has a ``client`` attribute, one analysis per HLO module
    is computed through ``hlo_module_cost_analysis``; otherwise the
    executable's own ``cost_analysis`` method is used when present.

    Raises:
      NotImplementedError: if neither path is available or XLA reports the
        operation as UNIMPLEMENTED.
    """
    xla_ext_exe = self.xla_extension_executable()
    err_msg = ("cost analysis unsupported on current XLA backend: "
               f"{type(xla_ext_exe)}")
    # TODO(b/259255524): Unify/merge the two cost_analysis calls below.
    if hasattr(xla_ext_exe, "client"):
      try:
        return [
            xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
            for m in xla_ext_exe.hlo_modules()
        ]
      except xla_extension.XlaRuntimeError as e:
        if self._is_unimplemented(e):
          raise NotImplementedError(err_msg) from e
        raise
    elif hasattr(xla_ext_exe, "cost_analysis"):
      try:
        return xla_ext_exe.cost_analysis()
      except xla_extension.XlaRuntimeError as e:
        if self._is_unimplemented(e):
          raise NotImplementedError(err_msg) from e
        raise
    else:
      raise NotImplementedError(err_msg)

  def memory_analysis(self) -> Any:
    """Return the executable's compiled memory statistics.

    Raises:
      NotImplementedError: if the executable has no
        ``get_compiled_memory_stats`` attribute or XLA reports the operation
        as UNIMPLEMENTED.
    """
    xla_ext_exe = self.xla_extension_executable()
    err_msg = ("memory analysis unsupported on current XLA backend: "
               f"{type(xla_ext_exe)}")
    if not hasattr(xla_ext_exe, "get_compiled_memory_stats"):
      raise NotImplementedError(err_msg)
    try:
      return xla_ext_exe.get_compiled_memory_stats()
    except xla_extension.XlaRuntimeError as e:
      if self._is_unimplemented(e):
        raise NotImplementedError(err_msg) from e
      raise

  def runtime_executable(self) -> Any:
    """Return the underlying XLA executable object."""
    return self.xla_extension_executable()
|
||
|
|
||
|
|
||
|
class XlaLowering(Lowering):
  """Adapter presenting internal XLA-backed computations as a ``Lowering``.

  Subclasses must override ``stablehlo``, ``compile`` and ``cost_analysis``;
  the HLO and MHLO views are derived from the StableHLO module.
  """

  compile_args: Dict[str, Any]

  def hlo(self) -> xc.XlaComputation:
    """Convert the StableHLO module to an HLO ``XlaComputation``."""
    stablehlo_text = mlir.module_to_string(self.stablehlo())
    return xla_extension.mlir.mlir_module_to_xla_computation(
        stablehlo_text, use_tuple_args=self.compile_args["tuple_args"])

  def mhlo(self) -> ir.Module:
    """Convert the StableHLO module to an MHLO module."""
    stablehlo_bytecode = mlir.module_to_bytecode(self.stablehlo())
    mhlo_source = xla_extension.mlir.stablehlo_to_mhlo(stablehlo_bytecode)
    # Parse under the context of the StableHLO module.
    with self.stablehlo().context:
      return ir.Module.parse(mhlo_source)

  def stablehlo(self) -> ir.Module:
    """Produce the StableHLO module for this computation. Must be overridden."""
    raise NotImplementedError("must override")

  def compile(
      self, compiler_options: Optional[CompilerOptions] = None) -> Executable:
    """Compile this lowering. Must be overridden."""
    raise NotImplementedError("must override")

  def as_text(self, dialect: Optional[str] = None) -> str:
    """Render this lowering as text in the requested dialect.

    Args:
      dialect: One of "mhlo", "stablehlo" or "hlo"; ``None`` means
        "stablehlo".

    Raises:
      ValueError: if ``dialect`` is not one of the names above.
    """
    chosen = "stablehlo" if dialect is None else dialect
    if chosen == "mhlo":
      return str(self.mhlo())
    if chosen == "stablehlo":
      return str(self.stablehlo())
    if chosen == "hlo":
      return self.hlo().as_hlo_text()
    raise ValueError(f"unknown dialect: {chosen}")

  def compiler_ir(self, dialect: Optional[str] = None) -> Any:
    """Return the object representation of this lowering in a dialect.

    Args:
      dialect: One of "mhlo", "stablehlo" or "hlo"; ``None`` means
        "stablehlo".

    Raises:
      ValueError: if ``dialect`` is not one of the names above.
    """
    chosen = "stablehlo" if dialect is None else dialect
    if chosen == "mhlo":
      return self.mhlo()
    if chosen == "stablehlo":
      return self.stablehlo()
    if chosen == "hlo":
      return self.hlo()
    raise ValueError(f"unknown dialect: {chosen}")

  def cost_analysis(self) -> Dict[str, float]:
    """Estimate execution cost. Must be overridden."""
    raise NotImplementedError("must override")
|
||
|
|
||
|
|
||
|
# -- Public-facing API, plus helpers
|
||
|
|
||
|
@dataclass
class ArgInfo:
  """Per-argument record: an abstract value plus a donation flag."""
  # Abstract value of the argument.
  aval: core.AbstractValue
  # Whether the argument is marked as donated.
  donated: bool
|
||
|
|
||
|
|
||
|
class Stage:
  """Base class exposing views derived from a pytree of ``ArgInfo``."""

  args_info: Any  # PyTree of ArgInfo

  @property
  def in_tree(self) -> tree_util.PyTreeDef:
    """Tree structure of the pair (positional arguments, keyword arguments)."""
    structure = tree_util.tree_structure(self.args_info)
    return structure

  @property
  def in_avals(self):
    """Tree of input avals."""
    def aval_of(info):
      return info.aval
    return tree_util.tree_map(aval_of, self.args_info)

  @property
  def donate_argnums(self):
    """Flat tuple of donated argument indices."""
    infos = tree_util.tree_leaves(self.args_info)
    # Indices refer to positions in the flattened argument list.
    return tuple(idx for idx, info in enumerate(infos) if info.donated)
|
||
|
|
||
|
|
||
|
def make_args_info(in_tree, in_avals, donate_argnums):
  """Build a pytree of ``ArgInfo`` with the structure of ``in_tree``.

  Args:
    in_tree: ``PyTreeDef`` used to unflatten the resulting ``ArgInfo`` list.
    in_avals: Pytree of abstract values; it is flattened before use.
    donate_argnums: Iterable of indices into the flattened ``in_avals``
      marking which arguments are donated.

  Returns:
    ``in_tree`` unflattened over one ``ArgInfo`` per flattened aval.
  """
  donated = frozenset(donate_argnums)
  leaves = tree_util.tree_leaves(in_avals)  # todo: remove
  infos = [ArgInfo(aval, idx in donated) for idx, aval in enumerate(leaves)]
  return in_tree.unflatten(infos)
|
||
|
|
||
|
class CompiledCallParams(NamedTuple):
  """Bundle of the values ``Compiled.call`` needs to run an executable."""
  # Executable whose ``call`` method is invoked with the flat arguments.
  executable: Executable
  # When true, calls that pass keyword arguments are rejected.
  no_kwargs: bool
  # Expected tree structure of the ``(args, kwargs)`` pair.
  in_tree: tree_util.PyTreeDef
  # Tree structure used to unflatten the flat outputs.
  out_tree: tree_util.PyTreeDef
|
||
|
|
||
|
|
||
|
class Compiled(Stage):
  """Compiled representation of a function specialized to types/values.

  A compiled computation is associated with an executable and the
  remaining information needed to execute it. It also provides a
  common API for querying properties of compiled computations across
  JAX's various compilation paths and backends.
  """
  # `_params` and `_call` are assigned in `__init__`, so they are listed here
  # alongside the other instance attributes.
  __slots__ = ["args_info", "out_tree", "_executable", "_no_kwargs",
               "_params", "_call"]

  args_info: Any                # PyTree of ArgInfo
  out_tree: tree_util.PyTreeDef
  _executable: Executable
  _no_kwargs: bool
  _params: CompiledCallParams
  _call: Optional[Any]          # lazily built dispatcher, see `__call__`

  def __init__(self, executable, args_info, out_tree, no_kwargs=False):
    """Wrap ``executable`` together with its argument and output structure.

    Args:
      executable: The ``Executable`` that backs this object.
      args_info: PyTree of ``ArgInfo`` describing the arguments.
      out_tree: ``PyTreeDef`` of the outputs.
      no_kwargs: If true, calling with keyword arguments raises.
    """
    self._executable = executable
    self._no_kwargs = no_kwargs
    self.args_info = args_info
    self.out_tree = out_tree
    self._params = CompiledCallParams(self._executable, self._no_kwargs,
                                      self.in_tree, self.out_tree)
    # Built on first use by `__call__`.
    self._call = None

  def compiler_ir(self):
    """Post-compilation IR.

    Compilation typically involves code transformation and
    optimization. This method exists to reflect the compiler's
    representation of the program after such passes, whenever
    possible.

    Deprecated: emits a ``DeprecationWarning``; use ``runtime_executable()``.

    Returns:
      The HLO modules of the runtime executable, or ``None`` if no runtime
      executable is available.
    """
    # TODO(frostig): remove (deprecated)
    # stacklevel=2 attributes the warning to the caller of `compiler_ir`.
    warnings.warn(
        "compiler_ir() is deprecated, consider runtime_executable() instead",
        DeprecationWarning, stacklevel=2)
    exe = self.runtime_executable()
    return exe.hlo_modules() if exe is not None else None

  def as_text(self) -> Optional[str]:
    """A human-readable text representation of this executable.

    Intended for visualization and debugging purposes. This is not a valid nor
    reliable serialization.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    try:
      return self._executable.as_text()
    except NotImplementedError:
      return None

  def cost_analysis(self) -> Optional[Any]:
    """A summary of execution cost estimates.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or serialized
    (e.g. nested dicts, lists, and tuples with numeric leaves). However, its
    structure can be arbitrary: it may be inconsistent across versions of JAX
    and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    # TODO(frostig): improve annotation (basic pytree of arbitrary structure)
    try:
      return self._executable.cost_analysis()
    except NotImplementedError:
      return None

  def memory_analysis(self) -> Optional[Any]:
    """A summary of estimated memory requirements.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or serialized
    (e.g. nested dicts, lists, and tuples with numeric leaves). However, its
    structure can be arbitrary: it may be inconsistent across versions of JAX
    and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    # TODO(frostig): improve annotation (basic pytree of arbitrary structure)
    try:
      return self._executable.memory_analysis()
    except NotImplementedError:
      return None

  def runtime_executable(self) -> Optional[Any]:
    """An arbitrary object representation of this executable.

    Intended for debugging purposes. This is not valid nor reliable
    serialization. The output has no guarantee of consistency across
    invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    # Translate "unavailable" into `None`, as documented above and as the
    # other query methods of this class do.
    try:
      return self._executable.runtime_executable()
    except NotImplementedError:
      return None

  @property
  def input_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
    """Input shardings, arranged in the structure of ``in_tree``."""
    shardings_flat = self._executable.input_shardings()
    return tree_util.tree_unflatten(self.in_tree, shardings_flat)  # pytype: disable=attribute-error

  @property
  def output_shardings(self):  # PyTree[sharding.XLACompatibleSharding]
    """Output shardings, arranged in the structure of ``out_tree``."""
    shardings_flat = self._executable.output_shardings()
    return tree_util.tree_unflatten(self.out_tree, shardings_flat)  # pytype: disable=attribute-error

  @staticmethod
  def call(*args, **kwargs):
    """Run the executable in ``args[0]`` on the remaining arguments.

    The first positional argument must be a ``CompiledCallParams``; everything
    after it (and all keyword arguments) is forwarded to the computation.

    Returns:
      A triple ``(outs, out_flat, args_flat)``: the outputs unflattened with
      ``params.out_tree``, the flat outputs, and the flat inputs.

    Raises:
      NotImplementedError: if ``jax_dynamic_shapes`` is enabled, or keyword
        arguments are passed while ``params.no_kwargs`` is set.
      TypeError: if the argument tree structure differs from
        ``params.in_tree``, or if the executable raises ``TypeError`` (with a
        dedicated message when any flat argument is a ``Tracer``).
    """
    # This is because `__call__` passes in `self._params` as the first argument.
    # Instead of making the call signature `call(params, *args, **kwargs)`
    # extract it from args because `params` can be passed as a kwarg by users
    # which might conflict here.
    params = args[0]
    args = args[1:]
    if jax.config.jax_dynamic_shapes:
      raise NotImplementedError
    if params.no_kwargs and kwargs:
      kws = ', '.join(kwargs.keys())
      raise NotImplementedError(
          "function was compiled by a transformation that does not support "
          f"keyword arguments, but called with keyword arguments: {kws}")
    args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
    if in_tree != params.in_tree:
      # TODO(frostig): provide more info about the source function
      # and transformation
      raise TypeError(
          f"function compiled for {params.in_tree}, called with {in_tree}")
    try:
      out_flat = params.executable.call(*args_flat)
    except TypeError as e:
      # We can't transform ahead-of-time compiled calls, since we've
      # lowered and compiled for a fixed function signature, and JAX
      # transformations change signatures. We interpret a Tracer
      # argument as an indication of a transformation attempt. We
      # could check this before the executable call, but we'd rather
      # avoid isinstance checks on the call path. Seeing a TypeError
      # might mean that arguments have JAX-invalid types, which in
      # turn might mean some are Tracers.
      for arg in args_flat:
        if isinstance(arg, core.Tracer):
          raise TypeError(
              "Cannot apply JAX transformations to a function lowered and "
              "compiled for a particular signature. Detected argument of "
              f"Tracer type {type(arg)}.") from e
      else:
        # No Tracer found: propagate the original TypeError unchanged.
        raise
    outs = tree_util.tree_unflatten(params.out_tree, out_flat)
    return outs, out_flat, args_flat

  def __call__(self, *args, **kwargs):
    """Execute the compiled computation on the given arguments.

    On first use, asks the executable for a C++ dispatcher; if it provides
    none (``None``), a Python fallback built on ``Compiled.call`` is used.
    The chosen callable is cached for subsequent calls.
    """
    if self._call is None:
      self._call = self._executable.create_cpp_call(self._no_kwargs,
                                                    self.in_tree,
                                                    self.out_tree)
      if self._call is None:
        params = self._params
        def cpp_call_fallback(*args, **kwargs):
          outs, _, _ = Compiled.call(params, *args, **kwargs)
          return outs
        self._call = cpp_call_fallback
    return self._call(*args, **kwargs)
|
||
|
|
||
|
|
||
|
class Lowered(Stage):
  """A function lowered for specific argument types and values.

  A lowering is a computation that is ready to be compiled. Instances pair
  such a lowering with the extra information needed to compile and run it
  later, and offer one API for inspecting lowered computations regardless of
  which JAX lowering path (:func:`~jax.jit`, :func:`~jax.pmap`, etc.)
  produced them.
  """
  __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]

  args_info: Any                # PyTree of ArgInfo
  out_tree: tree_util.PyTreeDef
  _lowering: XlaLowering
  _no_kwargs: bool

  def __init__(
      self,
      lowering: XlaLowering,
      args_info,  # PyTree of ArgInfo
      out_tree: tree_util.PyTreeDef,
      no_kwargs: bool = False):
    """Store the lowering together with its argument and output structure."""
    self.args_info = args_info
    self.out_tree = out_tree
    self._lowering = lowering
    self._no_kwargs = no_kwargs

  @classmethod
  def from_flat_info(cls,
                     lowering: XlaLowering,
                     in_tree: tree_util.PyTreeDef,
                     in_avals,
                     donate_argnums: Tuple[int, ...],
                     out_tree: tree_util.PyTreeDef,
                     no_kwargs: bool = False):
    """Alternate constructor taking flat info and an input ``PyTreeDef``.

    The ``args_info`` pytree is built with ``make_args_info``.

    Args:
      lowering: The lowering to wrap.
      in_tree: The ``PyTreeDef`` of (args, kwargs).
      in_avals: Input abstract values, passed on to ``make_args_info``.
      donate_argnums: Indices of donated arguments, passed on to
        ``make_args_info``.
      out_tree: The ``PyTreeDef`` of the outputs.
      no_kwargs: If ``True`` the transformation, and the
        ``Compiled`` returned from this object will not support keyword
        arguments (an error will be raised if some are provided).
    """
    args_info = make_args_info(in_tree, in_avals, donate_argnums)
    return cls(lowering, args_info, out_tree, no_kwargs=no_kwargs)

  def compile(
      self, compiler_options: Optional[CompilerOptions] = None) -> Compiled:
    """Compile the lowering and wrap the result in a ``Compiled``."""
    kw: Dict[str, Any] = {"compiler_options": compiler_options}
    executable = self._lowering.compile(**kw)  # pytype: disable=wrong-keyword-args
    return Compiled(
        executable,
        self.args_info,
        self.out_tree,
        no_kwargs=self._no_kwargs,
    )

  def as_text(self, dialect: Optional[str] = None) -> str:
    """Return human-readable text describing this lowering.

    Meant for visualization and debugging. The text is not required to be a
    valid or reliable serialization, and is passed through from the
    underlying lowering unchanged.

    Args:
      dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
    """
    text = self._lowering.as_text(dialect)
    return text

  def compiler_ir(self, dialect: Optional[str] = None) -> Optional[Any]:
    """Return an arbitrary object that represents this lowering.

    Meant for debugging. The object is not a valid or reliable serialization,
    and there is no guarantee that it is consistent across invocations.

    Args:
      dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")

    Returns:
      The representation, or ``None`` if it is unavailable (e.g. depending on
      backend, compiler, or runtime).
    """
    try:
      ir_obj = self._lowering.compiler_ir(dialect)
    except NotImplementedError:
      return None
    return ir_obj

  def cost_analysis(self) -> Optional[Any]:
    """Return a summary of estimated execution costs.

    Meant for visualization and debugging. The result is a simple structure
    that prints or serializes easily (e.g. nested dicts, lists, and tuples
    with numeric leaves), but its layout is otherwise arbitrary: it may vary
    across JAX/jaxlib versions, or even across invocations.

    Returns:
      The summary, or ``None`` if it is unavailable (e.g. depending on
      backend, compiler, or runtime).
    """
    # TODO(frostig): improve annotation (basic pytree of arbitrary structure)
    try:
      summary = self._lowering.cost_analysis()
    except NotImplementedError:
      return None
    return summary
|
||
|
|
||
|
|
||
|
class Wrapped(Protocol):
  """Protocol for a function that can be specialized, lowered, and compiled.

  It describes what functions such as ``jax.jit`` return. Calling such an
  object lowers, compiles, and executes just in time (JIT). Alternatively it
  can be lowered explicitly ahead of compilation, and the lowering compiled
  ahead of execution.
  """

  def __call__(self, *args, **kwargs):
    """Run the wrapped function, lowering and compiling it if needed."""
    raise NotImplementedError

  def lower(self, *args, **kwargs) -> Lowered:
    """Explicitly lower this function for the given arguments.

    Lowering stages the function out of Python and translates it to a
    compiler's input language, possibly in a backend-dependent manner. The
    result is ready to be compiled but has not been compiled yet.

    Returns:
      A ``Lowered`` instance representing the lowering.
    """
    raise NotImplementedError
|