2023-06-19 00:49:18 +02:00

289 lines
8.9 KiB

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reaching definition analysis.

This analysis attaches a set of Definition objects to each symbol, one
for each distinct definition that may reach it. The Definition objects are
mutable and may be used by subsequent analyses to further annotate data like
static type and value information.
The analysis also attaches the set of the symbols defined at the entry of
control flow statements.

Requires activity analysis.
"""

import weakref
import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer
class Definition(object):
  """Definition objects describe a unique definition of a variable.

  Subclasses of this may be used by passing an appropriate factory function to
  resolve.

  Attributes:
    param_of: None by default. Analyzer.visit_node sets it to a weakref.ref of
      the object that declares the symbol as a parameter (presumably an
      ast.AST node).
    directives: Dict, optional definition annotations
  """

  def __init__(self):
    self.param_of = None
    # A fresh dict per instance, so annotations are never shared.
    self.directives = {}

  def __repr__(self):
    # The object id tells apart definitions that would otherwise print alike.
    return '%s[%d]' % (self.__class__.__name__, id(self))
class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type. Only implements the strictly necessary operators.
  Every operator returns a new object and leaves its operands untouched.

  Attributes:
    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
      their possible definitions
  """

  def __init__(self, init_from=None):
    """Creates a state, optionally seeded from another state or a dict.

    Args:
      init_from: None (or an empty dict) for an empty state; a _NodeState to
        copy, in which case every definition set is duplicated so the copy
        can be modified independently; or a dict mapping each symbol to a
        single definition.

    Raises:
      AssertionError: if init_from is truthy but of an unsupported type.
    """
    if init_from:
      if isinstance(init_from, _NodeState):
        self.value = {
            s: set(other_infos) for s, other_infos in init_from.value.items()
        }
      elif isinstance(init_from, dict):
        # Each symbol starts with exactly one definition.
        self.value = {s: {def_} for s, def_ in init_from.items()}
      else:
        assert False, init_from
    else:
      self.value = {}

  def __eq__(self, other):
    # Equal when both states define the same symbols with the same sets of
    # definitions; dict equality compares keys and value sets.
    return self.value == other.value

  def __ne__(self, other):
    return not self.__eq__(other)

  def __or__(self, other):
    """Returns the union: each symbol maps to the definitions of both sides."""
    assert isinstance(other, _NodeState)
    result = _NodeState(self)
    for s, other_infos in other.value.items():
      if s in result.value:
        # Symbol known on both sides: merge the definition sets.
        result.value[s].update(other_infos)
      else:
        # Symbol only known to `other`: copy its set so the result does not
        # alias the operand's state.
        result.value[s] = set(other_infos)
    return result

  def __sub__(self, other):
    """Returns a copy without the symbols listed in the set `other`."""
    assert isinstance(other, set)
    result = _NodeState(self)
    for s in other:
      # Symbols absent from the state are ignored.
      result.value.pop(s, None)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level.

  Attributes:
    gen_map: Dict mapping a CFG node to the _NodeState holding the definitions
      that node generates. Entries are created once per node and then reused.
  """

  def __init__(self, graph, definition_factory):
    """Initializes the analyzer.

    Args:
      graph: the graph to walk; passed through to cfg.GraphVisitor.
      definition_factory: zero-argument callable that returns a new
        definition object each time it is called.
    """
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    self.gen_map = {}

  def init_state(self, _):
    # Every node starts with no reaching definitions.
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the definitions entering and leaving `node`.

    Args:
      node: the CFG node to process.

    Returns:
      True if the node's outgoing state differs from its previous value.
    """
    prev_defs_out = self.out[node]

    # Definitions reaching this node: union over all predecessors' outputs.
    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          def_ = self._definition_factory()
          node_symbols[s] = def_
        # Every param receives a definition. Params are not necessarily
        # considered as "modified".
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          # Weak reference: the definition does not keep `p` alive.
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      gen = self.gen_map[node]
      # Modified or deleted symbols lose their incoming definitions; the
      # definitions generated here are then added on top.
      kill = node_scope.modified | node_scope.deleted
      defs_out = gen | (defs_in - kill)
    else:
      # Nodes without scope information must be ones that can be ignored;
      # they pass their input through unchanged.
      assert self.can_ignore(node), (node.ast_node, node)
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here

  Attributes:
    allow_skips: set to False in the constructor.
    definition_factory: zero-argument callable handed to each Analyzer.
    graphs: mapping from a function definition node to its CFG.
    current_analyzer: the Analyzer of the function being visited, or None
      when outside any function.
    current_cfg_node: the CFG node of the statement being visited, or None.
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    self.definition_factory = definition_factory
    self.graphs = graphs
    self.current_analyzer = None
    self.current_cfg_node = None

  def visit_FunctionDef(self, node):
    """Analyzes the function's CFG, then annotates the names in its body."""
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    analyzer = Analyzer(subgraph, self.definition_factory)
    # Run the dataflow analysis so that analyzer.in_ / analyzer.out hold the
    # reaching definitions before any name is annotated.
    analyzer.visit_forward()

    # Recursively process any remaining subfunctions.
    self.current_analyzer = analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    # Restore the enclosing function's analyzer (None at the top level).
    self.current_analyzer = parent_analyzer

    return node

  def visit_Name(self, node):
    """Attaches the reaching definitions of a name as Static.DEFINITIONS."""
    if self.current_analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    analyzer = self.current_analyzer
    cfg_node = self.current_cfg_node

    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % anno.getanno(node, anno.Basic.QN))

    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Load):
      # A read sees the definitions that reach the statement's entry.
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.in_[cfg_node].value.get(qn, ())))
    else:
      # Any other context sees the definitions leaving the statement.
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.out[cfg_node].value.get(qn, ())))

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates `node` with the symbols defined by any of its predecessors."""
    preds = self.current_analyzer.graph.stmt_prev[node]
    node_defined_in = set()
    for p in preds:
      node_defined_in |= set(self.current_analyzer.out[p].value.keys())
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in))

  def visit_If(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For.
    # The loop target is visited under the CFG node of the iterated
    # expression, after which the previous CFG node is restored.
    parent = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = parent

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_Try(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_ExceptHandler(self, node):
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits `node`, tracking which CFG node is current during the visit."""
    parent = self.current_cfg_node

    # AST nodes indexed by the graph switch the current CFG node for the
    # duration of their visit.
    if (self.current_analyzer is not None and
        node in self.current_analyzer.graph.index):
      self.current_cfg_node = self.current_analyzer.graph.index[node]
    node = super(TreeAnnotator, self).visit(node)

    self.current_cfg_node = parent
    return node
def resolve(node, source_info, graphs, definition_factory=Definition):
  """Resolves reaching definitions for each symbol.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition]

  Returns:
    The node returned by TreeAnnotator.visit for `node`.
  """
  visitor = TreeAnnotator(source_info, graphs, definition_factory)
  node = visitor.visit(node)
  return node