551 lines
17 KiB
Python
551 lines
17 KiB
Python
|
import ast
|
||
|
import builtins
|
||
|
import dis
|
||
|
import enum
|
||
|
import inspect
|
||
|
import re
|
||
|
import typing
|
||
|
import warnings
|
||
|
|
||
|
from textwrap import dedent
|
||
|
from typing import Type
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from torch._C import (
|
||
|
_GeneratorType,
|
||
|
AnyType,
|
||
|
AwaitType,
|
||
|
BoolType,
|
||
|
ComplexType,
|
||
|
DeviceObjType,
|
||
|
DictType,
|
||
|
EnumType,
|
||
|
FloatType,
|
||
|
FutureType,
|
||
|
InterfaceType,
|
||
|
IntType,
|
||
|
ListType,
|
||
|
NoneType,
|
||
|
NumberType,
|
||
|
OptionalType,
|
||
|
StreamObjType,
|
||
|
StringType,
|
||
|
TensorType,
|
||
|
TupleType,
|
||
|
UnionType,
|
||
|
)
|
||
|
from torch._sources import get_source_lines_and_file
|
||
|
from .._jit_internal import ( # type: ignore[attr-defined]
|
||
|
_Await,
|
||
|
_qualified_name,
|
||
|
Any,
|
||
|
BroadcastingList1,
|
||
|
BroadcastingList2,
|
||
|
BroadcastingList3,
|
||
|
Dict,
|
||
|
Future,
|
||
|
is_await,
|
||
|
is_dict,
|
||
|
is_future,
|
||
|
is_ignored_fn,
|
||
|
is_list,
|
||
|
is_optional,
|
||
|
is_tuple,
|
||
|
is_union,
|
||
|
List,
|
||
|
Optional,
|
||
|
Tuple,
|
||
|
Union,
|
||
|
)
|
||
|
from ._state import _get_script_class
|
||
|
|
||
|
if torch.distributed.rpc.is_available():
|
||
|
from torch._C import RRefType
|
||
|
from .._jit_internal import is_rref, RRef
|
||
|
|
||
|
from torch._ops import OpOverloadPacket
|
||
|
|
||
|
|
||
|
class Module:
    """Minimal attribute namespace used while evaluating type comments.

    Attribute lookups that are not satisfied by the instance itself are
    served from the ``members`` mapping; unknown names raise RuntimeError.
    """

    def __init__(self, name, members):
        # ``name`` is only used to build the error message below.
        self.name, self.members = name, members

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so the ``name``
        # and ``members`` attributes set in __init__ never come through here.
        table = self.members
        try:
            found = table[name]
        except KeyError:
            raise RuntimeError(
                f"Module {self.name} has no member called {name}"
            ) from None
        return found
|
||
|
|
||
|
|
||
|
class EvalEnv:
    """Mapping used as the ``locals`` of type-comment evaluation.

    ``__getitem__`` resolves a name by consulting, in order: the fixed
    ``env`` table, the resolution callback ``rcb`` (when one was given),
    and finally ``builtins`` (unknown names yield ``None``).
    """

    # Class attribute shared by every instance; __init__ adds "RRef" to
    # this same dict when RPC support is available.
    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        if torch.distributed.rpc.is_available():
            # RRef is only imported at module level when RPC is available.
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        table = self.env
        if name in table:
            return table[name]
        resolver = self.rcb
        if resolver is None:
            # No callback: fall back to builtins, never raising KeyError.
            return getattr(builtins, name, None)
        return resolver(name)
