249 lines
9.1 KiB
Python
249 lines
9.1 KiB
Python
|
import ast
|
||
|
import inspect
|
||
|
import textwrap
|
||
|
import warnings
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
    """Check the ``__init__`` method of a given ``nn.Module``.

    It ensures that all instance-level attributes can be properly initialized.

    Specifically, we do type inference based on attribute values, even if
    the attribute in question has already been typed using Python3-style
    annotations or ``torch.jit.annotate``. This means that setting an
    instance-level attribute to ``[]`` (for ``List``), ``{}`` (for
    ``Dict``), or ``None`` (for ``Optional``) isn't enough information for
    us to properly initialize that attribute.

    An object of this class can walk a given ``nn.Module``'s AST and
    determine if it meets our requirements or not. Every offending
    assignment results in one call to ``warnings.warn``.

    Known limitations

    1. We can only check the AST nodes for certain constructs; we can't
       ``eval`` arbitrary expressions. This means that function calls,
       class instantiations, and complex expressions that resolve to one of
       the "empty" values specified above will NOT be flagged as
       problematic.
    2. We match on string literals, so if the user decides to use a
       non-standard import (e.g. ``from typing import List as foo``), we
       won't catch it.

    Example:

        .. code-block:: python

            class M(torch.nn.Module):
                def fn(self):
                    return []

                def __init__(self):
                    super().__init__()
                    self.x: List[int] = self.fn()

                def forward(self, x: List[int]):
                    self.x = x
                    return 1

        The above code will pass the ``AttributeTypeIsSupportedChecker``
        check since we have a function call in ``__init__``. However,
        it will still fail later with the ``RuntimeError`` "Tried to set
        nonexistent attribute: x. Did you forget to initialize it in
        __init__()?".

    Usage: ``AttributeTypeIsSupportedChecker().check(nn_module)``.
    """

    # Annotation names whose "empty" value (``[]``, ``{}``, ``None``) is not
    # enough information to infer the attribute's full type.
    #
    # TODO @ansley: add `Union` once landed
    #
    # NB: Even though `Tuple` is a "container", we don't want to check for it
    # here. `Tuple` functions as a type with an "infinite" number of
    # subtypes, in the sense that you can have `Tuple[()]`, `Tuple[T1]`,
    # `Tuple[T2]`, `Tuple[T1, T2]`, `Tuple[T2, T1]` and so on, and none of
    # these subtypes can be used in place of the other. Therefore, assigning
    # an empty tuple in `__init__` CORRECTLY means that that variable cannot
    # be reassigned later to a non-empty tuple. Same deal with `NamedTuple`.
    _CONTAINERS = frozenset({"List", "Dict", "Optional"})

    # Emitted (as a UserWarning) for every offending assignment.
    _EMPTY_CONTAINER_WARNING = (
        "The TorchScript type system doesn't support "
        "instance-level annotations on empty non-base "
        "types in `__init__`. Instead, either 1) use a "
        "type annotation in the class body, or 2) wrap "
        "the type in `torch.jit.Attribute`."
    )

    def check(self, nn_module: torch.nn.Module) -> None:
        """Walk ``type(nn_module).__init__`` and warn on empty container attributes.

        Args:
            nn_module: The instance of ``torch.nn.Module`` whose
                ``__init__`` method we wish to check.

        Errors raised by ``inspect.getsource`` (e.g. when the source of
        ``__init__`` is unavailable) propagate to the caller.
        """
        source = inspect.getsource(nn_module.__class__.__init__)

        # Ignore comments no matter the indentation (presumably so that a
        # comment at a shallower indent cannot stop `textwrap.dedent` from
        # stripping the method's common indentation). `# type:` comments are
        # kept.
        def is_useless_comment(line: str) -> bool:
            line = line.strip()
            return line.startswith("#") and not line.startswith("# type:")

        source = "\n".join(
            line for line in source.split("\n") if not is_useless_comment(line)
        )

        # This AST only contains the `__init__` method of the nn.Module
        init_ast = ast.parse(textwrap.dedent(source))

        # Names annotated in the class body. Looked up through the instance
        # (so base-class annotations are found when the class has none of
        # its own), falling back to "no annotations" if absent.
        self.class_level_annotations = list(
            getattr(nn_module, "__annotations__", {})
        )

        # Set by `visit_Assign` while visiting the value of an assignment
        # whose target is annotated at the class level.
        self.visiting_class_level_ann = False

        self.visit(init_ast)

    def _container_name(self, annotation):
        """Return ``"List"``, ``"Dict"`` or ``"Optional"`` for a matching annotation.

        Only the plain subscript form is recognized, i.e. an annotation that
        parses to ``Subscript(value=Name(id=<container>))``. Anything else
        (bare names such as ``int``, qualified names such as
        ``typing.List[int]``, other subscripts) yields ``None``.
        """
        if not isinstance(annotation, ast.Subscript):
            return None
        base = annotation.value
        if isinstance(base, ast.Name) and base.id in self._CONTAINERS:
            return base.id
        return None

    def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
        """Return True if ``node`` is the "empty" literal for ``ann_type``.

        ``[]`` for ``"List"``, ``{}`` for ``"Dict"`` and ``None`` for
        ``"Optional"``. Returns False for any other node (including a
        missing value, ``node is None``) and for unrecognized ``ann_type``.
        """
        if ann_type == "List":
            # Assigning `[]` to a `List` type gives you a Node where
            # value=List(elts=[], ctx=Load())
            return isinstance(node, ast.List) and not node.elts
        if ann_type == "Dict":
            # Assigning `{}` to a `Dict` type gives you a Node where
            # value=Dict(keys=[], values=[])
            return isinstance(node, ast.Dict) and not node.keys
        if ann_type == "Optional":
            # Assigning `None` to an `Optional` type gives you a
            # Node where value=Constant(value=None, kind=None).
            # Compare by identity: falsy literals such as `0`, `False` or
            # `""` are not `None` and do carry type information.
            return isinstance(node, ast.Constant) and node.value is None
        return False

    def _is_annotate_call(self, node) -> bool:
        """Return True if the Call ``node`` invokes ``torch.jit.annotate``.

        Recognized spellings are ``torch.jit.annotate(...)``,
        ``jit.annotate(...)`` and a bare ``annotate(...)`` (as after
        ``from torch.jit import annotate``).
        """
        func = node.func
        if isinstance(func, ast.Name):
            return func.id == "annotate"
        if not (isinstance(func, ast.Attribute) and func.attr == "annotate"):
            return False
        owner = func.value
        if isinstance(owner, ast.Name):
            return owner.id == "jit"
        return (
            isinstance(owner, ast.Attribute)
            and owner.attr == "jit"
            and isinstance(owner.value, ast.Name)
            and owner.value.id == "torch"
        )

    def visit_Assign(self, node):
        """Store assignment state when assigning to a Call Node.

        If we're visiting a Call Node (the right-hand side of an
        assignment statement), we won't be able to check the variable
        that we're assigning to (the left-hand side of an assignment).
        Because of this, we need to store this state in visitAssign.
        (Luckily, we only have to do this if we're assigning to a Call
        Node, i.e. ``torch.jit.annotate``. If we're using normal Python
        annotations, we'll be visiting an AnnAssign Node, which has its
        target built in.)
        """
        try:
            if (
                isinstance(node.value, ast.Call)
                and node.targets[0].attr in self.class_level_annotations
            ):
                self.visiting_class_level_ann = True
        except AttributeError:
            # Only reached when the value is a Call and the first target has
            # no `.attr` (e.g. a local `x = f()` or tuple unpacking); such
            # statements are not visited at all.
            return
        self.generic_visit(node)
        self.visiting_class_level_ann = False

    def visit_AnnAssign(self, node):
        """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.

        It checks if it conforms to our attribute annotation rules: warn on
        ``self.<attr>: List[...] = []`` / ``Dict[...] = {}`` /
        ``Optional[...] = None`` unless ``<attr>`` is annotated at the class
        level. Child nodes are not visited.
        """
        target = node.target
        # Only `self.<attr>: ...` is of interest. Local variables
        # (`x: T = ...`), subscript targets (`self[0]: T = ...`) and nested
        # attributes (`self.a.b: T = ...`) are skipped.
        if not (
            isinstance(target, ast.Attribute)
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
        ):
            return

        # If we have an attribute that's already been annotated at the
        # class level
        if target.attr in self.class_level_annotations:
            return

        # If we're not evaluating one of the specified problem types
        ann_type = self._container_name(node.annotation)
        if ann_type is None:
            return

        # Check if the assigned variable is empty (`node.value` is None for a
        # bare `self.x: T` declaration, which is never flagged)
        if self._is_empty_container(node.value, ann_type):
            warnings.warn(self._EMPTY_CONTAINER_WARNING)

    def visit_Call(self, node):
        """Determine if a Call node is 'torch.jit.annotate' in __init__.

        Visit a Call node in an ``nn.Module``'s ``__init__`` method and
        determine if it's ``torch.jit.annotate``. If so, see if it conforms
        to our attribute annotation rules. Calls nested in the callee or the
        arguments are visited first.
        """
        # If we have an attribute that's already been annotated at the
        # class level
        if self.visiting_class_level_ann:
            return

        self.generic_visit(node)

        # Only `torch.jit.annotate(the_type, the_value)` is checked; any other
        # call (e.g. `foo(List[int], [])`) must not be flagged. A well-formed
        # annotate call has args[0] == annotation, args[1] == actual value.
        if not self._is_annotate_call(node) or len(node.args) != 2:
            return

        ann_type = self._container_name(node.args[0])
        if ann_type is None:
            return

        # Check if the assigned variable is empty
        if self._is_empty_container(node.args[1], ann_type):
            warnings.warn(self._EMPTY_CONTAINER_WARNING)
|