152 lines
5.3 KiB
Python
152 lines
5.3 KiB
Python
|
from collections import defaultdict, namedtuple
|
||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||
|
|
||
|
import yaml
|
||
|
|
||
|
from torchgen.executorch.model import ETKernelIndex, ETKernelKey
|
||
|
|
||
|
from torchgen.gen import LineLoader, parse_native_yaml
|
||
|
from torchgen.model import (
|
||
|
BackendMetadata,
|
||
|
DispatchKey,
|
||
|
FunctionSchema,
|
||
|
NativeFunction,
|
||
|
OperatorName,
|
||
|
)
|
||
|
from torchgen.utils import NamespaceHelper
|
||
|
|
||
|
# Result pair for parsing native_functions.yaml: the NativeFunctions
# ("native_functions") together with the ExecuTorch kernel indices
# ("et_kernel_indices").
ETParsedYaml = namedtuple(
    "ETParsedYaml",
    ("native_functions", "et_kernel_indices"),
)

# Per-operator yaml keys that decide which kernels should be used.
# extract_kernel_fields() collects these keys and strip_et_fields() removes
# them from the loaded entries.
ET_FIELDS = [
    "kernels",
    "type_alias",
    "dim_order_alias",
]
|
||
|
|
||
|
|
||
|
def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}
            See ETKernelKey for description of arguments

    Returns:
        A dict mapping each generated ETKernelKey to the BackendMetadata of the
        kernel that serves it. Empty when `ei` has no `kernels` entry (or the
        entry is None).

    Raises:
        AssertionError: if the same kernel key is produced more than once.

    Note:
        Only the top-level dict is copied. The nested `dim_order_alias` and
        `arg_meta` dicts are shared with the caller and have their `__line__`
        key removed in place.
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    type_alias: Dict[str, List[str]] = e.pop("type_alias", {})  # type: ignore[assignment]
    dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {})  # type: ignore[assignment]
    # `__line__` is bookkeeping rather than an alias (presumably inserted by
    # LineLoader -- confirm); drop it before the aliases are interpreted.
    dim_order_alias.pop("__line__", None)

    kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:  # type: ignore[attr-defined]
        arg_meta = entry.get("arg_meta")
        if arg_meta is not None:
            # Use a default so a missing marker is not an error: the dict is
            # mutated in place, so the marker is already gone when the same
            # entry is parsed a second time, and it is never present when the
            # yaml was not produced with line markers.
            arg_meta.pop("__line__", None)

        kernel_name = entry.get("kernel_name")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        # Kernels that do not name a namespace fall back to "at".
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        # Without arg_meta the kernel is registered under the single default key.
        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
        )

        for kernel_key in kernel_keys:
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping
|
||
|
|
||
|
|
||
|
def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Build an ETKernelIndex from a loaded yaml list of operator entries.

    Each entry is shallow-copied, its `func` string is turned into an operator
    name, and the rest of the entry is passed to `parse_from_yaml` to obtain
    the `kernel key` -> `BackendMetadata` mapping for that operator. Operators
    whose mapping comes back empty are not recorded.

    Raises:
        KeyError: if an entry has no `func` key.
        AssertionError: if `func` is not a str, or if the operator name has
            already been recorded by an earlier entry.
    """
    per_op_kernels: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {}

    for raw_entry in es:  # type: ignore[attr-defined]
        # Work on a copy so popping `func` does not alter the caller's entry.
        entry = raw_entry.copy()

        func_decl = entry.pop("func")
        assert isinstance(func_decl, str), f"not a str: {func_decl}"

        # Only the entity name is parsed as a schema (presumably the part
        # left after an optional namespace prefix -- confirm).
        schema_text = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=func_decl, max_level=1
        ).entity_name
        opname = FunctionSchema.parse(schema_text).name

        assert (
            opname not in per_op_kernels
        ), f"Duplicate func found in yaml: {opname} already"

        kernel_mapping = parse_from_yaml(entry)
        if not kernel_mapping:
            continue
        per_op_kernels[opname] = kernel_mapping

    return ETKernelIndex(per_op_kernels)
|
||
|
|
||
|
|
||
|
def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
    """Collect the ET_FIELDS values of every operator in a loaded yaml list.

    For each entry the `func` string is turned into an operator name, and
    every ET_FIELDS key whose value is not None is stored under that name.
    Operators with none of those fields set get no entry. The entries
    themselves are only read, never modified.

    Returns:
        A defaultdict(dict) keyed by operator name.

    Raises:
        AssertionError: if an entry's `func` is missing or not a str.
    """
    collected: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict)

    for entry in es:  # type: ignore[attr-defined]
        func_decl = entry.get("func")
        assert isinstance(func_decl, str), f"not a str: {func_decl}"

        schema_text = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=func_decl, max_level=1
        ).entity_name
        opname = FunctionSchema.parse(schema_text).name

        for et_field in ET_FIELDS:
            value = entry.get(et_field)
            if value is None:
                continue
            # Indexing the defaultdict creates the per-operator dict on first use.
            collected[opname][et_field] = value

    return collected
|
||
|
|
||
|
|
||
|
def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: Optional[Set[DispatchKey]] = None,
    skip_native_fns_gen: bool = False,
) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]:
    """Parse native_functions.yaml into NativeFunctions plus the ET fields per operator.

    The file at `path` is loaded once with LineLoader. The ET_FIELDS values are
    collected per operator, then removed from the loaded entries, and the
    stripped entries are passed to `parse_native_yaml` as `loaded_yaml`.

    Returns:
        A tuple of (the `native_functions` of the `parse_native_yaml` result,
        the operator-indexed dict of fields to persist from
        native_functions.yaml to functions.yaml).
    """
    with open(path) as yaml_file:
        entries = yaml.load(yaml_file, Loader=LineLoader)

    # Collect the ET fields before they are stripped: strip_et_fields mutates
    # the loaded entries in place.
    kernel_fields = extract_kernel_fields(entries)

    # Remove ET specific fields from entries for BC compatibility
    strip_et_fields(entries)

    parsed = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=entries,
    )
    return parsed.native_functions, kernel_fields
|
||
|
|
||
|
|
||
|
def strip_et_fields(es: object) -> None:
    """Remove the ET specific fields from every operator entry, in place.

    `es` is a loaded yaml list of operator entries. Each ET_FIELDS key is
    dropped from each entry for BC compatibility; keys that are absent are
    ignored. Nothing is returned.
    """
    for op_entry in es:  # type: ignore[attr-defined]
        for et_field in ET_FIELDS:
            # The default makes a missing field a no-op instead of a KeyError.
            op_entry.pop(et_field, None)
|