210 lines
6.5 KiB
Python
210 lines
6.5 KiB
Python
|
from dataclasses import dataclass
|
||
|
from typing import List, Optional
|
||
|
|
||
|
import torchgen.api.types as api_types
|
||
|
|
||
|
from torchgen.api import cpp, structured
|
||
|
from torchgen.api.types import (
|
||
|
ArgName,
|
||
|
BaseCppType,
|
||
|
BaseCType,
|
||
|
Binding,
|
||
|
ConstRefCType,
|
||
|
CType,
|
||
|
NamedCType,
|
||
|
scalarT,
|
||
|
)
|
||
|
from torchgen.model import (
|
||
|
Argument,
|
||
|
BaseTy,
|
||
|
BaseType,
|
||
|
DispatchKey,
|
||
|
FunctionSchema,
|
||
|
NativeFunctionsGroup,
|
||
|
Type,
|
||
|
)
|
||
|
|
||
|
|
||
|
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Return the generated ufunc kernel name for an out= schema.

    The result has the form ``ufunc_{func.name.name}_{dispatch_key}``.

    Raises:
        AssertionError: if ``func`` is not an out= schema.
    """
    # Explicit raise rather than ``assert`` so the precondition is still
    # enforced under ``python -O``; the exception type is unchanged for callers.
    if not func.is_out_fn():
        raise AssertionError("ufunc.kernel_name should only be invoked on out schemas")
    return f"ufunc_{func.name.name}_{dispatch_key}"
|
||
|
|
||
|
|
||
|
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for ``g``, derived from its out= overload."""
    out_schema = g.out.func
    return schema_kernel_name(out_schema, dispatch_key)
|
||
|
|
||
|
|
||
|
# Argument types for CPU DispatchStubs (NB: used for CPU only).
#
# Tensors are dropped because they are stored in the TensorIterator; every
# other argument is forwarded. (Forwarding tensors too would work, but it
# would only waste argument registers.)
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    """Return the DispatchStub parameter type for ``t``, or None to omit it.

    Value types come from ``cpp.valuetype_type`` with ``symint=False``.
    A Scalar becomes ``const Scalar&``, and a Tensor yields None because it
    is not passed to the stub.

    Raises:
        AssertionError: for any other type.
    """
    # Dispatch stubs always take plain (non-symbolic) ints.
    value_type = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_type is not None:
        return value_type
    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if t == BaseType(BaseTy.Tensor):
        return None
    raise AssertionError(f"unrecognized type {repr(t)}")
|
||
|
|
||
|
|
||
|
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    """Map a scalar C++ type to the type that ufunc math is carried out in.

    Only ``api_types.scalar_t`` is supported; it maps to ``api_types.opmath_t``.

    Raises:
        NotImplementedError: for any other ``scalar_t``.
    """
    if scalar_t == api_types.scalar_t:
        return api_types.opmath_t
    # Same exception type as before, but name the offending type so codegen
    # failures are diagnosable.
    raise NotImplementedError(
        f"opmath_type: unsupported scalar type {scalar_t!r}; "
        "only api_types.scalar_t is supported"
    )
|
||
|
|
||
|
|
||
|
# Constructor parameter types for CUDA ufunctors (NB: CUDA only).
#
# A Tensor reaching the constructor is a partially applied scalar tensor. It
# may hold a higher-precision value, and we want to compute at that higher
# precision. It is therefore stored as opmath_t rather than scalar_t, exactly
# like a Scalar.
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    """Return the C++ type of a ufunctor constructor parameter.

    Value types come from ``cpp.valuetype_type`` with ``symint=False``.
    Both Scalar and Tensor map to ``opmath_type(scalar_t)``.

    Raises:
        AssertionError: for any other type.
    """
    value_type = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_type is not None:
        return value_type
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    raise AssertionError(f"unrecognized type {repr(t)}")
|
||
|
|
||
|
|
||
|
# operator() parameter types for ufunctors. Only Tensors are ever passed
# directly to operator().
#
# NB: CUDA only (although this also works for CPU).
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    """Return the operator() parameter type for ``t``, which must be a Tensor.

    The Tensor is bound as a plain ``scalar_t`` value.

    Raises:
        AssertionError: if ``t`` is not a Tensor.
    """
    if t != BaseType(BaseTy.Tensor):
        raise AssertionError(f"unrecognized type {repr(t)}")
    return NamedCType(binds, BaseCType(scalar_t))
