2796 lines
109 KiB
Python
2796 lines
109 KiB
Python
|
import dataclasses
|
||
|
import itertools
|
||
|
import re
|
||
|
|
||
|
from dataclasses import dataclass
|
||
|
from enum import auto, Enum
|
||
|
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
|
||
|
|
||
|
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
|
||
|
|
||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||
|
#
|
||
|
# DATA MODEL
|
||
|
#
|
||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||
|
#
|
||
|
# Some general principles for our data model.
|
||
|
#
|
||
|
# - Stop using C++ data types as the internal data representation
|
||
|
# format. Instead, the internal data structures are centered
|
||
|
# around JIT schema representation. This avoids a big problem
|
||
|
# with the old codegen where we read in all the types from
|
||
|
# native_functions.yaml and then immediately had to retranslate
|
||
|
# them into C++ types.
|
||
|
#
|
||
|
# - More semantic data representation. Instead of representing
|
||
|
# everything as dicts and strings, we define dataclasses for
|
||
|
# every interesting entity the code generation has to deal with.
|
||
|
# These dataclasses have strong semantic invariants: for example,
|
||
|
# we generally require them to roundtrip losslessly into the
|
||
|
# form they were parsed from. These structures are immutable
|
||
|
# and you're expected to populate information once during
|
||
|
# construction.
|
||
|
|
||
|
|
||
|
# Represent a source location; used for better error reporting
|
||
|
@dataclass(frozen=True)
class Location:
    """A source location (file plus line), used for better error reporting.

    Instances are immutable (``frozen=True``) and render as ``file:line``.
    """

    # Path of the file the location refers to; emitted verbatim by __str__.
    file: str
    # Line number within `file`.
    line: int

    def __str__(self) -> str:
        """Return the location formatted as ``<file>:<line>``."""
        return f"{self.file}:{self.line}"
|
||
|
|
||
|
|
||
|
# Valid values of the 'variants' field in native_functions.yaml
|
||
|
class Variant(Enum):
    """Valid values of the 'variants' field in native_functions.yaml."""

    function = auto()
    method = auto()
|
||
|
|
||
|
|
||
|
# Default kernel namespace
DEFAULT_KERNEL_NAMESPACE = "at::native"

# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()

# Functionality prefixes that are combined with every backend component to
# form per-backend dispatch key names; "" yields the bare backend name.
FUNCTIONALITY_KEYS = [
    "",
    "Quantized",
    "Sparse",
    "SparseCsr",
    "NestedTensor",
    "Autograd",
]

# This list guards dispatches that can be used in derivatives.yaml
# For now we omit AutogradFunctionality and AutogradOther
AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
    "Autograd" + component for component in BACKEND_COMPONENTS
]

FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
|
||
|
|
||
|
|
||
|
# This doesn't have to be in sync with the header, it only needs to contain
|
||
|
# entries that we actually use in the codegen or want pyi entries for
|
||
|
class DispatchKey(Enum):
    """Dispatch keys known to the code generator.

    Member names are the strings accepted by :meth:`parse` and produced by
    ``str()``. Apart from ``Undefined`` (fixed at 0) all values come from
    ``auto()``, so the declaration order below determines the numeric values;
    do not reorder members casually.
    """

    Undefined = 0
    # Alias of Undefined: same value, so it resolves to the same member.
    CatchAll = Undefined

    FPGA = auto()
    ORT = auto()
    Vulkan = auto()
    Metal = auto()
    MKLDNN = auto()
    OpenGL = auto()
    OpenCL = auto()
    IDEEP = auto()
    CustomRNGKeyId = auto()
    MkldnnCPU = auto()
    Sparse = auto()
    SparseCsr = auto()
    NestedTensor = auto()
    Dense = auto()

    PreDispatch = auto()
    Python = auto()
    FuncTorchDynamicLayerBackMode = auto()
    ZeroTensor = auto()
    Conjugate = auto()
    Negative = auto()
    BackendSelect = auto()
    Named = auto()
    AutogradOther = auto()
    AutogradFunctionality = auto()
    AutogradNestedTensor = auto()
    Tracer = auto()
    Autocast = auto()
    Batched = auto()
    VmapMode = auto()
    FuncTorchGradWrapper = auto()
    FuncTorchBatched = auto()
    BatchedNestedTensor = auto()
    FuncTorchVmapMode = auto()
    FuncTorchDynamicLayerFrontMode = auto()
    Functionalize = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()

    ADInplaceOrView = auto()
    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()
    FuncTorchBatchedDecomposition = auto()

    # One entry per (functionality key, backend component) pair; a
    # module-level loop after this class raises at import time if any
    # combination is missing and prints the expected list.
    # BEGIN autogenerated
    CPU = auto()
    CUDA = auto()
    HIP = auto()
    XLA = auto()
    MTIA = auto()
    MPS = auto()
    IPU = auto()
    XPU = auto()
    HPU = auto()
    VE = auto()
    Lazy = auto()
    Meta = auto()
    PrivateUse1 = auto()
    PrivateUse2 = auto()
    PrivateUse3 = auto()
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedHIP = auto()
    QuantizedXLA = auto()
    QuantizedMTIA = auto()
    QuantizedMPS = auto()
    QuantizedIPU = auto()
    QuantizedXPU = auto()
    QuantizedHPU = auto()
    QuantizedVE = auto()
    QuantizedLazy = auto()
    QuantizedMeta = auto()
    QuantizedPrivateUse1 = auto()
    QuantizedPrivateUse2 = auto()
    QuantizedPrivateUse3 = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseHIP = auto()
    SparseXLA = auto()
    SparseMTIA = auto()
    SparseMPS = auto()
    SparseIPU = auto()
    SparseXPU = auto()
    SparseHPU = auto()
    SparseVE = auto()
    SparseLazy = auto()
    SparseMeta = auto()
    SparsePrivateUse1 = auto()
    SparsePrivateUse2 = auto()
    SparsePrivateUse3 = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    SparseCsrHIP = auto()
    SparseCsrXLA = auto()
    SparseCsrMTIA = auto()
    SparseCsrMPS = auto()
    SparseCsrIPU = auto()
    SparseCsrXPU = auto()
    SparseCsrHPU = auto()
    SparseCsrVE = auto()
    SparseCsrLazy = auto()
    SparseCsrMeta = auto()
    SparseCsrPrivateUse1 = auto()
    SparseCsrPrivateUse2 = auto()
    SparseCsrPrivateUse3 = auto()
    NestedTensorCPU = auto()
    NestedTensorCUDA = auto()
    NestedTensorHIP = auto()
    NestedTensorXLA = auto()
    NestedTensorMTIA = auto()
    NestedTensorMPS = auto()
    NestedTensorIPU = auto()
    NestedTensorXPU = auto()
    NestedTensorHPU = auto()
    NestedTensorVE = auto()
    NestedTensorLazy = auto()
    NestedTensorMeta = auto()
    NestedTensorPrivateUse1 = auto()
    NestedTensorPrivateUse2 = auto()
    NestedTensorPrivateUse3 = auto()
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradHIP = auto()
    AutogradXLA = auto()
    AutogradMTIA = auto()
    AutogradMPS = auto()
    AutogradIPU = auto()
    AutogradXPU = auto()
    AutogradHPU = auto()
    AutogradVE = auto()
    AutogradLazy = auto()
    AutogradMeta = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()
    # END autogenerated

    def __str__(self) -> str:
        """Return the bare member name (no ``DispatchKey.`` prefix)."""
        return self.name

    def lower(self) -> str:
        """Return the member name in lower case."""
        return str(self).lower()

    @staticmethod
    def parse(value: str) -> "DispatchKey":
        """Return the member whose name is exactly ``value``.

        Alias names (e.g. ``CatchAll``) are accepted because ``__members__``
        includes aliases.

        Raises:
            AssertionError: if ``value`` does not name a dispatch key.
        """
        member = DispatchKey.__members__.get(value)
        if member is None:
            raise AssertionError(f"unknown dispatch key {value}")
        return member
|
||
|
|
||
|
|
||
|
class _TorchDispatchModeKey(Enum):
    """Private enumeration of torch-dispatch mode keys (FAKE, PROXY, FUNCTIONAL)."""

    FAKE = auto()
    PROXY = auto()
    FUNCTIONAL = auto()
|
||
|
|
||
|
|
||
|
def codegen_per_backend_entries() -> str:
    """Return the enum-member source lines expected for every per-backend key.

    One ``<functionality><backend> = auto()`` line is produced for each pair
    drawn from FUNCTIONALITY_KEYS (outer loop) and BACKEND_COMPONENTS (inner
    loop), joined with newlines. The result is used in the error message of
    the import-time completeness check on ``DispatchKey``.
    """
    # NOTE(review): the leading whitespace inside this literal is reproduced
    # exactly as it appears in the source under review; the extraction that
    # mangled this file may have collapsed a wider indent — confirm upstream.
    return "\n".join(
        f" {fk}{bc} = auto()"
        for fk in FUNCTIONALITY_KEYS
        for bc in BACKEND_COMPONENTS
    )
|
||
|
|
||
|
|
||
|
# Import-time sanity check: every (functionality key, backend component)
# combination must exist as a DispatchKey member. On the first missing one,
# print the full expected member list and abort with the same list embedded
# in the error message so it can be pasted into the enum.
for fk in FUNCTIONALITY_KEYS:
    for bc in BACKEND_COMPONENTS:
        if not hasattr(DispatchKey, fk + bc):
            r = codegen_per_backend_entries()
            print(r)
            raise RuntimeError(
                f"Missing {fk}{bc} from DispatchKey enum.  Here is the autogenerated list we expect to have:\n\n{r}"
            )
|
||
|
|
||
|
|
||
|
# Keys for which structured kernel generation is supported
# (consulted by is_structured_dispatch_key).
STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU}
# Keys that receive an implicit dispatch entry from ufunc codegen
# (consulted by is_ufunc_dispatch_key).
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}

# Set of supported dispatch keys
dispatch_keys = [
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    # Meta is a magic key: it is automatically generated for structured
    # kernels
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.SparseCsrMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]
|
||
|
|
||
|
|
||
|
# Dispatch keys that "support all backends". These codegen slightly differently
|
||
|
# than backend-specific keys.
|
||
|
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is one of the four Composite* dispatch keys.

    These are the keys that "support all backends" rather than naming a
    specific one.
    """
    return dk in {
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    }
|
||
|
|
||
|
|
||
|
# CUDA specific dispatch keys
|
||
|
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is one of the CUDA-specific dispatch keys.

    Covers the plain CUDA key and its Quantized, Sparse, SparseCsr,
    NestedTensor and Autograd variants.
    """
    return dk in {
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.AutogradCUDA,
    }
|
||
|
|
||
|
|
||
|
# Structured kernel generation is only supported for certain key types;
|
||
|
# otherwise use old-style
|
||
|
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if structured kernel generation is supported for ``dk``.

    Membership is defined by STRUCTURED_DISPATCH_KEYS; keys outside that set
    use old-style generation.
    """
    return dk in STRUCTURED_DISPATCH_KEYS
|
||
|
|
||
|
|
||
|
def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is a key handled by ufunc codegen."""
    # UFUNC_DISPATCH_KEYS is a subset of STRUCTURED_DISPATCH_KEYS, not the
    # same set: MPS is a structured key but is not a ufunc key.
    return dk in UFUNC_DISPATCH_KEYS
|
||
|
|
||
|
|
||
|
# This is oddly named ScalarType and not DType for symmetry with C++
|
||
|
class ScalarType(Enum):
    """Scalar (element) types known to the code generator.

    Member names are the strings accepted by the ``parse*`` helpers and
    produced by ``str()``.
    """

    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()
    Float8_e5m2 = auto()
    Float8_e5m2fnuz = auto()
    Float8_e4m3fn = auto()
    Float8_e4m3fnuz = auto()

    def __str__(self) -> str:
        """Return the bare member name (no ``ScalarType.`` prefix)."""
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> Optional["ScalarType"]:
        """Return the member named exactly ``value``, or None if there is none."""
        return ScalarType.__members__.get(value)

    @staticmethod
    def parse(value: str) -> "ScalarType":
        """Return the member named exactly ``value``.

        Raises:
            AssertionError: if ``value`` does not name a scalar type. Raised
                explicitly (not via ``assert``) so the check survives
                ``python -O``, matching the other ``parse`` helpers here.
        """
        parsed = ScalarType.maybe_parse(value)
        if parsed is None:
            raise AssertionError(f"unknown dtype {value}")
        return parsed

    @staticmethod
    def parse_set(values: str) -> "OrderedSet[ScalarType]":
        """Parse a ``", "``-separated list of dtype names and dtype classes.

        Each item is either a key of the module-level DTYPE_CLASSES mapping,
        which contributes every member of that class, or a single ScalarType
        name. Results accumulate in an OrderedSet in the order encountered.

        Raises:
            AssertionError: if an item is neither a class name nor a dtype.
        """
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for value in values.split(", "):
            if value in DTYPE_CLASSES:
                dtypes.update(DTYPE_CLASSES[value])
            else:
                dtypes.add(ScalarType.parse(value))
        return dtypes
