188 lines
6.4 KiB
Python
188 lines
6.4 KiB
Python
|
import cmath
|
||
|
import math
|
||
|
import warnings
|
||
|
|
||
|
from collections import OrderedDict
|
||
|
from typing import Dict, Optional
|
||
|
|
||
|
import torch
|
||
|
import torch.backends.cudnn as cudnn
|
||
|
|
||
|
from ..nn.modules.utils import _list_with_default, _pair, _quadruple, _single, _triple
|
||
|
|
||
|
# Cache mapping id(python_callable) -> "aten::..." op name.
# Stays None until _get_builtin_table() builds it on first use.
_builtin_table: Optional[Dict[int, str]] = None

# Modules scanned by _get_builtin_table(): each callable attribute found in
# them (minus a few exclusions) is registered as "aten::<attribute name>".
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950
|
||
|
|
||
|
_builtin_ops = [
|
||
|
# Pairs of (function, op_name)
|
||
|
(_pair, "aten::_pair"),
|
||
|
(_quadruple, "aten::_quadruple"),
|
||
|
(_single, "aten::_single"),
|
||
|
(_triple, "aten::_triple"),
|
||
|
(_list_with_default, "aten::list_with_default"),
|
||
|
(OrderedDict, "aten::dict"),
|
||
|
(dict, "aten::dict"),
|
||
|
(cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
|
||
|
(math.ceil, "aten::ceil"),
|
||
|
(math.copysign, "aten::copysign"),
|
||
|
(math.erf, "aten::erf"),
|
||
|
(math.erfc, "aten::erfc"),
|
||
|
(math.exp, "aten::exp"),
|
||
|
(math.expm1, "aten::expm1"),
|
||
|
(math.fabs, "aten::fabs"),
|
||
|
(math.floor, "aten::floor"),
|
||
|
(math.gamma, "aten::gamma"),
|
||
|
(math.lgamma, "aten::lgamma"),
|
||
|
(math.log, "aten::log"),
|
||
|
(math.log10, "aten::log10"),
|
||
|
(math.log1p, "aten::log1p"),
|
||
|
(math.pow, "aten::pow"),
|
||
|
(math.sqrt, "aten::sqrt"),
|
||
|
(math.isnan, "aten::isnan"),
|
||
|
(math.asinh, "aten::asinh"),
|
||
|
(math.atanh, "aten::atanh"),
|
||
|
(math.cosh, "aten::cosh"),
|
||
|
(math.sinh, "aten::sinh"),
|
||
|
(math.tanh, "aten::tanh"),
|
||
|
(math.acos, "aten::acos"),
|
||
|
(math.asin, "aten::asin"),
|
||
|
(math.atan, "aten::atan"),
|
||
|
(math.atan2, "aten::atan2"),
|
||
|
(math.cos, "aten::cos"),
|
||
|
(math.sin, "aten::sin"),
|
||
|
(math.tan, "aten::tan"),
|
||
|
(math.asinh, "aten::asinh"),
|
||
|
(math.atanh, "aten::atanh"),
|
||
|
(math.acosh, "aten::acosh"),
|
||
|
(math.fmod, "aten::fmod"),
|
||
|
(math.modf, "aten::modf"),
|
||
|
(math.factorial, "aten::factorial"),
|
||
|
(math.frexp, "aten::frexp"),
|
||
|
(math.isinf, "aten::isinf"),
|
||
|
(math.degrees, "aten::degrees"),
|
||
|
(math.radians, "aten::radians"),
|
||
|
(cmath.isnan, "aten::isnan"),
|
||
|
(cmath.isfinite, "aten::isfinite"),
|
||
|
(cmath.isinf, "aten::isinf"),
|
||
|
(cmath.phase, "aten::angle"),
|
||
|
(cmath.rect, "aten::polar"),
|
||
|
(cmath.log, "aten::log"),
|
||
|
(cmath.log10, "aten::log10"),
|
||
|
(cmath.sqrt, "aten::sqrt"),
|
||
|
(cmath.exp, "aten::exp"),
|
||
|
(cmath.sin, "aten::sin"),
|
||
|
(cmath.tan, "aten::tan"),
|
||
|
(cmath.cos, "aten::cos"),
|
||
|
(cmath.asin, "aten::asin"),
|
||
|
(cmath.acos, "aten::acos"),
|
||
|
(cmath.atan, "aten::atan"),
|
||
|
(cmath.sinh, "aten::sinh"),
|
||
|
(cmath.cosh, "aten::cosh"),
|
||
|
(cmath.tanh, "aten::tanh"),
|
||
|
(cmath.asinh, "aten::asinh"),
|
||
|
(cmath.acosh, "aten::acosh"),
|
||
|
(cmath.atanh, "aten::atanh"),
|
||
|
(math.ldexp, "aten::ldexp"),
|
||
|
(torch._assert, "aten::_assert"),
|
||
|
(torch.autograd.grad, "aten::grad"),
|
||
|
(torch.autograd.backward, "aten::backward"),
|
||
|
(torch._C._infer_size, "aten::_infer_size"),
|
||
|
(torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined]
|
||
|
(torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
|
||
|
(torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
|
||
|
(torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
|
||
|
(torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
|
||
|
(torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
|
||
|
(torch._C._get_tracing_state, "aten::_get_tracing_state"),
|
||
|
(torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
|
||
|
(warnings.warn, "aten::warn"),
|
||
|
(torch._VF.stft, "aten::stft"), # type: ignore[attr-defined]
|
||
|
(torch._VF.istft, "aten::istft"), # type: ignore[attr-defined]
|
||
|
(torch._VF.cdist, "aten::cdist"), # type: ignore[attr-defined]
|
||
|
(torch._VF.norm, "aten::norm"), # type: ignore[attr-defined]
|
||
|
(torch._VF.unique_dim, "aten::unique_dim"),
|
||
|
(torch._VF.unique_consecutive, "aten::unique_consecutive"), # type: ignore[attr-defined]
|
||
|
(torch._VF.nuclear_norm, "aten::nuclear_norm"),
|
||
|
(torch._VF.frobenius_norm, "aten::frobenius_norm"),
|
||
|
(torch._VF.tensordot, "aten::tensordot"), # type: ignore[attr-defined]
|
||
|
]
|
||
|
|
||
|
# Ops defined in torch.functional are also bound on the torch namespace.
# For these, we want to resolve the function to its Python implementation
# instead of looking up a builtin "aten::" schema.
|
||
|
|
||
|
|
||
|
def _gen_torch_functional_registered_ops():
    """Collect the ``torch.functional`` callables that bypass builtin lookup.

    Each returned object is the attribute of ``torch.functional`` with one of
    the hard-coded names below. The builtin-table scan skips these so that
    they resolve to their Python implementation rather than an ``aten::``
    schema.

    Returns:
        A ``set`` of the matching ``torch.functional`` callables.
    """
    # Ideally this would cover everything in torch.functional.__all__, but
    # only some of those functions can be compiled today, and others map
    # straight onto their aten:: implementations.
    # TODO: add support for more ops
    op_names = (
        "stft",
        "istft",
        "lu",
        "cdist",
        "norm",
        "unique",
        "unique_consecutive",
        "tensordot",
    )
    registered = set()
    for op_name in op_names:
        registered.add(getattr(torch.functional, op_name))
    return registered
|
||
|
|
||
|
|
||
|
# Computed once, when this module is imported.
_functional_registered_ops = _gen_torch_functional_registered_ops()


def _is_special_functional_bound_op(fn):
    """Report whether ``fn`` is one of the special ``torch.functional`` ops.

    Such ops should resolve to their Python implementation instead of being
    registered as an ``aten::`` builtin.

    Args:
        fn: any object; membership is tested against a ``set``.

    Returns:
        ``True`` if ``fn`` is in ``_functional_registered_ops``.
    """
    special_ops = _functional_registered_ops
    return fn in special_ops
|
||
|
|
||
|
|
||
|
# lazily built to ensure the correct initialization order
def _get_builtin_table():
    """Return the map from ``id(callable)`` to its ``"aten::..."`` op name.

    The table is built on the first call and cached in the module-level
    ``_builtin_table``; later calls return that same dict object.

    Building it:

    * appends ``(callable, "aten::<attribute name>")`` to ``_builtin_ops``
      for every callable attribute of each module in
      ``_modules_containing_builtins``, skipping ``torch.no_grad``,
      ``torch.autocast`` and the ops for which
      ``_is_special_functional_bound_op`` is true;
    * appends ``math.gcd``, ``math.isfinite`` and ``math.remainder``, plus
      ``torch.distributed.autograd``'s ``get_gradients``/``backward`` when
      ``dist_autograd.is_available()`` is true;
    * indexes every pair in ``_builtin_ops`` by ``id(callable)``.

    Note that the module-level ``_builtin_ops`` list is extended in place, so
    it holds the discovered pairs after the first call.

    Returns:
        The ``Dict[int, str]`` builtin table.
    """
    global _builtin_table
    if _builtin_table is not None:
        return _builtin_table
    # NOTE(review): the cache is published (as an empty dict) *before* it is
    # populated. Presumably this makes any re-entrant call during the scan
    # below return the partial table instead of rebuilding -- confirm.
    # Consequence: if the build raises part-way, later calls return the
    # partially-filled table instead of retrying.
    _builtin_table = {}

    def register_all(mod):
        # Append (callable, "aten::<name>") for each eligible attribute of mod.
        # torch.functional ops are skipped so they resolve to their Python
        # implementation rather than an aten:: schema.
        for name in dir(mod):
            v = getattr(mod, name)
            if (
                callable(v)
                and not _is_special_functional_bound_op(v)
                and v is not torch.no_grad
                and v is not torch.autocast
            ):
                # Fixup inconsistency in segment_reduce: the attribute is named
                # "_segment_reduce" but is registered as "aten::segment_reduce".
                if name == "_segment_reduce":
                    name = name[1:]
                _builtin_ops.append((v, "aten::" + name))

    for mod in _modules_containing_builtins:
        register_all(mod)

    _builtin_ops.append((math.gcd, "aten::gcd"))
    _builtin_ops.append((math.isfinite, "aten::isfinite"))
    _builtin_ops.append((math.remainder, "aten::mathremainder"))  # type: ignore[attr-defined]

    # Imported here rather than at module top -- presumably to avoid pulling in
    # torch.distributed at import time or an import cycle; confirm.
    import torch.distributed.autograd as dist_autograd

    if dist_autograd.is_available():
        _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))
        _builtin_ops.append((dist_autograd.backward, "aten::dist_backward"))

    # populate the _builtin_table from _builtin_ops
    # Keyed by id(): a later pair overwrites an earlier one for the same
    # callable. The tuples in _builtin_ops keep every callable alive, so its
    # id cannot be recycled while it is referenced from the list.
    for builtin, aten_op in _builtin_ops:
        _builtin_table[id(builtin)] = aten_op

    return _builtin_table
|
||
|
|
||
|
|
||
|
# Strong references to callables passed to ``_register_builtin``, keyed by
# ``id()``. The builtin table stores only ``id(fn)``; without holding ``fn``,
# a collected callable's id could be reused by an unrelated object, which
# ``_find_builtin`` would then wrongly report as a builtin.
_registered_builtin_refs: Dict[int, object] = {}


def _register_builtin(fn, op):
    """Register ``fn`` under the builtin op name ``op`` (e.g. ``"aten::foo"``).

    Builds the lazy builtin table first if needed, then maps ``id(fn)`` to
    ``op``, overwriting any earlier mapping for ``fn``. ``fn`` itself is kept
    alive for the lifetime of the module so its ``id`` key stays valid.

    Args:
        fn: the Python callable to register.
        op: the op name that ``_find_builtin(fn)`` will subsequently return.
    """
    _get_builtin_table()[id(fn)] = op
    # Pin fn so its id() cannot be recycled while the table still refers to it.
    _registered_builtin_refs[id(fn)] = fn
|
||
|
|
||
|
|
||
|
def _find_builtin(fn):
    """Look up the builtin op name registered for ``fn``.

    The lookup is by object identity (``id(fn)``) in the lazily-built table.

    Returns:
        The ``"aten::..."`` name, or ``None`` if ``fn`` is not registered.
    """
    table = _get_builtin_table()
    key = id(fn)
    return table.get(key, None)
|