3RNN/Lib/site-packages/tensorflow/python/autograph/pyct/ast_util.py
2024-05-26 19:49:15 +02:00

345 lines
10 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AST manipulation utilities."""
import ast
import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
class CleanCopier(object):
  """Deep-copies an AST, dropping every field whose name starts with '__'.

  Of the annotations attached to the original nodes, only those whose keys
  are listed in `preserve_annos` are copied onto the new nodes.
  """

  def __init__(self, preserve_annos):
    super().__init__()
    # Iterable of annotation keys to carry over, or a falsy value for none.
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean).

    Lists and tuples are copied element-wise and keep their container type.
    Any other non-AST value is assumed immutable and returned as-is.
    """
    if isinstance(node, list):
      return list(map(self.copy, node))
    if isinstance(node, tuple):
      return tuple(map(self.copy, node))
    if not isinstance(node, (gast.AST, ast.AST)):
      # Plain values (str, int, None, ...) can simply be shared.
      return node

    # Only fields that are actually set and not private are copied.
    copied_fields = {
        field: self.copy(getattr(node, field))
        for field in node._fields
        if not field.startswith('__') and hasattr(node, field)
    }
    clone = type(node)(**copied_fields)
    for key in self.preserve_annos or ():
      anno.copyanno(node, clone, key)
    return clone
def copy_clean(node, preserve_annos=None):
  """Returns a deep copy of an AST (or a list/tuple of AST nodes).

  Fields whose names start with '__' are not copied. Annotations are dropped
  too, except for the keys listed in `preserve_annos`.

  Args:
    node: ast.AST, or a list/tuple of them
    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
      copy

  Returns:
    ast.AST, or a list/tuple of them, mirroring the input
  """
  copier = CleanCopier(preserve_annos)
  return copier.copy(node)
class SymbolRenamer(gast.NodeTransformer):
  """Transformer that renames symbols to simple names.

  `name_map` maps qualified names (qual_names.QN) to replacement names; each
  replacement is converted with str(). Name nodes, annotated Attribute nodes,
  global/nonlocal name lists and FunctionDef names are rewritten. Name and
  Attribute nodes are looked up through their anno.Basic.QN annotation.
  """

  def __init__(self, name_map):
    self.name_map = name_map

  def _process_name_node(self, node):
    """Replaces node with a gast.Name if its QN annotation is in name_map."""
    qn = anno.getanno(node, anno.Basic.QN)
    if qn not in self.name_map:
      return self.generic_visit(node)
    replacement = gast.Name(
        str(self.name_map[qn]),
        ctx=node.ctx,
        annotation=None,
        type_comment=None)
    # The replacement inherits every annotation of the node it replaces.
    for key in anno.keys(node):
      anno.copyanno(node, replacement, key)
    return replacement

  def _process_list_of_strings(self, names):
    """Renames the entries of `names` in place and returns the same list."""
    for position, name in enumerate(names):
      qn = qual_names.QN(name)
      if qn in self.name_map:
        names[position] = str(self.name_map[qn])
    return names

  def visit_Nonlocal(self, node):
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Global(self, node):
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Name(self, node):
    return self._process_name_node(node)

  def visit_Attribute(self, node):
    if not anno.hasanno(node, anno.Basic.QN):
      # Without a qualified name the attribute cannot be renamed; only its
      # children are visited.
      return self.generic_visit(node)
    return self._process_name_node(node)

  def visit_FunctionDef(self, node):
    qn = qual_names.QN(node.name)
    if qn in self.name_map:
      node.name = str(self.name_map[qn])
    return self.generic_visit(node)