|
||
|
|
||
|
|
||
|
# Named groups of scalar types; ScalarType.parse_set expands these names
# into their member dtypes.
DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {}
# NB: Integral doesn't include boolean
DTYPE_CLASSES["Integral"] = OrderedSet(
    [
        ScalarType.Byte,
        ScalarType.Char,
        ScalarType.Int,
        ScalarType.Long,
        ScalarType.Short,
    ]
)
# NB: Floating doesn't include low precision types
DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
DTYPE_CLASSES["Complex"] = OrderedSet(
    [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
)
# The composite classes below are unions of the base classes above.
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = (
    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
)
|
||
|
|
||
|
|
||
|
# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
|
||
|
# NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc how
|
||
|
# to process it. Most logic will ignore keys they don't understand, so your
|
||
|
# new key will get silently ignored until you hook in logic to deal with it.
|
||
|
class UfuncKey(Enum):
    """Valid keys of a ``ufunc_inner_loop`` entry in native_functions.yaml."""

    # These are low level keys that represent exactly one particular
    # instantiation of the kernel produced by codegen
    CUDAFunctor = auto()
    CUDAFunctorOnOther = auto()
    CUDAFunctorOnSelf = auto()

    CPUScalar = auto()
    CPUVector = auto()

    # These are the ones users will usually specify, and
    # implicitly "fill in" the low level keys
    ScalarOnly = auto()  # CUDA*, CPUScalar
    Generic = auto()  # CUDA*, CPU*

    def __str__(self) -> str:
        """Return the bare member name (no ``UfuncKey.`` prefix)."""
        return self.name

    @staticmethod
    def parse(value: str) -> "UfuncKey":
        """Return the member whose name is exactly ``value``.

        Raises:
            AssertionError: if ``value`` does not name a ufunc key.
        """
        member = UfuncKey.__members__.get(value)
        if member is None:
            raise AssertionError(f"unknown ufunc key {value}")
        return member
|
||
|
|
||
|
|
||
|
class DeviceCheckType(Enum):
    """How the automatic device check is emitted for a native function.

    Member names are the strings accepted for the ``device_check`` YAML
    field (looked up by name in ``NativeFunction.from_yaml``).
    """

    NoCheck = 0
    ExactSame = 1
|
||
|
|
||
|
|
||
|
class ViewSchemaKind(Enum):
    """Classification of a schema by aliasing behaviour (three kinds)."""

    aliasing = auto()
    aliasing_inplace = auto()
    non_aliasing = auto()
|
||
|
|
||
|
|
||
|
# The basic input to the code generation is native_functions.yaml.
|
||
|
# The name "native", BTW, comes from the distinction between native
|
||
|
# functions and legacy TH functions. The legacy TH functions are gone,
|
||
|
# but the "native" descriptor has stuck.
|
||
|
#
|
||
|
# NativeFunction models a single entry in native_functions.yaml. Its
|
||
|
# fields roughly correspond to what you would see in the YAML itself,
|
||
|
# but after canonicalization and parsing has occurred.
|
||
|
#
|
||
|
# You can see some of the overall design patterns for how we setup
|
||
|
# dataclasses in this class, but we will defer a complete discussion
|
||
|
# of this at FunctionSchema.
|
||
|
@dataclass(frozen=True)
|
||
|
class NativeFunction:
|
||
|
# The namespace for this operator. For example, if we have "at::add"
|
||
|
# then the namespace would be "at". This enables ops to be registered
|
||
|
# through the same DSL with a custom namespace. If not specified, the
|
||
|
# default namespace would be "at".
|
||
|
namespace: str
|
||
|
|
||
|
# The function schema of the operator in question. This schema
|
||
|
# has been parsed; see FunctionSchema for more about its structure.
|
||
|
# (This type is quoted as we are forward referencing a type
|
||
|
# defined later in the file. I opted for this ordering of the
|
||
|
# classes for expository clarity.)
|
||
|
func: "FunctionSchema"
|
||
|
|
||
|
# Whether or not to generate mutable tensor arguments like regular
|
||
|
# ones
|
||
|
use_const_ref_for_mutable_tensors: bool
|
||
|
|
||
|
# Whether or not to omit automatic generation of a DeviceGuard
|
||
|
device_guard: bool
|
||
|
|
||
|
# How to emit automatic generation of device check
|
||
|
device_check: DeviceCheckType
|
||
|
|
||
|
# What python module to put the function in
|
||
|
python_module: Optional[str]
|
||
|
|
||
|
# TODO: figure out what this does
|
||
|
category_override: Optional[str]
|
||
|
|
||
|
# If no variants are specified in native_functions.yaml, this is
|
||
|
# assumed to be {'function'}.
|
||
|
variants: Set[Variant]
|
||
|
|
||
|
# Whether or not we should skip generating registrations for
|
||
|
# this kernel. This is a bit of a double-edged sword, as manual
|
||
|
# registrations don't participate in codegen-based selective build!
|
||
|
manual_kernel_registration: bool
|
||
|
|
||
|
# Whether or not to skip generating TensorMethod/Functions bindings
|
||
|
# for this kernel. Technically, this doesn't actually skip generating
|
||
|
# the binding; instead, the binding gets generated to __dispatch_{funcname}
|
||
|
# so you can make use of the normal binding if you need it.
|
||
|
manual_cpp_binding: bool
|
||
|
|
||
|
# The location in the YAML file where this native function entry was
|
||
|
# defined. This is for conveniently reporting error messages!
|
||
|
loc: "Location"
|
||
|
|
||
|
# A list of operators that are expected to be auto-generated for this NativeFunction.
|
||
|
# Note: This list isn't actually directly used by the codegen to generate anything.
|
||
|
# Instead, the codegen figures out what operators to generate purely based off of
|
||
|
# function schema, and uses the autogen declarations to error check.
|
||
|
# We expect every NativeFunction that gets auto-generated be explicitly called out
|
||
|
# in native_functions.yaml
|
||
|
autogen: List["OperatorName"]
|
||
|
|
||
|
# If non-empty, this kernel is subject to ufunc codegen.
|
||
|
# Sorted by ufunc_key
|
||
|
ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"]
|
||
|
|
||
|
# Whether or not this out function is a "structured kernel". Structured
|
||
|
# kernels are defined a little differently from normal kernels; in
|
||
|
# particular, their shape checking logic is defined separately from
|
||
|
# the kernel. Only out functions can be structured; other functions
|
||
|
# delegate to the out function using the structured_delegate keyword.
|
||
|
# Every structured kernel must have at least an out and a functional
|
||
|
# variant.
|
||
|
structured: bool
|
||
|
|
||
|
# Whether or not this non-out function is a structured kernel, defined
|
||
|
# in terms of the out kernel referenced by the string here.
|
||
|
structured_delegate: Optional["OperatorName"]
|
||
|
|
||
|
# Only valid for structured kernels. Specifies alternative of what
|
||
|
# to inherit from when defining the meta class for the structured
|
||
|
# operator. This will usually be TensorIteratorBase. This also
|
||
|
# changes the semantics of set_output to call the parent class.
|
||
|
structured_inherits: Optional[str]
|
||
|
|
||
|
# Structured kernels can declare elements as "precomputed". These elements
|
||
|
# are returned by the meta function in one struct and passed to the impl
|
||
|
# function in lieu of certain kernel arguments that these precomputed
|
||
|
# elements supersede. Information about the names and types of these
|
||
|
# precomputed elements and how they correspond to kernel arguments is stored
|
||
|
# in this member, if applicable.
|
||
|
precomputed: Optional["Precompute"]
|
||
|
|
||
|
# Argument names whose default should be excluded from the C++ interface.
|
||
|
# Intended for resolving overload ambiguities between signatures.
|
||
|
cpp_no_default_args: Set[str]
|
||
|
|
||
|
# Note [Abstract ATen methods]
|
||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||
|
# An abstract ATen method is one whose dispatch differs between
|
||
|
# types. These are implemented in derived types (with a
|
||
|
# standard (throwing) definition in Type). A concrete ATen
|
||
|
# method is one which has the same dispatch for all types;
|
||
|
# we just implement it in the base Type. This is exposed
|
||
|
# in Declarations.yaml via a field named 'abstract'.
|
||
|
is_abstract: bool
|
||
|
|
||
|
# Whether or not the NativeFunction contains a backend-agnostic kernel
|
||
|
has_composite_implicit_autograd_kernel: bool
|
||
|
has_composite_implicit_autograd_nested_tensor_kernel: bool
|
||
|
has_composite_explicit_autograd_kernel: bool
|
||
|
has_composite_explicit_autograd_non_functional_kernel: bool
|
||
|
|
||
|
# Tags are used to describe semantic information about (groups of) operators,
|
||
|
# That aren't easily inferrable directly from the operator's schema.
|
||
|
tags: Set[str]
|
||
|
|
||
|
# NB: The benefit of defining a dataclass is that we automatically get
|
||
|
# a constructor defined for all the fields we specify. No need
|
||
|
# to explicitly write it out.
|
||
|
|
||
|
# We parse both the NativeFunction + backend-specific information about it, which is stored in a corresponding BackendIndex.
|
||
|
@staticmethod
|
||
|
def from_yaml(
|
||
|
ei: Dict[str, object],
|
||
|
loc: "Location",
|
||
|
valid_tags: Set[str],
|
||
|
ignore_keys: Optional[Set[DispatchKey]] = None,
|
||
|
) -> Tuple[
|
||
|
"NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]
|
||
|
]:
|
||
|
"""
|
||
|
Parse a NativeFunction from a dictionary as directly parsed
|
||
|
from native_functions.yaml
|
||
|
"""
|
||
|
e = ei.copy()
|
||
|
|
||
|
funcs = e.pop("func")
|
||
|
assert isinstance(funcs, str), f"not a str: {funcs}"
|
||
|
# only support one level of namespace. E.g., aten::add
|
||
|
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
||
|
namespaced_entity=funcs, max_level=1
|
||
|
)
|
||
|
namespace = namespace_helper.get_cpp_namespace(default="aten")
|
||
|
func = FunctionSchema.parse(namespace_helper.entity_name)
|
||
|
|
||
|
cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
|
||
|
assert isinstance(cpp_no_default_args_list, list)
|
||
|
cpp_no_default_args = set(cpp_no_default_args_list)
|
||
|
|
||
|
use_const_ref_for_mutable_tensors = e.pop(
|
||
|
"use_const_ref_for_mutable_tensors", False
|
||
|
)
|
||
|
assert isinstance(use_const_ref_for_mutable_tensors, bool)
|
||
|
|
||
|
variants_s = e.pop("variants", "function")
|
||
|
assert isinstance(variants_s, str)
|
||
|
variants: Set[Variant] = set()
|
||
|
for v in variants_s.split(", "):
|
||
|
if v == "function":
|
||
|
variants.add(Variant.function)
|
||
|
elif v == "method":
|
||
|
variants.add(Variant.method)
|
||
|
else:
|
||
|
raise AssertionError(f"illegal variant {v}")
|
||
|
|
||
|
manual_kernel_registration = e.pop("manual_kernel_registration", False)
|
||
|
assert isinstance(
|
||
|
manual_kernel_registration, bool
|
||
|
), f"not a bool: {manual_kernel_registration}"
|
||
|
|
||
|
manual_cpp_binding = e.pop("manual_cpp_binding", False)
|
||
|
assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
|
||
|
|
||
|
device_guard = e.pop("device_guard", True)
|
||
|
assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
|
||
|
|
||
|
device_check_s = e.pop("device_check", None)
|
||
|
assert device_check_s is None or isinstance(
|
||
|
device_check_s, str
|
||
|
), f"not a str: {device_check_s}"
|
||
|
device_check: DeviceCheckType
|
||
|
if device_check_s is None:
|
||
|
device_check = DeviceCheckType.ExactSame
|
||
|
else:
|
||
|
device_check = DeviceCheckType[device_check_s]
|
||
|
|
||
|
structured = e.pop("structured", False)
|
||
|
assert isinstance(structured, bool), f"not a bool: {structured}"
|
||
|
|
||
|
structured_delegate_s = e.pop("structured_delegate", None)
|
||
|
assert structured_delegate_s is None or isinstance(
|
||
|
structured_delegate_s, str
|
||
|
), f"not a str: {structured_delegate_s}"
|
||
|
assert structured_delegate_s is None or "::" not in structured_delegate_s, (
|
||
|
"namespace is not supported in structured delegate,"
|
||
|
" using the same namespace as the native function"
|
||
|
)
|
||
|
structured_delegate: Optional[OperatorName] = None
|
||
|
if structured_delegate_s is not None:
|
||
|
structured_delegate = OperatorName.parse(structured_delegate_s)
|
||
|
|
||
|
structured_inherits = e.pop("structured_inherits", None)
|
||
|
assert structured_inherits is None or isinstance(
|
||
|
structured_inherits, str
|
||
|
), f"not a str: {structured_inherits}"
|
||
|
assert structured_inherits is None or "::" not in structured_inherits, (
|
||
|
"namespace is not supported in structured inherits,"
|
||
|
" using the same namespace as the native function"
|
||
|
)
|
||
|
|
||
|
python_module = e.pop("python_module", None)
|
||
|
assert python_module is None or isinstance(
|
||
|
python_module, str
|
||
|
), f"not a str: {python_module}"
|
||
|
assert (
|
||
|
python_module is None or Variant.method not in variants
|
||
|
), "functions in modules cannot be methods"
|
||
|
|
||
|
category_override = e.pop("category_override", None)
|
||
|
assert category_override is None or isinstance(
|
||
|
category_override, str
|
||
|
), f"not a str: {category_override}"
|
||
|
|
||
|
precomputed_dict = e.pop("precomputed", None)
|
||
|
assert precomputed_dict is None or structured is True
|
||
|
precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
|
||
|
|
||
|
tags_inp = e.pop("tags", [])
|
||
|
if isinstance(tags_inp, str):
|
||
|
tags_inp = [tags_inp]
|
||
|
assert isinstance(tags_inp, list)
|
||
|
|
||
|
# All aten ops generated by torchgen receive the pt2_compliant tag.
|
||
|
if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
|
||
|
tags_inp.append("pt2_compliant_tag")
|
||
|
|
||
|
tags: Set[str] = set()
|
||
|
for t in tags_inp:
|
||
|
assert len(valid_tags) > 0
|
||
|
# TODO: verify that the tag is valid and has an entry in tags.yaml
|
||
|
if t in valid_tags:
|
||
|
tags.add(t)
|
||
|
else:
|
||
|
raise AssertionError(f"illegal tag {t}")
|
||
|
|
||
|
from torchgen.api import cpp
|
||
|
|
||
|
raw_dispatch = e.pop("dispatch", None)
|
||
|
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
|
||
|
dispatch: Dict[DispatchKey, BackendMetadata] = {}
|
||
|
num_dispatch_keys: int = 0
|
||
|
if raw_dispatch is not None:
|
||
|
assert not manual_kernel_registration, (
|
||
|
"cannot specify both manual_kernel_registration and dispatch; with "
|
||
|
"manual registration, dispatch has no effect!"
|
||
|
)
|
||
|
redundant_composite_implicit_autograd = False
|
||
|
for ks, v in raw_dispatch.items():
|
||
|
if ks == "__line__":
|
||
|
continue # not worth tracking line numbers for dispatch entries
|
||
|
assert isinstance(ks, str), e
|
||
|
for k in ks.split(","):
|
||
|
dispatch_key = DispatchKey.parse(k.strip())
|
||
|
num_dispatch_keys += 1
|
||
|
|
||
|
if ignore_keys and dispatch_key in ignore_keys:
|
||
|
continue
|
||
|
assert dispatch_key in dispatch_keys, (
|
||
|
f"Dispatch key {dispatch_key} of kernel {v} "
|
||
|
"is not a supported dispatch key."
|
||
|
)
|
||
|
# We only allow at most 3 levels of namespace for kernels.
|
||
|
# We will append "native" to a custom kernel namespace.
|
||
|
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
||
|
v, max_level=3
|
||
|
)
|
||
|
kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
|
||
|
# Why is 'structured' included? External backends (e.g.
|
||
|
# XLA) opt into which ops are structured independently
|
||
|
# of which in-tree ops are structured
|
||
|
dispatch[dispatch_key] = BackendMetadata(
|
||
|
kernel=namespace_helper.entity_name,
|
||
|
structured=structured
|
||
|
and is_structured_dispatch_key(dispatch_key),
|
||
|
cpp_namespace=(kernel_namespace + "::native"),
|
||
|
)
|
||
|
if (
|
||
|
dispatch_key is DispatchKey.CompositeImplicitAutograd
|
||
|
and v == cpp.name(func)
|
||
|
):
|
||
|
redundant_composite_implicit_autograd = True
|
||
|
|
||
|
# We count the number of dispatch keys which have not been ignored to prevent a dispatch table
|
||
|
# in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
|
||
|
# from being treated as redundant.
|
||
|
assert not (
|
||
|
num_dispatch_keys == 1 and redundant_composite_implicit_autograd
|
||
|
), (
|
||
|
"unnecessary dispatch table for this function; just delete the dispatch "
|
||
|
"key entirely"
|
||
|
)
|
||
|
# if a function is a structured delegate, deleting the dispatch
|
||
|
# table is NOT semantics preserving
|
||
|
assert (
|
||
|
structured_delegate
|
||
|
or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
|
||
|
or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
|
||
|
or num_dispatch_keys != 1
|
||
|
), (
|
||
|
f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
|
||
|
f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "
|
||
|
"name, then delete the dispatch table"
|
||
|
)
|
||
|
elif not structured and structured_delegate is None:
|
||
|
name = str(func.name.name)
|
||
|
assert not (
|
||
|
name.startswith("new_")
|
||
|
or name.endswith("_like")
|
||
|
# TODO: maybe it's better to test the return
|
||
|
or (
|
||
|
func.arguments.tensor_options
|
||
|
and not func.arguments.has_tensor_arg()
|
||
|
)
|
||
|
), (
|
||
|
f"expected {name} to have a CompositeExplicitAutograd "
|
||
|
"dispatch entry, but there was no dispatch table. Factory functions "
|
||
|
"should not have implicit dispatch as they should not be decomposed "
|
||
|
"for __torch_dispatch__"
|
||
|
)
|
||
|
dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
|
||
|
cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
|
||
|
)
|
||
|
|
||
|
composites_in_dispatch = [
|
||
|
d
|
||
|
for d in dispatch
|
||
|
if d == DispatchKey.CompositeExplicitAutograd
|
||
|
or d == DispatchKey.CompositeExplicitAutogradNonFunctional
|
||
|
or d == DispatchKey.CompositeImplicitAutograd
|
||
|
or d == DispatchKey.CompositeImplicitAutogradNestedTensor
|
||
|
]
|
||
|
|
||
|
assert len(composites_in_dispatch) <= 1 or (
|
||
|
len(composites_in_dispatch) == 2
|
||
|
and (
|
||
|
DispatchKey.CompositeExplicitAutogradNonFunctional
|
||
|
not in composites_in_dispatch
|
||
|
)
|
||
|
and (
|
||
|
DispatchKey.CompositeImplicitAutogradNestedTensor
|
||
|
in composites_in_dispatch
|
||
|
)
|
||
|
), (
|
||
|
"cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
|
||
|
"or CompositeImplicitAutograd on a single kernel; each "
|
||
|
"strictly subsumes the other. If you wanted to provide an explicit autograd "
|
||
|
"implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
|
||
|
)
|
||
|
|
||
|
autogen_str = e.pop("autogen", "")
|
||
|
assert isinstance(autogen_str, str)
|
||
|
autogen = (
|
||
|
[]
|
||
|
if autogen_str == ""
|
||
|
else [OperatorName.parse(x) for x in autogen_str.split(", ")]
|
||
|
)
|
||
|
|
||
|
raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
|
||
|
ufunc_inner_loop = {}
|
||
|
if isinstance(raw_ufunc_inner_loop, str):
|
||
|
ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
|
||
|
raw_ufunc_inner_loop, UfuncKey.Generic
|
||
|
)
|
||
|
elif isinstance(raw_ufunc_inner_loop, dict):
|
||
|
for k, vo in raw_ufunc_inner_loop.items():
|
||
|
if k == "__line__":
|
||
|
continue
|
||
|
assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
|
||
|
assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {v}"
|
||
|
ufunc_key = UfuncKey.parse(k)
|
||
|
ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
|
||
|
else:
|
||
|
raise AssertionError(
|
||
|
f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
|
||
|
)
|
||
|
# Program the BackendIndex for the implicit dispatch entry from ufunc
|
||
|
if ufunc_inner_loop:
|
||
|
assert structured, "ufunc must be structured"
|
||
|
|
||
|
# Delay import ufunc here to avoid circular import issue
|
||
|
# See: https://github.com/pytorch/pytorch/issues/81294
|
||
|
import torchgen.api.ufunc as ufunc
|
||
|
|
||
|
for dispatch_key in UFUNC_DISPATCH_KEYS:
|
||
|
assert (
|
||
|
dispatch_key not in dispatch
|
||
|
), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
|
||
|
dispatch[dispatch_key] = BackendMetadata(
|
||
|
kernel=ufunc.schema_kernel_name(func, dispatch_key),
|
||
|
structured=True,
|
||
|
cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
|
||
|
)
|
||
|
|
||
|
if structured_delegate:
|
||
|
# Structured functions MUST have a dispatch table
|
||
|
is_abstract = True
|
||
|
else:
|
||
|
is_abstract = (
|
||
|
dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
|
||
|
and dispatch.keys()
|
||
|
!= {DispatchKey.CompositeImplicitAutogradNestedTensor}
|
||
|
and dispatch.keys()
|
||
|
!= {
|
||
|
DispatchKey.CompositeImplicitAutograd,
|
||
|
DispatchKey.CompositeImplicitAutogradNestedTensor,
|
||
|
}
|
||
|
)
|
||
|
|
||
|
has_composite_implicit_autograd_kernel = (
|
||
|
DispatchKey.CompositeImplicitAutograd in dispatch.keys()
|
||
|
)
|
||
|
has_composite_implicit_autograd_nested_tensor_kernel = (
|
||
|
DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch.keys()
|
||
|
)
|
||
|
has_composite_explicit_autograd_kernel = (
|
||
|
DispatchKey.CompositeExplicitAutograd in dispatch.keys()
|
||
|
)
|
||
|
has_composite_explicit_autograd_non_functional_kernel = (
|
||
|
DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch.keys()
|
||
|
)
|
||
|
|
||
|
# We aren't going to store dispatch metadata inline in NativeFunctions;
|
||
|
# instead it is separately indexed by backend (so other backends can
|
||
|
# add more dispatch entries after the fact). Reindex the individual
|
||
|
# metadata by OperatorName!
|
||
|
backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
|
||
|
|
||
|
# don't care if it exists or not; make it easier to use this function
|
||
|
# with other yaml parsers that aren't setting __line__ in the dict
|
||
|
e.pop("__line__", None)
|
||
|
assert not e, f"leftover entries: {e}"
|
||
|
|
||
|
# Asserts that we can't do in post_init, because they rely on backend-specific info
|
||
|
if structured_delegate is not None:
|
||
|
for key in STRUCTURED_DISPATCH_KEYS:
|
||
|
assert key not in dispatch, (
|
||
|
f"if structured_delegate, then must not have {key} in dispatch dictionary "
|
||
|
"(it is delegated!)"
|
||
|
)
|
||
|
|
||
|
return (
|
||
|
NativeFunction(
|
||
|
func=func,
|
||
|
use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
|
||
|
variants=variants,
|
||
|
structured=structured,
|
||
|
structured_delegate=structured_delegate,
|
||
|
structured_inherits=structured_inherits,
|
||
|
precomputed=precomputed,
|
||
|
autogen=autogen,
|
||
|
ufunc_inner_loop=ufunc_inner_loop,
|
||
|
manual_kernel_registration=manual_kernel_registration,
|
||
|
manual_cpp_binding=manual_cpp_binding,
|
||
|
python_module=python_module,
|
||
|
category_override=category_override,
|
||
|
device_guard=device_guard,
|
||
|
device_check=device_check,
|
||
|
loc=loc,
|
||
|
cpp_no_default_args=cpp_no_default_args,
|
||
|
is_abstract=is_abstract,
|
||
|
has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
|
||
|
has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
|
||
|
has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
|
||
|
has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
|
||
|
tags=tags,
|
||
|
namespace=namespace,
|
||
|
),
|
||
|
backend_metadata,
|
||
|
)
|
||
|
|
||
|
def validate_unstructured(self) -> None:
    """Assert the structured-related flags this check expects.

    Raises:
        AssertionError: if ``self.structured`` is truthy, or if
            ``self.structured_delegate`` is falsy.
    """
    # TODO: probably better to accumulate these errors and report them all
    # at once
    structured_msg = (
        "This function is structured, but there was "
        "no valid functional variant of it."
    )
    delegate_msg = (
        "This function delegates to another structured out function, "
        "but no valid function was found (the delegate may not exist, or it has the wrong type)"
    )
    assert not self.structured, structured_msg
    # NOTE(review): the message reads like it should fire when a delegate IS
    # set, yet the condition fails when it is NOT set -- confirm the intended
    # polarity with the callers before relying on this check.
    assert self.structured_delegate, delegate_msg
