from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    Expr,
    NamedCType,
    opmath_t,
    scalar_t,
    StructuredImplSignature,
    VectorizedCType,
)
from torchgen.api.ufunc import UfunctorBindings
from torchgen.context import with_native_function
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    DispatchKey,
    NativeFunctionsGroup,
    ScalarType,
    UfuncKey,
)
from torchgen.utils import OrderedSet


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                  CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary

# TODO: use BackendIndex
# dispatch_key: DispatchKey  # only CPU/CUDA right now


# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because when USERS instantiate functors
# they are templated.
# A functor looks something like this:
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#         : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) const {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#
@dataclass(frozen=True)
class UfunctorSignature:
    """Signature of one CUDA functor struct generated for the ufunc group ``g``.

    The methods render the individual C++ pieces of the struct (fields,
    constructor, ``operator()`` declaration) as strings.
    """

    g: NativeFunctionsGroup
    # Forwarded unchanged to ufunc.ufunctor_arguments.
    scalar_tensor_idx: Optional[int]
    # C++ identifier of the struct; also used as the constructor name.
    name: str

    def arguments(self) -> UfunctorBindings:
        """Return the constructor/apply bindings computed by ``torchgen.api.ufunc``."""
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> List[Binding]:
        """Return one member binding per constructor argument."""
        # fields are renamed to have a trailing underscore, as is conventional
        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

    def returns_type(self) -> CType:
        """Return the C++ return type of ``operator()``."""
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        """Render the member declarations, one ``<type> <name>;`` per line."""
        return "\n".join(f"{f.type} {f.name};" for f in self.fields())

    def inline_defn_ctor(self) -> str:
        """Render the inline constructor that copies every argument into its field.

        When the functor has no constructor arguments the member-initializer
        list is omitted altogether: a bare ``:`` followed by ``{}`` is not
        valid C++.
        """
        ctor_args = self.arguments().ctor
        args_str = ", ".join(a.decl() for a in ctor_args)
        if not ctor_args:
            return f"{self.name}({args_str}) {{}}"
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        init_str = ", ".join(f"{a.name}_({a.name})" for a in ctor_args)
        return f"{self.name}({args_str}) : {init_str} {{}}"

    def decl_apply(self) -> str:
        """Render the declaration (no body) of the const ``operator()``."""
        args_str = ", ".join(a.decl() for a in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"


@dataclass(frozen=True)
class UfuncSignature:
    """Signature of the templated ``ufunc::`` function the generated code calls."""

    g: NativeFunctionsGroup
    # Callee name exactly as it is written into the generated call expression.
    name: str
    # Compute type forwarded to ufunc.ufunc_arguments.
    compute_t: CType

    def arguments(self) -> List[Binding]:
        """Return the bindings of the ufunc for the configured compute type."""
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
        """Render a call to the ufunc, translating ``ctx`` into its arguments."""
        call_args = ", ".join(a.expr for a in translate(ctx, self.arguments()))
        return f"{self.name}({call_args})"


# steps:
#   1. take the functional signature
#   2. use api.ufunc to convert it to template signature.
#      this establishes the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
#
# StructuredImplSignature context
#   ~> functor constructor sig
#
# Functor constructor context
#   ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
#   ~> template sig
#


def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True when ``g`` has exactly two tensor-like non-out arguments."""
    num_tensors = sum(
        1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    )
    return num_tensors == 2


def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
    """Build the CUDA functor signatures for ``g`` and the C++ source defining them.

    Returns a pair ``(sigs, source)``: ``sigs`` maps each supported dtype to
    the functor signature to use per ``UfuncKey``; ``source`` holds the
    definitions of the functors generated here.  Functors whose key appears
    directly in ``g.out.ufunc_inner_loop`` are referenced by name only and
    are not generated.

    Raises AssertionError when a scalar-specialized key is set on a function
    that is not eligible for it, when ScalarOnly and Generic disagree on the
    ufunc name, or when a functor has to be generated but neither is defined.
    """
    # First, build the functors.
    ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: List[str] = []
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        # The struct must be a template over scalar_t: opmath_t is derived
        # from it and the dispatch code instantiates the functor as
        # Name<scalar_t>.
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)


@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    """Describes one "this operand is a CPU scalar" specialization of a binary op."""

    # Index of the operand tested for being a CPU scalar (before the +1 offset
    # applied in compute_ufunc_cuda_dtype_body).
    scalar_idx: int
    # Name under which the extracted scalar value is offered to the functor
    # constructor.
    ctor_tensor: str
    # Functor flavour used when this specialization applies.
    ufunc_key: UfuncKey


BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]


def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Render the C++ body of one dtype case of the CUDA kernel.

    The body is an ``if``/``else if`` chain: one branch per entry of
    ``BinaryScalarSpecializationConfigs`` whose key is in ``inner_loops``,
    followed by a final ``else`` that uses the ``UfuncKey.CUDAFunctor``
    functor.  ``g`` and ``dtype`` are currently not consulted.

    Raises KeyError if ``inner_loops`` has no ``UfuncKey.CUDAFunctor`` entry.
    """
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        # NOTE(review): the +1 presumably skips the output operand at index 0
        # of the iterator -- confirm against TensorIterator.
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: List[Union[Expr, Binding]] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
    """
    return body


@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Render the complete CUDA implementation of the structured ufunc ``g``.

    The output contains the functor definitions, the stub type and dispatch
    declaration, the dtype-switching kernel, its dispatch registration and
    the structured implementation that calls the kernel directly.
    """
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(sig.arguments())};
}}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                  CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@dataclass(frozen=True)
class StubSignature:
    """Names and C++ snippets for the dispatch stub / kernel of the ufunc ``g``.

    Every argument list is assembled from a list of parts and then joined, so
    a stub without extra arguments produces ``(TensorIteratorBase& iter)``
    instead of a list with a trailing comma.
    """

    g: NativeFunctionsGroup

    @property
    def name(self) -> str:
        """Name of the dispatch stub: ``<op>_stub``."""
        return f"{str(self.g.functional.func.name.name)}_stub"

    @property
    def kernel_name(self) -> str:
        """Name of the kernel function: ``<op>_kernel``."""
        return f"{str(self.g.functional.func.name.name)}_kernel"

    @property
    def type_name(self) -> str:
        """Name of the function-pointer type alias: ``<op>_fn``."""
        return f"{str(self.g.functional.func.name.name)}_fn"

    def arguments(self) -> List[Binding]:
        """Return the stub's bindings beyond the iterator, from ``torchgen.api.ufunc``."""
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        """Render the function-pointer type of the kernel."""
        param_types = ["TensorIteratorBase&", *(a.type for a in self.arguments())]
        return f"void(*)({', '.join(param_types)})"

    def dispatch_decl(self) -> str:
        """Render the ``DECLARE_DISPATCH`` invocation (no trailing semicolon)."""
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        """Render the ``DEFINE_DISPATCH`` invocation (no trailing semicolon)."""
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        """Render the kernel's function header (no body)."""
        params = ["TensorIteratorBase& iter", *(a.defn() for a in self.arguments())]
        return f"void {self.kernel_name}({', '.join(params)})"

    def type_defn(self) -> str:
        """Render the ``using <op>_fn = ...`` alias (no trailing semicolon)."""
        return f"using {self.type_name} = {self.type()}"

    # must be called from context where this is TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        """Render a call through the dispatch stub, translating ``ctx`` to its arguments."""
        call_args = [
            "device_type()",
            "*this",
            *(a.expr for a in translate(ctx, self.arguments())),
        ]
        return f"{self.name}({', '.join(call_args)})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        """Render a direct call to the kernel function, bypassing the stub."""
        call_args = ["*this", *(a.expr for a in translate(ctx, self.arguments()))]
        return f"{self.kernel_name}({', '.join(call_args)})"


@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Render the CPU structured implementation of ``g``.

    The output declares and defines the dispatch stub and emits the
    structured implementation, which forwards to the stub.
    """
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};

