from typing import List, Optional from torchgen.api import dispatcher from torchgen.api.types import ( BaseCppType, BaseCType, Binding, boolT, ConstRefCType, CType, longT, NamedCType, tensorT, ) from torchgen.model import ( Argument, BaseTy, BaseType, FunctionSchema, NativeFunction, NativeFunctionsViewGroup, ) # This file describes the translation of JIT schema to API's used # when creating view lambdas that are used by the functionalization pass. # There are two types of lambdas: forward lambdas and reverse lambdas. # These API's mostly follow the dispatcher API, with a few quirks: # - The lambda capture has to convert reference types to value types # - While the forward lambda just directly calls into the at::_ops API # (following the dispatcher convention), the logic here for the reverse lambda # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). # The lambdas generated for each view op in the functionalization pass are of the form # [capture_arguments](outer_arguments) -> returns_type { # return name(inner_arguments); # } # Define some specific lambda input arguments. 
def _lambda_binding(
    binding_name: str, ctype: CType, *, arg_name: str, arg_ty: BaseTy
) -> Binding:
    # Private helper: assemble a Binding whose C++-side name/type are
    # (binding_name, ctype) and whose schema-side Argument is (arg_name, arg_ty)
    # with no default and no annotation.
    return Binding(
        name=binding_name,
        nctype=NamedCType(name=binding_name, type=ctype),
        argument=Argument(
            name=arg_name, type=BaseType(arg_ty), default=None, annotation=None
        ),
        default=None,
    )


# `const Tensor& base`
base_binding = _lambda_binding(
    "base",
    ConstRefCType(BaseCType(tensorT)),
    arg_name="base",
    arg_ty=BaseTy.Tensor,
)

# `const Tensor& mutated_view`
# NOTE(review): the placeholder Argument is named "base" rather than "mutated_view".
# Looks copy-pasted; presumably harmless if `argument` is unused for these bindings
# (see the NB on inverse_return_mode_binding) -- confirm before changing.
mutated_view_binding = _lambda_binding(
    "mutated_view",
    ConstRefCType(BaseCType(tensorT)),
    arg_name="base",
    arg_ty=BaseTy.Tensor,
)

# `int64_t mutated_view_idx`
# NOTE(review): the placeholder Argument is a Tensor named "base" even though the
# C++ type is a long. Kept exactly as-is; confirm nothing reads it before changing.
mutated_view_idx_binding = _lambda_binding(
    "mutated_view_idx",
    BaseCType(longT),
    arg_name="base",
    arg_ty=BaseTy.Tensor,
)

# `bool reapply_views`
reapply_views_binding = _lambda_binding(
    "reapply_views",
    BaseCType(boolT),
    arg_name="reapply_views",
    arg_ty=BaseTy.bool,
)

InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")

# `at::functionalization::InverseReturnMode inverse_return_mode`
# NB: the schema-side type is not actually a bool, but it doesn't matter because
# the Argument isn't used.
inverse_return_mode_binding = _lambda_binding(
    "inverse_return_mode",
    BaseCType(InverseReturnModeT),
    arg_name="inverse_return_mode",
    arg_ty=BaseTy.bool,
)


# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: Optional[bool] = None,
) -> str:
    # `reapply_views` only matters for the forward lambda (the reverse function
    # always receives the runtime flag), so it may only be omitted when is_reverse.
    assert reapply_views is not None or is_reverse

    if is_reverse:
        return reverse_name(g.view, include_namespace)

    # Forward case: we call straight into the at::_ops API, which always needs
    # the namespace, and we need a view_copy variant to fall back on.
    assert include_namespace
    assert g.view_copy is not None
    target = g.view if reapply_views else g.view_copy
    api_name = target.func.name.unambiguous_name()
    return f"at::_ops::{api_name}::call"


def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    # For the reverse direction the "reapply_views" flag is plumbed into a single
    # inverse function that supports both copy and non-copy variants (avoiding
    # twice as many hand-written view inverse functions).
    api_name = f.func.name.unambiguous_name()
    # Declarations are emitted unqualified; call-sites need the full namespace.
    if not include_namespace:
        return f"{api_name}_inverse"
    return f"at::functionalization::FunctionalInverses::{api_name}_inverse"


def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
    # Captures are: one leading mode flag, followed by every argument except `self`.
    # Captured values must not be C++ reference types (that would dangle inside the
    # capture), so reference types (IntArrayRef) are converted to value types (vector).
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    if is_reverse:
        mode_binding = inverse_return_mode_binding
    else:
        mode_binding = reapply_views_binding

    captured = [mode_binding]
    for arg in flat_args[1:]:
        captured.append(dispatcher.argument(arg, remove_non_owning_ref_types=True))
    return captured


def returns_type(func: FunctionSchema) -> CType:
    # Every view op must return at least one output, all of them tensor-like.
    assert len(func.returns) > 0
    assert all(ret.type.is_tensor_like() for ret in func.returns)
    # The lambda itself always returns a single tensor: for multi-tensor outputs,
    # each tensor is tracked individually.
    return BaseCType(tensorT)


def outer_arguments(*, is_reverse: bool) -> List[Binding]:
    # Forward lambdas take (base, mutated_view_idx); reverse lambdas additionally
    # take mutated_view in between.
    params = [base_binding]
    if is_reverse:
        params.append(mutated_view_binding)
    params.append(mutated_view_idx_binding)
    return params


def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
    # View ops that return multiple tensors (like `split`) get a separate lambda per
    # output; replaying such an op requires indexing into its output, so hand back
    # the index binding for them and None for single-tensor ops.
    num_returns = len(func.returns)
    if num_returns > 1:
        return mutated_view_idx_binding
    if num_returns == 1 and func.returns[0].type.is_list_like():
        return mutated_view_idx_binding
    return None


def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    # The forward lambda calls the at::_ops API and the reverse lambda calls the
    # view inverse API; both follow the dispatcher API. In either case the original
    # tensor argument is swapped out for the lambda arg "base".
    call_args = [base_binding]
    if is_reverse:
        # The reverse lambda also passes "mutated_view" and the inverse return mode.
        call_args += [mutated_view_binding, inverse_return_mode_binding]
        # Calling convention: view_inverse functions for ops with multiple tensor
        # outputs take an additional index argument.
        index_binding = inner_call_index(func)
        if index_binding is not None:
            call_args.append(index_binding)

    call_args.extend(dispatcher.argument(arg) for arg in flat_args[1:])
    return call_args