|
||
|
|
||
|
# __post_init__ functions in dataclasses can be used to do extra
|
||
|
# validation after construction.
|
||
|
#
|
||
|
# Notice that we don't do any type validation here. In fact, we
|
||
|
# rely exclusively on mypy to check if you've done types correctly!
|
||
|
# Validation is for nontrivial invariants that cannot be (conveniently)
|
||
|
# encoded in the type system.
|
||
|
def __post_init__(self) -> None:
    """Validate invariants that relate several fields of this function.

    No type validation is done here; only relationships between the schema
    (``self.func``) and the surrounding flags and tags are checked.

    Raises:
        AssertionError: as soon as one of the checks below fails.
    """
    # Schemas that take out arguments may only be exposed as functions.
    if self.func.arguments.out:
        assert self.variants == {Variant.function}, (
            "Native functions with out arguments MUST "
            "be declared with only function variant; e.g., variants: function; "
            "otherwise you will tickle a Python argument binding bug "
            "(which usually manifests itself as the result variable being undefined.)"
        )

    # The same message is used for both structured flavours below.
    guard_msg = "device_guard: False is not respected by structured kernels"

    if self.structured:
        # `structured` belongs on the out= variant only.
        assert self.func.kind() == SchemaKind.out, (
            "Put structured field on the out= "
            "variant of a function; did you mean structured_delegate?"
        )
        assert self.device_guard, guard_msg

    if self.structured_delegate:
        # `structured_delegate` belongs on everything except the out= variant.
        assert self.func.kind() != SchemaKind.out, (
            "structured_delegate field not allowed "
            "on out= functions; did you mean structured?"
        )
        assert self.device_guard, guard_msg

    # Technically, with the asserts above, this assert is impossible to
    # happen
    assert (
        not self.structured or not self.structured_delegate
    ), "Cannot have both structured and structured_delegate on function"

    # Every name in cpp_no_default_args must be an argument with a default.
    args_with_default = {
        arg.name
        for arg in self.func.schema_order_arguments()
        if arg.default is not None
    }
    invalid_args = set.difference(self.cpp_no_default_args, args_with_default)
    assert not invalid_args, f"Invalid cpp_no_default_args: {invalid_args}"

    if self.structured_inherits is not None:
        assert (
            self.structured
        ), "structured_inherits must also imply structured: True"

    op_name = str(self.func.name)
    if op_name.startswith("_foreach"):
        assert self.device_check == DeviceCheckType.NoCheck, (
            "foreach kernels fall back to slow path when tensor are on different devices, "
            "device_check not allowed to be enabled"
        )

    # NB: if your function accidentally has rand/dropout/... in its name
    # but is not actually random, feel free to amend this to special case
    expects_seed_tag = "rand" in op_name
    if not expects_seed_tag:
        mentions_dropout = "dropout" in op_name or any(
            "dropout" in arg.name for arg in self.func.arguments.flat_all
        )
        expects_seed_tag = (
            mentions_dropout
            # Backwards of dropout is typically deterministic
            and "backward" not in op_name
            and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
        )
    if not expects_seed_tag:
        expects_seed_tag = self.func.arguments.has_generator_arg()
    if expects_seed_tag:
        assert "nondeterministic_seeded" in self.tags, op_name
|
||
|
|
||
|
@property
def has_composite_kernel(self) -> bool:
    """Whether this function has one of the three main composite kernels.

    Returns:
        True when a CompositeImplicitAutograd, CompositeExplicitAutograd or
        CompositeExplicitAutogradNonFunctional kernel flag is set.

    A CompositeImplicitAutogradNestedTensor kernel alone does not make this
    True.  The previous expression also or-ed in
    ``implicit and implicit_nested_tensor``, which can never change the
    result because ``implicit`` alone already satisfies the first clause,
    so that dead clause has been dropped.
    """
    return (
        self.has_composite_implicit_autograd_kernel
        or self.has_composite_explicit_autograd_kernel
        or self.has_composite_explicit_autograd_non_functional_kernel
    )