{sig.defn()} {{
  {stub_sig.call(sig.arguments())};
}}
"""


def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Render the C++ body of one dtype case of the CPU kernel.

    Emits ``cpu_kernel_vec`` when ``inner_loops`` has a ``CPUVector`` entry
    and ``cpu_kernel`` otherwise.  ``dtype`` is only used in the assertion
    message.

    Raises AssertionError if ``CPUScalar`` is missing from ``inner_loops`` or
    if it contains keys other than ``CPUScalar`` / ``CPUVector``.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = inner_loops.get(UfuncKey.CPUVector)

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Bindings that get unpacked into local variables: everything except
    # bindings backed by an Argument whose type is not Scalar.
    unpacked = [
        b
        for b in parent_ctx
        if not (
            isinstance(b.argument, Argument)
            and b.argument.type != BaseType(BaseTy.Scalar)
        )
    ]

    # Setup scalar in scope
    body: List[str] = []
    ctx: List[Expr] = []
    for b in unpacked:
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in unpacked:
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        # Unpacked scalars first, then the lambda's own parameters.
        r: List[Union[Expr, Binding]] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = "\n".join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""
{body_str}
cpu_kernel(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""


@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Render the CPU kernel of ``g`` together with its dispatch registration.

    For every dtype the scalar (and, when available, vectorized) ufunc is
    picked with this precedence: the specific ``CPUScalar``/``CPUVector``
    entry, then ``ScalarOnly`` (scalar loop only), then ``Generic``.
    """
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                compute_t: CType
                if k is UfuncKey.CPUScalar:
                    compute_t = BaseCType(scalar_t)
                elif k is UfuncKey.CPUVector:
                    compute_t = VectorizedCType(BaseCType(scalar_t))
                else:
                    raise AssertionError()
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                # First writer wins, which is what implements the precedence.
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)
    # The line break after the `// anonymous namespace` comment is required:
    # without it the comment would swallow the declarations that follow.
    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""