# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reaching definition analysis.

This analysis attaches a set of Definition objects to each symbol, one
for each distinct definition that may reach it. The Definition objects are
mutable and may be used by subsequent analyses to further annotate data like
static type and value information.
The analysis also attaches the set of the symbols defined at the entry of
control flow statements.

Requires activity analysis.
"""

import weakref

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer


class Definition(object):
  """Definition objects describe a unique definition of a variable.

  Subclasses of this may be used by passing an appropriate factory function to
  resolve.

  Attributes:
    param_of: None by default. For symbols that are parameters, the analyzer
      stores a `weakref.ref` to the corresponding entry of the scope's
      `params` mapping (presumably the AST node of the parameter; verify
      against activity analysis). Call it to obtain the referent.
    directives: Dict, optional definition annotations
  """

  def __init__(self):
    self.param_of = None
    self.directives = {}

  def __repr__(self):
    return '%s[%d]' % (self.__class__.__name__, id(self))


class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type. Only implements the strictly necessary operators.

  Attributes:
    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
      their possible definitions
  """

  def __init__(self, init_from=None):
    """Creates a state, optionally copying or wrapping `init_from`.

    Args:
      init_from: None or any falsy value for an empty state; another
        _NodeState, whose per-symbol sets are copied; or a dict mapping each
        symbol to a single definition, which is wrapped in a one-element set.
    """
    if init_from:
      if isinstance(init_from, _NodeState):
        # Copy each set so that the new state can be mutated independently.
        self.value = {
            s: set(other_infos) for s, other_infos in init_from.value.items()
        }
      elif isinstance(init_from, dict):
        self.value = {s: {def_} for s, def_ in init_from.items()}
      else:
        assert False, init_from
    else:
      self.value = {}

  def __eq__(self, other):
    """Returns True if both states define the same symbols with equal sets."""
    if frozenset(self.value.keys()) != frozenset(other.value.keys()):
      return False
    return all(self.value[s] == other.value[s] for s in self.value)

  def __ne__(self, other):
    return not self.__eq__(other)

  def __or__(self, other):
    """Returns a new state holding the per-symbol union of both states."""
    assert isinstance(other, _NodeState)
    result = _NodeState(self)
    for s, other_infos in other.value.items():
      if s in result.value:
        result.value[s].update(other_infos)
      else:
        result.value[s] = set(other_infos)
    return result

  def __sub__(self, other):
    """Returns a new state with the symbols in the set `other` removed."""
    assert isinstance(other, set)
    result = _NodeState(self)
    for s in other:
      result.value.pop(s, None)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level."""

  def __init__(self, graph, definition_factory):
    """Creates the analyzer.

    Args:
      graph: the CFG to analyze; forwarded to cfg.GraphVisitor.
      definition_factory: zero-argument callable that returns a new
        definition object each time it is called.
    """
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    # Maps each CFG node to the _NodeState of the definitions it generates.
    # Filled lazily by visit_node, once per node.
    self.gen_map = {}

  def init_state(self, _):
    """Returns the initial (empty) state used for every node."""
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the in/out definition sets of a single CFG node.

    Args:
      node: the CFG node to process.

    Returns:
      True if the node's output state differs from its previous output state,
      False otherwise.
    """
    prev_defs_out = self.out[node]

    # The input state is the union of the outputs of all predecessors.
    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          def_ = self._definition_factory()
          node_symbols[s] = def_
        # Every param receives a definition. Params are not necessarily
        # considered as "modified".
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      # out = gen | (in - kill). The state operators return new objects, so
      # computing this once is sufficient.
      gen = self.gen_map[node]
      kill = node_scope.modified | node_scope.deleted
      defs_out = gen | (defs_in - kill)

    else:
      assert self.can_ignore(node), (node.ast_node, node)
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out


class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here
  """

  def __init__(self, source_info, graphs, definition_factory):
    """Creates the annotator.

    Args:
      source_info: forwarded to transformer.Base.
      graphs: mapping indexed by function definition nodes, yielding the CFG
        of each function.
      definition_factory: zero-argument callable passed to each Analyzer.
    """
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    self.definition_factory = definition_factory
    self.graphs = graphs
    # Analyzer of the function currently being visited, None outside of any
    # function.
    self.current_analyzer = None
    # CFG node of the statement currently being visited, if any.
    self.current_cfg_node = None

  def visit_FunctionDef(self, node):
    """Runs the analysis on the function's CFG, then annotates its contents."""
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    analyzer = Analyzer(subgraph, self.definition_factory)
    analyzer.visit_forward()

    # Recursively process any remaining subfunctions.
    self.current_analyzer = analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = parent_analyzer

    return node

  def visit_Name(self, node):
    """Annotates a name with the definitions that reach it.

    Names in Load context receive the definitions from the input state of the
    current CFG node; all other names receive those from its output state.
    """
    if self.current_analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    analyzer = self.current_analyzer
    cfg_node = self.current_cfg_node

    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Load):
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.in_[cfg_node].value.get(qn, ())))
    else:
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.out[cfg_node].value.get(qn, ())))

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates a statement with the symbols defined by its predecessors.

    The annotation is the frozenset of all symbols present in the output state
    of any CFG node listed in the graph's `stmt_prev` for this statement.
    """
    preds = self.current_analyzer.graph.stmt_prev[node]
    node_defined_in = set()
    for p in preds:
      node_defined_in |= set(self.current_analyzer.out[p].value.keys())
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in))

  def visit_If(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For.
    parent = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = parent

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_Try(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_ExceptHandler(self, node):
    self._aggregate_predecessors_defined_in(node)
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits a node, tracking the CFG node it corresponds to, if any.

    While the node is being visited, `current_cfg_node` points to its entry in
    the current graph's index when one exists; the previous value is restored
    afterwards.
    """
    parent = self.current_cfg_node

    if (self.current_analyzer is not None and
        node in self.current_analyzer.graph.index):
      self.current_cfg_node = self.current_analyzer.graph.index[node]
    node = super(TreeAnnotator, self).visit(node)

    self.current_cfg_node = parent
    return node


def resolve(node, source_info, graphs, definition_factory=Definition):
  """Resolves reaching definitions for each symbol.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition]

  Returns:
    ast.AST
  """
  visitor = TreeAnnotator(source_info, graphs, definition_factory)
  node = visitor.visit(node)
  return node