221 lines
7.6 KiB
Python
221 lines
7.6 KiB
Python
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Live variable analysis.
|
||
|
|
||
|
See https://en.wikipedia.org/wiki/Live_variable_analysis for a definition of
the following terms: live variable, live in, live out, which are used
throughout this file.
|
||
|
|
||
|
This analysis attaches the following:
|
||
|
* symbols that are live at the exit of control flow statements
|
||
|
* symbols that are live at the entry of control flow statements
|
||
|
|
||
|
Requires activity analysis.
|
||
|
"""
|
||
|
|
||
|
import gast
|
||
|
|
||
|
from tensorflow.python.autograph.pyct import anno
|
||
|
from tensorflow.python.autograph.pyct import cfg
|
||
|
from tensorflow.python.autograph.pyct import transformer
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||
|
|
||
|
|
||
|
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that performs liveness analysis at statement level.

  Liveness flows backwards: the live-in set of a CFG node is derived from the
  live-in sets of its successors (`node.next`). Per-node results are stored in
  `self.in_` (live at entry) and `self.out` (live at exit).
  """

  def __init__(self, graph, include_annotations):
    """Creates the analyzer.

    Args:
      graph: cfg.Graph, the control flow graph to analyze.
      include_annotations: bool. When False, symbols listed in a scope's
        `annotations` are excluded from the symbols it reads.
    """
    super(Analyzer, self).__init__(graph)
    self.include_annotations = include_annotations

  def init_state(self, _):
    """Returns the initial liveness state of a CFG node: an empty set."""
    return set()

  def lamba_check(self, fn_ast_node):
    """Returns True if `fn_ast_node` is a lambda.

    Lambda functions are assumed to be used only in the place where they are
    defined, and not later, so their closures are not made live at the points
    they reach. (The misspelled name is kept for compatibility.)
    """
    return isinstance(fn_ast_node, gast.Lambda)

  def _successors_live_in(self, node):
    """Returns the union of the live-in sets of `node`'s CFG successors."""
    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]
    return live_out

  def visit_node(self, node):
    """Recomputes the live-in and live-out sets of one CFG node.

    Args:
      node: the CFG node to update.

    Returns:
      bool, whether the node's live-in set changed.
    """
    prev_live_in = self.in_[node]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

      # Always build a fresh set here: `node_scope` is a shared annotation,
      # and an in-place `-=` on an alias of `node_scope.read` would
      # permanently strip the annotation symbols from the scope itself.
      if self.include_annotations:
        gen = node_scope.read
      else:
        gen = node_scope.read - node_scope.annotations
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. whether x needs to be added if x.y is live. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified | node_scope.deleted

      live_out = self._successors_live_in(node)
      # `|` yields a new set, so the in-place updates below never touch `gen`.
      live_in = gen | (live_out - kill)

      reaching_functions = anno.getanno(
          node.ast_node, anno.Static.DEFINED_FNS_IN)
      for fn_ast_node in reaching_functions:
        if self.lamba_check(fn_ast_node):
          continue
        fn_scope = anno.getanno(fn_ast_node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
        # Any closure of a reaching function definition is conservatively
        # considered live.
        live_in |= (fn_scope.read - fn_scope.bound)

    else:
      assert self.can_ignore(node), (node.ast_node, node)

      # Nodes without a scope neither read nor write symbols, so liveness
      # passes through them unchanged.
      live_out = self._successors_live_in(node)
      live_in = live_out

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in
|
||
|
|
||
|
|
||
|
class TreeAnnotator(transformer.Base):
  """Runs liveness analysis on each of the functions defined in the AST.

  If a function defines other local functions, those will have separate CFGs.
  However, dataflow analysis needs to tie up these CFGs to properly emulate the
  effect of closures. In the case of liveness, the parent function's live
  variables must account for the variables that are live at the entry of each
  subfunction. For example:

    def foo():
      # baz is live from here on
      def bar():
        print(baz)

  This analyzer runs liveness analysis on each individual function, accounting
  for the effect above.

  Statements that are not inside any function have no CFG to consult and are
  left unannotated.
  """

  def __init__(self, source_info, graphs, include_annotations):
    super(TreeAnnotator, self).__init__(source_info)
    self.include_annotations = include_annotations
    self.allow_skips = False
    # Maps each function/lambda AST node to its cfg.Graph.
    self.graphs = graphs
    # Analyzer of the innermost function being visited; None outside any
    # function.
    self.current_analyzer = None

  def visit(self, node):
    """Visits `node` and attaches LIVE_VARS_IN to statements in the CFG."""
    node = super(TreeAnnotator, self).visit(node)
    if (self.current_analyzer is not None and
        isinstance(node, gast.stmt) and
        node in self.current_analyzer.graph.index):
      cfg_node = self.current_analyzer.graph.index[node]
      anno.setanno(node, anno.Static.LIVE_VARS_IN,
                   frozenset(self.current_analyzer.in_[cfg_node]))
    return node

  def _analyze_function(self, node, is_lambda):
    """Runs liveness on the CFG of `node`, then annotates its body.

    Args:
      node: a function or lambda AST node; must be a key of `self.graphs`.
      is_lambda: bool, currently unused.

    Returns:
      The visited node.
    """
    parent_analyzer = self.current_analyzer

    analyzer = Analyzer(self.graphs[node], self.include_annotations)
    analyzer.visit_reverse()
    self.current_analyzer = analyzer
    try:
      node = self.generic_visit(node)
    finally:
      # Restore on every exit path so an error raised inside a nested function
      # does not leave the annotator pointing at that function's analyzer.
      self.current_analyzer = parent_analyzer
    return node

  def visit_Lambda(self, node):
    return self._analyze_function(node, is_lambda=True)

  def visit_FunctionDef(self, node):
    return self._analyze_function(node, is_lambda=False)

  def _block_statement_live_out(self, node):
    """Sets LIVE_VARS_OUT on `node` to the union of its successors' live-in."""
    if self.current_analyzer is None:
      # Outside any function there is no CFG; mirror the guard in `visit`.
      return node
    successors = self.current_analyzer.graph.stmt_next[node]
    stmt_live_out = set()
    for s in successors:
      stmt_live_out.update(self.current_analyzer.in_[s])
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(stmt_live_out))
    return node

  def _block_statement_live_in(self, node, entry_node):
    """Sets LIVE_VARS_IN on `node` to the live-in of its entry node.

    `entry_node` is either a node of the current CFG, or a nested block
    statement that was already annotated when its children were visited.
    """
    if self.current_analyzer is None:
      # Outside any function there is no CFG; mirror the guard in `visit`.
      return node
    if entry_node in self.current_analyzer.graph.index:
      cfg_node = self.current_analyzer.graph.index[entry_node]
      stmt_live_in = frozenset(self.current_analyzer.in_[cfg_node])
    else:
      assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
          'If not matching a CFG node, must be a block statement:'
          ' {}'.format(entry_node))
      stmt_live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
    anno.setanno(node, anno.Static.LIVE_VARS_IN, stmt_live_in)
    return node

  def visit_If(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.test)

  def visit_For(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.iter)

  def visit_While(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.test)

  def visit_Try(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.body[0])

  def visit_ExceptHandler(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.body[0])

  def visit_With(self, node):
    # Only LIVE_VARS_IN is attached here; the live-in comes from the first
    # `with` item.
    node = self.generic_visit(node)
    return self._block_statement_live_in(node, node.items[0])

  def visit_Expr(self, node):
    """Attaches LIVE_VARS_OUT to an expression statement."""
    node = self.generic_visit(node)
    if self.current_analyzer is None:
      # E.g. a module-level docstring: there is no CFG to consult.
      return node
    cfg_node = self.current_analyzer.graph.index[node]
    anno.setanno(node, anno.Static.LIVE_VARS_OUT,
                 frozenset(self.current_analyzer.out[cfg_node]))
    return node
|
||
|
|
||
|
|
||
|
# TODO(mdan): Investigate the possibility of removing include_annotations.
|
||
|
def resolve(node, source_info, graphs, include_annotations=True):
  """Annotates the statements under `node` with their live symbols.

  Liveness is computed separately for every function and lambda found in
  `node`, each against its own CFG taken from `graphs`. Statements that appear
  in a CFG receive `anno.Static.LIVE_VARS_IN`; `if`, `for`, `while` and `try`
  statements and expression statements also receive
  `anno.Static.LIVE_VARS_OUT`.

  Args:
    node: ast.AST, the tree to annotate.
    source_info: transformer.SourceInfo
    graphs: Dict[ast.AST, cfg.Graph], mapping each function (or lambda) node
      in `node` to its control flow graph.
    include_annotations: Bool, whether type annotations should be included in
      the analysis.

  Returns:
    ast.AST, the annotated tree.
  """
  annotator = TreeAnnotator(source_info, graphs, include_annotations)
  return annotator.visit(node)
|