import ast
import inspect
import textwrap
import copy
import functools
from types import FunctionType
from typing import cast, Union, Callable, Dict, Optional, Any
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
from torch._sources import normalize_source_lines
import torch


class AST_Rewriter(ast.NodeTransformer):
    """
    Take a FunctionType object representing a `forward` method, then
    perform an AST rewrite to swap out nodes that are not symbolically
    traceable with a callsite to the FX alternative.

    To support swapping out an AST node, define a new `visit` method on
    that node. For more details, see:
    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
    """

    def rewrite(self, fn: FunctionType) -> FunctionType:
        """
        Re-compile `fn` from its source after applying this transformer.

        Args:
            fn: the function (or bound method) whose source is rewritten.
                Its source must be retrievable via `inspect.getsourcelines`.

        Returns:
            A new function object built from the rewritten code, bound to
            `fn.__globals__` and carrying `fn_compiled`'s name, defaults and
            keyword-only defaults.
        """
        # Normalize the source lines
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        normalized_str = textwrap.dedent(''.join(sourcelines))

        # Rewrite the original AST
        source_ast = ast.parse(normalized_str)
        # Remember which name the source's `def` binds so the compiled
        # function can be found even when that name is already a global
        # (e.g. `fn` is a module-level function).
        def_names = [
            stmt.name
            for stmt in source_ast.body
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        dest_ast = ast.fix_missing_locations(self.visit(source_ast))

        # Pull out the compiled function from the newly-created Module.
        # Execute in a *copy* of the globals so the original namespace is
        # not polluted by the re-definition.
        code = compile(dest_ast, "", "exec")
        globals_dict = copy.copy(fn.__globals__)
        keys_before = set(globals_dict)
        exec(code, globals_dict)
        if len(def_names) == 1:
            fn_compiled = globals_dict[def_names[0]]
        else:
            # Fallback: identify the function as the single newly-added key.
            new_keys = list(set(globals_dict) - keys_before)
            assert len(new_keys) == 1
            fn_compiled = globals_dict[new_keys[0]]

        # return the compiled function with the original globals
        def change_func_globals(f, globals):
            """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
            # __globals__ is a private member of the function class
            # so we have to copy the function, f, all of its member, except f.__globals__
            g = FunctionType(
                f.__code__,
                globals,
                name=f.__name__,
                argdefs=f.__defaults__,
                closure=f.__closure__,
            )
            g = functools.update_wrapper(g, f)
            g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
            return g

        # Return the correct FunctionType object
        return change_func_globals(fn_compiled, globals=fn.__globals__)

    def visit_Assert(self, node):
        """
        Swap out the Assert node (Python's `assert`) with a callsite to the
        symbolically-traceable torch._assert function
        """
        # Create the Call node
        n = ast.parse('torch._assert()', mode='eval')
        assert isinstance(n, ast.Expression)
        call_node = n.body
        assert isinstance(call_node, ast.Call)
        # `assert cond` without a message becomes `torch._assert(cond, "")`
        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
        call_node.args = [node.test, msg]

        # Ensure that the new node conforms to the Python AST grammar
        expr_wrapper = ast.Expr(value=call_node)

        # Return the new Call node to signify that we want to use it as
        # a replacement for the original _assert node
        return ast.copy_location(expr_wrapper, node)

    def visit_AnnAssign(self, node):
        """
        Swap out Python's AnnAssign with an Assign node where the annotation function is called.
        Example:
             Original:
             y: Tensor_Type(1,2,3, Dyn) = f2(x)
            Output:
             y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))

        A bare annotation with no assigned value (e.g. `y: int`) is left
        unchanged, since there is no value to pass to `annotate`.
        """
        if node.value is None:
            # Nothing is assigned; building `annotate(None, ...)` would put
            # a `None` child into the AST and make `compile` fail.
            return node
        annotate_call = ast.Call(
            func=ast.Name(id='annotate', ctx=ast.Load()),
            args=[node.value, node.annotation],
            keywords=[],
        )
        assign = ast.Assign(targets=[node.target], value=annotate_call)
        return ast.copy_location(assign, node)


class RewritingTracer(Tracer):
    """Tracer that passes `root` through `_rewrite` before tracing it."""

    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        """Rewrite `root` with `_rewrite`, then delegate to `Tracer.trace`."""
        return super().trace(_rewrite(root), concrete_args)


def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
    """
    Apply `AST_Rewriter` to a module's `forward` or to a free function.

    For a `torch.nn.Module`, returns an instance of a freshly created
    `RewrittenModule` class whose `__dict__` entries are shallow copies of
    the original's and whose `forward` is the rewritten one. For any other
    callable, returns the rewritten function.
    """
    if isinstance(fn, torch.nn.Module):
        # Rewrite this module's `forward` as well as the `forward`s of
        # all of this module's recursive descendents. Return the new,
        # rewritten module hierarchy.
        def rewrite_module(m : torch.nn.Module):
            class RewrittenModule(torch.nn.Module):
                def __init__(self, orig):
                    super().__init__()
                    # NOTE(review): only Module values stored directly in
                    # `orig.__dict__` are rewritten here; children kept in
                    # the `_modules` dict appear to be shallow-copied with
                    # their original `forward` -- confirm this is intended.
                    for k, v in orig.__dict__.items():
                        if isinstance(v, torch.nn.Module):
                            self.__dict__[k] = copy.copy(rewrite_module(v))
                        else:
                            self.__dict__[k] = copy.copy(v)

            # A new class is created per module, so assigning `forward` on
            # the class does not affect other rewritten modules.
            RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
            return RewrittenModule(m)
        return rewrite_module(fn)
    else:
        # Rewrite this single free function
        return AST_Rewriter().rewrite(cast(FunctionType, fn))