Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/autograph/pyct/common_transformers/anf.py

621 lines
23 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversion to A-normal form.
The general idea of A-normal form is that every intermediate value is
explicitly named with a variable. For more, see
https://en.wikipedia.org/wiki/A-normal_form.
The specific converters used here are based on Python AST semantics as
documented at https://greentreesnakes.readthedocs.io/en/latest/.
"""
import collections
import gast
from tensorflow.python.autograph.pyct import gast_util
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
# TODO(mdan): Replace with naming.Namer.
class DummyGensym:
  """A deliberately naive generator of fresh variable names.

  Each request appends an underscore and the next number of a running
  sequence to the requested stem. The sequence is offset by 1000, so the first
  name handed out ends in `_1001`. No attempt is made to avoid clashes with
  names that already exist in the code being transformed.
  """

  def __init__(self):
    # A proper implementation needs to account for:
    #   * ctx.info.namespace
    #   * all the symbols defined in the AST
    #   * the symbols generated so far
    self._idx = 0

  def new_name(self, stem='tmp'):
    """Returns `stem`, an underscore, and the next number in the sequence."""
    self._idx = self._idx + 1
    suffix = str(1000 + self._idx)
    return '_'.join((stem, suffix))
def REPLACE(_1, _2, _3):  # pylint: disable=invalid-name
  """Replacement directive that always answers "replace with a variable"."""
  return True


def LEAVE(_1, _2, _3):  # pylint: disable=invalid-name
  """Replacement directive that always answers "leave the node in place"."""
  return False


# Unique sentinel used as a wildcard; it is compared by identity only.
ANY = object()
class ASTEdgePattern(collections.namedtuple(
    'ASTEdgePattern', ['parent', 'field', 'child'])):
  """Describes a kind of edge between a parent AST node and one child.

  A pattern has three slots:
    - `parent`: type(s) the parent node is tested against with `isinstance`,
    - `field`: the field name, compared with `==`, and
    - `child`: type(s) the child node is tested against with `isinstance`.

  The pattern matches an edge only when all three slots match. Any slot may
  hold the special value `anf.ANY`, which matches everything. Note that the
  nodes being matched come from the `gast` library rather than the standard
  `ast` module, which may affect `isinstance` checks.
  """
  __slots__ = ()

  def matches(self, parent, field, child):
    """Returns whether the edge `parent --field--> child` fits this pattern."""
    parent_ok = self.parent is ANY or isinstance(parent, self.parent)
    if not parent_ok:
      return False
    field_ok = self.field is ANY or field == self.field
    if not field_ok:
      return False
    return self.child is ANY or isinstance(child, self.child)
class AnfTransformer(transformer.Base):
  """Performs the conversion to A-normal form (ANF)."""

  # The algorithm is a postorder recursive tree walk.  Any given node A may, in
  # general, require creation of a series B of Assign statements, which compute
  # and explicitly name the intermediate values needed to compute the value of
  # A.  If A was already a statement, it can be replaced with the sequence B +
  # [A].  If A was an expression, B needs to be propagated up the tree until a
  # statement is encountered.  Since the `ast.NodeTransformer` framework makes
  # no provision for subtraversals returning side information, this class
  # accumulates the sequence B in an instance variable.

  # The only other subtlety is that some Python statements (like `if`) have both
  # expression fields (`test`) and statement list fields (`body` and `orelse`).
  # Any additional assignments needed to name all the intermediate values in the
  # `test` can be prepended to the `if` node, but assignments produced by
  # processing the `body` and the `orelse` need to be kept together with them,
  # and not accidentally lifted out of the `if`.

  def __init__(self, ctx, config):
    """Creates an ANF transformer.

    Args:
      ctx: transformer.Context
      config: Configuration. Either a sequence of `(pattern, directive)` pairs
        that is consulted in order, or `None` to select the default rules
        (leave literals and names alone, replace every other expression).

    Raises:
      AssertionError: If `config` is `None` and `gast_util` reports neither
        `GAST2` nor `GAST3`.
    """
    super(AnfTransformer, self).__init__(ctx)
    if config is None:
      # These could be pulled out, but are generally considered to already be in
      # A-normal form.  Thus they are left in by default, but could be pulled
      # out if the configuration calls for it.
      if gast_util.GAST2:
        literal_node_types = (
            gast.Num, gast.Str, gast.Bytes, gast.NameConstant,
            gast.Name  # Name is here to cover True, False, and None in Python 2
        )
      elif gast_util.GAST3:
        literal_node_types = (
            gast.Constant,
            gast.Name  # Name is here to cover True, False, and None in Python 2
        )
      else:
        # Raised explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let execution continue with
        # `literal_node_types` unbound and fail later with a NameError.
        raise AssertionError(
            'Unsupported gast version: neither GAST2 nor GAST3 is set.')
      self._overrides = [
          (ASTEdgePattern(ANY, ANY, literal_node_types), LEAVE),
          (ASTEdgePattern(ANY, ANY, gast.expr), REPLACE)]
    else:
      self._overrides = config
    self._gensym = DummyGensym()
    self._pending_statements = []

  def _consume_pending_statements(self):
    """Returns the accumulated statements and resets the accumulator."""
    ans = self._pending_statements
    self._pending_statements = []
    return ans

  def _add_pending_statement(self, stmt):
    """Queues `stmt` for emission ahead of the statement being processed."""
    self._pending_statements.append(stmt)

  def _match(self, pattern, parent, field, child):
    """Tests one configuration pattern; a bare `ANY` matches every edge."""
    if pattern is ANY:
      return True
    else:
      return pattern.matches(parent, field, child)

  def _should_transform(self, parent, field, child):
    """Returns the verdict of the first configuration rule matching the edge."""
    for pat, result in self._overrides:
      if self._match(pat, parent, field, child):
        return result(parent, field, child)
    # Fell off the end of the pattern list: do not transform
    return False

  def _do_transform_node(self, node):
    """Queues `<fresh name> = node` and returns a reference to the new name."""
    temp_name = self._gensym.new_name()
    temp_assign = templates.replace(
        'temp_name = expr', temp_name=temp_name, expr=node)[0]
    self._add_pending_statement(temp_assign)
    answer = templates.replace('temp_name', temp_name=temp_name)[0]
    return answer

  def _ensure_node_in_anf(self, parent, field, node):
    """Puts `node` in A-normal form, by replacing it with a variable if needed.

    The exact definition of A-normal form is given by the configuration.  The
    parent and the incoming field name are only needed because the configuration
    may be context-dependent.

    Args:
      parent: An AST node, the parent of `node`.
      field: The field name under which `node` is the child of `parent`.
      node: An AST node, potentially to be replaced with a variable reference.

    Returns:
      node: An AST node; the argument if transformation was not necessary,
        or the new variable reference if it was.
    """
    if node is None:
      return node
    if _is_trivial(node):
      return node
    if isinstance(node, list):
      # If something's field was actually a list, e.g., variadic arguments.
      return [self._ensure_node_in_anf(parent, field, n) for n in node]
    if isinstance(node, gast.keyword):
      node.value = self._ensure_node_in_anf(parent, field, node.value)
      return node
    if isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
      # These nodes aren't really extractable in their own right, but their
      # subnodes might be.  Propagate the parent and field name to the child
      # nodes, instead of querying the configuration for children of, e.g.,
      # gast.Starred.
      return self._ensure_fields_in_anf(node, parent, field)
    if self._should_transform(parent, field, node):
      return self._do_transform_node(node)
    else:
      return node

  def _ensure_fields_in_anf(self, node, parent=None, super_field=None):
    """Puts every field of `node` in A-normal form, in place.

    Args:
      node: The AST node whose fields are processed.
      parent: If given, reported to the configuration as the parent instead of
        `node` itself.
      super_field: If given, reported to the configuration as the field name
        instead of each field's own name.

    Returns:
      The same `node`, with its fields updated.
    """
    parent_supplied = node if parent is None else parent
    for field in node._fields:
      if field.startswith('__'):
        continue
      field_supplied = field if super_field is None else super_field
      setattr(node, field, self._ensure_node_in_anf(
          parent_supplied, field_supplied, getattr(node, field)))
    return node

  def _visit_strict_statement(self, node, children_ok_to_transform=True):
    """Visits a statement; returns the extracted statements followed by it."""
    assert not self._pending_statements
    node = self.generic_visit(node)
    if children_ok_to_transform:
      self._ensure_fields_in_anf(node)
    results = self._consume_pending_statements()
    results.append(node)
    return results

  def _visit_trivial_only_statement(self, node, msg):
    """Visits a statement, raising ValueError(msg) if anything was extracted."""
    assert not self._pending_statements
    node = self.generic_visit(node)
    self._ensure_fields_in_anf(node)
    if self._pending_statements:
      raise ValueError(msg)
    else:
      return node

  def _visit_strict_expression(self, node):
    """Visits an expression and puts its fields in A-normal form."""
    node = self.generic_visit(node)
    self._ensure_fields_in_anf(node)
    return node

  def _visit_trivial_only_expression(self, node, msg):
    """Visits an expression, raising ValueError(msg) if anything was extracted."""
    k = len(self._pending_statements)
    node = self.generic_visit(node)
    self._ensure_fields_in_anf(node)
    # This check relies on there being no opportunities to consume pending
    # statements while traversing children of an expression.
    if len(self._pending_statements) != k:
      raise ValueError(msg)
    else:
      return node

  # Note on code order: These are listed in the same order as the grammar
  # elements on https://github.com/serge-sans-paille/gast

  # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default.

  def visit_Return(self, node):
    return self._visit_strict_statement(node)

  def visit_Delete(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  def visit_Assign(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  def visit_AugAssign(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  def visit_Print(self, node):
    return self._visit_strict_statement(node)

  def visit_For(self, node):
    assert not self._pending_statements
    # It's important to visit node.iter first, because any statements created
    # thereby need to live outside the body.
    self.visit(node.iter)
    node.iter = self._ensure_node_in_anf(node, 'iter', node.iter)
    iter_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.iter, but that is correct because by
    # this point the node.iter link has been checked.  It may be somewhat
    # expensive if the configuration didn't call for transforming node.iter, as
    # then it may be large and will be uselessly transformed again.  This
    # behavior is what causes the documented effect that configuration callables
    # may be invoked more than once on the same links; if the code is rewritten
    # not to do that (anywhere), the docstring of `transform` should be updated.
    node = self.generic_visit(node)
    assert not self._pending_statements
    iter_stmts.append(node)
    return iter_stmts

  def visit_AsyncFor(self, node):
    msg = ('Nontrivial AsyncFor nodes not supported yet '
           '(need to think through the semantics).')
    return self._visit_trivial_only_statement(node, msg)

  def visit_While(self, node):
    assert not self._pending_statements
    self.visit(node.test)
    node.test = self._ensure_node_in_anf(node, 'test', node.test)
    if self._pending_statements:
      msg = ('While with nontrivial test not supported yet '
             '(need to avoid precomputing the test).')
      raise ValueError(msg)
    # If traversing node.test yielded no statements extracted, the generic visit
    # will do the right thing.
    return self.generic_visit(node)

  def visit_If(self, node):
    assert not self._pending_statements
    # It's important to visit node.test first, because any statements created
    # thereby need to live outside the body.
    self.visit(node.test)
    node.test = self._ensure_node_in_anf(node, 'test', node.test)
    condition_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.test, but that is correct because by
    # this point the node.test link has been checked.  It may be somewhat
    # expensive if the configuration didn't call for transforming node.test, as
    # then it may be large and will be uselessly transformed again.  This
    # happens in several places.
    node = self.generic_visit(node)
    assert not self._pending_statements
    condition_stmts.append(node)
    return condition_stmts

  def visit_With(self, node):
    assert not self._pending_statements
    # It's important to visit node.items first, because any statements created
    # thereby need to live outside the body.
    for item in node.items:
      self.visit(item)
    node.items = [self._ensure_node_in_anf(node, 'items', n)
                  for n in node.items]
    contexts_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.items, but that is correct because by
    # this point the node.items link has been checked.  It may be somewhat
    # expensive if the configuration didn't call for transforming node.items, as
    # then it may be large and will be uselessly transformed again.  This
    # happens in several places.
    node = self.generic_visit(node)
    assert not self._pending_statements
    contexts_stmts.append(node)
    return contexts_stmts

  def visit_AsyncWith(self, node):
    msg = ('Nontrivial AsyncWith nodes not supported yet '
           '(need to think through the semantics).')
    return self._visit_trivial_only_statement(node, msg)

  def visit_Raise(self, node):
    return self._visit_strict_statement(node)

  # Try should be correct by default.

  def visit_Assert(self, node):
    msg = ('Nontrivial Assert nodes not supported yet '
           '(need to avoid computing the test when assertions are off, and '
           'avoid computing the irritant when the assertion does not fire).')
    return self._visit_trivial_only_statement(node, msg)

  # Import and ImportFrom should be correct by default.

  def visit_Exec(self, node):
    return self._visit_strict_statement(node)

  # Global and Nonlocal should be correct by default.

  def visit_Expr(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  # Pass, Break, and Continue should be correct by default.

  def visit_BoolOp(self, node):
    msg = ('Nontrivial BoolOp nodes not supported yet '
           '(need to preserve short-circuiting semantics).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_BinOp(self, node):
    return self._visit_strict_expression(node)

  def visit_UnaryOp(self, node):
    return self._visit_strict_expression(node)

  def visit_Lambda(self, node):
    msg = ('Nontrivial Lambda nodes not supported '
           '(cannot insert statements into lambda bodies).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_IfExp(self, node):
    msg = ('Nontrivial IfExp nodes not supported yet '
           '(need to convert to If statement, to evaluate branches lazily '
           'and insert statements into them).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Dict(self, node):
    return self._visit_strict_expression(node)

  def visit_Set(self, node):
    return self._visit_strict_expression(node)

  def visit_ListComp(self, node):
    msg = ('ListComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_SetComp(self, node):
    msg = ('SetComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_DictComp(self, node):
    msg = ('DictComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_GeneratorExp(self, node):
    msg = ('GeneratorExp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_Await(self, node):
    msg = ('Nontrivial Await nodes not supported yet '
           '(need to think through the semantics).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Yield(self, node):
    return self._visit_strict_expression(node)

  def visit_YieldFrom(self, node):
    msg = ('Nontrivial YieldFrom nodes not supported yet '
           '(need to unit-test them in Python 2).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Compare(self, node):
    if len(node.ops) > 1:
      msg = ('Multi-ary compare nodes not supported yet '
             '(need to preserve short-circuiting semantics).')
      raise ValueError(msg)
    return self._visit_strict_expression(node)

  def visit_Call(self, node):
    return self._visit_strict_expression(node)

  def visit_Repr(self, node):
    msg = ('Nontrivial Repr nodes not supported yet '
           '(need to research their syntax and semantics).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_FormattedValue(self, node):
    msg = ('Nontrivial FormattedValue nodes not supported yet '
           '(need to unit-test them in Python 2).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_JoinedStr(self, node):
    msg = ('Nontrivial JoinedStr nodes not supported yet '
           '(need to unit-test them in Python 2).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Attribute(self, node):
    return self._visit_strict_expression(node)

  def visit_Subscript(self, node):
    return self._visit_strict_expression(node)

  # Starred and Name are correct by default, because the right thing to do is to
  # just recur.

  def visit_List(self, node):
    node = self.generic_visit(node)
    # Elements of an assignment target must stay as they are.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_in_anf(node)
    return node

  def visit_Tuple(self, node):
    node = self.generic_visit(node)
    # Elements of an assignment target must stay as they are.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_in_anf(node)
    return node
def _is_py2_name_constant(node):
  """Returns whether `node` is a `gast.Name` spelled True, False, or None."""
  if not isinstance(node, gast.Name):
    return False
  return node.id in ('True', 'False', 'None')
def _is_trivial(node):
  """Returns whether to consider the given node 'trivial'.

  The definition of 'trivial' is a node that can't meaningfully be pulled out
  into its own assignment statement.

  This is surprisingly difficult to do robustly across versions of Python and
  gast, as the parsing of constants has changed, if I may, constantly.

  Args:
    node: An AST node to check for triviality

  Returns:
    trivial: A Python `bool` indicating whether the node is trivial.
  """
  trivial_node_types = (
      # Variable names
      gast.Name,
      # Non-nodes that show up as AST fields
      bool,
      str,
      # Boolean operators
      gast.And,
      gast.Or,
      # Binary operators
      gast.Add,
      gast.Sub,
      gast.Mult,
      gast.MatMult,
      gast.Div,
      gast.Mod,
      gast.Pow,
      gast.LShift,
      gast.RShift,
      gast.BitOr,
      gast.BitXor,
      gast.BitAnd,
      gast.FloorDiv,
      # Unary operators
      gast.Invert,
      gast.Not,
      gast.UAdd,
      gast.USub,
      # Comparison operators
      gast.Eq,
      gast.NotEq,
      gast.Lt,
      gast.LtE,
      gast.Gt,
      gast.GtE,
      gast.Is,
      gast.IsNot,
      gast.In,
      gast.NotIn,
      # Other leaf nodes that don't make sense standalone.
      gast.expr_context,
  )
  # A Name spelled True/False/None is excluded here so that it is still
  # offered to the configuration rather than being short-circuited as trivial.
  if isinstance(node, trivial_node_types) and not _is_py2_name_constant(node):
    return True
  if gast_util.is_ellipsis(node):
    return True

  return False
def transform(node, ctx, config=None):
  """Converts the given node to A-normal form (ANF).

  The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form

  The specific converters used here are based on Python AST semantics as
  documented at https://greentreesnakes.readthedocs.io/en/latest/.

  What exactly should be considered A-normal form for any given programming
  language is not completely obvious.  The transformation defined here is
  therefore configurable as to which syntax to replace with a fresh variable and
  which to leave be.  The configuration is intentionally flexible enough to
  define very precise variable insertion transformations, should that be
  desired.

  The configuration is a list of syntax rules, each of which is a 2-tuple:
  - An `ASTEdgePattern` (which see) defining a type of AST edge, and
  - A replacement directive saying whether to transform children of such
    edges.

  The special object `anf.ANY` may be used as a pattern that matches all edges.

  Each replacement directive is one of three possible things:
  - The object `anf.REPLACE`, meaning "Replace this child node with a variable",
  - The object `anf.LEAVE`, meaning "Do not replace this child node with a
    variable", or
  - A Python callable.  If a callable, it is called with the parent node, the
    field name, and the child node, and must compute a boolean indicating
    whether to transform the child node or not.  The callable is free to use
    whatever context information it chooses.  The callable may be invoked more
    than once on the same link, and must produce the same answer each time.

  The syntax rules are tested in order, and the first match governs.  If no rule
  matches, the node is not transformed.

  The above rules notwithstanding,
  - Variable references are never replaced with (fresh) variables, as that would
    accomplish nothing.
  - The left-hand children of Assign and AugAssign nodes, and the children of
    Del nodes, are never replaced with variables, as that would break their
    semantics.
  - The right-hand children of Assign nodes are never replaced with variables,
    as the original assignment would still have to be present in the result
    to define the new variable.  (That is, there's no point in transforming
    `x = sin(y)` into `tmp = sin(y); x = tmp`.)
  - The right-hand children of AugAssign nodes are never replaced with variables
    either, but only because the difference from Assign was considered a
    potential source of confusion (and it would have been slightly awkward in
    the code to treat the RHS differently than the LHS).
  - Various special-purpose AST nodes are not exposed to the configuration, lest
    the transform produce invalid syntax like, e.g., `tmp = +; x = 1 tmp 2`.

  For example, the configuration
  ```python
  [(anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.expr), anf.REPLACE)]
  ```
  gives explicit fresh names to all expressions regardless of context (except as
  outlined above), whereas
  ```python
  [(anf.ASTEdgePattern(gast.If, "test", anf.ANY), anf.REPLACE)]
  ```
  only transforms the conditionals of `if` statements (but not, e.g., `while`).

  If no configuration is supplied, the default behavior is to transform all
  expressions except literal constants, which is defined as a configuration as
  ```python
  # For Python 3, and gast library versions before 0.3
  literals = (gast.Num, gast.Str, gast.Bytes, gast.NameConstant)
  [(anf.ASTEdgePattern(anf.ANY, anf.ANY, literals), anf.LEAVE),
   (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.expr), anf.REPLACE)]
  ```

  Args:
    node: The node to transform.
    ctx: transformer.EntityInfo.  TODO(mdan): What information does this
      argument provide?
    config: Optional ANF configuration.  If omitted, ANF replaces all
      expressions except literal constants.

  Returns:
    Whatever `AnfTransformer(ctx, config).visit(node)` returns.
  """
  anf_transformer = AnfTransformer(ctx, config)
  return anf_transformer.visit(node)