|
||
|
|
||
|
@property
def is_view_op(self) -> bool:
    """Whether the schema and tags mark this operator as a view.

    Returns:
        True if any of these hold:
        - some return has an alias annotation that is not a write;
        - the op is tagged ``inplace_view`` and is neither ``resize_`` nor
          ``resize_as_``;
        - some argument's annotation has ``"*"`` in ``alias_set_after``.
    """
    # Non-mutating view: an annotated return that is not written to.
    for ret in self.func.returns:
        if ret.annotation is not None and not ret.annotation.is_write:
            return True

    # See Note [resize_ in Functionalization] for more details
    if "inplace_view" in self.tags and str(self.func.name) not in (
        "resize_",
        "resize_as_",
    ):
        return True

    # Wildcard view: an input whose alias set after the call contains "*".
    return any(
        arg.annotation is not None and "*" in arg.annotation.alias_set_after
        for arg in self.func.schema_order_arguments()
    )
|
||
|
|
||
|
@property
def view_schema_kind(self) -> ViewSchemaKind:
    """Classify this operator by how it aliases its inputs.

    Returns:
        ``non_aliasing`` when ``is_view_op`` is false, ``aliasing_inplace``
        for view ops whose base name is marked inplace, and ``aliasing``
        for every other view op.

    Raises:
        AssertionError: if an inplace view op lacks the ``inplace_view`` tag.
    """
    if not self.is_view_op:
        return ViewSchemaKind.non_aliasing
    if self.func.name.name.inplace:
        assert "inplace_view" in self.tags
        return ViewSchemaKind.aliasing_inplace
    return ViewSchemaKind.aliasing
|
||
|
|
||
|
@property
def root_name(self) -> str:
    """Return the ``base`` of this function's operator name."""
    base_name = self.func.name.name
    return base_name.base
|
||
|
|
||
|
@property
def part_of_structured_group(self) -> bool:
    """Whether this function is structured itself or names a structured delegate."""
    has_delegate = self.structured_delegate is not None
    return self.structured or has_delegate
|
||
|
|
||
|
|
||
|
class SchemaKind(Enum):
    """The flavours a function schema can be classified as."""

    # Returns a newly allocated output.
    functional = auto()
    # Modifies the self argument in place.
    inplace = auto()
    # Writes the result into an explicitly provided out argument.
    out = auto()
    # Has a writable positional argument after self, but is neither
    # inplace nor out.
    mutable = auto()
    # Has an out argument whose name starts with "_scratch_".
    scratch = auto()
|
||
|
|
||
|
|
||
|
# A structured kernel is guaranteed to have a functional and out variant, and
|
||
|
# optionally an inplace variant.
|
||
|
#
|
||
|
# NB: we create NativeFunctionsGroup *even if* the function is not
|
||
|
# actually annotated structured. Test the structured boolean to see if it
|
||
|
# actually is structured or not.
|
||
|
@dataclass(frozen=True)
class NativeFunctionsGroup:
    """The related variants of one operator, bundled together.

    ``functional`` and ``out`` are always present; ``inplace`` and ``mutable``
    are optional.  A group is created even when the operator is not annotated
    structured, so check ``structured`` before treating it as such.
    """

    functional: NativeFunction
    inplace: Optional[NativeFunction]
    mutable: Optional[NativeFunction]
    out: NativeFunction

    @property
    def structured(self) -> bool:
        # Whether or not the operator has a meta() function. This information is backend-agnostic.
        return self.out.structured

    def __post_init__(self) -> None:
        """Check that the grouped variants are mutually consistent.

        Raises:
            AssertionError: on mismatched signatures, mixed structured and
                unstructured members, wrong schema kinds, differing
                namespaces, or inconsistent structured delegates.
            RuntimeError: when the members tagged ``generated`` do not match
                the names listed in the members' ``autogen`` fields.
        """
        members = list(self.functions())

        expected_sig: FunctionSchema = self.functional.func.signature()
        for fn in members:
            fn_sig = fn.func.signature()
            if expected_sig != fn_sig:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from two NativeFunctions "
                    f"that don't have matching signatures: {expected_sig} != {fn_sig}"
                )

            if self.structured != fn.part_of_structured_group:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from structured and unstructured "
                    f"functions: {self.out.func.name} and {fn.func.name}"
                )

        # Each slot must hold a schema of the matching kind, and all members
        # must live in the functional variant's namespace.
        assert self.functional.func.kind() == SchemaKind.functional
        assert self.out.func.kind() == SchemaKind.out
        assert self.functional.namespace == self.out.namespace
        if self.inplace is not None:
            assert self.inplace.func.kind() == SchemaKind.inplace
            assert self.inplace.namespace == self.functional.namespace

        if self.mutable is not None:
            assert self.mutable.func.kind() == SchemaKind.mutable
            assert self.mutable.namespace == self.functional.namespace
            # See Note [Overload Ambiguity With Functional Variants]
            assert self.functional.func.name.name.functional_overload

        if self.structured:
            # For now, structured composite kernels are not supported (need some
            # design work to figure out how to make the composite case work)
            assert (
                not self.out.has_composite_implicit_autograd_kernel
                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
            )

            out_name = self.out.func.name
            assert self.functional.structured_delegate == out_name, (
                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
                f"but its actual delegate is {out_name}"
            )
            if self.inplace is not None:
                assert self.inplace.structured_delegate == out_name

        # Members tagged "generated" must be exactly the names that the
        # members' `autogen` fields call out.
        generated = sorted(
            str(fn.func.name) for fn in members if "generated" in fn.tags
        )
        generated_str = ", ".join(generated)

        declared: Set[str] = set()
        for fn in members:
            declared.update(str(op) for op in fn.autogen)
        declared_str = ", ".join(sorted(declared))

        if not declared and generated:
            # NOTE(review): the entry named in this message is the last member
            # yielded by functions(); confirm that is the intended one.
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_str}'."
                " In order to generate them however, we expect them to be called out explicitly in the yaml."
                f" Please add an 'autogen: {generated_str}' line to the entry for {str(members[-1].func.name)}"
            )
        if declared_str != generated_str:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_str}'."
                f" To do so, it expects a line: 'autogen: {generated_str}'."
                f" Instead, it found 'autogen: {declared_str}'"
            )

    def signature(self) -> "FunctionSchema":
        """Return the signature of the out variant's schema."""
        return self.out.func.signature()

    def functions(self) -> Iterator[NativeFunction]:
        """Yield functional, out, then inplace and mutable when present."""
        yield self.functional
        yield self.out
        for optional_member in (self.inplace, self.mutable):
            if optional_member is not None:
                yield optional_member

    @property
    def root_name(self) -> str:
        return self.functional.root_name

    @staticmethod
    def from_dict(
        d: Dict[SchemaKind, NativeFunction]
    ) -> Optional["NativeFunctionsGroup"]:
        """Build a group from a kind-to-function mapping.

        Returns:
            The group, or None when the mapping has a single entry or has no
            out variant.

        Raises:
            AssertionError: if the mapping is empty, has a key other than
                functional/inplace/mutable/out, or lacks a functional entry.
        """
        assert d
        if len(d) == 1:
            return None
        remaining = dict(d)  # non-destructive updates please
        functional = remaining.pop(SchemaKind.functional, None)
        inplace = remaining.pop(SchemaKind.inplace, None)
        mutable = remaining.pop(SchemaKind.mutable, None)
        out = remaining.pop(SchemaKind.out, None)
        assert not remaining
        assert functional is not None
        # There are a few operators which only have functional/inplace variants;
        # these don't count as structured for our purposes here
        if out is None:
            return None
        # assuming all variants have the same namespace
        return NativeFunctionsGroup(
            functional=functional,
            inplace=inplace,
            mutable=mutable,
            out=out,
        )
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class BackendMetadata:
    """Per-backend information about the kernel for one operator."""

    # Name of the backend kernel for a given operator.  For in-tree backends
    # the names come directly from the 'dispatch' field in
    # native_functions.yaml.  The dispatch entry is optional; leaving it out
    # is equivalent to having written:
    #
    #   dispatch:
    #       CompositeImplicitAutograd: $operator_name
    kernel: str
    # Whether the operator has a structured kernel implemented for this
    # particular backend.  In-tree backends all share the value listed in
    # native_functions.yaml, whereas external backends like XLA can
    # independently toggle which ops are structured.
    structured: bool

    # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
    cpp_namespace: str

    def supports_symint(self) -> bool:
        """Return True when the kernel name contains ``"_symint"``."""
        marker = "_symint"
        return marker in self.kernel
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class UfuncInnerLoop:
    """One parsed ``ufunc_inner_loop`` entry: a name plus the dtypes it supports."""

    name: str
    supported_dtypes: OrderedSet[ScalarType]
    # key is stored here because it affects the semantics of name,
    # so its helpful to have them together for further processing
    ufunc_key: UfuncKey

    @staticmethod
    def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop":
        """Parse a string of the form ``"<name> (<dtype spec>, <dtype spec>, ...)"``.

        The text before the first space is the name.  The remainder must be
        wrapped in parentheses; its ``", "``-separated items are each expanded
        through ``ScalarType.parse_set`` and unioned together.

        Raises:
            ValueError: if ``value`` contains no space.
            AssertionError: if the remainder is not wrapped in parentheses.
        """
        name, dtype_spec = value.split(" ", 1)
        assert dtype_spec[0] == "("
        assert dtype_spec[-1] == ")"
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for item in dtype_spec[1:-1].split(", "):
            dtypes |= ScalarType.parse_set(item)
        return UfuncInnerLoop(name=name, supported_dtypes=dtypes, ufunc_key=ufunc_key)
