Traktor/myenv/Lib/site-packages/torchgen/executorch/parse.py

152 lines
5.3 KiB
Python
Raw Permalink Normal View History

2024-05-26 05:12:46 +02:00
from collections import defaultdict, namedtuple
from typing import Any, Dict, List, Optional, Set, Tuple
import yaml
from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader, parse_native_yaml
from torchgen.model import (
BackendMetadata,
DispatchKey,
FunctionSchema,
NativeFunction,
OperatorName,
)
from torchgen.utils import NamespaceHelper
# Result of parsing native_functions.yaml: the NativeFunctions together with the
# ET backend (kernel) indices.
ETParsedYaml = namedtuple(
    "ETParsedYaml",
    "native_functions et_kernel_indices",
)

# ET-specific fields of a native_functions.yaml entry that determine which
# kernels should be used. `strip_et_fields` removes exactly these keys.
ET_FIELDS = [
    "kernels",
    "type_alias",
    "dim_order_alias",
]
def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}
            See ETKernelKey for description of arguments

    Returns:
        A dict from each ETKernelKey to the BackendMetadata of the kernel that
        serves it. It is empty when `ei` has no `kernels` entry. A kernel entry
        whose `arg_meta` is None is registered under the default kernel key.

    Raises:
        AssertionError: if two kernel entries produce the same kernel key.

    The `__line__` bookkeeping key (added by LineLoader) is filtered into copies
    instead of being popped from the caller's yaml. Its absence is tolerated,
    and the same loaded yaml can be parsed more than once.
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    # A key that is present but left empty in the yaml loads as None; treat it
    # the same as an absent key.
    type_alias: Dict[str, List[str]] = e.pop("type_alias", None) or {}  # type: ignore[assignment]
    # `ei.copy()` is shallow, so popping `__line__` here would mutate the
    # caller's nested dict; build a filtered copy instead.
    dim_order_alias: Dict[str, List[str]] = _without_line_key(
        e.pop("dim_order_alias", None) or {}  # type: ignore[arg-type]
    )

    kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:  # type: ignore[attr-defined]
        arg_meta = entry.get("arg_meta")
        if arg_meta is not None:
            # Same reasoning as above: copy without `__line__` rather than
            # popping it from the caller's entry.
            arg_meta = _without_line_key(arg_meta)

        # The kernel name may carry up to 3 namespace levels; when it has none,
        # the C++ namespace falls back to `at`.
        kernel_name = entry.get("kernel_name")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
        )

        for kernel_key in kernel_keys:
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping


def _without_line_key(mapping: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of `mapping` without LineLoader's `__line__` key."""
    return {key: value for key, value in mapping.items() if key != "__line__"}
def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Build an ETKernelIndex from a loaded yaml list of operator entries.

    For every entry, the `func` schema is parsed to obtain the operator name.
    The rest of the entry (a copy without `func`) goes to `parse_from_yaml`,
    which returns that operator's kernel-key -> BackendMetadata mapping.
    Operators whose mapping comes back empty are left out of the index.

    Raises:
        AssertionError: if an entry's `func` is not a str, or if an operator
            that already has kernels recorded appears again.
    """
    indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {}

    for entry in es:  # type: ignore[attr-defined]
        remaining = entry.copy()
        func_schema = remaining.pop("func")
        assert isinstance(func_schema, str), f"not a str: {func_schema}"

        entity = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=func_schema, max_level=1
        ).entity_name
        op_name = FunctionSchema.parse(entity).name

        # Only operators stored below are checked, so a repeat of an entry
        # that produced no kernels is not reported here.
        assert op_name not in indices, f"Duplicate func found in yaml: {op_name} already"

        kernel_index = parse_from_yaml(remaining)
        if kernel_index:
            indices[op_name] = kernel_index

    return ETKernelIndex(indices)
def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
    """Collect the ET kernel-related fields of each operator, keyed by operator name.

    For each entry in the loaded yaml list, the `func` schema is parsed to get
    the operator name. Every field listed in ET_FIELDS whose value is not None
    is then recorded under that name. Operators with no such field get no key.
    The returned object is a `defaultdict(dict)`.
    """
    fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict)

    for entry in es:  # type: ignore[attr-defined]
        schema = entry.get("func")
        assert isinstance(schema, str), f"not a str: {schema}"
        op_name = FunctionSchema.parse(
            NamespaceHelper.from_namespaced_entity(
                namespaced_entity=schema, max_level=1
            ).entity_name
        ).name

        present = {
            key: value
            for key in ET_FIELDS
            if (value := entry.get(key)) is not None
        }
        # Touch the defaultdict only when something was found, so operators
        # without ET fields do not get an empty entry.
        if present:
            fields[op_name].update(present)

    return fields
def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: Optional[Set[DispatchKey]] = None,
    skip_native_fns_gen: bool = False,
) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]:
    """Parse an ET-flavoured native_functions.yaml.

    Returns:
        A pair of:
        - the NativeFunctions produced by `parse_native_yaml`, and
        - the ET fields to persist from native_functions.yaml to
          functions.yaml, indexed by operator name.

    The ET fields are captured first and then stripped from the loaded entries.
    The stripped entries are passed to `parse_native_yaml` via `loaded_yaml`.
    """
    # NOTE(review): LineLoader is presumed to be a safe loader variant — confirm.
    with open(path) as yaml_file:
        entries = yaml.load(yaml_file, Loader=LineLoader)

    # Must run before strip_et_fields, which deletes these fields in place.
    kernel_fields = extract_kernel_fields(entries)
    # Remove ET specific fields from entries for BC compatibility
    strip_et_fields(entries)

    parsed = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=entries,
    )
    return parsed.native_functions, kernel_fields
def strip_et_fields(es: object) -> None:
    """Remove, in place, every ET_FIELDS key from each entry of a loaded operator yaml.

    A missing field is simply skipped. The entries are left without the
    ET-specific keys for backward compatibility with consumers that do not
    know about them.
    """
    for operator_entry in es:  # type: ignore[attr-defined]
        for et_key in ET_FIELDS:
            operator_entry.pop(et_key, None)