345 lines
10 KiB
Python
345 lines
10 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""AST manipulation utilities."""
|
|
|
|
import ast
|
|
|
|
import gast
|
|
|
|
from tensorflow.python.autograph.pyct import anno
|
|
from tensorflow.python.autograph.pyct import parser
|
|
from tensorflow.python.autograph.pyct import qual_names
|
|
|
|
|
|
class CleanCopier(object):
  """Visitor, similar in spirit to NodeTransformer, that duplicates an AST."""

  def __init__(self, preserve_annos):
    super(CleanCopier, self).__init__()
    # Annotation keys that are carried over onto every copied node. A falsy
    # value (None, empty collection) means no annotation is copied.
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean).

    Lists and tuples are copied element by element. Any value that is not a
    list, a tuple or an AST node is returned unchanged (not duplicated).

    Args:
      node: an AST node, a list/tuple of such values, or a plain value
    Returns:
      The copy, of the same kind as the input.
    """
    if isinstance(node, list):
      return [self.copy(item) for item in node]
    if isinstance(node, tuple):
      return tuple(self.copy(item) for item in node)
    if not isinstance(node, (gast.AST, ast.AST)):
      # Whatever is left is assumed to be a value type, which may simply be
      # shared between the original and the copy.
      return node

    # From here on, node is known to be an AST node. Fields whose name starts
    # with '__' and fields that are not set on the node are left out.
    copied_fields = {
        name: self.copy(getattr(node, name))
        for name in node._fields
        if not name.startswith('__') and hasattr(node, name)
    }
    duplicate = type(node)(**copied_fields)

    for key in self.preserve_annos or ():
      anno.copyanno(node, duplicate, key)
    return duplicate
|
|
|
|
|
|
def copy_clean(node, preserve_annos=None):
  """Creates a deep copy of an AST.

  Fields whose name is prefixed by '__' are left out of the copy. Annotations
  are only carried over for the keys listed in `preserve_annos`.

  Args:
    node: ast.AST
    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
      copy
  Returns:
    ast.AST
  """
  copier = CleanCopier(preserve_annos)
  return copier.copy(node)
|
|
|
|
|
|
class SymbolRenamer(gast.NodeTransformer):
  """Transformer that can rename symbols to simple names."""

  def __init__(self, name_map):
    # Mapping from qualified name to its replacement; looked up with QN keys
    # and converted with str() when a replacement is applied.
    self.name_map = name_map

  def _process_name_node(self, node):
    """Replaces `node` with a Name node if its QN annotation is in the map."""
    qualified_name = anno.getanno(node, anno.Basic.QN)
    if qualified_name not in self.name_map:
      return self.generic_visit(node)

    replacement = gast.Name(
        str(self.name_map[qualified_name]),
        ctx=node.ctx,
        annotation=None,
        type_comment=None)
    # Every annotation of the original node is carried over.
    for key in anno.keys(node):
      anno.copyanno(node, replacement, key)
    return replacement

  def _process_list_of_strings(self, names):
    """Renames, in place, the entries of `names` found in the map."""
    for index, name in enumerate(names):
      qualified_name = qual_names.QN(name)
      if qualified_name in self.name_map:
        names[index] = str(self.name_map[qualified_name])
    return names

  def visit_Nonlocal(self, node):
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Global(self, node):
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Name(self, node):
    return self._process_name_node(node)

  def visit_Attribute(self, node):
    if not anno.hasanno(node, anno.Basic.QN):
      # Renaming attributes is not supported; only descend into the children.
      return self.generic_visit(node)
    return self._process_name_node(node)

  def visit_FunctionDef(self, node):
    function_name = qual_names.QN(node.name)
    if function_name in self.name_map:
      node.name = str(self.name_map[function_name])
    return self.generic_visit(node)
|
|
|
|
|
|
def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: a single AST node, or a list/tuple of nodes; containers are
      processed element by element with the same renamer
    name_map: mapping from qualified name to the replacement name
  Returns:
    The visited node, or a list/tuple (matching the input kind) of visited
    nodes.
  """
  renamer = SymbolRenamer(name_map)
  if not isinstance(node, (list, tuple)):
    return renamer.visit(node)

  renamed = [renamer.visit(item) for item in node]
  return renamed if isinstance(node, list) else tuple(renamed)
|
|
|
|
|
|
def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a gast.Dict node.

  Each keyword contributes one entry: its `arg` wrapped in a gast.Constant as
  the key, and its `value` node as the value.
  """
  # A single pass over `keywords`, so that one-shot iterables also work.
  entries = [(gast.Constant(kw.arg, kind=None), kw.value) for kw in keywords]
  return gast.Dict(
      keys=[key for key, _ in entries],
      values=[value for _, value in entries])