|
||
|
|
||
|
|
||
|
# BackendIndex represents a backend.
|
||
|
# The BackendIndex encodes per-operator information that is potentially different
|
||
|
# for each backend. The most obvious example is the name of the kernel
|
||
|
# (the 'dispatch' entry in native_functions.yaml).
|
||
|
# However, there can be other examples of different backends having different information.
|
||
|
# External backends can choose to opt their kernels to be structured independently from in-tree backends,
|
||
|
# which means that this information isn't inherently tied to a NativeFunction- it's different per backend.
|
||
|
@dataclass(frozen=True)
class BackendIndex:
    """Per-operator information for a single backend (one dispatch key)."""

    dispatch_key: DispatchKey
    # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others.
    # All in-tree ops use out kernels, while XLA uses functional kernels.
    use_out_as_primary: bool
    # Whether the backend requires a device guard, and device checks.
    # For in-tree backends, this is currently just CUDA/HIP
    # For out-of-tree backends, this is currently just Intel XPU
    device_guard: bool
    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
    external: bool
    # Other backend-specific information that is on a per-operator basis
    index: Dict["OperatorName", BackendMetadata]

    @staticmethod
    def grow_index(
        parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
        child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
    ) -> None:
        """Copy every entry of ``child_index`` into ``parent_index`` in place.

        Raises:
            AssertionError: if an operator is already registered under the
                same dispatch key in ``parent_index``.
        """
        for key, child_entries in child_index.items():
            for op_name, metadata in child_entries.items():
                # parent_index[key] is looked up per entry on purpose, so an
                # empty child mapping never touches the parent.
                already_present = op_name in parent_index[key]
                assert (
                    not already_present
                ), f"duplicate operator {op_name} for dispatch key {key}"
                parent_index[key][op_name] = metadata

    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
        """Return the group's out variant if ``use_out_as_primary``, else the functional one."""
        return g.out if self.use_out_as_primary else g.functional

    def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """Whether this backend has an index entry for ``g``."""
        return self.get_kernel(g) is not None

    def get_kernel(
        self, g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Optional[BackendMetadata]:
        """Look up the metadata for a function, or for a group's primary function.

        Returns:
            The matching entry of ``self.index``, or None when absent.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = self.primary(g)
        else:
            assert_never(g)
        return self.index.get(f.func.name)

    def native_function_class_name(self) -> Optional[str]:
        """Return ``"<dispatch key>NativeFunctions"`` for external backends, else None."""
        if not self.external:
            # TODO: This discrepancy isn't required; we could also generate
            # a class for in-tree kernels. It'll just require carefully
            # updating every kernel definition + callsite of every in-tree aten kernel.
            return None
        return f"{str(self.dispatch_key)}NativeFunctions"
|
||
|
|
||
|
|
||
|
# The function schema is undoubtedly the most important data structure
|
||
|
# in all of the codegen, as it defines the type signature for operators,
|
||
|
# and most of the code generation we do is type directed (e.g., look at
|
||
|
# the types, decide what to do. Think about how we code generate
|
||
|
# C++ function stubs!)
|
||
|
#
|
||
|
# We will also see in this class the general structure for how we model
|
||
|
# data in this code generation. A few notable properties to point out
|
||
|
# ahead of time:
|
||
|
#
|
||
|
# - These dataclasses are a *lossless* representation of the strings
|
||
|
# they are parsed from. In fact, we assert that given the
|
||
|
# information stored in the dataclass, we can exactly reconstruct
|
||
|
# the string we parsed from (and assert this inside the parse
|
||
|
# definition). There are a few reasons for this:
|
||
|
#
|
||
|
# - If you find that it is difficult to reconstruct the string
|
||
|
# given a dataclass, that is a clue that your data
|
||
|
# representation is wrong.
|
||
|
#
|
||
|
# - It helps ensure that all relevant information is present
|
||
|
# in the dataclass, so that downstream users aren't tempted
|
||
|
# to reparse the original string to get some information
|
||
|
# that was omitted.
|
||
|
#
|
||
|
# - It forces you to represent the data in-memory in the same way
|
||
|
# it is recorded textually, which makes the dataclasses easier
|
||
|
# to understand for someone who is familiar with the
|
||
|
# textual format. (As a tradeoff, it means you have to model
|
||
|
# the syntax, even when it is inconvenient. But maybe that means
|
||
|
# the syntax is bad!) If you don't understand the internal
|
||
|
# representation, go look at the printing code to see how
|
||
|
# it maps onto the surface syntax!
|
||
|
#
|
||
|
# - It makes it easy to test the parsing code, as parsing code
|
||
|
# that is inconsistent with the string code will fail early
|
||
|
# and loudly. (As a tradeoff, it makes the parsing code a bit
|
||
|
# brittle (in particular, with trivial whitespace changes you
|
||
|
# are likely to trigger an assert error).)
|
||
|
#
|
||
|
# In general, try to make the __str__ code as simple as possible
|
||
|
# (even at the cost of more complex parsing logic.) Additionally,
|
||
|
# try to minimize redundancy in data representation. (Precomputed
|
||
|
# fields are OK though: they are defined as a simple function on
|
||
|
# the canonical representation in question.)
|
||
|
#
|
||
|
# - These dataclasses are all frozen; once constructed their
|
||
|
# values never change. This makes it easy to tell where any
|
||
|
# given data came from: just look to the constructor. As a
|
||
|
# tradeoff, you can't easily "decorate" a schema with extra
|
||
|
# information from a post-facto analysis. We impose this
|
||
|
# restriction to make these structures more understandable.
|
||
|
#
|
||
|
@dataclass(frozen=True)
|
||
|
class FunctionSchema:
|
||
|
# The name of the operator this function schema describes.
|
||
|
name: "OperatorName"
|
||
|
|
||
|
arguments: "Arguments"
|
||
|
|
||
|
# TODO: Need to handle collisions with argument names at some point
|
||
|
returns: Tuple["Return", ...]
|
||
|
|
||
|
def schema_order_arguments(self) -> Iterator["Argument"]:
    """Iterate over the arguments: positional, then keyword-only, then out."""
    argument_groups = (
        self.arguments.flat_positional,
        self.arguments.flat_kwarg_only,
        self.arguments.out,
    )
    return itertools.chain(*argument_groups)
|
||
|
|
||
|
# Splits "name(args) -> returns" into its three parts.
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

@staticmethod
def parse(func: str) -> "FunctionSchema":
    """Parse a schema string of the form ``name(args) -> returns``.

    Raises:
        AssertionError: if ``func`` does not match ``decl_re`` exactly once,
            or if the parsed schema does not print back to ``func``.
    """
    # We should probably get a proper parser here
    matches = FunctionSchema.decl_re.findall(func)
    assert len(matches) == 1, f"Invalid function schema: {func}"
    op_str, args_str, returns_str = matches[0]
    schema = FunctionSchema(
        name=OperatorName.parse(op_str),
        arguments=Arguments.parse(args_str),
        returns=parse_returns(returns_str),
    )
    # The parsed form must round-trip to exactly the text it came from.
    assert str(schema) == func, f"{str(schema)} != {func}"
    return schema
|
||
|
|
||
|
def returns_are_aliased(self) -> bool:
    """Whether any return carries a write annotation.

    Returns:
        True if at least one entry of ``self.returns`` has a non-None
        annotation whose ``is_write`` is true, else False.
    """
    # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
    #
    # Test the predicate itself rather than yielding the matching return
    # objects, so the answer never depends on their truthiness.
    return any(
        r.annotation is not None and r.annotation.is_write for r in self.returns
    )
|
||
|
|
||
|
def __post_init__(self) -> None:
    """Validate the aliasing and return conventions of this schema.

    The checks cover: out arguments line up with returns; mutable positional
    arguments are not returned; returns are either all mutable or all
    immutable; mutable returns alias an out argument or ``self``; and the
    return shape expected of out=, inplace, tensor-options and
    "functional"-overload schemas.

    Raises:
        AssertionError: when any of those conventions is violated.
    """
    for arg, ret in zip(self.arguments.out, self.returns):
        assert arg.annotation == ret.annotation, (
            "Out arguments must have matching return Tensor; furthermore, "
            "the ith-argument needs to correspond to the ith return"
        )
    # We also enforce that if you have any mutable, positional args, then they are not returned.
    # This makes it easier to group these functions properly with their functional/out= counterparts.
    for a in self.arguments.post_self_positional_mutable:
        assert not any(
            a.annotation == r.annotation for r in self.returns
        ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
    # Invariant: we expect out arguments to appear as keyword arguments in the schema.
    # This means that all mutable returns should be aliased to a keyword argument
    # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
    # See Note [is_out_fn]
    out_and_self = list(self.arguments.out) + [
        arg for arg in self.arguments.flat_positional if arg.name == "self"
    ]
    mutable_returns = [
        ret
        for ret in self.returns
        if ret.annotation is not None and ret.annotation.is_write
    ]
    immutable_returns = [
        ret
        for ret in self.returns
        if ret.annotation is None or not ret.annotation.is_write
    ]
    # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
    # because:
    # (1) It's more annoying to handle properly
    # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
    #     Instead, we expect the (a!) argument to not be returned.
    assert (
        len(mutable_returns) == 0 or len(immutable_returns) == 0
    ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
    for ret in mutable_returns:
        assert any(ret.annotation == arg.annotation for arg in out_and_self), (
            'All mutable returns must be aliased either to a keyword argument, or to "self". '
            "Did you forget to mark an out argument as keyword-only?"
        )
    if self.arguments.out:
        # out= ops that return their mutable inputs are only really useful for method chaining.
        # And method chaining is only really useful if the thing you're returning is a plain Tensor.
        # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
        # and all other types of out= op schemas should return void.
        # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
        if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
            # Both message fragments are parenthesized so they form a single
            # assert message; previously the second fragment was a detached
            # expression statement and never shown.
            assert len(self.returns) == 0, (
                "out= ops that accept tensor lists as out arguments "
                "are expected to have no return type (since you can't do method chaining on them)"
            )
        else:
            # mutable keyword arguments whose name has _scratch_ prefix are
            # scratch tensors for memory planning and should not be returned
            assert len(
                [
                    arg
                    for arg in self.arguments.out
                    if not arg.name.startswith("_scratch_")
                ]
            ) == len(
                self.returns
            ), "Must return as many arguments as there are out arguments, or no return at all"

    if self.name.name.inplace:
        self_a = self.arguments.self_arg
        # Inplace schemas need a self argument annotated as written to.
        assert (
            self_a
            and self_a.argument.annotation
            and self_a.argument.annotation.is_write
        )
        if self_a.argument.type == BaseType(BaseTy.Tensor):
            # All inplace ops with an ordinary `Tensor self` argument should return self,
            # to allow for method chaining.
            assert (
                len(self.returns) == 1
                and self.returns[0].annotation == self_a.argument.annotation
            )
        else:
            # You can't method chain on non-tensor self arguments though (like a List[Tensor])
            # so in all other cases we expect the return type to be none.
            assert len(self.returns) == 0

    if self.arguments.tensor_options is not None:
        assert self.kind() == SchemaKind.functional, (
            "Found an operator that is not functional or out variant, but has tensor options arguments. "
            "This is not allowed- tensor options arguments are only allowed for factory functions. "
            f"schema: {str(self)}"
        )
    if self.is_functional_fn():
        assert self.kind() == SchemaKind.functional, (
            "Found an operator that is not functional, but its overload contains the string 'functional'. "
            "This is a special keyword in the codegen, please use a different overload name. "
            f"schema: {str(self)}"
        )
|
||
|
|
||
|
def is_functional_fn(self) -> bool:
    """Whether the overload name contains the substring ``"functional"``."""
    overload = self.name.overload_name
    return "functional" in overload
|
||
|
|
||
|
def is_out_fn(self) -> bool:
    """Whether this schema has at least one out argument."""
    # Note [is_out_fn]
    #
    # out functions are the variants which take an explicit out= argument
    # to populate into.  We need to know if a schema corresponds to an
    # out function for several reasons:
    #
    #   - They codegen differently in C++ API
    #       - codegen to at::add_out rather than at::add
    #       - out argument is moved to front of C++ argument list
    #
    # out functions are DEFINED to be any function with a keyword-only
    # argument that is mutable.  In principle, this could lead to a
    # false positive if you define a function that mutates a
    # kwarg only argument, but this isn't the "true" output of this
    # function.  A more robust definition that would work in this
    # case would also look at:
    #
    #   - The output types.  Out functions take in the arguments
    #     they mutate and then return them again; this is sort
    #     of "definitionally" what makes something an out function.
    #     Historically, we DO check this for consistency.
    #   - Correspondence with pure variant.  An out function
    #     should have a signature equivalent to its pure variant,
    #     but just with extra kwargs for the output elements.  This
    #     is difficult to actually check for and historically
    #     we only do this check in tools/
    out_arguments = self.arguments.out
    return bool(out_arguments)
|
||
|
|
||
|
def kind(self) -> SchemaKind:
    """
    What kind of schema is this?  A functional schema is one
    that returns a newly allocated output; an inplace schema
    modifies the self argument inplace; an out schema writes
    the result into an explicitly provided out argument.

    Precedence is inplace, then scratch (an out argument named with the
    ``_scratch_`` prefix), then out, then mutable (a writable positional
    argument after self), and finally functional.

    Raises:
        AssertionError: if the schema is both inplace and has out arguments.
    """
    out_args = self.arguments.out
    is_out = bool(out_args)
    is_scratch = any(arg.name.startswith("_scratch_") for arg in out_args)
    is_inplace = self.name.name.inplace
    is_mutable = any(
        arg.annotation is not None and arg.annotation.is_write
        for arg in self.arguments.post_self_positional
    )
    assert not (is_out and is_inplace)
    # out= and inplace schemas can also have post_self_positional mutable args,
    # but we give precedence to out= and inplace when deciding the schema kind.
    # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
    # to also worry about mutable post_self_positional arguments,
    # but it seems like a much bigger lift to classify them as having a new schema kind.
    # The number of ops that fit in this strange category is small enough that
    # we can probably manually write code for them instead of forcing the codegen to handle them.
    if is_inplace:
        return SchemaKind.inplace
    if is_scratch:
        assert (
            is_out
        ), "invariant: all scratch operators are expected to be out= operators too"
        return SchemaKind.scratch
    if is_out:
        assert (
            not is_scratch
        ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
        return SchemaKind.out
    if is_mutable:
        return SchemaKind.mutable
    return SchemaKind.functional
|
||
|
|
||
|
# For every return:
|
||
|
# - If the return aliases an input, we return the input name
|
||
|
# - Otherwise, we return None.
|
||
|
# If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
|
||
|
def aliased_return_names(self) -> List[Optional[str]]:
    """Return, per return value, the name of the argument it aliases.

    A return aliases an argument when the argument carries an annotation
    that compares equal to the return's annotation. Entries are None for
    returns that alias nothing.

    Raises:
        AssertionError: if a single return aliases more than one argument.
    """
    every_arg = self.arguments.flat_all
    names: List[Optional[str]] = []
    for ret in self.returns:
        matches = [
            arg.name
            for arg in every_arg
            if arg.annotation is not None and arg.annotation == ret.annotation
        ]
        if len(matches) > 1:
            aliased_names = ", ".join(matches)
            raise AssertionError(
                f"Found a return ({ret.name})that aliases multiple inputs ({aliased_names})"
            )
        names.append(matches[0] if matches else None)
    return names
|
||
|
|
||
|
def signature(
    self,
    *,
    strip_default: bool = False,
    strip_view_copy_name: bool = False,
    keep_return_names: bool = False,
) -> "FunctionSchema":
    """
    Certain schemas are 'related', in that they are simply
    inplace/out/functional versions of the same function.  This method
    factors these schemas into the "core" functional signature which
    is equal across all versions.

    Here is what normalization happens to the schema to convert
    it to a signature:
    - The overload name is stripped (name is retained, since
      it expresses semantic content about what the function does)
    - Inplace is set False
    - Out arguments are stripped
    - Mutable post_self_positional args are converted to returns
    - Mutability annotations are stripped  (this is sound
      because you cannot overload on mutability annotation)
    - Return names are stripped since they are not overloadable and
      some variants have return names but some not
      (unless keep_return_names is set)
    - TensorOptions are dropped
      because out= variants of factory functions don't include them
      (and we want to be able to pair up factory functions with their out variants)

    Finally, we want to be able to pair up related "view" and their
    corresponding "view_copy" operators.  We do this by optionally
    (strip_view_copy_name=True) stripping the trailing "_copy" from the
    base name; a trailing "_scatter" is rewritten to "_inverse" in the
    same mode.

    Example of a mutable op before and after:

    f.func (Mutable operator):
    _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950

    f.func (Corresponding functional operator):
    _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)  # noqa: B950

    f.func.signature() output:
    _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)  # noqa: B950
    """

    def strip_ret_annotation(r: Return) -> Return:
        return Return(
            name=r.name if keep_return_names else None,
            type=r.type,
            annotation=None,
        )

    base_name = self.name.name.base
    if strip_view_copy_name:
        # Only the trailing suffix is rewritten. str.replace() would also
        # rewrite earlier occurrences of the same substring in the name.
        if base_name.endswith("_copy"):
            base_name = base_name[: -len("_copy")]
        elif base_name.endswith("_scatter"):
            base_name = base_name[: -len("scatter")] + "inverse"

    # find mutable inputs that are not originally returned, and convert them to returns
    returns_from_mutable_inputs = tuple(
        # When we're grouping functions we strip the return names,
        # but when we're generating the actual functional variants then we follow
        # a convention for what to name the returns
        Return(
            name=f"{a.name}_out" if keep_return_names else None,
            type=a.type,
            annotation=None,
        )
        for a in itertools.chain(
            # Order is important here (otherwise e.g. inplace with mutable args
            # and out= with mutable args won't have the same signature)
            [self.arguments.self_arg.argument]
            if self.arguments.self_arg is not None
            else [],
            self.arguments.out,
            self.arguments.post_self_positional,
        )
        if a.annotation is not None
        and a.annotation.is_write
        and not any(a.annotation == r.annotation for r in self.returns)
    )
    original_returns = tuple(map(strip_ret_annotation, self.returns))
    # Ordering is important here. We expect the "mutable input" returns to come last.
    returns = original_returns + returns_from_mutable_inputs

    args_sig = self.arguments.signature(strip_default=strip_default)
    # See Note [bernoulli.p schema]
    if str(self.name) == "bernoulli.p":
        args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))

    return FunctionSchema(
        name=OperatorName(
            name=BaseOperatorName(
                base=base_name,
                inplace=False,
                dunder_method=self.name.name.dunder_method,
            ),
            overload_name="",  # stripped
        ),
        arguments=args_sig,
        returns=returns,
    )
|
||
|
|
||
|
def view_signature(self) -> "FunctionSchema":
    """Return self.signature() computed with strip_view_copy_name=True."""
    view_sig = self.signature(strip_view_copy_name=True)
    return view_sig
|
||
|
|
||
|
def with_name(self, name: "OperatorName") -> "FunctionSchema":
    """Build a new FunctionSchema with `name`, reusing this schema's arguments and returns."""
    renamed = FunctionSchema(
        name=name,
        arguments=self.arguments,
        returns=self.returns,
    )
    return renamed
