import torch
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map_only
from torch.utils import _pytree as pytree
from torch.multiprocessing.reductions import StorageWeakRef

import _operator
from enum import Enum
import itertools
from typing import Set, Dict
from collections import defaultdict

__all__ = ['reinplace']


class _ViewType(Enum):
    NonView = 0
    SingleOutputView = 1
    MultiOutputView = 2


def _is_view_op(tgt):
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            # check if op is a view
            return first_arg.alias_info is not None and not first_arg.alias_info.is_write


def _get_view_type(tgt) -> _ViewType:
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            # check if op is a view
            if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
                # check if op is a multi-output view
                if '*' in first_arg.alias_info.after_set:
                    return _ViewType.MultiOutputView
                else:
                    return _ViewType.SingleOutputView
    return _ViewType.NonView
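
# For illustration, a hedged sketch of what these helpers return for a few common ATen
# overloads, based on the alias annotations in their schemas (the specific overloads
# named here are illustrative; the pass itself only inspects schema alias_info):
#
#   _get_view_type(torch.ops.aten.diagonal.default)  # diagonal(Tensor(a) self, ...) -> Tensor(a)
#   # -> _ViewType.SingleOutputView
#   _get_view_type(torch.ops.aten.split.Tensor)      # split.Tensor(Tensor(a -> *) self, ...) -> Tensor(a)[]
#   # -> _ViewType.MultiOutputView
#   _get_view_type(torch.ops.aten.add.Tensor)        # no alias annotation on self
#   # -> _ViewType.NonView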


# Stores a bunch of metadata related to functionalization on each node.
# Relevant metadata:
# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTensors instead of Tensors)
#   The fake tensor output from running the current node
# n.meta['view_of']: Node
#   If the current node n is a view of some base tensor, the 'view_of' field tells us which
#   view node was used to generate the current node (a view tensor).
#   This information actually makes `fake_result` redundant, but we can use `fake_result`
#   to sanity check that our aliasing information is correct.
@compatibility(is_backward_compatible=False)
class _FunctionalizationMetadataProp(torch.fx.Interpreter):

    def run_node(self, node: Node):
        self.node_counter += 1
        result = super().run_node(node)
        node.meta['fake_result'] = result
        node.meta['node_idx'] = self.node_counter

        # (1) Update metadata with the list of nodes that are used by this node
        # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
        # We don't want to treat it as "being used as an input".
        node_args = node.args
        if node.target is torch.ops.aten.copy_.default:
            node_args = node_args[1:]

        # (2) Update metadata to track aliasing information about view tensor nodes.
        if node.op == 'call_function':
            view_type = _get_view_type(node.target)
            if view_type == _ViewType.SingleOutputView:
                assert isinstance(node.args[0], Node)
                node.meta['view_of'] = node.args[0]
            elif view_type == _ViewType.MultiOutputView:
                self.multi_output_view_nodes[node] = node.args[0]
            # Check if we returned a multi-output view,
            # and we're now grabbing the individual views from the output.
            #
            # For multi-output views, we want to map each output view to the base,
            # but this mapping involves two separate nodes in FX IR.
            # e.g. "a, b = x_1.split(...)" becomes:
            # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
            # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
            # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
            # And we'd like to set:
            # getitem1.meta['view_of'] = x_1
            elif node.target is _operator.getitem:
                list_arg = node.args[0]
                maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
                if maybe_base_of_view is not None:
                    # Note: we could also track indexing info here for multi-output views.
                    # I don't think this metadata is strictly needed for de-functionalization.
                    assert isinstance(maybe_base_of_view, Node)
                    node.meta['view_of'] = maybe_base_of_view

        if 'view_of' in node.meta:
            # We're linking the current node with its first argument as views.
            # Assert here that this is actually the case, and their storages are the same.
            assert isinstance(node.meta['fake_result'], FakeTensor)
            assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
            view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
            base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
            assert view_storage == base_storage
        return result

    def propagate(self, *args):
        self.multi_output_view_nodes = {}
        self.node_counter = -1

        with FakeTensorMode() as mode:
            fake_args = [mode.from_tensor(a) for a in args]
            return super().run(*fake_args)

def _schemas_match(functional_schema, inplace_schema):
    names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
    arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
        a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
    # for the inplace op, its first argument should be mutable
    assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
    # and its remaining arguments shouldn't be.
    assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
    return names_match and arg_types_match
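
# As a concrete sketch (schemas paraphrased from ATen; treat the exact strings as
# illustrative rather than authoritative):
#   functional: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   inplace:    add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
# The names match ("add_" == "add" + "_"), the argument types line up one-to-one,
# and only the first inplace argument is mutable, so _schemas_match(...) returns True.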

# TODO: this should be beefed up to be able to properly re-inplace with:
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
# - out= ops (e.g. angle -> angle.out)
# TODO: we should also figure this info out using torchgen.
def _maybe_get_inplace_op(op):
    # __module__ seems broken; it returns torch._ops.aten which doesn't exist
    if not isinstance(op, torch._ops.OpOverload):
        return None
    # Some view ops have inplace variants (as_strided_, etc),
    # but we do NOT want the reinplacing pass to directly add these into the program.
    # (they'll require extra special handling, and aren't really useful for perf anyway)
    if _is_view_op(op):
        return None
    op_namespace = op.__module__.split(".")[-1]
    op_base_name = op.overloadpacket.__name__
    maybe_namespace_module = getattr(torch.ops, op_namespace)
    maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
    if maybe_inplace_op is None:
        return None

    inplace_overloads = [
        getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
    ]
    inplace_overloads_with_matching_schemas = [
        f
        for f in inplace_overloads
        if _schemas_match(op._schema, f._schema)
    ]
    # Just because foo() and foo_() are both existing operators,
    # they aren't guaranteed to have compatible schemas.
    # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
    # even though several overloads of pow_ exist.
    if len(inplace_overloads_with_matching_schemas) == 0:
        return None
    assert len(inplace_overloads_with_matching_schemas) == 1
    inplace_op = inplace_overloads_with_matching_schemas[0]
    return inplace_op
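
# Rough sketch of the expected behavior (the specific ops are illustrative assumptions):
#   _maybe_get_inplace_op(torch.ops.aten.add.Tensor)       # -> torch.ops.aten.add_.Tensor
#   _maybe_get_inplace_op(torch.ops.aten.diagonal.default) # -> None (view op, skipped above)
#   _maybe_get_inplace_op(_operator.getitem)               # -> None (not an OpOverload)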

_VIEW_INVERSE_MAP = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
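
# For reference, each {view}_scatter op is the functional "write back through a view"
# counterpart of the view op it maps to. Roughly (a sketch, not the actual kernel):
#   out = torch.ops.aten.select_scatter.default(base, src, dim, index)
# behaves like:
#   out = base.clone()
#   out.select(dim, index).copy_(src)
# which is why replaying the matching view op lets us check whether a later
# {view}_scatter call is the "inverse" of an existing view in the graph.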

# This function, given a set of (aliased) tensor nodes,
# returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
# in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
    def _add_if_tensor(x, set_):
        if isinstance(x, FakeTensor):
            set_.add(StorageWeakRef(x._typed_storage()))

    nodes_used_after = set()
    for t in tensor_aliases:
        # get all nodes that use the current alias
        usage_nodes = t.users
        for n in usage_nodes:
            # We only care about usages after the current node
            if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
                continue
            # We also don't care about intermediate view ops.
            # They only matter if their output is then used elsewhere
            # (either in an out-of-place op, or as an output to the function).
            if n in tensor_aliases:
                if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
                    continue
            nodes_used_after.add(n)
    return nodes_used_after

# Given an op that we're trying to re-inplace, "b = foo(a)",
# and given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)",
# re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF
# there are any aliases in the alias_set(a) that satisfy:
# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
#     as "alias"
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
    def matching_view_metadata(a, b):
        return a.size() == b.size() and \
            a.stride() == b.stride() and \
            a.storage_offset() == b.storage_offset()

    view_inverse_nodes = set()
    # Go through them in node order, so we can see chains of view_scatter ops.
    for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
        if n.target not in _VIEW_INVERSE_MAP:
            continue
        base = n.args[0]
        mutated_view = n.args[1]
        assert isinstance(base, Node)
        assert isinstance(base.meta['fake_result'], FakeTensor)
        assert isinstance(mutated_view, Node)
        assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
        # Check that this view_inverse op actually corresponds to doing the inverse
        # of one of our existing self_alias nodes.
        original_view = _VIEW_INVERSE_MAP[n.target]
        for self_alias in self_aliases:
            # We're looking for some alias of the self arg, "alias",
            # that was created from some op `alias = foo(base, args...)`
            # such that the current _scatter op "inverts" that foo call.
            # We can check that by running the original op again, and checking that the strides match.
            if 'view_of' not in self_alias.meta:
                continue
            self_alias_base = self_alias.meta['view_of']
            try:
                # Here we're trying to re-use the args from the view_scatter call inside of the corresponding
                # view op, which might throw. This just indicates that the view_scatter op isn't a valid inverse
                # of the current alias we're looking at.
                view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
                expected_metadata = self_alias.meta['fake_result']
                # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
                if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
                        matching_view_metadata(view_replay_metadata, expected_metadata):
                    view_inverse_nodes.add(n)
            except Exception:
                continue

    return view_inverse_nodes
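
# Concretely, this is meant to catch (functionalized) patterns like the sketch below,
# where the diagonal_scatter call undoes the earlier diagonal view (same base metadata,
# same view args), so the scatter node can be marked as removable:
#   b = torch.ops.aten.diagonal(a, 0, 1)
#   b_updated = torch.ops.aten.fill(b, 0)
#   a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)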


@compatibility(is_backward_compatible=True)
def reinplace(gm, *sample_args):
    """
    Given an fx.GraphModule, modifies it to perform "reinplacing",
    mutating the nodes of the graph.
    We look for out-of-place op call sites like `b = a.add(...)`,
    and convert them to be inplace (`b = a.add_(...)`),
    as long as the input to the current operator ("a") isn't re-used
    anywhere later in the graph.

    This pass currently expects to operate on a **functional, ATen** graph.
    This can be obtained by running `make_fx(functionalize(f))`.

    Sample inputs are needed to determine aliasing relationships of the inputs.
    In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
    inputs to the program.
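
    A typical invocation looks roughly like the sketch below (the exact
    `functionalize`/`make_fx` import locations and the toy function are
    illustrative, not requirements of this pass):

        from torch.fx.experimental.proxy_tensor import make_fx
        from torch.func import functionalize

        def f(x):
            y = x.clone()
            return torch.add(y, 1)

        sample = torch.randn(4)
        gm = make_fx(functionalize(f))(sample)
        gm = reinplace(gm, sample)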

    Given a node "b = foo(a, args...)", the algorithm for re-inplacing is as follows:

    (1) Perform some initial checks on the metadata of "a" and "args..."
        that can disqualify them from being reinplaced.

        (1a) Check that the self argument we're attempting to reinplace
             has acceptable dtype/size metadata to reinplace with.

             For example, if we have:
               a = torch.ones(1)
               b = torch.ones(10)
               out = torch.add(a, b)
             we can't turn that into
               a.add_(b)
             because that would require resizing "a".

             Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
             because that would require changing a's dtype (from e.g. float32 to bool).
             Note that in this specific example, we could technically do better.

             If we see the pattern:
               a_1 = a.ge(b)
               a_2 = aten._to_copy(a_1, a.dtype)
             then this should be valid to completely re-inplace
             (this is exactly what functionalization will emit when it sees a.ge_(b)).

             This optimization is only really important for user programs
             that directly use inplace comparison ops though.

             We also cannot re-inplace on tensors that have overlapping memory,
             e.g. torch.ones(1).expand(4, 4).add_(1)

        (1b) Check if "a" is an alias of any of the program inputs.

             If it is, skip and move to the next node.
             Inplace'ing an op that would cause it to mutate a program input is not sound,
             because that would be a side effect visible to the user.

             NOTE: there's a future optimization that we should make:
             if "a" is a (alias of a) program input, but later in the program
             there is a node that looks like "a.copy_(...)",
             then re-inplacing is ok to do - we are temporarily re-using a's buffer,
             which will later be overwritten by the copy_() call.

             This will be an important optimization to have for programs that mutate
             their inputs. It currently isn't implemented though.

        (1c) Check if "a" and "args..." alias

             For example, re-inplacing to create code like the below
             isn't guaranteed to be sound:

               aten.mul_(a, a)

    (2) Check that "a" and all of its outstanding aliases are not used anywhere
        later in the graph. If this is the case, then it's safe to re-inplace
        to "b = foo_(a)".

        There are a few caveats to this, explained in more detail below:
        (a) If "a" is used later as an argument to a view op, that is okay.
            It's only a problem if "a" (or that view) is later passed
            into a normal operator, or if it is returned as the program output.
        (b) If "a" is a repeat argument in `foo()`, then don't reinplace.
            Most ATen kernels don't make any guarantees that this is sound,
            e.g. if you do aten.mul_(a, a).
            So we'll just ban re-inplacing in this case.
        (c) If "a" is used as an input into a view "inverse" / "scatter"
            operator, it is potentially fine to re-inplace
            (and remove that scatter operator from the graph).
            See below for a more detailed example.

        NOTE: there is an optimization in this step that is crucial
        to fully recovering performance from functionalization.

        Given this program:
        def f(x):
            a = torch.ops.aten.add(x, x)
            b = torch.ops.aten.diagonal(a)
            torch.ops.aten.fill_(b, 0)
            return a

        Functionalization will emit the following:
        def f(x):
            a = torch.ops.aten.add(x, x)
            b = torch.ops.aten.diagonal(a, 0, 1)
            b_updated = torch.ops.aten.fill(b, 0)
            a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
            return a_updated

        Ordinarily, we would not be able to reinplace the fill,
        because "b" aliases with "a", which is used by the diagonal_scatter call.

        "re-inplacing" is on the hook for figuring out that it is ok to
        completely remove the expensive diagonal_scatter call, if we re-inplace the add().

        So, for every `alias in alias_set(a)`, instead of checking
        that "alias" is not used anywhere later in the graph,
        we check that
        EITHER:
          (a) alias is not used anywhere later in the graph
        OR:
          (b) alias is used exactly once later on in the graph,
              in the following op:

                out = foo_scatter(alias, x, args...)

              where the following must hold:
                (i) "foo_scatter" is the "inverse" operator for foo.
                    This only applies to "foo" ops that are view operators,
                    which view into a subset of the original tensor's memory.
                    In practice, there are ~4 operators where this applies:
                      diagonal -> diagonal_scatter
                      slice -> slice_scatter
                      select -> select_scatter
                      as_strided -> as_strided_scatter
                (ii) "args..." are the same between the foo() and foo_scatter() calls.

    (3) Perform the actual re-inplacing on foo!

        (3b) is the common case, but special care is needed for {view}_scatter (3a)

        (3a) {view}_scatter ops.

             Consider this program:
               a = torch.zeros(2, 2)
               b = torch.ones(2)
               a[0] = b

             Post functionalization, that will look like:
               a = torch.zeros(2, 2)
               b = torch.ones(2)
               a_updated = torch.select_scatter(a, b, 0, 0)

             In this case though, there is no "functional" op to re-inplace!
             Instead, we'd like to directly remove the select_scatter call.
             We already know from (2) that this is valid,
             because "a" has no later usages in the graph.

             We perform the re-inplacing on the {view}_scatter op like so
             Before:
               a_updated = torch.select_scatter(a, b, args...)
             After:
               a_slice = a.select(args...)
               a_slice.copy_(b)

        (3b) Otherwise, replace the functional op with its inplace variant.
             Before:
               b = foo(a, args...)
             After:
               a.foo_(args...)

    (4) Finally, after converting either:
        Before:
          b = foo(a)
        After:
          foo_(a)
        or
        Before:
          b = {slice}_scatter(a, mutated_slice, args...)
        After:
          slice = {slice}(a, args...)
          slice.copy_(mutated_slice)

        we now need to find all later nodes that use "b" as an argument
        and update them to take in "a" instead.

        Note that for the majority of inplace ops, this isn't actually necessary
        (because most inplace ops return "self" as their output).
        This isn't generally true for all mutable ops though, which is why
        we need to actually replace all of the arguments.

        We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
        that maps a given tensor storage to the set of all nodes that take in that storage
        as an input.
        Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
        together.

    (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
        during step (3) get manually deleted from the graph.
        Their outputs are no longer used, so technically standard DCE would be able
        to do this, but we can no longer run FX's DCE pass now that we have mutable
        ops in the graph.
    """
    _FunctionalizationMetadataProp(gm).propagate(*sample_args)

    # Useful debug printing
    # def _print(x):
    #     if isinstance(x, FakeTensor):
    #         print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')

    # for n in gm.graph.nodes:
    #     print(n.format_node())
    #     if hasattr(n, 'meta'):
    #         print(f'node_idx: {n.meta["node_idx"]}')
    #         if 'fake_result' in n.meta:
    #             tree_map(_print, n.meta['fake_result'])
    #         if 'view_of' in n.meta:
    #             print(f'view_of: {str(n.meta["view_of"])}')
    #     print()

    # We need to know which nodes correspond to inputs (or their aliases)
    # so we know not to re-inplace them.
    # NOTE: later, we'll need to add an optimization for fully recovering performance
    # on programs that mutate inputs.
    input_storages = {
        StorageWeakRef(
            node.meta['fake_result']._typed_storage()
        ) for node in gm.graph.nodes if node.op == 'placeholder'}

    # We also need to know for a given node, what are all of its aliasing nodes.
    storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
    for n in gm.graph.nodes:
        if 'fake_result' in n.meta:
            # Tree-mapping because some ops can return lists of tensors.
            def _add_to_map(x):
                if isinstance(x, FakeTensor):
                    storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
            pytree.tree_map_(_add_to_map, n.meta['fake_result'])
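    # Illustrative note: if the graph contains a view such as `b = aten.view(a, ...)`,
    # then `a` and `b` share a storage, so both nodes land in the same set here and
    # each is reachable from the other node's storage key.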

    # inplace-ify functional ops, subject to the constraints written below.
    all_later_view_inverse_nodes_to_delete = set()
    for idx, node in enumerate(gm.graph.nodes):
        if node.op == 'call_function':

            # Today, the re-inplace pass only directly acts on:
            # - functional ops with an inplace variant
            # - {view}_scatter ops that can be potentially removed from the graph.
            # Both of these ops take in tensor first args, so filtering on this condition
            # makes the later code simpler.
            # We should revisit this at some point though, particularly when we also want
            # the reinplacer to be able to handle out= and mutable operators
            # and tensorlist first args (like `_foreach_` ops).
            if not isinstance(node.target, torch._ops.OpOverload):
                continue
            if len(node.target._schema.arguments) < 1:
                continue
            if type(node.target._schema.arguments[0].type) != torch.TensorType:
                continue

            # Step 1a: Check that the self argument we're attempting to reinplace
            # has the same size/stride as the output.
            # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
            # as it would require resizing scalar_tensor.
            # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
            # this is probably an optimization to revisit later).
            self_arg = node.args[0]
            self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
            node_flattened = pytree.tree_leaves(node.meta['fake_result'])
            self_has_wrong_metadata = False
            if len(self_flattened) == len(node_flattened):
                for self_meta, node_meta in zip(self_flattened, node_flattened):
                    if self_meta.numel() != node_meta.numel():
                        self_has_wrong_metadata = True
                    if self_meta.dtype != node_meta.dtype:
                        self_has_wrong_metadata = True
                    # We also cannot re-inplace on tensors that have internal memory overlap.
                    # e.g. torch.ones(1).expand(4, 4).add_(1)
                    if torch._debug_has_internal_overlap(self_meta) == 1:
                        self_has_wrong_metadata = True
            # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
            # since users should never really be calling the functional "torch.ops.aten.resize"
            # op directly in their programs.
            if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
                continue

            # Step 1b: ensure that the op we're trying to re-inplace isn't a program input
            self_arg_name = self_arg.name
            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
            if self_arg_storage in input_storages:
                # TODO: later, add the optimization for handling `copy_()` calls in the graph.
                continue
            if len([x for x in node.args if x is self_arg]) > 1:
                # Step 1c:
                # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
                # so we prevent re-inplacing in this case.
                continue

            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
            self_aliases = storage_to_nodes[self_arg_storage]

            # First, we find all later usages of any of the aliases of self_arg.
            later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
            # Then, we check if any of those later usages are actually view_scatter ops
            # that are safe to fully remove.
            later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)

            # Step 2: Check to see if the input to the op is re-used later in the graph.
            # If not (same goes for its aliases), then this op is safe to re-inplace.
            # This is a slightly roundabout way to check that there are no later usages of the current self argument.
            # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
            can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
            if not can_reinplace:
                continue

            # Step 3a: Special handling for when we see *_scatter operators.
            # When we see an operator like `b = torch.slice_scatter(a, ...)`,
            # instead of trying to "inplace" it into a.slice_scatter_(...),
            # we would prefer to remove it from the graph entirely,
            # and instead copy_() the slice directly into the larger tensor.
            # See the description of the algorithm for a full example.
            if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
                view_op = _VIEW_INVERSE_MAP[node.target]
                # Before:
                #   base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
                # After:
                #   slice = torch.ops.aten.slice.default(base, args...)
                #   slice.copy_(mutated_slice)
                with gm.graph.inserting_before(node):
                    mutated_slice_node = node.args[1]
                    remaining_slice_args = node.args[2:]
                    slice_node = gm.graph.create_node(
                        'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
                    copy_node = gm.graph.create_node(
                        'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
                # Add the slice_scatter node to our "nodes to delete" list.
                all_later_view_inverse_nodes_to_delete.add(node)

            else:
                # Step 3b: Check to see if this operator has an inplace variant.
                maybe_inplace_op = _maybe_get_inplace_op(node.target)
                if maybe_inplace_op is None:
                    continue
                # And if so, replace it with its inplace variant.
                node.target = maybe_inplace_op

            # At this point, 'storage_to_nodes' will be stale.
            # Now that we're inplacing `b = foo(a)`, we need to effectively
            # union together the dict values for b and a's storage.
            # Hmm... morally I think we also want to keep the `fake_result` metadata
            # up to date here, but I'm not sure how easy it is to do.
            # Maybe it's fine to wait until the end of the pass to update it.
            curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
            storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
            storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])

            # Need to remember the view_scatter view nodes we found so we can remove them later.
            all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)

            # Step 4:
            # Now that we've replaced b = a.foo() with a.foo_(),
            # we need to replace any later usages of "b" with "a"
            for old in itertools.chain([node], later_view_inverse_node_usages):
                new = old.args[0]
                nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
                for node_to_update in nodes_to_update:
                    new_args = []
                    args = node_to_update.args

                    def replace_arg(a):
                        if a == old:
                            return new
                        return a

                    # First, replace usages of "b" with "a"
                    node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
                    node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)

                    # Second, update our storage_to_nodes data structure.
                    old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
                    node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])

                    old_res_storage = {
                        StorageWeakRef(
                            x._typed_storage()
                        ) for x in old_flattened_res if isinstance(x, FakeTensor)}
                    node_res_storage = {
                        StorageWeakRef(
                            x._typed_storage()
                        ) for x in node_flattened_res if isinstance(x, FakeTensor)}

                    # This will happen if we're updating a view op, e.g. replacing
                    #   x = view(old)
                    # with
                    #   x = view(new)
                    # When that happens, we need to make sure to keep our
                    # storage mapping up to date.
                    #
                    # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
                    # or multiple tensors that all share the same storage.
                    # We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
                    if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
                        new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
                        new_res_storage = {
                            StorageWeakRef(
                                x._typed_storage()
                            ) for x in new_flattened_res if isinstance(x, FakeTensor)}
                        assert len(new_res_storage) == 1
                        (old_ref,) = old_res_storage
                        (new_ref,) = new_res_storage
                        (node_ref,) = node_res_storage
                        # Technically, "old_ref" and all its aliases will remain
                        # in our mapping.
                        # That should be fine though, since we deleted "old"
                        # from the graph at this point.
                        storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
                        storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])

    # Step 5: delete any _scatter nodes that we de-functionalized
    # Need to take care not to delete any of these nodes until after *all* modifications
    # to the graph are finished.
    for to_delete in all_later_view_inverse_nodes_to_delete:
        gm.graph.erase_node(to_delete)

    gm.recompile()
    return gm