|
||
|
|
||
|
|
||
|
def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or ``None``.

    Python 3 annotations are tried first (via ``try_real_annotations``). If
    ``fn`` has none, its source is searched for a ``# type:`` comment, which
    is parsed with ``parse_type_line`` using ``rcb`` to resolve names.

    Args:
        fn: the callable to inspect; for an ``OpOverloadPacket`` the
            annotations of its ``op`` are used.
        rcb: resolution callback forwarded to ``parse_type_line``.
        loc: source location forwarded to the type-conversion helpers.
        is_method: when True, the first parameter (``self``) is dropped from
            real annotations so the result lines up with type comments,
            which never mention ``self``.

    Returns:
        The signature tuple, or ``None`` when there are neither annotations
        nor a type comment, or the source could not be retrieved.
    """
    if isinstance(fn, OpOverloadPacket):
        signature = try_real_annotations(fn.op, loc)
    else:
        signature = try_real_annotations(fn, loc)
    if signature is not None and is_method:
        # If this is a method, then the signature will include a type for
        # `self`, but type comments do not contain a `self`. So strip it
        # away here so everything is consistent (`inspect.ismethod` does
        # not work here since `fn` is unbound at this point)
        param_types, return_type = signature
        signature = (param_types[1:], return_type)

    if signature is None:
        type_line = None
        try:
            source_lines = get_source_lines_and_file(fn)[0]
        except (OSError, TypeError):
            # Source is unavailable: TypeError for objects with no Python
            # source (e.g. builtins), OSError when the source text cannot be
            # read (e.g. functions created via exec). Treat both as "no type
            # comment", consistent with check_fn.
            source_lines = None
        if source_lines is not None:
            type_line = get_type_line(dedent("".join(source_lines)))
        # type_line stays None both when the source was unavailable and when
        # it had no annotations.
        if type_line is not None:
            signature = parse_type_line(type_line, rcb, loc)

    return signature
|
||
|
|
||
|
|
||
|
def is_function_or_method(the_callable):
    """Return True only for Python-level functions and bound methods.

    This is narrower than ``inspect.isroutine``: built-in (C-implemented)
    functions and methods are rejected.
    """
    predicates = (inspect.isfunction, inspect.ismethod)
    return any(check(the_callable) for check in predicates)
|
||
|
|
||
|
|
||
|
def is_vararg(the_callable):
    """Return True if ``the_callable`` takes a ``*args`` parameter.

    Callables that are not plain functions/methods (e.g. instances of a
    class) are examined through their ``__call__`` attribute. If the target
    is still not a Python function or method (e.g. a builtin), the answer is
    False.
    """
    target = the_callable
    if callable(target) and not is_function_or_method(target):  # noqa: B004
        # De-sugar the call so the signature of __call__ can be inspected.
        target = target.__call__

    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None
|
||
|
|
||
|
|
||
|
def get_param_names(fn, n_args):
    """Return the parameter names of ``fn`` as a list of strings.

    An ``OpOverloadPacket`` is replaced by its ``op``. A callable object
    whose ``__call__`` is a Python function/method is de-sugared to that
    ``__call__``. When no Python signature is available, placeholder names
    ``"0"`` .. ``str(n_args - 1)`` are returned instead.
    """
    if isinstance(fn, OpOverloadPacket):
        fn = fn.op

    # De-sugar calls to classes
    should_desugar = (
        not is_function_or_method(fn)
        and callable(fn)
        and is_function_or_method(fn.__call__)
    )  # noqa: B004
    if should_desugar:
        fn = fn.__call__

    if not is_function_or_method(fn):
        # Not a Python function/method, so use a default positional
        # name list.
        return list(map(str, range(n_args)))

    if is_ignored_fn(fn):
        # Follow __wrapped__ chains to reach the original definition.
        fn = inspect.unwrap(fn)
    return inspect.getfullargspec(fn).args
|
||
|
|
||
|
|
||
|
def check_fn(fn, loc):
    """Ensure the source of ``fn`` is exactly one top-level function.

    This rejects class instantiations used as calls inside script
    functions. If the source cannot be retrieved (OSError or TypeError),
    the check is silently skipped.

    Raises:
        torch.jit.frontend.FrontendError: if the source is a single class
            definition, or is anything other than exactly one ``def``.
    """
    try:
        raw_lines = get_source_lines_and_file(fn)[0]
        source = dedent("".join(raw_lines))
    except (OSError, TypeError):
        return

    statements = ast.parse(source).body
    sole = statements[0] if len(statements) == 1 else None
    if isinstance(sole, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc,
            f"Cannot instantiate class '{sole.name}' in a script function",
        )
    if not isinstance(sole, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function"
        )
|
||
|
|
||
|
|
||
|
def _eval_no_call(stmt, glob, loc):
|
||
|
"""Evaluate statement as long as it does not contain any method/function calls."""
|
||
|
bytecode = compile(stmt, "", mode="eval")
|
||
|
for insn in dis.get_instructions(bytecode):
|
||
|
if "CALL" in insn.opname:
|
||
|
raise RuntimeError(
|
||
|
f"Type annotation should not contain calls, but '{stmt}' does"
|
||
|
)
|
||
|
return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204
|
||
|
|
||
|
|
||
|
def parse_type_line(type_line, rcb, loc):
    """Parse a function type comment into TorchScript types.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Names are resolved through an ``EvalEnv`` built from ``rcb``.

    Returns:
        ``(arg_types, return_type)`` where ``arg_types`` is a list.

    Raises:
        RuntimeError: if either side fails to evaluate with a NameError or
            SyntaxError.
    """
    args_src, ret_src = split_type_line(type_line)
    names = EvalEnv(rcb)

    try:
        parsed_args = _eval_no_call(args_src, {}, names)
    except (NameError, SyntaxError) as e:
        raise RuntimeError(
            "Failed to parse the argument list of a type annotation"
        ) from e

    # "(T)" evaluates to T itself rather than a 1-tuple, so normalize.
    if not isinstance(parsed_args, tuple):
        parsed_args = (parsed_args,)

    try:
        parsed_ret = _eval_no_call(ret_src, {}, names)
    except (NameError, SyntaxError) as e:
        raise RuntimeError(
            "Failed to parse the return type of a type annotation"
        ) from e

    return [ann_to_type(a, loc) for a in parsed_args], ann_to_type(parsed_ret, loc)
|
||
|
|
||
|
|
||
|
def get_type_line(source):
    """Try to find the line containing a comment with the type annotation.

    Two layouts are recognized:

    * a single function type comment, e.g. ``# type: (int, str) -> bool``;
    * the PEP 484 per-argument layout, where each argument line carries its
      own ``# type:`` comment and a final ``# type: (...) -> R`` line gives
      the return type. The argument types are spliced into the ``(...)``.

    Comments of the form ``type: ignore`` (optionally followed by one or
    more comma-separated error codes in brackets) at the end of a line are
    not considered annotations.

    Args:
        source: the (dedented) source code of a function.

    Returns:
        The stripped type-comment line, or ``None`` if there is none.

    Raises:
        RuntimeError: if no valid type comment exists but a malformed prefix
            (e.g. ``#type:``) does, or if several type comments are present
            without a ``# type: (...) -> ...`` return line.
    """
    type_comment = "# type:"

    lines = list(enumerate(source.split("\n")))
    type_lines = [line for line in lines if type_comment in line[1]]
    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.

    # An ignore type comment can be of following format:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    #   3) type: ignore[rule-code, other-code]
    # This ignore statement must be at the end of the line

    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile(r"# type:\ ignore(\[[a-zA-Z0-9_,\s-]+\])?$")
    type_lines = [line for line in type_lines if not type_pattern.search(line[1])]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = [
            line for line in lines if wrong_type_pattern.search(line[1])
        ]
        if wrong_type_lines:
            raise RuntimeError(
                "The annotation prefix in line "
                + str(wrong_type_lines[0][0])
                + " is probably invalid.\nIt must be '# type:'"
                + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                + "\nfor examples"
            )
        return None
    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_line = None
    parameter_type_lines = []
    for line_num, line in type_lines:
        if "# type: (...) -> " in line:
            return_line = (line_num, line)
            break
        # Every entry of type_lines contains type_comment, so anything before
        # the return line is a per-parameter annotation.
        parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line[1] for line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    def get_parameter_type(line):
        item_type = line[line.find(type_comment) + len(type_comment) :]
        return item_type.strip()

    parameter_types = ", ".join(get_parameter_type(line) for line in parameter_type_lines)

    # Substitute only the "(...)" placeholder of the type comment itself, so
    # any other "..." on the line (e.g. inside the return type) is preserved.
    merged = return_line[1].replace(
        "# type: (...)", f"# type: ({parameter_types})", 1
    )
    # Strip like the single-line case above: split_type_line slices the
    # annotation assuming the line begins with "# type:".
    return merged.strip()
|
||
|
|
||
|
|
||
|
def split_type_line(type_line):
    """Split the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    Any text before the ``# type:`` marker (indentation, or code when the
    comment trails a ``def`` line) is ignored, and the ``->`` is searched for
    only after the marker.

    Raises:
        RuntimeError: if no ``->`` separator is found.
    """
    marker = "# type:"
    marker_pos = type_line.find(marker)
    if marker_pos >= 0:
        start_offset = marker_pos + len(marker)
        search_from = start_offset
    else:
        # No marker: keep the historical behavior of assuming the line starts
        # with it.
        start_offset = len(marker)
        search_from = 0
    try:
        arrow_pos = type_line.index("->", search_from)
    except ValueError:
        raise RuntimeError(
            "Syntax error in type annotation (cound't find `->`)"
        ) from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip()
|
||
|
|
||
|
|
||
|
def try_real_annotations(fn, loc):
    """Build a signature from the Python 3 annotations of ``fn``, if any.

    Returns:
        ``(arg_types, return_type)`` converted with ``ann_to_type`` (missing
        annotations are passed through as ``inspect.Signature.empty``), or
        ``None`` if ``inspect.signature`` raises ValueError or if neither the
        parameters nor the return value are annotated.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    params = list(sig.parameters.values())
    has_any_annotation = sig.return_annotation is not sig.empty or any(
        p.annotation is not sig.empty for p in params
    )
    if not has_any_annotation:
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in params]
    return arg_types, ann_to_type(sig.return_annotation, loc)