|
||
|
|
||
|
@property
def modifies_arguments(self) -> bool:
    """True when kind() is one of inplace, out or mutable."""
    schema_kind = self.kind()
    return schema_kind in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
|
||
|
|
||
|
def has_symint(self) -> bool:
    """Delegate to the arguments' has_symint_arg() check."""
    arguments = self.arguments
    return arguments.has_symint_arg()
|
||
|
|
||
|
def __str__(self) -> str:
|
||
|
all_arguments_str = str(self.arguments)
|
||
|
if len(self.returns) == 1:
|
||
|
returns = str(self.returns[0]) # omit parentheses
|
||
|
else:
|
||
|
returns = "(" + ", ".join(map(str, self.returns)) + ")"
|
||
|
return f"{self.name}({all_arguments_str}) -> {returns}"
|
||
|
|
||
|
|
||
|
# Here is the rest of the data model, described more briefly.
|
||
|
|
||
|
|
||
|
# Simplified version for what actually shows up in built-ins.
|
||
|
# Look at alias_info.h for expanded syntax. If you need the structure,
|
||
|
# you also need to make this structure recursive so it can be lined
|
||
|
# up with the type components too. For primitives this isn't really
|
||
|
# necessary
|
||
|
@dataclass(frozen=True)
class Annotation:
    # Alias set before the optional " -> ". Typically only has one element.
    # Not actually a set so we can conveniently assume it is canonically ordered
    alias_set: Tuple[str, ...]
    # True when the annotation carries the '!' (write) marker
    is_write: bool
    # Alias set after " -> "; empty when the annotation has no arrow section
    alias_set_after: Tuple[str, ...]

    @staticmethod
    def parse(ann: str) -> "Annotation":
        """Parse an alias annotation such as "a", "a!", "a|b" or "a! -> a|b".

        Raises:
            AssertionError: if `ann` does not match the grammar, if a
                multi-element alias set is marked as written, if both the
                before and after sets have more than one element, or if the
                parsed value does not print back to `ann`.
        """
        # TODO: implement a proper parser if this gets more ugly
        # Regex Explanation:
        # Example: "a! -> a|b"
        # Group #1: the complete alias set before the optional '!', i.e. one
        #           letter followed by any number of "|letter"; matches "a"
        #           in the example. The whole set is captured as one group
        #           because a repeated capture group only retains its final
        #           repetition, which would drop the middle of "a|b|c".
        # Group #2: optional "is write" flag, matches '!' in the example.
        # Group #3: optional alias set after the arrow, supports wildcard,
        #           matches "a|b" in the example.
        m = re.match(
            r"^([a-z](?:\|[a-z])*)(!?)(?: -> (\*|[a-z](?:\|[a-z])*))?$", ann
        )

        assert m is not None, f"unrecognized alias annotation {ann}"
        alias_set = tuple(m.group(1).split("|"))
        is_write = m.group(2) == "!"
        assert not (
            is_write and len(alias_set) > 1
        ), f"alias set larger than 1 is not mutable, got {ann} instead."
        after_set = tuple(m.group(3).split("|")) if m.group(3) else tuple()
        assert not (
            len(alias_set) > 1 and len(after_set) > 1
        ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
        r = Annotation(
            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
        )
        # Parsing must round-trip losslessly
        assert str(r) == ann, f"{r} != {ann}"
        return r

    def __str__(self) -> str:
        """Print the annotation in the same syntax that parse() accepts."""
        alias_set = "|".join(self.alias_set)
        if self.is_write:
            alias_set = f"{alias_set}!"
        alias_set_after = "|".join(self.alias_set_after)
        if alias_set_after:
            alias_set = f'{alias_set}{" -> "}{alias_set_after}'
        return alias_set
|
||
|
|
||
|
|
||
|
# The base class for the type system. This is also loosely modeled
|
||
|
# off of jit_type.h, but we've simplified the hierarchy to focus
|
||
|
# in on the aspects of the type system that matter for code generation
|
||
|
# (for example, there's no SingleElementType subclass anymore).
|
||
|
# You never actually construct a Type; usually it's going to be one
|
||
|
# of the subclasses. If Python had ADTs this would be one!
|
||
|
@dataclass(frozen=True)
class Type:
    @staticmethod
    def parse(t: str) -> "Type":
        """Parse `t` into a Type and assert that it prints back to `t`."""
        parsed = Type._parse(t)
        assert str(parsed) == t, f"{parsed} != {t}"
        return parsed

    @staticmethod
    def _parse(t: str) -> "Type":
        """Parse without the round-trip check.

        Forms are tried in this order: trailing '?' (optional), trailing
        '[]' / '[N]' (list), the custom class prefix, then a BaseTy name.

        Raises:
            RuntimeError: if `t` is none of the above.
        """
        optional_m = re.match(r"^(.+)\?$", t)
        if optional_m is not None:
            return OptionalType(Type.parse(optional_m.group(1)))

        list_m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
        if list_m is not None:
            raw_size = list_m.group(2)
            size = None if raw_size is None else int(raw_size)
            return ListType(elem=Type.parse(list_m.group(1)), size=size)

        # '__torch__.torch.classes.' is the prefix for custom class
        custom_m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
        if custom_m is not None:
            return CustomClassType(custom_m.group(1))

        try:
            base = BaseTy[t]
        except KeyError as e:
            raise RuntimeError(f"unrecognized type {t}") from e
        return BaseType(base)

    def __str__(self) -> str:
        raise NotImplementedError

    # WARNING: These concepts are not very well-defined.  For example,
    # is "int?" nullable? How about "int?[]".  They are defined
    # so we can conveniently generate legacy Declarations.yaml but
    # really we should probably just remove these at some point

    def is_base_ty_like(self, base_ty: "BaseTy") -> bool:
        raise NotImplementedError

    def is_tensor_like(self) -> bool:
        """Shorthand for is_base_ty_like(BaseTy.Tensor)."""
        return self.is_base_ty_like(BaseTy.Tensor)

    def is_generator_like(self) -> bool:
        """Shorthand for is_base_ty_like(BaseTy.Generator)."""
        return self.is_base_ty_like(BaseTy.Generator)

    def is_symint_like(self) -> bool:
        """Shorthand for is_base_ty_like(BaseTy.SymInt)."""
        return self.is_base_ty_like(BaseTy.SymInt)

    def is_nullable(self) -> bool:
        raise NotImplementedError

    def is_list_like(self) -> Optional["ListType"]:
        raise NotImplementedError
|
||
|
|
||
|
|
||
|
# Base types are simple, atomic types with no further structure
|
||
|
class BaseTy(Enum):
    """Names of the atomic types; Type._parse looks a type string up here by member name."""

    # NB: member order determines the auto() values, so keep it stable.
    Generator = auto()
    ScalarType = auto()
    Tensor = auto()
    int = auto()
    Dimname = auto()
    DimVector = auto()
    float = auto()
    str = auto()
    bool = auto()
    Layout = auto()
    Device = auto()
    DeviceIndex = auto()
    Scalar = auto()
    MemoryFormat = auto()
    QScheme = auto()
    Storage = auto()
    Stream = auto()
    SymInt = auto()
    ConstQuantizerPtr = auto()  # TODO: rename
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class BaseType(Type):
    # Which atomic type this is
    name: BaseTy

    def __str__(self) -> str:
        """Print the enum member's name, e.g. "Tensor"."""
        member_name = self.name.name
        return f"{member_name}"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """True only when this is exactly `base_ty`."""
        return self.name == base_ty

    def is_nullable(self) -> bool:
        """A base type is never nullable."""
        return False

    def is_list_like(self) -> Optional["ListType"]:
        """A base type is never a list."""
        return None

    def is_symint_like(self) -> bool:
        """True only for BaseTy.SymInt."""
        return self.name == BaseTy.SymInt
|
||
|
|
||
|
|
||
|
# Optional types may be specified, or may also be validly given None
|
||
|
@dataclass(frozen=True)
class OptionalType(Type):
    # The wrapped (non-optional) type
    elem: Type

    def __str__(self) -> str:
        """Print the element type followed by '?'."""
        return str(self.elem) + "?"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """Answer for the wrapped element type."""
        inner = self.elem
        return inner.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        """Answer for the wrapped element type."""
        inner = self.elem
        return inner.is_symint_like()

    def is_nullable(self) -> bool:
        """An optional type is always nullable."""
        return True

    def is_list_like(self) -> Optional["ListType"]:
        """Answer for the wrapped element type."""
        inner = self.elem
        return inner.is_list_like()
|
||
|
|
||
|
|
||
|
# A type representing a PyTorch custom class
|
||
|
@dataclass(frozen=True)
class CustomClassType(Type):
    # Class name without the "__torch__.torch.classes." prefix
    class_name: str

    def __str__(self) -> str:
        """
        Return the class name prefixed with __torch__.torch.classes.
        """
        prefix = "__torch__.torch.classes."
        return f"{prefix}{self.class_name}"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """A custom class is never like any base type."""
        return False

    def is_symint_like(self) -> bool:
        """A custom class is never SymInt-like."""
        return False

    def is_nullable(self) -> bool:
        """
        Assume a custom class is not nullable.
        """
        return False

    def is_list_like(self) -> Optional["ListType"]:
        """A custom class is never a list."""
        return None
|
||
|
|
||
|
|
||
|
# List types specify that we may have multiples of an element. We
|
||
|
# also support explicit sizes on list types, but these have
|
||
|
# some nontrivial semantics! (However, for C++ API purposes, explicit
|
||
|
# sizes are mostly erased from the type system.)
|
||
|
#
|
||
|
# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
|
||
|
# int[] elaborates differently than bool[3]!
|
||
|
@dataclass(frozen=True)
class ListType(Type):
    # Element type of the list
    elem: Type
    # Explicit size from "T[N]", or None for an unsized "T[]"
    size: Optional[int]

    def __str__(self) -> str:
        """Print as "elem[]" or "elem[N]".

        The size is tested against None rather than for truthiness so that
        an explicit size of 0 prints as "elem[0]" and still round-trips
        through Type.parse.
        """
        size = f"{self.size}" if self.size is not None else ""
        return f"{self.elem}[{size}]"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        """Answer for the element type."""
        return self.elem.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        """Answer for the element type."""
        return self.elem.is_symint_like()

    def is_nullable(self) -> bool:
        """Answer for the element type."""
        return self.elem.is_nullable()

    def is_list_like(self) -> Optional["ListType"]:
        """A list type is its own list-like view."""
        return self
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class Argument:
    # NB: I didn't put kwarg_only as a boolean field here, unlike
    # c10::Argument, so that printing works correctly

    name: str
    type: Type
    # Default value exactly as written in the schema, or None if absent
    default: Optional[str]

    # The semantics of the annotation field are a little strange.
    #
    # Alias annotations parametrize Tensors (since Tensors are the only things
    # that can alias.)  This motivates why I write Tensor(a!)?  (and not, for
    # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
    # which may be optional (i.e., the alias annotation should bind first to
    # Tensor, before the optional postfix annotation).
    #
    # However, despite being a property of Tensor, we (and c10::Argument)
    # store the annotation at the top level of the Argument, rather than
    # inside the embedded Tensor type.  In the C++ version of this
    # class, we then go through great lengths to mimic the type
    # structure in the annotation structure so we can correlate
    # annotations with types.
    #
    # Now, it turns out, in all applications in code generation, the
    # structure of annotated types is very simple.  So we just hard
    # code it here.  But if we ever do get anything more complex, this
    # model will have to change!
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> "Argument":
        """Parse one argument of the form "<type> <name>" or "<type> <name>=<default>".

        Raises:
            AssertionError: if an annotated Tensor has an unsupported suffix
                or the parsed argument does not print back to `arg`.
        """
        name: str
        default: Optional[str]
        type_and_annot, name_and_default = arg.rsplit(" ", 1)
        if "=" in name_and_default:
            # Split on the first '=' only, so a default that itself contains
            # '=' stays intact instead of failing to unpack.
            name, default = name_and_default.split("=", 1)
        else:
            name = name_and_default
            default = None
        # TODO: deduplicate annotation matching with Return
        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        annotation: Optional[Annotation]
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in [
                "",
                "?",
                "[]",
            ], "unrecognized alias analysis form with Tensor"
            type_s = "Tensor" + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        type = Type.parse(type_s)
        r = Argument(
            name=name,
            type=type,
            default=default,
            annotation=annotation,
        )
        # Parsing must round-trip losslessly
        assert str(r) == arg, f"{str(r)} != {arg}"
        return r

    @property
    def is_write(self) -> bool:
        """True when the argument has an annotation marked as written."""
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        """Print the argument in the same syntax that parse() accepts."""
        type = f"{self.type}"
        if self.annotation:
            assert type in ["Tensor", "Tensor?", "Tensor[]"]
            # Splice the annotation in right after "Tensor", before any suffix
            type = type.replace("Tensor", f"Tensor({self.annotation})")
        if self.name is None:
            return type
        else:
            mb_default = ""
            if self.default:
                mb_default = f"={self.default}"
            return f"{type} {self.name}{mb_default}"
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class Return:
    # Optional return name; None for an unnamed return
    name: Optional[str]
    type: Type
    # Alias annotation attached to the Tensor, if any
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> "Return":
        """Parse one return of the form "<type>" or "<type> <name>".

        Raises:
            AssertionError: if an annotated Tensor has an unsupported suffix
                or the parsed return does not print back to `arg`.
        """
        name: Optional[str]
        head, sep, tail = arg.rpartition(" ")
        if sep:
            type_and_annot, name = head, tail
        else:
            type_and_annot, name = arg, None

        annotation: Optional[Annotation]
        annotated = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        if annotated is None:
            type_s = type_and_annot
            annotation = None
        else:
            # If you update this, make sure the __str__ still works too
            suffix = annotated.group(2)
            assert suffix in [
                "",
                "?",
                "[]",
            ], "unrecognized alias analysis form with Tensor"
            type_s = "Tensor" + suffix
            annotation = Annotation.parse(annotated.group(1))

        parsed = Return(
            name=name,
            type=Type.parse(type_s),
            annotation=annotation,
        )
        assert str(parsed) == arg, f"{str(parsed)} != {arg}"
        return parsed

    @property
    def is_write(self) -> bool:
        """True when the return has an annotation marked as written."""
        if self.annotation is None:
            return False
        return self.annotation.is_write

    def __str__(self) -> str:
        """Print the return in the same syntax that parse() accepts."""
        rendered = f"{self.type}"
        if self.annotation:
            assert rendered in ["Tensor", "Tensor?", "Tensor[]"]
            rendered = rendered.replace("Tensor", f"Tensor({self.annotation})")
        if self.name is None:
            return rendered
        return f"{rendered} {self.name}"
|
||
|
|
||
|
|
||
|
# Represents the self argument for functions that may be methods
|
||
|
@dataclass(frozen=True)
class SelfArgument:
    # The underlying Argument that plays the role of `self`
    argument: Argument
|
||
|
|
||
|
|
||
|
# Bundle of arguments that represent a TensorOptions. This is mostly
|
||
|
# relevant for the public C++ API but we bake it into the core data
|
||
|
# model because other APIs often have to interact with it
|
||
|
@dataclass(frozen=True)
class TensorOptionsArguments:
    # The four arguments that are grouped into one TensorOptions bundle
    dtype: Argument
    layout: Argument
    device: Argument
    pin_memory: Argument

    def all(self) -> Sequence[Argument]:
        """Return the four arguments in order: dtype, layout, device, pin_memory."""
        members = (self.dtype, self.layout, self.device, self.pin_memory)
        return list(members)
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class Arguments:
    # pre_self_positional is usually empty, but is notably non-empty
    # for where.self, where the condition argument comes before the
    # self argument
    pre_self_positional: Tuple[Argument, ...]
    self_arg: Optional[SelfArgument]
    post_self_positional: Tuple[Argument, ...]

    pre_tensor_options_kwarg_only: Tuple[Argument, ...]
    tensor_options: Optional[TensorOptionsArguments]
    # post_tensor_options is typically memory format, which should be
    # part of tensor options but isn't right now, and is usually
    # placed after the tensor options arguments
    post_tensor_options_kwarg_only: Tuple[Argument, ...]

    # Unlike in the previous codegen, we have factored out 'out' arguments
    # in the canonical representation, removing them from kwarg
    # arguments.  This choice is justified by numerous downstream
    # transformations which treat out arguments specially; additionally,
    # you can see that canonicity is not violated!
    out: Tuple[Argument, ...]  # these are also kwarg-only

    @property
    def flat_non_out(self) -> Sequence[Argument]:
        """Positional then kwarg-only arguments, flattened to plain Arguments."""
        return [*self.flat_positional, *self.flat_kwarg_only]

    @property
    def flat_positional(self) -> Sequence[Argument]:
        """Positional arguments with the self argument unwrapped."""
        own: List[Argument] = []
        if self.self_arg is not None:
            own.append(self.self_arg.argument)
        return [*self.pre_self_positional, *own, *self.post_self_positional]

    @property
    def post_self_positional_mutable(self) -> Sequence[Argument]:
        """The post-self positional arguments whose is_write is true."""
        return list(filter(lambda a: a.is_write, self.post_self_positional))

    # NB: doesn't contain out arguments
    @property
    def flat_kwarg_only(self) -> Sequence[Argument]:
        """Kwarg-only arguments with the tensor options bundle expanded."""
        options: Sequence[Argument] = []
        if self.tensor_options is not None:
            options = self.tensor_options.all()
        return [
            *self.pre_tensor_options_kwarg_only,
            *options,
            *self.post_tensor_options_kwarg_only,
        ]

    @property
    def flat_all(self) -> Sequence[Argument]:
        """Every argument, flattened: positional, kwarg-only, then out."""
        return [*self.flat_positional, *self.flat_kwarg_only, *self.out]

    @property
    def non_out(
        self,
    ) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
        """Positional then kwarg-only arguments, keeping the wrapper objects."""
        grouped: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
        grouped += self.positional
        grouped += self.kwarg_only
        return grouped

    @property
    def positional(self) -> Sequence[Union[Argument, SelfArgument]]:
        """Positional arguments, with self kept as a SelfArgument."""
        grouped: List[Union[Argument, SelfArgument]] = list(self.pre_self_positional)
        if self.self_arg is not None:
            grouped.append(self.self_arg)
        grouped += self.post_self_positional
        return grouped

    @property
    def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
        """Kwarg-only arguments, with tensor options kept as a single bundle."""
        grouped: List[Union[Argument, TensorOptionsArguments]] = list(
            self.pre_tensor_options_kwarg_only
        )
        if self.tensor_options is not None:
            grouped.append(self.tensor_options)
        grouped += self.post_tensor_options_kwarg_only
        return grouped

    @property
    def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
        """Every argument, keeping wrappers: positional, kwarg-only, then out."""
        grouped: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
        grouped += self.positional
        grouped += self.kwarg_only
        grouped += self.out
        return grouped

    def mutable_arg_names(self) -> List[str]:
        """Names of all arguments whose annotation is marked as written."""
        names: List[str] = []
        for a in self.flat_all:
            if a.annotation is not None and a.annotation.is_write:
                names.append(a.name)
        return names

    def has_tensor_arg(self) -> bool:
        """True if any non-out argument has a tensor-like type."""
        for a in self.flat_non_out:
            if a.type.is_tensor_like():
                return True
        return False

    def has_symint_arg(self) -> bool:
        """True if any non-out argument has a SymInt-like type."""
        for a in self.flat_non_out:
            if a.type.is_symint_like():
                return True
        return False

    def has_generator_arg(self) -> bool:
        """True if any non-out argument has a generator-like type."""
        for a in self.flat_non_out:
            if a.type.is_generator_like():
                return True
        return False

    def signature(self, *, strip_default: bool = False) -> "Arguments":
        """Return the normalized arguments used for schema signatures.

        Annotations are removed, defaults are removed when `strip_default`
        is set, and tensor options and out arguments are dropped.
        """

        # dataclasses.replace could be used here, but it is less
        # type safe so for now I've opted to type everything out
        def bare(a: Argument) -> Argument:
            return Argument(
                name=a.name,
                type=a.type,
                default=None if strip_default else a.default,
                annotation=None,
            )

        pre_self = tuple(bare(a) for a in self.pre_self_positional)
        if self.self_arg is None:
            bare_self = None
        else:
            bare_self = SelfArgument(bare(self.self_arg.argument))
        post_self = tuple(bare(a) for a in self.post_self_positional)
        # Since TensorOptions are dropped, the post_tensor_options_kwargs are
        # converted to pre_tensor_options_kwargs
        kwarg_only = tuple(
            bare(a) for a in self.pre_tensor_options_kwarg_only
        ) + tuple(bare(a) for a in self.post_tensor_options_kwarg_only)

        return Arguments(
            pre_self_positional=pre_self,
            self_arg=bare_self,
            post_self_positional=post_self,
            pre_tensor_options_kwarg_only=kwarg_only,
            # TensorOptions are dropped in signature,
            # so we can pair factory functions with their out= variants.
            tensor_options=None,
            post_tensor_options_kwarg_only=tuple(),
            # out arguments are dropped in signature
            out=(),
        )

    def remove_self_annotation(self) -> "Arguments":
        """Return a copy whose self argument has no annotation; requires a self argument."""
        assert self.self_arg is not None
        unannotated = dataclasses.replace(self.self_arg.argument, annotation=None)
        return dataclasses.replace(self, self_arg=SelfArgument(unannotated))

    def with_out_args(self, outs: List[Argument]) -> "Arguments":
        """Return a copy with `outs` as the out arguments; requires no existing out arguments."""
        assert len(self.out) == 0
        return dataclasses.replace(self, out=tuple(outs))

    @staticmethod
    def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
        """Split a ", "-separated argument string into (positional, kwarg_only, out)."""
        positional: List[Argument] = []
        kwarg_only: List[Argument] = []
        out: List[Argument] = []
        current = positional

        # TODO: Use a real parser here; this will get bamboozled
        # by signatures that contain things like std::array<bool, 2> (note the space)
        for piece in args.split(", "):
            if not piece:
                continue
            if piece == "*":
                assert (
                    current is positional
                ), "invalid syntax: kwarg-only specifier * can only occur once"
                current = kwarg_only
                continue
            parsed = Argument.parse(piece)
            # Currently, we rely directly on the invariant that there are NO
            # kwarg-only mutating arguments.  If you want to relax this,
            # we will need a more semantic way of matching that takes
            # into account return arguments.  In that case, you will have
            # to manage out computation a level up, in FunctionSchema.  See Note
            # [is_out_fn]
            writes = parsed.annotation is not None and parsed.annotation.is_write
            if not writes:
                assert current is not out
            elif current is kwarg_only:
                # The first written kwarg-only argument starts the out section;
                # written positional arguments stay positional.
                current = out
            current.append(parsed)

        return positional, kwarg_only, out

    @staticmethod
    def parse(args: str) -> "Arguments":
        """
        Input: 'int x, int y, int z'
        """

        # We do this in two phases.  First we parse into three
        # main categories: positional, kwarg_only, out.
        # Then, we reparse positional and kwarg_only to separate
        # out the self argument and tensor options arguments.

        positional, kwarg_only, out = Arguments._preparse(args)

        # Split self argument
        self_ix = next(
            (ix for ix, a in enumerate(positional) if a.name == "self"), None
        )
        pre_self_positional: List[Argument]
        self_arg: Optional[SelfArgument]
        post_self_positional: List[Argument]
        if self_ix is None:
            pre_self_positional = []
            self_arg = None
            post_self_positional = positional
        else:
            pre_self_positional = positional[:self_ix]
            self_arg = SelfArgument(positional[self_ix])
            post_self_positional = positional[self_ix + 1 :]

        # Group tensor options arguments
        pre_tensor_options_kwarg_only: List[Argument] = []
        tensor_options: Optional[TensorOptionsArguments] = None
        post_tensor_options_kwarg_only: List[Argument] = []
        kwarg_only_acc = pre_tensor_options_kwarg_only

        def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
            accepted = [ty, OptionalType(ty)]

            def matches(a: Argument) -> bool:
                return a.name == name and a.type in accepted

            return matches

        predicates = [  # order matters
            pred("dtype", Type.parse("ScalarType")),
            pred("layout", Type.parse("Layout")),
            pred("device", Type.parse("Device")),
            pred("pin_memory", Type.parse("bool")),
        ]
        width = len(predicates)

        i = 0
        while i < len(kwarg_only):
            window = kwarg_only[i : i + width]
            # If there is enough space, and the next `width` arguments
            # look like TensorOptions arguments...
            if len(window) == width and all(
                p(a) for p, a in zip(predicates, window)
            ):
                assert kwarg_only_acc is pre_tensor_options_kwarg_only
                # Group them together as one argument
                dtype, layout, device, pin_memory = window
                tensor_options = TensorOptionsArguments(
                    dtype=dtype,
                    layout=layout,
                    device=device,
                    pin_memory=pin_memory,
                )
                i += width
                kwarg_only_acc = post_tensor_options_kwarg_only
                continue
            kwarg_only_acc.append(kwarg_only[i])
            i += 1

        return Arguments(
            pre_self_positional=tuple(pre_self_positional),
            self_arg=self_arg,
            post_self_positional=tuple(post_self_positional),
            pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
            tensor_options=tensor_options,
            post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
            out=tuple(out),
        )

    def __str__(self) -> str:
        """Print the arguments as a ", "-separated list, with "*" before any kwarg-only or out arguments."""
        pieces: List[str] = [str(a) for a in self.flat_positional]
        if self.flat_kwarg_only or self.out:
            pieces.append("*")
        pieces += [str(a) for a in self.flat_kwarg_only]
        pieces += [str(a) for a in self.out]
        return ", ".join(pieces)

    def __post_init__(self) -> None:
        # TODO: These invariants are weirdly asymmetric?
        # TODO: Fancier types?
        if self.self_arg is None:
            assert not self.pre_self_positional
        if self.tensor_options is None:
            assert not self.post_tensor_options_kwarg_only

        # We don't allow any of the following to have argument annotations,
        # to keep things simple.
        has_mutable_pre_self = any(
            a.annotation is not None and a.annotation.is_write
            for a in self.pre_self_positional
        )
        assert (
            not has_mutable_pre_self
        ), "mutable pre_self_positional arguments are not currently supported in the schema"
|
||
|
|
||
|
|
||
|
# Names that validly are __iXXX__ indicating inplace operations.
|
||
|
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
|
||
|
# NB: PyTorch hasn't actually implemented all of these
|
||
|
AUGMENTED_ASSIGNMENT_NAMES = [
    # arithmetic
    "add", "sub", "mul", "div", "mod", "pow",
    # shifts
    "lshift", "rshift",
    # bitwise
    "and", "xor", "or",
]  # fmt: skip
|
||
|
|
||
|
|
||
|
# A BaseOperatorName is what we think of the operator name, without
|
||
|
# the overload name. Unusually, we don't represent this as just a
|
||
|
# string; instead, we directly represent a few important semantic
|
||
|
# bits of information we derive from the string: namely whether
|
||
|
# or not it's inplace (add_) and whether or not it's a double-underscore
|
||
|
# method (__add__)
|
||
|
@dataclass(frozen=True)
class BaseOperatorName:
    """Structured form of an operator name without its overload name.

    `parse` splits a textual name into these fields and `__str__` is its
    inverse; `parse` asserts that the round trip reproduces the input.
    """

    # The name with the inplace marker, dunder wrapping and "_functional"
    # suffix removed.
    base: str
    # True for "foo_" style names and for "__iXXX__" dunder names.
    inplace: bool
    # True when the name was written as "__XXX__".
    dunder_method: bool
    # Note [Overload Ambiguity With Functional Variants]
    # A handful of operators have both a "mutable" and a "functional" variant.
    # (native_batch_norm is a good example, although this isn't the case today).
    # For those operators, the mutable and functional variant take in the same set of
    # arguments, but have different alias annotations.
    # this makes it ambiguous when you try to resolve an OverloadPacket into an overload,
    # given a set of input arguments.
    #
    # So instead of making the "functional" variant in this case a real overload, e.g:
    #   native_batch_norm (mutable variant)
    #   native_batch_norm.functional (functional variant)
    # we make it a new base operator,
    #   native_batch_norm_functional (functional variant)
    #
    # In an ideal world, we would probably invert this so the operators were:
    #   native_batch_norm.mutable (mutable variant)
    #   native_batch_norm (functional variant)
    #
    # Doing that is BC-breaking though, so we're stuck with the above modeling.
    functional_overload: bool = False

    @staticmethod
    def parse(op: str) -> "BaseOperatorName":
        """Parse a textual base operator name.

        Raises AssertionError for an empty name, a name ending in "_out",
        a non-inplace dunder name whose stem starts with "i", a
        "_functional" name that is also dunder or inplace, or a name that
        does not round-trip through `__str__`.
        """
        assert op != ""
        assert not op.endswith("_out"), (
            "_out suffix is reserved and not permitted for operator names; "
            "did you mean to specify an out overload name instead?"
        )

        dunder_match = re.match(r"^__([^_]+)__$", op)
        is_dunder = dunder_match is not None
        if dunder_match is not None:
            stem = dunder_match.group(1)
            is_inplace = stem in [f"i{name}" for name in AUGMENTED_ASSIGNMENT_NAMES]
            if is_inplace:
                # Drop the leading "i" of e.g. "iadd".
                stem = stem[1:]
            else:
                # temporary, this is not intrinsically true but
                # has been historically true for dunder methods
                # we support (but, if we ever got, say, __int__, this would
                # be wrong!)
                assert stem[0] != "i"
        else:
            # A trailing underscore marks the inplace variant.
            is_inplace = op[-1] == "_"
            stem = op[:-1] if is_inplace else op

        # See Note [Overload Ambiguity With Functional Variants]
        suffix = "_functional"
        is_functional = stem.endswith(suffix)
        if is_functional:
            stem = stem[: -len(suffix)]
            # This seems complicated and unnecessary, so banning dunder methods
            # for now on ops that have a functional + mutable variant (like native_batch_norm).
            assert not is_dunder and not is_inplace

        parsed = BaseOperatorName(
            base=stem,
            inplace=is_inplace,
            dunder_method=is_dunder,
            functional_overload=is_functional,
        )
        assert str(parsed) == op, f"{str(parsed)} != {op}"
        return parsed

    def __str__(self) -> str:
        """Return the textual operator name these fields describe."""
        if self.dunder_method:
            marker = "i" if self.inplace else ""
            return f"__{marker}{self.base}__"

        # The inplace marker takes precedence over the functional suffix.
        if self.inplace:
            tail = "_"
        elif self.functional_overload:
            tail = "_functional"
        else:
            tail = ""
        return f"{self.base}{tail}"
|
||
|
|
||
|
|
||
|
# Operator name is the base operator name along with the (typically not
|
||
|
# user visible) overload string.
|
||
|
@dataclass(frozen=True)
class OperatorName:
    """A base operator name paired with an overload name (possibly empty)."""

    name: BaseOperatorName
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> "OperatorName":
        """Parse "base" or "base.overload"; only the first "." is a separator.

        Raises AssertionError if the parsed result does not print back to
        `op_name` (or if the base name itself is rejected).
        """
        base_text, _, overload_text = op_name.partition(".")
        parsed = OperatorName(
            name=BaseOperatorName.parse(base_text), overload_name=overload_text
        )
        assert str(parsed) == op_name, f"{str(parsed)} != {op_name}"
        return parsed

    def __str__(self) -> str:
        """Return "base.overload", or just "base" when there is no overload."""
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}.{self.overload_name}"

    # NB: This must be synchronized with the naming scheme in
    # aten/src/ATen/templates/Operators.h
    # Given a function schema "aten::op.overload(...)",
    # If there is no overload name, this returns f"{op}"
    # If there is an overload name, this returns f"{op}_{overload}"
    def unambiguous_name(self) -> str:
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}_{self.overload_name}"

    def _non_inplace_base(self) -> BaseOperatorName:
        # Rebuild the base name from `base` and `dunder_method` only:
        # `inplace` is forced to False and `functional_overload` is left at
        # its default value.
        return BaseOperatorName(
            base=self.name.base,
            inplace=False,
            dunder_method=self.name.dunder_method,
        )

    def remove_inplace(self) -> "OperatorName":
        """Return this name with the inplace flag cleared, overload kept."""
        return OperatorName(
            name=self._non_inplace_base(),
            overload_name=self.overload_name,
        )

    def with_overload(self, overload: str) -> "OperatorName":
        """Return a non-inplace copy of this name using `overload` instead."""
        return OperatorName(
            name=self._non_inplace_base(),
            overload_name=overload,
        )
