122 lines
4.8 KiB
Python
122 lines
4.8 KiB
Python
|
import ast
|
||
|
import inspect
|
||
|
import textwrap
|
||
|
import copy
|
||
|
import functools
|
||
|
from types import FunctionType
|
||
|
from typing import cast, Union, Callable, Dict, Optional, Any
|
||
|
from torch.fx._symbolic_trace import Tracer
|
||
|
from torch.fx.graph import Graph
|
||
|
from torch._sources import normalize_source_lines
|
||
|
import torch
|
||
|
|
||
|
class AST_Rewriter(ast.NodeTransformer):
    """
    Take a FunctionType object representing a `forward` method, then
    perform an AST rewrite to swap out nodes that are not symbolically
    traceable with a callsite to the FX alternative.

    To support swapping out an AST node, define a new `visit` method on
    that node. For more details, see:
    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
    """

    def rewrite(self, fn: FunctionType):
        """
        Return a new function equivalent to ``fn`` but with its AST rewritten
        by this transformer's ``visit_*`` methods.

        Args:
            fn: the function (or bound method) to rewrite. Its source must be
                retrievable with ``inspect.getsourcelines``.

        Returns:
            A new ``FunctionType`` compiled from the rewritten source and
            bound to ``fn.__globals__``.
        """
        # Normalize the source lines so an indented (nested / method) def can
        # be parsed on its own.
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        source = ''.join(sourcelines)
        normalized_str = textwrap.dedent(source)

        # Rewrite the original AST
        source_ast = ast.parse(normalized_str)
        dest_ast = ast.fix_missing_locations(self.visit(source_ast))

        # Execute the rewritten module against a copy of the original globals
        # (so the real module namespace is never modified) and a *separate*
        # locals dict that captures whatever the source defines. Diffing the
        # global keys instead finds nothing when the function's name is
        # already a global -- e.g. when rewriting a module-level function.
        code = compile(dest_ast, "", "exec")
        globals_dict = copy.copy(fn.__globals__)
        defined: Dict[str, Any] = {}
        exec(code, globals_dict, defined)
        assert len(defined) == 1, (
            "expected the rewritten source to define exactly one name, "
            f"got {sorted(defined)}"
        )
        fn_compiled = next(iter(defined.values()))

        # Return the compiled function re-bound to the original globals
        return self._change_func_globals(fn_compiled, fn.__globals__)

    @staticmethod
    def _change_func_globals(f, new_globals):
        """
        Return a copy of function ``f`` that uses ``new_globals`` as its
        ``__globals__``.

        Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)
        """
        # ``__globals__`` is a read-only attribute of function objects, so
        # build a new function from f's code, defaults and closure, then copy
        # over the remaining metadata (name, qualname, doc, __dict__, ...).
        g = FunctionType(
            f.__code__,
            new_globals,
            name=f.__name__,
            argdefs=f.__defaults__,
            closure=f.__closure__,
        )
        g = functools.update_wrapper(g, f)
        g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
        return g

    def visit_Assert(self, node):
        """
        Swap out the Assert node (Python's `assert`) with a callsite to the
        symbolically-traceable torch._assert function.

        Example: ``assert x, "msg"`` becomes ``torch._assert(x, "msg")``;
        an assert without a message passes ``""`` as the message.
        """
        # Create the Call node
        n = ast.parse('torch._assert()', mode='eval')
        assert isinstance(n, ast.Expression)
        call_node = n.body
        assert isinstance(call_node, ast.Call)
        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
        call_node.args = [node.test, msg]

        # Ensure that the new node conforms to the Python AST grammar
        # (a bare Call is an expression; statements need an Expr wrapper).
        expr_wrapper = ast.Expr(value=call_node)

        # Return the new node to signify that we want to use it as
        # a replacement for the original Assert node
        return ast.copy_location(expr_wrapper, node)

    def visit_AnnAssign(self, node):
        """
        Swap out Python's AnnAssign with an Assign node where the annotation
        function is called.

        Example:
        Original:
        y: Tensor_Type(1,2,3, Dyn) = f2(x)
        Output:
        y = annotate(f2(x), Tensor_Type(1,2,3, Dyn))

        The generated call refers to the name ``annotate``, which must be
        resolvable when the rewritten function runs. A bare annotation with
        no value (``y: T``) is returned unchanged.
        """
        # Nothing to annotate: emitting ``annotate(<None>, T)`` would produce
        # an AST that ``compile`` rejects.
        if node.value is None:
            return node

        new_node = ast.Assign(
            targets=[node.target],
            value=ast.Call(
                func=ast.Name(id='annotate', ctx=ast.Load()),
                args=[node.value, node.annotation],
                keywords=[],
            ),
        )
        # Keep the original line/column info for better error messages.
        return ast.copy_location(new_node, node)
|
||
|
|
||
|
|
||
|
class RewritingTracer(Tracer):
    """
    FX ``Tracer`` that first passes ``root`` through ``_rewrite`` (which
    applies ``AST_Rewriter`` to it) and then traces the result with the
    base-class ``Tracer.trace``.
    """

    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        # Swap `root` for its rewritten counterpart, then defer to the stock
        # FX tracing logic with the same concrete arguments.
        rewritten_root = _rewrite(root)
        return super().trace(rewritten_root, concrete_args)
|
||
|
|
||
|
|
||
|
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
    """
    Return an AST-rewritten counterpart of ``fn``.

    * A non-Module callable is handed straight to ``AST_Rewriter().rewrite``.
    * An ``nn.Module`` is replaced by an instance of a freshly created
      ``RewrittenModule`` class whose ``forward`` is the rewritten version
      of the original ``forward``. Each entry of the original instance's
      ``__dict__`` is shallow-copied onto the new instance; entries that are
      themselves ``nn.Module`` objects are rewritten recursively first.

    NOTE(review): submodules registered through ``nn.Module.__setattr__``
    are stored inside ``orig.__dict__['_modules']`` rather than directly in
    ``orig.__dict__``, so they are only shallow-copied as part of that dict
    and their ``forward`` methods are not rewritten -- confirm whether that
    is intended.
    """
    if not isinstance(fn, torch.nn.Module):
        # Rewrite this single free function
        return AST_Rewriter().rewrite(cast(FunctionType, fn))

    def rewrite_module(m : torch.nn.Module):
        # A new class per call, so assigning `forward` never affects other
        # rewritten modules.
        class RewrittenModule(torch.nn.Module):
            def __init__(self, orig):
                super().__init__()
                for attr_name, attr_value in orig.__dict__.items():
                    # Module-valued attributes are rewritten before copying;
                    # everything else is copied as-is (shallow copy).
                    source_value = (
                        rewrite_module(attr_value)
                        if isinstance(attr_value, torch.nn.Module)
                        else attr_value
                    )
                    self.__dict__[attr_name] = copy.copy(source_value)

        RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
        return RewrittenModule(m)

    return rewrite_module(fn)
|