|
||
|
|
||
|
|
||
|
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Find the common TorchScript type of the member values of Enum ``e``.

    Each distinct Python type among the member values is converted with
    ``try_ann_to_type``, and the results are unified with
    ``torch._C.unify_type_list``. ``AnyType`` is returned if unification
    yields a falsy result.

    NOTE(review): an earlier header comment claimed AnyType is returned when
    values have different types, while the note below says an exception is
    raised in that case. Confirm against unify_type_list.

    Raises:
        ValueError: if ``e`` defines no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Name the Enum class itself; ``e.__class__`` would be its metaclass,
        # which tells the user nothing about which Enum is empty.
        raise ValueError(f"No enum values defined for: '{e.__qualname__}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # If Enum values are of different types, an exception will be raised here.
    # Even though Python supports this case, we chose to not implement it to
    # avoid overcomplicate logic here for a rare use case. Please report a
    # feature request if you find it necessary.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
|
||
|
|
||
|
|
||
|
def is_tensor(ann):
    """Return True if the class ``ann`` should be treated as a Tensor.

    Subclasses of ``torch.Tensor`` qualify directly. The legacy dtype-specific
    classes (``torch.FloatTensor`` and friends) also qualify, but a warning is
    emitted because TorchScript does not enforce their dtype.
    """
    if issubclass(ann, torch.Tensor):
        return True

    legacy_tensor_classes = (
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
        torch.BoolTensor,
    )
    if not issubclass(ann, legacy_tensor_classes):
        return False

    warnings.warn(
        "TorchScript will treat type annotations of Tensor "
        "dtype-specific subtypes as if they are normal Tensors. "
        "dtype constraints are not enforced in compilation either."
    )
    return True
|
||
|
|
||
|
|
||
|
def _fake_rcb(inp):
|
||
|
return None
|
||
|
|
||
|
|
||
|
def try_ann_to_type(ann, loc, rcb=None):
|
||
|
ann_args = typing.get_args(ann) # always returns a tuple!
|
||
|
|
||
|
if ann is inspect.Signature.empty:
|
||
|
return TensorType.getInferred()
|
||
|
if ann is None:
|
||
|
return NoneType.get()
|
||
|
if inspect.isclass(ann) and is_tensor(ann):
|
||
|
return TensorType.get()
|
||
|
if is_tuple(ann):
|
||
|
# Special case for the empty Tuple type annotation `Tuple[()]`
|
||
|
if len(ann_args) == 1 and ann_args[0] == ():
|
||
|
return TupleType([])
|
||
|
return TupleType([try_ann_to_type(a, loc) for a in ann_args])
|
||
|
if is_list(ann):
|
||
|
elem_type = try_ann_to_type(ann_args[0], loc)
|
||
|
if elem_type:
|
||
|
return ListType(elem_type)
|
||
|
if is_dict(ann):
|
||
|
key = try_ann_to_type(ann_args[0], loc)
|
||
|
value = try_ann_to_type(ann_args[1], loc)
|
||
|
# Raise error if key or value is None
|
||
|
if key is None:
|
||
|
raise ValueError(
|
||
|
f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
|
||
|
)
|
||
|
if value is None:
|
||
|
raise ValueError(
|
||
|
f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
|
||
|
)
|
||
|
return DictType(key, value)
|
||
|
if is_optional(ann):
|
||
|
if issubclass(ann_args[1], type(None)):
|
||
|
contained = ann_args[0]
|
||
|
else:
|
||
|
contained = ann_args[1]
|
||
|
valid_type = try_ann_to_type(contained, loc)
|
||
|
msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
|
||
|
assert valid_type, msg.format(repr(ann), repr(contained), repr(loc))
|
||
|
return OptionalType(valid_type)
|
||
|
if is_union(ann):
|
||
|
# TODO: this is hack to recognize NumberType
|
||
|
if set(ann_args) == {int, float, complex}:
|
||
|
return NumberType.get()
|
||
|
inner: List = []
|
||
|
# We need these extra checks because both `None` and invalid
|
||
|
# values will return `None`
|
||
|
# TODO: Determine if the other cases need to be fixed as well
|
||
|
for a in typing.get_args(ann):
|
||
|
if a is None:
|
||
|
inner.append(NoneType.get())
|
||
|
maybe_type = try_ann_to_type(a, loc)
|
||
|
msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
|
||
|
assert maybe_type, msg.format(repr(ann), repr(maybe_type), repr(loc))
|
||
|
inner.append(maybe_type)
|
||
|
return UnionType(inner) # type: ignore[arg-type]
|
||
|
if torch.distributed.rpc.is_available() and is_rref(ann):
|
||
|
return RRefType(try_ann_to_type(ann_args[0], loc))
|
||
|
if is_future(ann):
|
||
|
return FutureType(try_ann_to_type(ann_args[0], loc))
|
||
|
if is_await(ann):
|
||
|
elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
|
||
|
return AwaitType(elementType)
|
||
|
if ann is float:
|
||
|
return FloatType.get()
|
||
|
if ann is complex:
|
||
|
return ComplexType.get()
|
||
|
if ann is int or ann is torch.SymInt:
|
||
|
return IntType.get()
|
||
|
if ann is str:
|
||
|
return StringType.get()
|
||
|
if ann is bool:
|
||
|
return BoolType.get()
|
||
|
if ann is Any:
|
||
|
return AnyType.get()
|
||
|
if ann is type(None):
|
||
|
return NoneType.get()
|
||
|
if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
|
||
|
return InterfaceType(ann.__torch_script_interface__)
|
||
|
if ann is torch.device:
|
||
|
return DeviceObjType.get()
|
||
|
if ann is torch.Generator:
|
||
|
return _GeneratorType.get()
|
||
|
if ann is torch.Stream:
|
||
|
return StreamObjType.get()
|
||
|
if ann is torch.dtype:
|
||
|
return IntType.get() # dtype not yet bound in as its own type
|
||
|
if inspect.isclass(ann) and issubclass(ann, enum.Enum):
|
||
|
if _get_script_class(ann) is None:
|
||
|
scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
|
||
|
name = scripted_class.qualified_name()
|
||
|
else:
|
||
|
name = _qualified_name(ann)
|
||
|
return EnumType(name, get_enum_value_type(ann, loc), list(ann))
|
||
|
if inspect.isclass(ann):
|
||
|
maybe_script_class = _get_script_class(ann)
|
||
|
if maybe_script_class is not None:
|
||
|
return maybe_script_class
|
||
|
if torch._jit_internal.can_compile_class(ann):
|
||
|
return torch.jit._script._recursive_compile_class(ann, loc)
|
||
|
|
||
|
# Maybe resolve a NamedTuple to a Tuple Type
|
||
|
if rcb is None:
|
||
|
rcb = _fake_rcb
|
||
|
return torch._C._resolve_type_from_object(ann, loc, rcb)
|
||
|
|
||
|
|
||
|
def ann_to_type(ann, loc, rcb=None):
|
||
|
the_type = try_ann_to_type(ann, loc, rcb)
|
||
|
if the_type is not None:
|
||
|
return the_type
|
||
|
raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
|
||
|
|
||
|
|
||
|
# Names exported by a wildcard import of this module.
__all__ = [
    # typing helpers imported from .._jit_internal
    "Any", "List", "BroadcastingList1", "BroadcastingList2", "BroadcastingList3",
    "Tuple", "is_tuple", "is_list", "Dict", "is_dict", "is_optional", "is_union",
    # TorchScript type objects imported from torch._C
    "TensorType", "TupleType", "FloatType", "ComplexType", "IntType",
    "ListType", "StringType", "DictType", "AnyType",
    # defined in this module
    "Module",
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    "get_signature", "check_fn", "get_param_names", "parse_type_line",
    "get_type_line", "split_type_line", "try_real_annotations",
    "try_ann_to_type", "ann_to_type",
]
|