|
||
|
|
||
|
|
||
|
def gets_generated_out_inplace_wrapper(
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
) -> bool:
    """Report whether `f` is a non-functional op with no kernel of its own in
    backend `b`, while `b` does have a kernel for the group's functional op.
    """
    if f.func.kind() is SchemaKind.functional:
        return False
    if b.has_kernel(f):
        return False
    return b.has_kernel(g.functional)
|
||
|
|
||
|
|
||
|
# NativeFunction objects that are views (f.is_view_op returns True)
|
||
|
# are added into a `NativeFunctionsViewGroup`, which we can use to
|
||
|
# easily access the generated (optional) view_copy NativeFunction.
|
||
|
# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup.
|
||
|
# See Note [Codegen'd {view}_copy Operators]
|
||
|
#
|
||
|
# One property of this representation is that in order for a view-like op to be part of
|
||
|
# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
|
||
|
# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
|
||
|
# but don't have corresponding aliasing `narrow.out` op.
|
||
|
# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup.
|
||
|
@dataclass(frozen=True)
class NativeFunctionsViewGroup:
    """Groups a view operator with its optional copy and inplace variants.

    Construction validates (via assertions) that the members are mutually
    consistent; see `__post_init__`.
    """

    view: NativeFunction
    # Note: the {view}_copy operator is optional because we currently don't generate copy variants
    # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
    # (we already get them "for free" through decomposition)
    view_copy: Optional[NativeFunction]
    # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant.
    view_inplace: Optional[NativeFunction]

    def __post_init__(self) -> None:
        """Assert that the group is well formed.

        Raises AssertionError when:
          - `view` is not a view op;
          - `view_copy` is missing although `gets_generated_view_copy(view)`
            says one is expected;
          - `view_copy` is not named "*_copy"/"*_scatter", does not share the
            view's signature (after stripping the copy name), or lacks the
            'view_copy' tag;
          - `view_inplace` does not share the view's signature;
          - `view` has a composite kernel of either kind but `view_inplace`
            does not.
        """
        assert self.view.is_view_op
        if self.view_copy is None:
            assert not gets_generated_view_copy(self.view), (
                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
                " The codegen expects you to add a corresponding operator to native_functions.yaml:"
                f" {get_view_copy_name(self.view)!s}."
                " See Note [view_copy NativeFunctions] for details."
            )
        else:
            assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
            assert self.view.func.signature() == self.view_copy.func.signature(
                strip_view_copy_name=True,
            )
            # Report the operator that is actually missing the tag, together
            # with the tags it does carry.
            assert "view_copy" in self.view_copy.tags, (
                f"{str(self.view_copy.func.name)} (tags: {str(self.view_copy.tags)})"
                " appears to be a view_copy operator. The codegen expects"
                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
                " See Note [view_copy NativeFunctions] for details."
            )
        if self.view_inplace is not None:
            assert self.view.func.signature() == self.view_inplace.func.signature()

        # A composite view requires its inplace variant (if any) to be
        # composite in the same way.
        if self.view.has_composite_implicit_autograd_kernel:
            if self.view_inplace is not None:
                assert self.view_inplace.has_composite_implicit_autograd_kernel, (
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                    " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
                )
        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
            if self.view_inplace is not None:
                assert (
                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
                ), (
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
                )

    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
        """Yield the view, then the inplace variant (if present), then the
        copy variant (if present and `include_copy` is true)."""
        yield self.view
        if self.view_inplace is not None:
            yield self.view_inplace
        if self.view_copy is not None and include_copy:
            yield self.view_copy

    @property
    def root_name(self) -> str:
        """The `root_name` of the view operator."""
        return self.view.root_name

    @property
    def composite(self) -> bool:
        # We currently assert that the "group" is consistent.
        # If the view op is composite, then its view_inplace op is too.
        return self.view.has_composite_implicit_autograd_kernel