|
||
|
|
||
|
|
||
|
# Parameter types of the ufunc template function the user actually writes.
# Scalar and Tensor arguments use the computation type compute_t, which is
# opmath_t on CUDA and scalar_t on CPU.
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """Return the ufunc parameter type for ``t``.

    Value types come from ``cpp.valuetype_type`` with ``symint=False``.
    Scalar and Tensor both map to ``compute_t``.

    Raises:
        AssertionError: for any other type.
    """
    value_type = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_type is not None:
        return value_type
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    raise AssertionError(f"unrecognized type {repr(t)}")
|
||
|
|
||
|
|
||
|
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Bind ``a`` as a ufunctor constructor parameter (no default value)."""
    nctype = ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(name=a.name, nctype=nctype, argument=a, default=None)
|
||
|
|
||
|
|
||
|
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Bind ``a`` as a ufunctor operator() parameter (no default value)."""
    nctype = ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(name=a.name, nctype=nctype, argument=a, default=None)
|
||
|
|
||
|
|
||
|
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    """Bind ``a`` as a ufunc template-function parameter (no default value)."""
    nctype = ufunc_type(a.type, binds=a.name, compute_t=compute_t)
    return Binding(name=a.name, nctype=nctype, argument=a, default=None)
|
||
|
|
||
|
|
||
|
@dataclass(eq=True, frozen=True)
class UfunctorBindings:
    """Bindings of a ufunctor, split by where each argument is received.

    Produced by ``ufunctor_arguments``.
    """

    # Parameters of the host-side constructor.
    ctor: List[Binding]
    # Parameters of the device-side operator().
    apply: List[Binding]
|
||
|
|
||
|
|
||
|
# ufunctors are a CUDA-only concept: functors that take some of their
# arguments in a host-side constructor and the rest in the device-side
# apply (operator()). E.g.,
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#
# "ctor" refers to the CUDAFunctorOnSelf_add constructor, and "apply" refers
# to the operator() definition.
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
) -> UfunctorBindings:
    """Split the non-out arguments of ``g`` into constructor and apply bindings.

    Non-tensor arguments always go to the constructor. Tensor-like arguments
    go to operator(), with one exception. When ``scalar_tensor_idx`` is given,
    the tensor-like argument at that position is treated as a scalar tensor
    and goes to the constructor instead. The position counts tensor-like
    arguments only.

    Raises:
        AssertionError: if ``scalar_tensor_idx`` is given but does not select
            any tensor-like argument.
    """
    ctor: List[Binding] = []
    apply: List[Binding] = []
    # Tensor-like arguments still to skip before reaching the scalar tensor.
    # None means there is no scalar tensor, or it has already been consumed.
    remaining = scalar_tensor_idx
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
        elif remaining == 0:
            # This is the scalar tensor: put it in the ctor anyway.
            ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
            remaining = None
        else:
            if remaining is not None:
                remaining -= 1
            apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
    # Explicit raise rather than ``assert`` so that an out-of-range index is
    # still rejected under ``python -O``, instead of silently routing every
    # tensor to operator().
    if remaining is not None:
        raise AssertionError(
            f"scalar_tensor_idx={scalar_tensor_idx} does not select any "
            "tensor-like argument"
        )
    return UfunctorBindings(ctor=ctor, apply=apply)
|
||
|
|
||
|
|
||
|
# ufuncs are the inner-loop template functions you write in ufunc/add.h,
# which perform the actual computation. E.g.,
#
#   template <typename T>
#   C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
#     return self + alpha * other;
#   }
#
# In this file T is called compute_t, and the caller binds it.
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
    """Return one ufunc binding per non-out argument of ``g``'s functional variant."""
    bindings: List[Binding] = []
    for arg in g.functional.func.arguments.flat_non_out:
        bindings.append(ufunc_argument(arg, compute_t=compute_t))
    return bindings
|
||
|
|
||
|
|
||
|
# Stubs are the DispatchStub trampolines that CPU kernels use to reach their
# vectorized versions. E.g.,
#
#   using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
#   DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Return the structured bindings of ``g``'s out= non-out, non-tensor arguments."""
    bindings: List[Binding] = []
    for arg in g.out.func.arguments.flat_non_out:
        # Stubs drop every tensor argument (tensors are implicit in the
        # TensorIterator argument) and keep everything else.
        if arg.type.is_tensor_like():
            continue
        bindings.extend(structured.argument(arg))
    return bindings
|