|
|
|
|
|
|
class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node."""

  def __init__(self, pattern):
    # Pattern node that the node currently being visited is compared against.
    self.pattern = pattern
    # Patterns of the enclosing nodes, saved while a child is being compared.
    self.pattern_stack = []
    # Starts out True; no_match() is the only place that sets it to False.
    self.matches = True

  def compare_and_visit(self, node, pattern):
    """Compares `node` against `pattern`, then restores the prior pattern."""
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    """Records the mismatch and returns False."""
    self.matches = False
    return False

  def is_wildcard(self, p):
    """Tells whether `p` is the wildcard '_'.

    The wildcard may be given as a Name node with id '_', as the plain value
    '_', or as a single-element list/tuple holding either of those.
    """
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p = p[0]
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    return True if p == '_' else False

  def generic_visit(self, node):
    """Compares the fields of `node` with those of the current pattern."""
    if not self.matches:
      # A mismatch was already found; nothing left to check.
      return

    pattern = self.pattern
    for field in node._fields:
      if field.startswith('__'):
        continue

      if not hasattr(node, field):
        # A field unset on the node only matches a pattern where the field is
        # missing or falsy as well.
        if getattr(pattern, field, None):
          return self.no_match()
        continue
      if not hasattr(pattern, field):
        return self.no_match()

      value = getattr(node, field)
      expected = getattr(pattern, field)

      if self.is_wildcard(expected):
        continue

      if isinstance(value, (list, tuple)):
        if not (isinstance(expected, (list, tuple)) and
                len(value) == len(expected)):
          return self.no_match()
        for value_item, expected_item in zip(value, expected):
          self.compare_and_visit(value_item, expected_item)
      elif isinstance(value, (gast.AST, ast.AST)):
        # The types are compatible when either one is an instance of the
        # other's type.
        if not (isinstance(value, type(expected)) or
                isinstance(expected, type(value))):
          return self.no_match()
        self.compare_and_visit(value, expected)
      elif value != expected:
        # Everything else is assumed to be a value type.
        return self.no_match()
|
|
|
|
|
|
def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: ast.AST, or a string, which is first parsed with
      parser.parse_str
  Returns:
    bool
  """
  pattern_node = (
      parser.parse_str(pattern) if isinstance(pattern, str) else pattern)

  matcher = PatternMatcher(pattern_node)
  matcher.visit(node)
  return matcher.matches
|
|
|
|
|
|
# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
|
|
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
  It tries to break down the unpacking if possible. In effect, it has the same
  effect as passing the assigned values in SSA form to apply_fn.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  Nested tuple/list targets are broken down recursively.

  Args:
    targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST], should
      be used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
      respective nodes of each single assignment
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)

  for target in targets:
    if not isinstance(target, (gast.Tuple, gast.List)):
      # Plain target: nothing to unpack.
      apply_fn(target, values)
      continue

    for index, target_el in enumerate(target.elts):
      if isinstance(values, (gast.Tuple, gast.List)):
        # Literal on the right-hand side: pair the elements up directly.
        value_el = values.elts[index]
      else:
        # Opaque value: refer to its element by subscripting it.
        value_el = gast.Subscript(
            values, parser.parse_expression(str(index)), ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)
|
|
|
|
|
|
def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure.

  Besides AST nodes, the yielded pairs may consist of two strings or of two
  None values, when these appear as elements of list-valued fields.

  Args:
    node: Union[ast.AST, Iterable[ast.AST]]
    other: Union[ast.AST, Iterable[ast.AST]]
  Yields:
    Tuple[ast.AST, ast.AST]
  Raises:
    ValueError: if the two trees don't have identical structure, including
      when a different number of root nodes is given on each side.
  """
  node_stack = list(node) if isinstance(node, (list, tuple)) else [node]
  other_stack = list(other) if isinstance(other, (list, tuple)) else [other]

  # Validate the roots explicitly. Relying on an assert inside the loop would
  # let a mismatch pass silently when one side is empty (the loop body never
  # runs) or when asserts are disabled.
  if len(node_stack) != len(other_stack):
    raise ValueError(
        'inconsistent number of root nodes: {} and {}'.format(
            len(node_stack), len(other_stack)))

  # Both stacks always grow and shrink together, so checking one suffices.
  while node_stack:
    n = node_stack.pop()
    o = other_stack.pop()

    if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or
        (not isinstance(o, (ast.AST, gast.AST, str)) and o is not None) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue
    if n is None:
      assert o is None, 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      # Fields prefixed with '__' and fields unset on either side are skipped.
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))
|