Traktor/myenv/Lib/site-packages/torch/fx/experimental/rewriter.py

122 lines
4.8 KiB
Python
Raw Permalink Normal View History

2024-05-26 05:12:46 +02:00
import ast
import inspect
import textwrap
import copy
import functools
from types import FunctionType
from typing import cast, Union, Callable, Dict, Optional, Any
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
from torch._sources import normalize_source_lines
import torch
class AST_Rewriter(ast.NodeTransformer):
    """
    Take a FunctionType object representing a `forward` method, then
    perform an AST rewrite to swap out nodes that are not symbolically
    traceable with a callsite to the FX alternative.

    To support swapping out an AST node, define a new `visit` method on
    that node. For more details, see:
    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
    """

    def rewrite(self, fn: FunctionType):
        """
        Return a new function equivalent to ``fn`` but compiled from its
        rewritten AST.

        The source of ``fn`` is read with ``inspect.getsourcelines``, so it
        must be available. The returned function executes with
        ``fn.__globals__`` as its globals.

        Args:
            fn: the function (or bound method) whose source is rewritten.

        Returns:
            The newly compiled, rewritten function.
        """
        # Normalize the source lines (align the body with the `def` line so
        # nested/indented definitions dedent cleanly).
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        source = ''.join(sourcelines)
        normalized_str = textwrap.dedent(source)

        # Rewrite the original AST
        source_ast = ast.parse(normalized_str)
        dest_ast = ast.fix_missing_locations(self.visit(source_ast))

        # Pull out the compiled function from the newly-created Module.
        # Execute into a *copy* of the globals so the caller's namespace is
        # not modified by the `def` statement.
        code = compile(dest_ast, "", "exec")
        globals_dict = copy.copy(fn.__globals__)
        keys_before = set(globals_dict.keys())
        exec(code, globals_dict)

        def_names = [
            node.name
            for node in dest_ast.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        if len(def_names) == 1:
            # Look the function up by its `def` name. Diffing the keys alone
            # breaks when that name was already bound in the globals (e.g. a
            # module-level free function): exec then rebinds an existing key
            # instead of adding a new one.
            fn_compiled = globals_dict[def_names[0]]
        else:
            # Fallback for sources that are not a single `def` statement.
            new_keys = list(set(globals_dict.keys()) - keys_before)
            assert len(new_keys) == 1
            fn_compiled = globals_dict[new_keys[0]]

        # return the compiled function with the original globals
        def change_func_globals(f, globals):
            """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
            # __globals__ is a read-only attribute of the function object,
            # so build a new function from f's code object and members,
            # substituting only the globals.
            g = FunctionType(
                f.__code__,
                globals,
                name=f.__name__,
                argdefs=f.__defaults__,
                closure=f.__closure__,
            )
            g = functools.update_wrapper(g, f)
            g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
            return g

        # Return the correct FunctionType object
        return change_func_globals(fn_compiled, globals=fn.__globals__)

    def visit_Assert(self, node):
        """
        Swap out the Assert node (Python's `assert`) with a callsite to the
        symbolically-traceable torch._assert function
        """
        # Create the Call node
        n = ast.parse('torch._assert()', mode='eval')
        assert isinstance(n, ast.Expression)
        call_node = n.body
        assert isinstance(call_node, ast.Call)
        # `assert cond` without a message becomes `torch._assert(cond, "")`
        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
        call_node.args = [node.test, msg]

        # Ensure that the new node conforms to the Python AST grammar
        expr_wrapper = ast.Expr(value=call_node)

        # Return the new Call node to signify that we want to use it as
        # a replacement for the original _assert node
        return ast.copy_location(expr_wrapper, node)

    def visit_AnnAssign(self, node):
        """
        Swap out Python's AnnAssign with an Assign node where the annotation function is called.
        Example:
             Original:
             y: Tensor_Type(1,2,3, Dyn) = f2(x)
            Output:
             y = annotate(f2(x), Tensor_Type(1,2,3, Dyn))

        A bare annotation without a value (``y: T``) is returned unchanged,
        since there is no value to pass to ``annotate``.
        """
        if node.value is None:
            # `annotate(None, T)` would not compile, and a bare annotation
            # assigns nothing, so there is nothing to rewrite.
            return node
        new_node = ast.Assign(targets=[node.target], value=ast.Call(
            func=ast.Name(id='annotate', ctx=ast.Load()),
            args=[node.value, node.annotation], keywords=[]))
        # Keep the original location so errors point at the user's line.
        return ast.copy_location(new_node, node)
class RewritingTracer(Tracer):
    """
    FX ``Tracer`` that first passes ``root`` through ``_rewrite`` (the AST
    rewrite performed by ``AST_Rewriter``) and then traces the result with
    the base ``Tracer.trace``.
    """

    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        # Rewrite non-traceable constructs before handing off to the base tracer.
        rewritten_root = _rewrite(root)
        return super().trace(rewritten_root, concrete_args)
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
    """
    Apply ``AST_Rewriter`` to ``fn``.

    * A non-``nn.Module`` callable is rewritten directly and the rewritten
      function is returned.
    * An ``nn.Module`` is replaced by an instance of a freshly created
      ``RewrittenModule`` class. That class's ``forward`` is the rewritten
      ``forward`` of the original module, and the instance's ``__dict__``
      holds shallow copies of the original's ``__dict__`` entries. Entries
      that are themselves ``nn.Module`` instances are rewritten recursively.
    """
    if not isinstance(fn, torch.nn.Module):
        # Rewrite this single free function
        return AST_Rewriter().rewrite(cast(FunctionType, fn))

    def rewrite_module(m : torch.nn.Module):
        # A new class is created on every call so that each rewritten module
        # carries its own rewritten `forward` as a class attribute.
        class RewrittenModule(torch.nn.Module):
            def __init__(self, orig):
                super().__init__()
                for attr_name, attr_value in orig.__dict__.items():
                    # NOTE(review): nn.Module normally registers child modules
                    # in `_modules` rather than directly in `__dict__`, so this
                    # branch may not see them. In that case children would be
                    # shared through the shallow-copied `_modules` dict instead
                    # of rewritten. Confirm against the intent to rewrite all
                    # descendants.
                    if isinstance(attr_value, torch.nn.Module):
                        attr_value = rewrite_module(attr_value)
                    self.__dict__[attr_name] = copy.copy(attr_value)

        RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
        return RewrittenModule(m)

    return rewrite_module(fn)