3RNN/Lib/site-packages/tensorflow/python/autograph/pyct/transformer.py
2024-05-26 19:49:15 +02:00

539 lines
16 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A node transformer that includes utilities for SCT."""
import collections
import enum
import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import templates
class AnalysisLevel(enum.IntEnum):
NONE = 0
ACTIVITY = 1
DEFINEDNESS = 2
LIVENESS = 3
# TODO(znado): Use namedtuple.
class Context(object):
"""Contains information about a source code transformation.
This object is mutable, and is updated during conversion. Not thread safe.
Attributes:
info: EntityInfo, immutable.
namer: naming.Namer.
current_origin: origin_info.OriginInfo, holds the OriginInfo of the last
AST node to be processed successfully. Useful for error handling.
user: An user-supplied context object. The object is opaque to the
infrastructure, but will pe passed through to all custom transformations.
"""
def __init__(self, info, namer, user_context):
self.info = info
self.namer = namer
self.current_origin = None
self.user = user_context
# TODO(mdan): Move to a standalone file.
class EntityInfo(
collections.namedtuple(
'EntityInfo',
('name', 'source_code', 'source_file', 'future_features', 'namespace'))
):
"""Contains information about a Python entity.
Immutable.
Examples of entities include functions and classes.
Attributes:
name: The name that identifies this entity.
source_code: The entity's source code.
source_file: The entity's source file.
future_features: Tuple[Text], the future features that this entity was
compiled with. See
https://docs.python.org/2/reference/simple_stmts.html#future.
namespace: Dict[str, ], containing symbols visible to the entity (excluding
parameters).
"""
pass
class _StateStack(object):
"""Templated context manager.
This class provides syntactic sugar for a stack of objects of known
type. It allows accessing attributes of the object at the top of the stack
directly against this object, which allows for very terse syntax.
For example, this code:
stack = _StateStack(Foo)
stack.enter()
stack.bar
Is equivalent to:
stack = []
stack.append(Foo())
foo = stack[-1]
foo.bar
See _State for more on how this is used.
Attributes:
type: Any, the type of objects that this stack holds
level: int, the current stack depth
stack: List[Any], the actual stack
value: Any, the instance of the object at the top of the stack
"""
def __init__(self, type_):
# Because we override __setattr__, we need to attach these attributes using
# the superclass' setattr.
object.__setattr__(self, 'type', type_)
object.__setattr__(self, '_stack', [])
if not hasattr(type_, 'no_root'):
self.enter()
def __enter__(self):
self.enter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.exit()
def enter(self):
self._stack.append(self.type())
def exit(self):
self._stack.pop()
@property
def stack(self):
return self._stack
@property
def level(self):
return len(self._stack)
@property
def value(self):
return self._stack[-1]
def __iter__(self):
return iter(self._stack)
def __getattr__(self, key):
return getattr(self._stack[-1], key)
def __setattr__(self, key, value):
setattr(self._stack[-1], key, value)
class _State(object):
"""Syntactic sugar for accessing an instance of a StateStack context manager.
This structure offers syntactic sugar over a dict of stacks of objects
of known type. These structures are useful to keep state during AST walks.
Multiple different scopes can be tracked in parallel. For example:
s = _State()
s[foo].enter()
s[bar].enter() # this will not affect s[foo]
Element access has special semantics:
* keys are a data type
* element values are _StateStack(type=key) objects
* missing elements are automatically added, similarly to defaultdict
For example, the following block :
_State s
s[Foo]
Is equivalent to:
s = {}
if Foo not in s:
s[Foo] = Foo()
s[Foo]
See Base for how it's used.
"""
def __init__(self):
self._value = {}
def __getitem__(self, key):
if key not in self._value:
self._value[key] = _StateStack(key)
return self._value[key]
class NodeStateTracker(object):
"""Base class for general-purpose Python code transformation.
This abstract class provides helpful functions, like state tracking within
the scope of arbitrary node, helpers for processing code blocks, debugging,
mapping of transformed code to original code, and others.
Scope-local state tracking: to keep state across nodes, at the level of
(possibly nested) scopes, use enter/exit_local_scope and set/get_local.
You must call enter/exit_local_scope manually, but the transformer detects
when they are not properly paired.
The transformer allows keeping state across calls that is local
to arbitrary nodes and their descendants, using the self.state attribute.
Multiple independent scopes are allowed and automatically constructed.
For example, to keep track of the `If` node that encloses any `Name` node,
one can write:
```
class FooType(object):
def __init__(self):
self.foo_property = None
class DummyTransformer(NodeStateTracker, ast.NodeTransformer):
def visit_If(self, node):
self.state[FooType].enter()
self.state[FooType].foo_property = node
node = self.veneric_visit(node)
self.state[FooType].exit()
return node
def visit_Name(self, node):
self.state[FooType].foo_property # will hold the innermost enclosing if
```
Alternatively, the `enter()`/`exit()` calls can be managed by a `with`
statement:
```
def visit_If(self, node):
with self.state[FooType] as foo:
foo.foo_property = node
return self.generic_visit(node)
```
"""
# TODO(mdan): Document all extra features.
def __init__(self, ctx):
"""Initialize the transformer.
Subclasses should call this.
Args:
ctx: A Context object.
"""
self._lineno = 0
self._col_offset = 0
self.ctx = ctx
# Allows scoping of local variables to keep state across calls to visit_*
# methods. Multiple scope hierarchies may exist and are keyed by tag. A
# scope is valid at one or more nodes and all its children. Scopes created
# in child nodes supersede their parent. Scopes are isolated from one
# another.
self.state = _State()
def debug_print(self, node):
"""Helper method useful for debugging. Prints the AST."""
if __debug__:
print(pretty_printer.fmt(node))
return node
def debug_print_src(self, node):
"""Helper method useful for debugging. Prints the AST as code."""
if __debug__:
print(parser.unparse(node))
return node
def visit_block(self, nodes, before_visit=None, after_visit=None):
"""A more powerful version of generic_visit for statement blocks.
An example of a block is the body of an if statement.
This function allows specifying a postprocessing callback (the
after_visit argument) argument which can be used to move nodes to a new
destination. This is done by after_visit by returning a non-null
second return value, e.g. return new_node, new_destination.
For example, a transformer could perform the following move:
foo()
bar()
baz()
foo()
if cond:
bar()
baz()
The above could be done with a postprocessor of this kind:
def after_visit(node):
if node_is_function_call(bar):
new_container_node = build_cond()
new_container_node.body.append(node)
return new_container_node, new_container_node.body
else:
# Once we set a new destination, all subsequent items will be
# moved to it, so we don't need to explicitly handle baz.
return node, None
Args:
nodes: enumerable of AST node objects. If None, the function returns None.
before_visit: optional callable that is called before visiting each item
in nodes
after_visit: optional callable that takes in an AST node and returns a
tuple (new_node, new_destination). It is called after visiting each item
in nodes. Is used in the same was as the
visit_* methods: new_node will replace the node; if not None,
new_destination must be a list, and subsequent nodes will be placed
in this list instead of the list returned by visit_block.
Returns:
A list of AST node objects containing the transformed items fron nodes,
except those nodes that have been relocated using after_visit.
"""
if nodes is None:
return None
results = []
node_destination = results
for node in nodes:
if before_visit:
# TODO(mdan): We can modify node here too, if ever needed.
before_visit()
replacement = self.visit(node)
if after_visit and replacement:
replacement, new_destination = after_visit(replacement)
else:
new_destination = None
if replacement:
if isinstance(replacement, (list, tuple)):
node_destination.extend(replacement)
else:
node_destination.append(replacement)
# Allow the postprocessor to reroute the remaining nodes to a new list.
if new_destination is not None:
node_destination = new_destination
return results
# TODO(mdan): Rename to PythonCodeTransformer.
class Base(NodeStateTracker, gast.NodeTransformer):
"""Base class for general-purpose Python-to-Python code transformation.
This is an extension of ast.NodeTransformer that provides the additional
functions offered by NodeStateTracker.
"""
def create_assignment(self, target, expression):
template = """
target = expression
"""
return templates.replace(template, target=target, expression=expression)
# TODO(mdan): Remove.
def apply_to_single_assignments(self, targets, values, apply_fn):
"""Applies a function to each individual assignment.
This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
It tries to break down the unpacking if possible. In effect, it has the same
effect as passing the assigned values in SSA form to apply_fn.
Examples:
The following will result in apply_fn(a, c), apply_fn(b, d):
a, b = c, d
The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
a, b = c
The following will result in apply_fn(a, (b, c)):
a = b, c
It uses the visitor pattern to allow subclasses to process single
assignments individually.
Args:
targets: list, tuple of or individual AST node. Should be used with the
targets field of an ast.Assign node.
values: an AST node.
apply_fn: a function of a single argument, which will be called with the
respective nodes of each single assignment. The signature is
apply_fn(target, value), no return value.
"""
if not isinstance(targets, (list, tuple)):
targets = (targets,)
for target in targets:
if isinstance(target, (gast.Tuple, gast.List)):
for i in range(len(target.elts)):
target_el = target.elts[i]
if isinstance(values, (gast.Tuple, gast.List)):
value_el = values.elts[i]
else:
value_el = gast.Subscript(values, i, ctx=gast.Store())
self.apply_to_single_assignments(target_el, value_el, apply_fn)
else:
# TODO(mdan): Look into allowing to rewrite the AST here.
apply_fn(target, values)
def visit(self, node):
if not isinstance(node, gast.AST):
# This is not that uncommon a mistake: various node bodies are lists, for
# example, posing a land mine for transformers that need to recursively
# call `visit`. The error needs to be raised before the exception handler
# below is installed, because said handler will mess up if `node` is not,
# in fact, a node.
msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
' visit lists of nodes, use "visit_block" instead').format(
type(node))
raise ValueError(msg)
if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
return node
parent_origin = self.ctx.current_origin
if anno.hasanno(node, anno.Basic.ORIGIN):
self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)
try:
processing_expr_node = isinstance(node, gast.Expr)
if processing_expr_node:
entry_expr_value = node.value
result = super(Base, self).visit(node)
# Adjust for consistency: replacing the value of an Expr with
# an Assign node removes the need for the Expr node.
if (processing_expr_node and isinstance(result, gast.Expr) and
(result.value is not entry_expr_value)):
# When the replacement is a list, it is assumed that the list came
# from a template that contained a number of statements, which
# themselves are standalone and don't require an enclosing Expr.
if isinstance(result.value,
(list, tuple, gast.Assign, gast.AugAssign)):
result = result.value
# By default, all replacements receive the origin info of the replaced
# node.
if result is not node and result is not None:
inherited_origin = anno.getanno(
node, anno.Basic.ORIGIN, default=parent_origin)
if inherited_origin is not None:
nodes_to_adjust = result
if isinstance(result, (list, tuple)):
nodes_to_adjust = result
else:
nodes_to_adjust = (result,)
for n in nodes_to_adjust:
if not anno.hasanno(n, anno.Basic.ORIGIN):
anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
finally:
self.ctx.current_origin = parent_origin
return result
class CodeGenerator(NodeStateTracker, gast.NodeVisitor):
"""Base class for general-purpose Python-to-string code transformation.
Similar to Base, but outputs arbitrary strings instead of a Python AST.
This uses the same visitor mechanism that the standard NodeVisitor uses,
meaning that subclasses write handlers for the different kinds of nodes.
New code is generated using the emit method, which appends to a code buffer
that can be afterwards obtained from code_buffer.
Example:
class SimpleCodeGen(CodeGenerator):
def visitIf(self, node):
self.emit('if ')
self.visit(node.test)
self.emit(' { ')
self.visit(node.body)
self.emit(' } else { ')
self.visit(node.orelse)
self.emit(' } ')
node = ast.parse(...)
gen = SimpleCodeGen()
gen.visit(node)
# gen.code_buffer contains the resulting code
"""
def __init__(self, ctx):
super(CodeGenerator, self).__init__(ctx)
self._output_code = ''
self.source_map = {}
def emit(self, code):
self._output_code += code
@property
def code_buffer(self):
return self._output_code
def visit(self, node):
if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
return
parent_origin = self.ctx.current_origin
eof_before = len(self._output_code)
if anno.hasanno(node, anno.Basic.ORIGIN):
self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)
try:
ret = super(CodeGenerator, self).visit(node)
# By default, all replacements receive the origin info of the replaced
# node.
eof_after = len(self._output_code)
if eof_before - eof_after:
inherited_origin = anno.getanno(
node, anno.Basic.ORIGIN, default=parent_origin)
if inherited_origin is not None:
self.source_map[(eof_before, eof_after)] = inherited_origin
return ret
finally:
self.ctx.current_origin = parent_origin