179 lines
5.2 KiB
Python
179 lines
5.2 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""An analysis that determines the reach of a function definition.
|
|
|
|
A function definition is said to reach a statement if that function may exist
|
|
(and therefore may be called) when that statement executes.
|
|
"""
|
|
|
|
import gast
|
|
|
|
from tensorflow.python.autograph.pyct import anno
|
|
from tensorflow.python.autograph.pyct import cfg
|
|
from tensorflow.python.autograph.pyct import transformer
|
|
|
|
|
|
class Definition(object):
  """Describes one unique definition of a function.

  Attributes:
    def_node: the node that introduces the definition (presumably the
        FunctionDef or Lambda AST node -- TODO confirm with callers).
  """

  def __init__(self, def_node):
    node = def_node
    self.def_node = node
|
|
|
|
|
|
class _NodeState(object):
  """Dataflow state for the reaching function definitions analysis.

  This is a value type: every operator returns a fresh instance and leaves its
  operands untouched. Only the strictly necessary operators are implemented.

  Attributes:
    value: Set, the function definitions that may exist at a given point of
        the CFG. The Analyzer in this module stores the defining AST nodes
        (FunctionDef or Lambda) here.
  """

  def __init__(self, init_from=None):
    # Always copy, so that distinct states never alias the same set.
    self.value = set(init_from) if init_from else set()

  def __eq__(self, other):
    # Returning NotImplemented (instead of touching `other.value`) lets Python
    # fall back to its default comparison for foreign types.
    if not isinstance(other, _NodeState):
      return NotImplemented
    return self.value == other.value

  def __ne__(self, other):
    if not isinstance(other, _NodeState):
      return NotImplemented
    return self.value != other.value

  def __or__(self, other):
    """Returns a new state holding the union of both states (the CFG join)."""
    # An explicit type check instead of `assert`, which is stripped under -O;
    # unsupported operands now raise TypeError via Python's operator protocol.
    if not isinstance(other, _NodeState):
      return NotImplemented
    result = _NodeState(self.value)
    result.value.update(other.value)
    return result

  def __add__(self, value):
    """Returns a new state equal to this one with `value` added."""
    result = _NodeState(self.value)
    result.value.add(value)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
|
|
|
|
|
|
class Analyzer(cfg.GraphVisitor):
  """CFG visitor computing, per statement, the function definitions reaching it.

  A FunctionDef or Lambda CFG node contributes its own AST node to the state
  flowing out of it; states coming from predecessors are joined with `|`.
  """

  def __init__(self, graph, external_defs):
    super(Analyzer, self).__init__(graph)
    # Extra definitions that reach the graph's entry from outside, e.g. the
    # functions that an inner function closes over.
    self.external_defs = external_defs

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    """Recomputes `in_`/`out` for `node`; returns True if `out` changed."""
    old_out = self.out[node]

    # The entry is seeded with the external definitions. Any other node starts
    # from its own previous out-state, so states only grow across iterations.
    incoming = (_NodeState(self.external_defs)
                if node is self.graph.entry else old_out)
    for predecessor in node.prev:
      incoming = incoming | self.out[predecessor]

    outgoing = incoming
    if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
      outgoing = outgoing + node.ast_node

    self.in_[node] = incoming
    self.out[node] = outgoing
    return old_out != outgoing
|
|
|
|
|
|
class TreeAnnotator(transformer.Base):
  """AST visitor annotating nodes with the function definitions reaching them.

  Every node that is part of the CFG of the function currently being visited
  receives an `anno.Static.DEFINED_FNS_IN` annotation: the set of function
  definitions that may exist when that node executes.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      def f():
        pass
      def g():
        # `def f` reaches here
  """

  def __init__(self, source_info, graphs):
    super(TreeAnnotator, self).__init__(source_info)
    self.graphs = graphs
    self.allow_skips = False
    self.current_analyzer = None

  def _defined_fns_in(self, node):
    """Returns the defs reaching `node` in the current CFG, or None if unknown.

    None is returned when no function has been entered yet, or when `node` is
    not a node of the current function's CFG.
    """
    analyzer = self.current_analyzer
    if analyzer is None or node not in analyzer.graph.index:
      return None
    return analyzer.in_[analyzer.graph.index[node]].value

  def _proces_function(self, node):
    """Runs the analysis on a FunctionDef/Lambda, then visits its children."""
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    # Definitions reaching the point where this function is defined also reach
    # its body (closures); top-level functions start with none.
    defined_in = self._defined_fns_in(node)
    if defined_in is None:
      defined_in = ()

    analyzer = Analyzer(subgraph, defined_in)
    analyzer.visit_forward()

    self.current_analyzer = analyzer
    try:
      node = self.generic_visit(node)
    finally:
      # Restore even on error so the visitor is not left pointing at the
      # inner function's analyzer.
      self.current_analyzer = parent_analyzer
    return node

  def visit_FunctionDef(self, node):
    return self._proces_function(node)

  def visit_Lambda(self, node):
    return self._proces_function(node)

  def visit(self, node):
    # There is no analyzer before entering the top level function, and not
    # every AST node is a CFG node; such nodes are left unannotated.
    defined_in = self._defined_fns_in(node)
    if defined_in is not None:
      anno.setanno(node, anno.Static.DEFINED_FNS_IN, defined_in)

    extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
    if extra_node is not None:
      # Use the same guard as above: previously this dereferenced
      # `self.current_analyzer` unconditionally and crashed when it was None.
      extra_defined_in = self._defined_fns_in(extra_node)
      if extra_defined_in is not None:
        anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN, extra_defined_in)

    return super(TreeAnnotator, self).visit(node)
|
|
|
|
|
|
def resolve(node, source_info, graphs):
  """Annotates `node` with the function definitions reaching its statements.

  Args:
    node: ast.AST, the tree to annotate.
    source_info: transformer.SourceInfo
    graphs: Dict[ast.AST, cfg.Graph], the CFG of each FunctionDef and Lambda
        visited under `node`, keyed by that AST node.

  Returns:
    ast.AST, the tree returned by the annotating visitor.
  """
  annotator = TreeAnnotator(source_info, graphs)
  return annotator.visit(node)
|