|
||
|
|
||
|
|
||
|
def gets_generated_view_copy(f: NativeFunction) -> bool:
    """Decide whether the view operator `f` gets a copy variant.

    The checks run in order and stop at the first one that rules `f` out.
    """
    ruled_out = (
        # Only aliasing (view) operators get a copy variant.
        not f.is_view_op
        # We don't need to bother generating copy variants for
        # CompositeImplicitAutograd ops, because we can let them decompose
        # into base view ops.
        or f.has_composite_implicit_autograd_kernel
        # We also don't need to generate copy variants for inplace views.
        or "inplace_view" in f.tags
        # Assume ops ending in _inverse have manually-defined copy variants
        # (e.g. slice_inverse() has the copy variant slice_scatter()).
        # We -could- probably generate these as well, but the codegen will be
        # slightly different, and hand-writing these few kernels keeps codegen
        # complexity lower.
        or f.func.name.name.base.endswith("_inverse")
    )
    return not ruled_out
|
||
|
|
||
|
|
||
|
# Given a NativeFunction that corresponds to a view op,
|
||
|
# returns the OperatorName of the corresponding "copy" variant of the op.
|
||
|
def get_view_copy_name(f: NativeFunction) -> "OperatorName":
    """Build the operator name "<base>_copy" for the view op `f`, keeping
    its dunder flag and overload name.

    Raises AssertionError if `f` is not "narrow" and
    `gets_generated_view_copy(f)` is false.
    """
    # Right now, when asking for a view op's corresponding "view_copy" name
    # we assert for sanity that the op is allowed to have a generated view_copy variant.
    # (We can do this because "gets_generated_view_copy()" tells us which ops get a generated view_copy op).
    # However, narrow_copy() already exists as an op directly in native_functions.yaml.
    # I'm hardcoding narrow_copy here for now to maintain the assert,
    # But we could also just get rid of the assert.
    has_explicit_copy_operator = str(f.func.name) in ("narrow",)
    if not has_explicit_copy_operator:
        assert gets_generated_view_copy(f)

    view_name = f.func.name
    copy_base = BaseOperatorName(
        base=f"{view_name.name.base}_copy",
        inplace=False,
        dunder_method=view_name.name.dunder_method,
    )
    return OperatorName(name=copy_base, overload_name=view_name.overload_name)
|
||
|
|
||
|
|
||
|
# Helper functions for parsing argument lists (both inputs and returns)
|
||
|
|
||
|
|
||
|
def parse_returns(return_decl: str) -> Tuple[Return, ...]:
    """
    Parse a return declaration into a tuple of Return objects.

    A declaration wrapped in parentheses has them removed first; the
    remainder is split on ", " and each piece is handed to Return.parse.

    Input: '()'
    Output: ()
    """
    if return_decl == "()":
        return ()
    is_parenthesized = return_decl[0] == "(" and return_decl[-1] == ")"
    inner = return_decl[1:-1] if is_parenthesized else return_decl
    return tuple(map(Return.parse, inner.split(", ")))
|
||
|
|
||
|
|
||
|
# A Precompute instance consists of a map from kernel argument name
|
||
|
# to the list of Argument instances that should replace that
|
||
|
# kernel argument in the impl function.
|
||
|
@dataclass(frozen=True)
class Precompute:
    """Precomputed-argument specification: replacements plus additions."""

    # A map from kernel argument name -> a list of precomputed
    # elements that replaces/supersedes it.
    replace: Dict[str, List[Argument]]
    # List of precomputed args added without replacement
    add: List[Argument]

    @staticmethod
    def parse(src: object) -> "Precompute":
        """Build a Precompute from a list of declaration strings.

        Raises AssertionError if `src` is not a list, if an entry other
        than the last lacks " -> ", or if the replacement entries do not
        round-trip through `to_list`.
        """
        assert isinstance(src, list)

        # src is a list of strings of the format:
        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
        #   [{add decl}[, {add decl}, ...]]
        # The last line is optional and contains the precomputed parameters that are
        # added without replacement.
        # The other lines are parsed to get the names of which precomputed elements
        # should replace which kernel arguments.
        added = []
        last_entry = src[-1]
        if " -> " not in last_entry:
            added = [Argument.parse(decl.strip()) for decl in last_entry.split(",")]
            # Only the replacement entries remain to be parsed (and compared
            # against to_list() below).
            src = src[:-1]

        replacements = {}
        for entry in src:
            assert isinstance(entry, str)
            assert " -> " in entry, (
                "precomputed parameters without replacement"
                " are allowed only in the last line"
            )

            kernel_arg, decls_text = entry.split(" -> ")
            replacements[kernel_arg] = [
                Argument.parse(decl.strip()) for decl in decls_text.split(",")
            ]

        parsed = Precompute(replace=replacements, add=added)
        assert parsed.to_list() == src, "r.to_list() != src"
        return parsed

    def __post_init__(self) -> None:
        # the template parameters are upper so if these are the
        # same then it is ambiguous
        replacing_args = (arg for group in self.replace.values() for arg in group)
        for arg in (*self.add, *replacing_args):
            assert arg.name.upper() != arg.name

    def to_list(self) -> List[str]:
        """Return one "{kernel param} -> {replacements}" string per entry of
        `replace`; the `add` arguments are not included."""
        return [
            f"{kernel_param} -> {', '.join(str(param) for param in params)}"
            for kernel_param, params in self.replace.items()
        ]
|