import contextlib
import copy
import itertools
import linecache
import os
import sys
import traceback
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import torch
import torch.nn as nn
import torch.overrides
from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer

from ._compatibility import compatibility
from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode

__all__ = [
    "reduce_graph_module",
    "reduce_package_graph_module",
    "reduce_deploy_graph_module",
    "GraphModule",
]

_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"


# Normal exec loses the source code; however, we can work with
# the linecache module to recover it.
# Using _exec_with_source will add it to our local cache
# and then tools like TorchScript will be able to get source info.
class _EvalCacheLoader:
    def __init__(self):
        self.eval_cache = {}
        self.next_id = 0

    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
        """Store the source in a private cache, and add a lazy entry in linecache
        that allows the source to be retrieved by 'filename'.

        Args:
            src (str): The module source to cache
            globals (dict): The module globals

        Returns:
            str: The cache key (and dummy filename) generated for src.
        """
        key = self._get_key()
        if co_fields:
            key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
        self.eval_cache[key] = src

        # Don't mutate globals so that this loader is only used
        # to populate linecache, and doesn't interact with other modules
        # that might check `__loader__`
        globals_copy = globals.copy()
        globals_copy["__file__"] = key
        globals_copy["__name__"] = key
        globals_copy["__loader__"] = self
        linecache.lazycache(key, globals_copy)

        return key

    # Part of the loader protocol (PEP 302)
    # linecache will use this method when trying to find source code
    def get_source(self, module_name) -> Optional[str]:
        if module_name in self.eval_cache:
            return self.eval_cache[module_name]
        return None

    def _get_key(self):
        key = f"<eval_with_key>.{self.next_id}"
        self.next_id += 1
        return key


_loader = _EvalCacheLoader()


def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
    key = _loader.cache(src, globals, co_fields)
    exec(compile(src, key, "exec"), globals)
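
# A minimal, hedged illustration (not executed as part of this module) of why
# the linecache dance above matters: code compiled through _exec_with_source
# gets a synthetic "<eval_with_key>.N" filename whose source linecache can
# recover, so tools like inspect and traceback can show the generated code.
#
#     scope: Dict[str, Any] = {}
#     _exec_with_source("def f(x):\n    return x + 1\n", scope)
#     import inspect
#     print(inspect.getsource(scope["f"]))  # prints the cached source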


def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
    return _method_from_src(
        method_name="forward", src=src, globals=globals, co_fields=co_fields
    )


def _method_from_src(
    method_name: str, src: str, globals: Dict[str, Any], co_fields=None
) -> Callable:
    # avoid mutating the passed-in dict
    globals_copy = globals.copy()
    _exec_with_source(src, globals_copy, co_fields)
    fn = globals_copy[method_name]
    del globals_copy[method_name]
    return fn


def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
    if name in _custom_builtins:
        return _custom_builtins[name].import_str
    if _is_from_torch(name):
        return "import torch"
    module_name, attr_name = importer.get_name(obj)
    return f"from {module_name} import {attr_name} as {name}"


def _format_import_block(globals: Dict[str, Any], importer: Importer):
    import_strs: Set[str] = set()
    for name, obj in globals.items():
        import_strs.add(_format_import_statement(name, obj, importer))
    # Sort the imports so we have a stable import block that allows us to
    # hash the graph module and get a consistent key for use in a cache.
    return "\n".join(sorted(import_strs))
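
# Hedged sketch of what _format_import_block produces (names here are
# hypothetical): for globals like {"torch": torch, "helper": some_fn from
# my_pkg.utils}, the block would read, in sorted order:
#
#     from my_pkg.utils import some_fn as helper
#     import torch
#
# This block is prepended to the generated forward() source when a GraphModule
# is pickled, so the code can be re-executed in a fresh namespace.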


@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
    # BC: attribute name was changed from `code` to `_code` to facilitate
    # making `code` into a property and adding a docstring to it
    fn_src = body.get("_code") or body["code"]
    forward = _forward_from_src(import_block + fn_src, {})
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
    forward = importer.import_module(generated_module_name).forward
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
    ns = {}
    ns["__builtins__"] = importer.patched_builtins
    fn_src = body.get("_code")
    assert fn_src is not None
    forward = _forward_from_src(import_block + fn_src, ns)
    return _deserialize_graph_module(forward, body)


# We create a dummy class here because symbolic_trace pulls the forward()
# function off of the class, rather than the instance. This class is used
# in _deserialize_graph_module() below.
class _CodeOnlyModule(torch.nn.Module):
    def __init__(self, body):
        super().__init__()
        self.__dict__ = body


def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module:
    """
    Deserialize a GraphModule given the dictionary of the original module,
    using the code to reconstruct the graph. We delete the actual graph before
    saving the dictionary so that changes to the in-memory graph format do not
    get serialized.
    """

    # Try to retrieve the forward source in a backward-compatible way
    _CodeOnlyModule.forward = forward

    tracer_cls = body.get("_tracer_cls")
    if tracer_cls is None:
        from ._symbolic_trace import Tracer

        tracer_cls = Tracer

    graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")

    # This is a workaround for a mypy linter issue related to
    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
    cls_tracer: Any = tracer_cls

    class KeepModules(cls_tracer):
        # we shouldn't trace into any of the submodules,
        # because they were not traced in the original GraphModule
        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
            return True

    com = _CodeOnlyModule(body)

    tracer_extras = body.get("_tracer_extras", {})
    graph = KeepModules().trace(com, **tracer_extras)

    # Manually set Tracer class on the reconstructed Graph, to avoid
    # referencing the private local subclass KeepModules.
    graph._tracer_cls = tracer_cls
    from ._lazy_graph_module import _make_graph_module
    gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls)

    # The GraphModule constructor only retains attributes referenced by the graph.
    # In this case, our goal is to return a GraphModule as close to identical as
    # possible to the one put into the package. If any additional attributes were
    # present in body, we should keep them.
    for k, v in body.items():
        if not hasattr(gm, k):
            setattr(gm, k, v)
    return gm
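

# Hedged note on the retrace above: deserialization does not load a pickled
# Graph; it re-traces the generated forward() with a tracer that treats every
# submodule as a leaf. A minimal sketch of the invariant this preserves
# (module names are hypothetical):
#
#     gm = torch.fx.symbolic_trace(MyModule())
#     body = gm.__dict__.copy()
#     body.pop("_graph")
#     gm2 = _deserialize_graph_module(type(gm).forward, body)
#     # gm2's graph is reconstructed from gm's generated code; node names may
#     # differ, but the computation is equivalent.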


# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
    *prefix, field = target.split(".")
    for item in prefix:
        f = getattr(from_module, item)
        t = getattr(to_module, item, None)
        if f is t:
            # we have already installed one of its parents
            # (e.g. target = root.linear.weight, but we have already installed root.linear)
            # once we install a parent, we no longer need to copy the children
            # since all the needed properties will already be present
            return

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        from_module, to_module = f, t

    orig = getattr(from_module, field)
    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
        to_module.register_buffer(field, orig)
    else:
        setattr(to_module, field, orig)


# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
# This installs empty Modules where none exist yet if they are subpaths of target
def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
    *prefix, field = target.split(".")
    for item in prefix:
        t = getattr(to_module, item, None)

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        to_module = t

    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(from_obj, torch.Tensor) and not isinstance(
        from_obj, torch.nn.Parameter
    ):
        to_module.register_buffer(field, from_obj)
    else:
        setattr(to_module, field, from_obj)
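

# Hedged usage sketch for the two helpers above: intermediate Modules are
# created on demand, and plain tensors land as named buffers.
#
#     root = torch.nn.Module()
#     _assign_attr(torch.zeros(3), root, "sub.inner.w")
#     assert isinstance(root.sub.inner, torch.nn.Module)   # created on demand
#     assert "sub.inner.w" in dict(root.named_buffers())   # tensor -> buffer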


class _WrappedCall:
    def __init__(self, cls, cls_call):
        self.cls = cls
        self.cls_call = cls_call

    # Previously, if an error occurred when valid
    # symbolically-traced code was run with an invalid input, the
    # user would see the source of the error as coming from
    # `File "<eval_with_key>.N"`, where N is some number. We use
    # this function to generate a more informative error message. We
    # return the traceback itself, a message explaining that the
    # error occurred in a traced Module's generated forward
    # function, and five lines of context surrounding the faulty
    # line
    @staticmethod
    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
        # auxiliary variables (for readability)
        err_lineno = frame_summary.lineno
        assert err_lineno is not None
        line = frame_summary.line
        assert line is not None
        err_line_len = len(line)
        all_src_lines = linecache.getlines(frame_summary.filename)

        # constituent substrings of the error message
        tb_repr = traceback.format_exc()
        custom_msg = (
            "Call using an FX-traced Module, "
            f"line {err_lineno} of the traced Module's "
            "generated forward function:"
        )
        before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
        marker = "~" * err_line_len + "~~~ <--- HERE"
        err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])

        # joined message
        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])

    def __call__(self, obj, *args, **kwargs):
        try:
            if self.cls_call is not None:
                return self.cls_call(obj, *args, **kwargs)
            else:
                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            assert e.__traceback__
            topmost_framesummary: traceback.FrameSummary = (
                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
            )  # type: ignore[arg-type]
            if "eval_with_key" in topmost_framesummary.filename:
                print(
                    _WrappedCall._generate_error_message(topmost_framesummary),
                    file=sys.stderr,
                )
                raise e.with_traceback(None)  # noqa: TRY200
            else:
                raise e


@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    """
    GraphModule is an nn.Module generated from an fx.Graph. GraphModule has a
    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
    from that ``graph``.

    .. warning::

        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """

    def __new__(cls: "Type[GraphModule]", *args, **kwargs):
        # Each instance of a graph module needs its own forward method,
        # so create a new singleton class for each instance.
        # It is a subclass of the user-defined class; the only difference
        # is an extra layer to install the forward method.

        # address issue described at https://github.com/pytorch/pytorch/issues/63883
        # in other words, traverse class hierarchy to fix the redundant class definition problem
        for t in cls.__mro__:
            c = t.__qualname__.split(".")[-1]
            if c != "GraphModuleImpl":
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass

        return super().__new__(GraphModuleImpl)

    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        class_name: str = "GraphModule",
    ):
        """
        Construct a GraphModule.

        Args:

            root (Union[torch.nn.Module, Dict[str, Any]]):
                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
                over into the appropriate place within the GraphModule's module hierarchy.

            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation

            class_name (str): ``class_name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
                error messages will report as originating from ``GraphModule``. It may be helpful to set this
                to ``root``'s original name or a name that makes sense within the context of your transform.
        """
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, "training"):
                self.training = root.training

            # When we pickle/unpickle a graph module, we don't want to drop any modules or attributes.
            if isinstance(root, _CodeOnlyModule):
                for k, _ in root.named_children():
                    _copy_attr(root, self, k)

                for k, _ in root.named_buffers():
                    _copy_attr(root, self, k)

                for k, _ in root.named_parameters():
                    _copy_attr(root, self, k)

            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError(
                            "Node "
                            + str(node)
                            + " referenced target "
                            + node.target
                            + " but that target was not provided in ``root``!"
                        )
                    targets_to_copy.append(node.target)
            # Sort targets in ascending order of the # of atoms.
            # This will ensure that less deeply nested attributes are assigned
            # before more deeply nested attributes. For example, foo.bar
            # will be assigned before foo.bar.baz. Otherwise, we might assign
            # the user-provided ``foo.bar`` and wipe out the previously-assigned
            # ``foo.bar.baz``
            targets_to_copy.sort(key=lambda t: t.count("."))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError("Unsupported type " + str(root) + " passed for root!")

        self.graph = graph

        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        self._tracer_cls = None
        if (
            self.graph._tracer_cls
            and "<locals>" not in self.graph._tracer_cls.__qualname__
        ):
            self._tracer_cls = self.graph._tracer_cls

        self._tracer_extras = {}
        if self.graph._tracer_extras:
            self._tracer_extras = self.graph._tracer_extras

        # Dictionary to store metadata
        self.meta: Dict[str, Any] = {}
        self._replace_hook = None
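
    # Hedged construction sketch (``MyModule`` is hypothetical): a GraphModule
    # is typically built from a traced module and its graph.
    #
    #     traced = torch.fx.symbolic_trace(MyModule())
    #     gm = GraphModule(traced, traced.graph, class_name="MyModuleTraced")
    #     # gm.forward now executes the code generated from traced.graph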

    # TorchScript breaks trying to compile the graph setter because of the
    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
    #
    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
    __jit_unused_properties__ = ["graph"]

    @property
    def graph(self) -> Graph:
        """
        Return the ``Graph`` underlying this ``GraphModule``
        """
        return self._graph

    @graph.setter
    def graph(self, g: Graph) -> None:
        """
        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
        recompile the ``GraphModule`` so that the generated ``forward()`` function
        corresponds to ``g``
        """
        assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}"
        self._graph = g
        g.owning_module = self
        self.recompile()
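
    # Hedged sketch of the two ways to change the graph, per the class
    # docstring above:
    #
    #     gm.graph = new_graph        # the setter recompiles automatically
    #
    #     for node in gm.graph.nodes: # in-place edits need an explicit
    #         ...                     # recompile() afterwards
    #     gm.recompile()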

    @compatibility(is_backward_compatible=False)
    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
        """Dumps out module to ``folder`` with ``module_name`` so that it can be
        imported with ``from <folder> import <module_name>``

        Args:

            folder (Union[str, os.PathLike]): The folder to write the code out to

            module_name (str): Top-level name to use for the ``Module`` while
                writing out the code
        """
        folder = Path(folder)
        Path(folder).mkdir(exist_ok=True)
        torch.save(self.state_dict(), folder / "state_dict.pt")
        tab = " " * 4
        custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])
        model_str = f"""
import torch
{custom_builtins}

from torch.nn import *
class {module_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
"""

        def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
            safe_reprs = [
                nn.Linear,
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
            ]
            if type(module) in safe_reprs:
                return f"{module.__repr__()}"
            else:
                return None

        blobified_modules = []
        for module_name, module in self.named_children():
            module_str = _gen_model_repr(module_name, module)
            if module_str is None:
                module_file = folder / f"{module_name}.pt"
                torch.save(module, module_file)
                blobified_modules.append(module_name)
                module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
                module_str = f"torch.load(r'{module_file}') # {module_repr}"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"

        model_str += (
            f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        )
        model_str += f"{_addindent(self.code, 4)}\n"

        module_file = folder / "module.py"
        module_file.write_text(model_str)

        init_file = folder / "__init__.py"
        init_file.write_text("from .module import *")

        if len(blobified_modules) > 0:
            warnings.warn(
                "Was not able to save the following children modules as reprs - "
                f"saved as pickled files instead: {blobified_modules}"
            )
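
    # Hedged usage sketch (paths are hypothetical): dump the module as an
    # importable package, then load it back in a separate process.
    #
    #     gm.to_folder("dumped", module_name="FxModule")
    #     # later, with the parent of "dumped" on sys.path:
    #     # from dumped import FxModule
    #     # m = FxModule()   # weights restored from state_dict.pt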

    @compatibility(is_backward_compatible=True)
    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
        """
        Adds the given submodule to ``self``.

        This installs empty Modules where none exist yet if they are
        subpaths of ``target``.

        Args:
            target: The fully-qualified string name of the new submodule
                (See example in ``nn.Module.get_submodule`` for how to
                specify a fully-qualified string.)
            m: The submodule itself; the actual object we want to
                install in the current Module

        Return:
            bool: Whether or not the submodule could be inserted. For
                this method to return True, each object in the chain
                denoted by ``target`` must either a) not exist yet,
                or b) reference an ``nn.Module`` (not a parameter or
                other attribute)
        """
        *prefix, field = target.split(".")
        mod: torch.nn.Module = self

        for item in prefix:

            submod = getattr(mod, item, None)

            if submod is None:
                submod = torch.nn.Module()
                setattr(mod, item, submod)

            if not isinstance(submod, torch.nn.Module):
                return False

            mod = submod

        mod.add_module(field, m)
        return True
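
    # Hedged sketch (path name is hypothetical): installing a module at a
    # nested path creates the intermediate parents, mirroring _assign_attr.
    #
    #     ok = gm.add_submodule("features.block0", torch.nn.ReLU())
    #     assert ok and isinstance(gm.get_submodule("features.block0"), torch.nn.ReLU)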

    @compatibility(is_backward_compatible=True)
    def delete_submodule(self, target: str) -> bool:
        """
        Deletes the given submodule from ``self``.

        The module will not be deleted if ``target`` is not a valid
        target.

        Args:
            target: The fully-qualified string name of the submodule to delete
                (See example in ``nn.Module.get_submodule`` for how to
                specify a fully-qualified string.)

        Returns:
            bool: Whether or not the target string referenced a
                submodule we want to delete. A return value of ``False``
                means that the ``target`` was not a valid reference to
                a submodule.
        """
        atoms = target.split(".")
        path, target_submod = atoms[:-1], atoms[-1]
        mod: torch.nn.Module = self

        # Get the parent module
        for item in path:

            if not hasattr(mod, item):
                return False

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                return False

        if not hasattr(mod, target_submod):
            return False

        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
            return False

        delattr(mod, target_submod)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_all_unused_submodules(self) -> None:
        """
        Deletes all unused submodules from ``self``.

        A Module is considered "used" if any one of the following is
        true:
        1. It has children that are used
        2. Its forward is called directly via a ``call_module`` node
        3. It has a non-Module attribute that is used from a
        ``get_attr`` node

        This method can be called to clean up an ``nn.Module`` without
        manually calling ``delete_submodule`` on each unused submodule.
        """
        used: List[str] = []

        for node in self.graph.nodes:

            if node.op == "call_module" or node.op == "get_attr":

                # A list of strings representing the different parts
                # of the path. For example, `foo.bar.baz` gives us
                # ["foo", "bar", "baz"]
                fullpath = node.target.split(".")

                # If we're looking at multiple parts of a path, join
                # them with a dot. Otherwise, return that single
                # element without doing anything to it.
                def join_fn(x: str, y: str) -> str:
                    return ".".join([x, y] if y else [x])

                # Progressively collect all the names of intermediate
                # modules. For example, if we have the target
                # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
                # `foo.bar.baz` to the list.
                used.extend(itertools.accumulate(fullpath, join_fn))

                # For a `call_module` node, also register all recursive submodules
                # as used
                if node.op == "call_module":
                    try:
                        submod = self.get_submodule(node.target)

                        for submod_name, _ in submod.named_modules():
                            if submod_name != "":
                                used.append(".".join([node.target, submod_name]))
                    except AttributeError:
                        # Node referenced nonexistent submodule, don't need to
                        # worry about GCing anything
                        pass

        to_delete = [name for name, _ in self.named_modules() if name not in used]

        for name in to_delete:
            self.delete_submodule(name)
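
    # Hedged sketch: after erasing the only `call_module` node that used a
    # submodule, this pass garbage-collects the now-orphaned module.
    #
    #     gm.graph.erase_node(node_calling_sub)   # hypothetical node
    #     gm.delete_all_unused_submodules()
    #     gm.recompile()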

    @property
    def code(self) -> str:
        """
        Return the Python code generated from the ``Graph`` underlying this
        ``GraphModule``.
        """
        if not hasattr(self, "_code"):
            raise RuntimeError(
                "Code has not been generated! Please report a bug to PyTorch"
            )
        return self._code

    @compatibility(is_backward_compatible=True)
    def recompile(self) -> PythonCode:
        """
        Recompile this GraphModule from its ``graph`` attribute. This should be
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module="self")
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map

        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

        # Determine whether this class explicitly defines a __call__ implementation
        # to wrap. If it does, save it in order to have wrapped_call invoke it.
        # If it does not, wrapped_call can use a dynamic call to super() instead.
        # In most cases, super().__call__ should be torch.nn.Module.__call__.
        # We do not want to hold a reference to Module.__call__ here; doing so will
        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
        cls_call = cls.__call__ if "__call__" in vars(cls) else None

        if "_wrapped_call" not in vars(cls):
            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

        def call_wrapped(self, *args, **kwargs):
            return self._wrapped_call(self, *args, **kwargs)

        cls.__call__ = call_wrapped  # type: ignore[method-assign]

        return python_code
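
    # Hedged sketch (input ``x`` is hypothetical): after recompile(), the
    # regenerated source is visible on the instance, and calls route through
    # _WrappedCall for the friendlier error messages defined above.
    #
    #     gm.recompile()
    #     print(gm.code)   # regenerated forward() source
    #     out = gm(x)      # __call__ -> _wrapped_call -> generated forward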

    # Passing Tracer as an argument allows subclasses extending fx.GraphModule
    # to define their own Tracer (extending fx.Tracer).
    def __reduce_deploy__(self, importer: Importer):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, importer)
        return (reduce_deploy_graph_module, (dict_without_graph, import_block))

    def __reduce_package__(self, exporter: PackageExporter):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (
            reduce_package_graph_module,
            (dict_without_graph, generated_module_name),
        )

    def __reduce__(self):
        """
        Serialization of GraphModule. We serialize only the generated code, not
        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
        backward-compatibility guarantees, whereas Python source code does.
        On the deserialization side, we symbolically trace through the generated
        code to regenerate the underlying ``Graph``.
        """
        dict_without_graph = self.__dict__.copy()

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph["_graph"]
        return (reduce_graph_module, (dict_without_graph, import_block))
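
    # Hedged round-trip sketch: pickling goes through __reduce__ above, and
    # unpickling re-traces the generated code (see _deserialize_graph_module).
    #
    #     import pickle
    #     gm2 = pickle.loads(pickle.dumps(gm))
    #     # gm2 behaves like gm; its Graph was rebuilt from the pickled source,
    #     # so node names may differ from the original.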

    def _deepcopy_init(self):
        return GraphModule.__init__

    # Because __reduce__ is defined for serialization, we need to define
    # __deepcopy__; otherwise deepcopy would call __reduce__ and cause
    # symbolic tracing to occur every time we try to copy the object.
    def __deepcopy__(self, memo):
        res = type(self).__new__(type(self))
        memo[id(self)] = res
        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])
        # Hooks are lost during `GraphModule.__init__`, so we need to copy them
        # over explicitly. Note that right now we only copy state_dict-related
        # hooks, to reduce BC-related issues; we can copy forward/backward-related
        # hooks in the future as well if needed.
        extra_preserved_attrs = [
            "_state_dict_hooks",
            "_load_state_dict_pre_hooks",
            "_load_state_dict_post_hooks",
            "_replace_hook",
        ]
        for attr in extra_preserved_attrs:
            if attr in self.__dict__:
                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
                setattr(res, attr_name, attr)
        return res

    def __copy__(self):
        from ._lazy_graph_module import _make_graph_module
        res = _make_graph_module(self, self.graph)
        res.meta = getattr(self, "meta", {})
        return res

    @compatibility(is_backward_compatible=False)
    def print_readable(self, print_output=True):
        """
        Return the Python code generated for the current GraphModule and its children GraphModules
        """
        verbose_python_code = self._graph.python_code(root_module="self", verbose=True)
        module_code = verbose_python_code.src
        module_code = module_code.lstrip("\n")
        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
        module_code = _addindent(module_code, 4)

        submodule_code_list = [""]
        for submodule in self.children():
            if isinstance(submodule, GraphModule):
                submodule_code_list.append(submodule.print_readable(print_output=False))
        submodule_code = "\n".join(submodule_code_list)
        submodule_code = _addindent(submodule_code, 4)

        output = module_code + submodule_code
        if print_output:
            print(module_code + submodule_code)
        return output

    def __str__(self) -> str:
        orig_str = super().__str__()
        print_readable_reminder = (
            "# To see more debug info, please use `graph_module.print_readable()`"
        )
        return "\n".join([orig_str, self._code, print_readable_reminder])

    def _replicate_for_data_parallel(self):
        new_gm = self.__copy__()
        new_gm._is_replica = True
        return new_gm

    @contextlib.contextmanager
    def _set_replace_hook(self, f):
        """
        Takes a callable which will be called every time we replace a node
        with a new node, or change the node's name. The callable takes three
        arguments: the old node being changed, the NAME of the new node, and
        the user node which consumes the old node being replaced.
        """
        assert callable(f), "Replace hook must be a callable."
        prev, self._replace_hook = self._replace_hook, f
        try:
            yield
        finally:
            self._replace_hook = prev


# workarounds for issues in __torch_function__

# WAR for __torch_function__ not handling tensor lists,
# fix is in https://github.com/pytorch/pytorch/pull/34725
# orig_cat = torch.cat
# def patched_cat(*args, **kwargs):
#     tensors = args[0]
#     for t in tensors:
#         if isinstance(t, Proxy):
#             return t.__torch_function__(patched_cat, (), args, kwargs)
#     return orig_cat(*args, **kwargs)
# patched_cat.__module__ = 'torch'
# patched_cat.__name__ = 'cat'
# torch.cat = patched_cat