from typing import List, Optional, Sequence, Union

from torchgen import local
from torchgen.api import cpp

from torchgen.api.types import (
    ArgName,
    BaseCType,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    deviceT,
    layoutT,
    ListCType,
    MutRefCType,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    tensorT,
)
from torchgen.model import (
    Argument,
    FunctionSchema,
    Return,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never

# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels.  The intention is to make the native API and the dispatcher
# API line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# NB: this is symint aware: you will get the non-SymInt variant for some
# dispatch entries and SymInt for others.
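#
# A native signature for a schema consists of a kernel name (see name() below)
# plus one Binding per schema argument (see arguments() at the bottom of this
# file); the helpers in between translate individual JIT schema types into the
# C++ types those bindings use.

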
def name(func: FunctionSchema) -> str:
    name = str(func.name.name)
    # TODO: delete this!
    if func.is_out_fn():
        name += "_out"
    if func.name.overload_name:
        name += f"_{func.name.overload_name}"
    return name


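# Translate a single JIT schema type into the NamedCType used in the native
# signature.  Optional tensors, optional tensor lists, and (optional) Scalars
# are passed by const reference (or by mutable reference for mutable tensors,
# depending on local.use_const_ref_for_mutable_tensors()); everything else
# falls through to the C++ API rules in torchgen.api.cpp.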
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    if str(t) == "Tensor?":
        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
        if mutable and not local.use_const_ref_for_mutable_tensors():
            return NamedCType(binds, MutRefCType(tensor_type))
        else:
            return NamedCType(binds, ConstRefCType(tensor_type))
    elif str(t) == "Tensor?[]":
        return NamedCType(
            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
        )
    elif str(t) == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif str(t) == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)


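# Return types are taken verbatim from the C++ API; per-argument types are
# computed via argumenttype_type() above.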
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    return cpp.returns_type(rs, symint=symint)


def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)


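# Compute the Binding(s) for one schema argument.  A plain Argument produces a
# single binding (defaulted only for non-out variants), a SelfArgument is
# treated like its underlying Argument, and TensorOptionsArguments is expanded
# into its four constituent dtype/layout/device/pin_memory bindings.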
def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments],
    *,
    is_out: bool,
    symint: bool,
) -> List[Binding]:
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaults
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default).
    should_default = not is_out
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type, symint=symint)
        return [
            Binding(
                nctype=argument_type(a, binds=a.name, symint=symint),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase the SelfArgument distinction
        return argument(a.argument, is_out=is_out, symint=symint)
    elif isinstance(a, TensorOptionsArguments):
        default = None
        if should_default:
            default = "{}"
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter.
        return [
            Binding(
                nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
                name="dtype",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
                name="layout",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
                name="device",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
                name="pin_memory",
                default=default,
                argument=a,
            ),
        ]
    else:
        assert_never(a)


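# Compute all Bindings for a function schema: non-out arguments first, followed
# by out arguments.  For out variants no arguments are defaulted (see
# argument() above).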
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    args.extend(func.arguments.non_out)
    args.extend(func.arguments.out)
    return [
        r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
    ]
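

# Illustrative usage sketch (not part of this module; the schema string below
# is hypothetical):
#
#   func = FunctionSchema.parse("foo(Tensor self, Scalar alpha=1) -> Tensor")
#   name(func)                              # "foo"
#   for b in arguments(func, symint=False):
#       print(b.nctype.cpp_type(), b.name, b.default)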