from collections import defaultdict, namedtuple
from typing import Any, Dict, List, Optional, Set, Tuple

import yaml

from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader, parse_native_yaml
from torchgen.model import (
    BackendMetadata,
    DispatchKey,
    FunctionSchema,
    NativeFunction,
    OperatorName,
)
from torchgen.utils import NamespaceHelper

# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])

# Fields in native_functions.yaml used to determine which kernels should be used
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]


def _without_line_key(mapping: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``mapping`` without the ``__line__`` bookkeeping key.

    The key is presumably added by LineLoader to loaded mappings (confirm in
    torchgen.gen); it is not kernel metadata. Copying instead of popping in place
    leaves the caller's loaded yaml untouched, so the same document can be
    parsed more than once. It also tolerates mappings that never had the key.
    """
    return {k: v for k, v in mapping.items() if k != "__line__"}


def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}
            See ETKernelKey for description of arguments

    Returns:
        Mapping from every kernel key to the BackendMetadata of the kernel that
        serves it. It is empty when ``ei`` has no ``kernels`` entry.

    Raises:
        AssertionError: if two kernel entries produce the same kernel key.

    Neither ``ei`` nor the mappings nested inside it are modified.
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    type_alias: Dict[str, List[str]] = e.pop("type_alias", {})  # type: ignore[assignment]
    # `e` is only a shallow copy, so the nested alias dict is shared with the
    # caller. Filter into a new dict rather than popping "__line__" in place.
    dim_order_alias: Dict[str, List[str]] = _without_line_key(
        e.pop("dim_order_alias", {})  # type: ignore[arg-type]
    )

    kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:  # type: ignore[attr-defined]
        arg_meta = entry.get("arg_meta")
        if arg_meta is not None:
            # Copy as well. The old in-place `pop("__line__")` mutated the
            # caller's yaml and raised KeyError on any re-parse.
            arg_meta = _without_line_key(arg_meta)

        kernel_name = entry.get("kernel_name")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        # A kernel without arg_meta gets a single key built with default=True.
        # Otherwise ETKernelKey expands arg_meta and the aliases into one or
        # more keys.
        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
        )

        for kernel_key in kernel_keys:
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping


def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Given a loaded yaml representing a list of operators, for each op extract the mapping
    of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
    that should be used by the kernel key).

    Args:
        es: Loaded yaml; an iterable of per-operator dicts that each carry a
            string ``func`` entry.

    Returns:
        An ETKernelIndex over every operator that yielded at least one kernel key.
    """
    indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {}
    for ei in es:  # type: ignore[attr-defined]
        e = ei.copy()

        funcs = e.pop("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        opname = FunctionSchema.parse(namespace_helper.entity_name).name

        # NOTE(review): only ops with a non-empty index are recorded below, so a
        # duplicate `func` is detected only when the earlier occurrence had kernels.
        assert opname not in indices, f"Duplicate func found in yaml: {opname} already"

        if len(index := parse_from_yaml(e)) != 0:
            indices[opname] = index

    return ETKernelIndex(indices)


def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
    """Given a loaded yaml representing a list of operators, extract the
    kernel key related fields indexed by the operator name.

    Args:
        es: Loaded yaml; an iterable of per-operator dicts that each carry a
            string ``func`` entry.

    Returns:
        A defaultdict mapping each operator name to the subset of ET_FIELDS it
        sets to a non-None value. Operators that set none of them are absent.
    """
    fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict)
    for ei in es:  # type: ignore[attr-defined]
        funcs = ei.get("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        opname = FunctionSchema.parse(namespace_helper.entity_name).name

        for field in ET_FIELDS:
            if (value := ei.get(field)) is not None:
                fields[opname][field] = value

    return fields


def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: Optional[Set[DispatchKey]] = None,
    skip_native_fns_gen: bool = False,
) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]:
    """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
    of fields to persist from native_functions.yaml to functions.yaml

    Args:
        path: Path to the yaml file to parse.
        tags_yaml_path: Path to the tags yaml, forwarded to parse_native_yaml.
        ignore_keys: Forwarded to parse_native_yaml.
        skip_native_fns_gen: Forwarded to parse_native_yaml.

    Returns:
        A tuple ``(native_functions, et_kernel_fields)``, where the second item
        comes from extract_kernel_fields.
    """
    with open(path) as f:
        es = yaml.load(f, Loader=LineLoader)

    # Capture the ET-only fields before they are stripped below.
    et_kernel = extract_kernel_fields(es)

    # Remove ET specific fields from entries for BC compatibility
    strip_et_fields(es)

    native_yaml = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=es,
    )
    return native_yaml.native_functions, et_kernel


def strip_et_fields(es: object) -> None:
    """Given a loaded yaml representing a list of operators,
    remove ET specific fields from every entries for BC compatibility

    Args:
        es: Loaded yaml; an iterable of per-operator dicts, modified in place.
    """
    for entry in es:  # type: ignore[attr-defined]
        for field in ET_FIELDS:
            entry.pop(field, None)