def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: ast.AST, or a list/tuple of them
    name_map: mapping from qual_names.QN to the replacement name

  Returns:
    The transformed node(s); a list input yields a list and a tuple input
    yields a tuple.
  """
  renamer = SymbolRenamer(name_map)
  if not isinstance(node, (list, tuple)):
    return renamer.visit(node)
  renamed = [renamer.visit(n) for n in node]
  return renamed if isinstance(node, list) else tuple(renamed)
def keywords_to_dict(keywords):
  """Converts a list of ast.keyword nodes into a gast.Dict node.

  A `name=value` keyword becomes the entry `'name': value`. A `**value`
  keyword (one whose `arg` is None) becomes the unpacking entry `**value`,
  which ast/gast Dict nodes represent with a bare `None` key (rather than a
  `None` constant, which would produce the literal entry `None: value`).

  Args:
    keywords: List[ast.keyword]

  Returns:
    gast.Dict
  """
  keys = []
  values = []
  for kw in keywords:
    if kw.arg is None:
      keys.append(None)
    else:
      keys.append(gast.Constant(kw.arg, kind=None))
    values.append(kw.value)
  return gast.Dict(keys=keys, values=values)
class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node.

  The visited node and the pattern are walked in lockstep, field by field.
  After `visit` returns, the `matches` attribute is True if and only if no
  mismatch was found.

  Wildcards: a pattern value that is `gast.Name(id='_')` or the string '_',
  or a single-element list/tuple holding one of those, matches any value in
  that position. A single-element list wildcard matches a list of any length.
  """

  def __init__(self, pattern):
    self.pattern = pattern
    # Outer patterns saved while descending into child nodes.
    self.pattern_stack = []
    self.matches = True

  def compare_and_visit(self, node, pattern):
    """Visits `node` with `pattern` as the current pattern, then restores it."""
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    """Records a mismatch; returns False for convenient early returns."""
    self.matches = False
    return False

  def is_wildcard(self, p):
    """Returns True if the pattern value `p` matches anything."""
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    if p == '_':
      return True
    return False

  def _compare_value(self, v, p):
    """Compares one field value, or one list element, `v` with pattern `p`."""
    if self.is_wildcard(p):
      return
    if isinstance(v, (list, tuple)):
      if not isinstance(p, (list, tuple)) or len(v) != len(p):
        self.no_match()
        return
      for v_item, p_item in zip(v, p):
        # Elements get the same wildcard, type and value checks as scalar
        # fields. Without this, field-less nodes such as Compare.ops=[Lt()]
        # would match [Gt()], and non-AST elements (strings in Global.names,
        # None keys in Dict) would crash on the missing `_fields`.
        self._compare_value(v_item, p_item)
        if not self.matches:
          return
    elif isinstance(v, (gast.AST, ast.AST)):
      if not isinstance(v, type(p)) and not isinstance(p, type(v)):
        self.no_match()
        return
      self.compare_and_visit(v, p)
    elif v != p:
      # Everything else is assumed to be a value type.
      self.no_match()

  def generic_visit(self, node):
    if not self.matches:
      return
    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue
      if not hasattr(node, f):
        # A field missing from the node only matches a missing or falsy
        # pattern field.
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        continue
      if not hasattr(pattern, f):
        return self.no_match()
      self._compare_value(getattr(node, f), getattr(pattern, f))
      if not self.matches:
        return
def matches(node, pattern):
  """Returns whether an AST matches a pattern AST.

  Matching is structural: corresponding nodes must be of compatible types and
  their fields must agree. Anywhere in the pattern, the symbol `_` acts as a
  wildcard, either as a Name node with id='_' or as a one-element list
  containing it (which matches a list of any length).

  Args:
    node: ast.AST
    pattern: ast.AST, or a source string that is parsed with parser.parse_str

  Returns:
    bool
  """
  pattern_node = (
      parser.parse_str(pattern) if isinstance(pattern, str) else pattern)
  matcher = PatternMatcher(pattern_node)
  matcher.visit(node)
  return matcher.matches
# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Calls apply_fn once for every elementary assignment.

  Tuple/List targets are unpacked recursively. When the value is itself a
  Tuple/List node, its elements are paired with the target elements
  positionally. Otherwise each target element is paired with a subscript
  `values[i]`. The net effect is as if the assigned values were passed to
  apply_fn in SSA form.

  Examples:
    a, b = c, d   ->  apply_fn(a, c), apply_fn(b, d)
    a, b = c      ->  apply_fn(a, c[0]), apply_fn(b, c[1])
    a = b, c      ->  apply_fn(a, (b, c))

  Starred targets are not special-cased. They are treated like any other
  element at their position.

  Args:
    targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST], should be
      used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
      respective nodes of each single assignment
  """
  target_seq = targets if isinstance(targets, (list, tuple)) else (targets,)
  for target in target_seq:
    if not isinstance(target, (gast.Tuple, gast.List)):
      apply_fn(target, values)
      continue
    values_are_literal = isinstance(values, (gast.Tuple, gast.List))
    for i, target_el in enumerate(target.elts):
      if values_are_literal:
        value_el = values.elts[i]
      else:
        idx = parser.parse_expression(str(i))
        value_el = gast.Subscript(values, idx, ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)
def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure. Pairs of corresponding nodes
  are yielded depth-first, parents before children. Siblings come out
  last-to-first because an explicit stack is used. String and None entries of
  node lists are yielded too. Fields whose names start with '__', and fields
  that are None in either tree, are skipped.

  Args:
    node: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
    other: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]

  Yields:
    Tuple[ast.AST, ast.AST]

  Raises:
    ValueError: if the two trees don't have identical structure.
  """
  node_stack = list(node) if isinstance(node, (list, tuple)) else [node]
  other_stack = list(other) if isinstance(other, (list, tuple)) else [other]
  # Checked explicitly rather than with an assert, so that mismatched inputs
  # raise the documented ValueError, also under `python -O`.
  if len(node_stack) != len(other_stack):
    raise ValueError('inconsistent number of nodes: {} and {}'.format(
        len(node_stack), len(other_stack)))

  # The stacks only ever grow by equal amounts (list lengths are checked
  # before extending), so they stay the same length and empty out together.
  while node_stack:
    n = node_stack.pop()
    o = other_stack.pop()

    n_valid = n is None or isinstance(n, (ast.AST, gast.AST, str))
    o_valid = o is None or isinstance(o, (ast.AST, gast.AST, str))
    if (not n_valid or not o_valid or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue
    if n is None:
      assert o is None, 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))