3RNN/Lib/site-packages/tensorflow/python/autograph/converters/directives.py

178 lines
6.6 KiB
Python
Raw Normal View History

2024-05-26 19:49:15 +02:00
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles directives.
This converter removes the directive functions from the code and moves the
information they specify into AST annotations. It is a specialized form of
static analysis, one that is specific to AutoGraph.
Note that this requires that the actual directive functions are static - that
is, they do not change at runtime. So if you do something like this:
tf.autograph.set_loop_options = <new function>
Then the directive may no longer be recognized. Furthermore, if the
converted function is cached, such an action may be irreversible.
"""
import inspect
import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.util import tf_inspect
# Annotation key under which this module records the Python object that a
# Name or Attribute node resolves to: visit_Name stores the value found in the
# conversion namespace, and visit_Attribute stores the attribute looked up on
# an annotated parent module.
STATIC_VALUE = 'static_value'
"""Used for AST annotations, see visit_Name."""
class _LoopScope(object):
def __init__(self):
self.ast_node = None
self.statements_visited = 0
def _map_args(call_node, function):
  """Maps AST call nodes to the actual function's arguments.

  The call is bound against `function`'s signature using the AST nodes
  themselves as argument values, so the result associates parameter names with
  the AST expressions that were passed for them.

  Args:
    call_node: ast.Call
    function: Callable[..., Any], the actual function matching call_node

  Returns:
    Dict[Text, ast.AST], mapping each of the function's argument names to
    the respective AST node. Arguments left at `directives.UNSPECIFIED` are
    omitted.

  Raises:
    ValueError: if the default arguments are not correctly set, i.e. an
      argument that was not passed in the call has a default other than
      `directives.UNSPECIFIED`.
  """
  args = call_node.args
  kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
  call_args = tf_inspect.getcallargs(function, *args, **kwds)

  # Keyword arguments not specified in kwds will be mapped to their defaults,
  # which are Python values. Since we don't currently have a way to transform
  # those into AST references, we simply remove them. By convention, directives
  # use UNSPECIFIED as default value for optional arguments. No other
  # defaults should be present.
  unexpected_defaults = [
      k for k in call_args
      if (k not in kwds
          and call_args[k] not in args
          and call_args[k] is not directives.UNSPECIFIED)
  ]
  if unexpected_defaults:
    # Materialize the (name, value) pairs: a bare zip() is a lazy iterator on
    # Python 3 and would be rendered as "<zip object at 0x...>" in the message.
    offending = [(k, call_args[k]) for k in unexpected_defaults]
    raise ValueError('Unexpected keyword argument values, %s, for function %s'
                     % (offending, function))
  return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}
class DirectivesTransformer(converter.Base):
  """Parses compiler directives and converts them into AST annotations."""

  def _process_symbol_directive(self, call_node, directive):
    """Records a directive on every original definition of its target symbol.

    Args:
      call_node: the call AST node invoking the directive; its first positional
        argument is the target symbol.
      directive: the directive function that the call resolves to.

    Returns:
      The unchanged call_node.

    Raises:
      ValueError: if the call has no positional arguments.
    """
    if not call_node.args:
      raise ValueError('"%s" requires a positional first argument'
                       ' as the target' % directive.__name__)
    symbol = call_node.args[0]
    # The arguments are mapped anew for each definition, so every definition
    # receives its own dict.
    for definition in anno.getanno(symbol, anno.Static.ORIG_DEFINITIONS):
      definition.directives[directive] = _map_args(call_node, directive)
    return call_node

  def _process_statement_directive(self, call_node, directive):
    """Records a directive on the loop node tracked by the current scope.

    Args:
      call_node: the call AST node invoking the directive.
      directive: the directive function that the call resolves to.

    Returns:
      The unchanged call_node.

    Raises:
      ValueError: if more than one statement has been counted in the current
        scope, or if the scope level is below 2 (presumably meaning that no
        loop has been entered -- `level` is supplied by the state container,
        not by _LoopScope itself).
    """
    loop_scope = self.state[_LoopScope]
    if loop_scope.statements_visited > 1:
      raise ValueError(
          '"%s" must be the first statement in the loop block'
          % directive.__name__)
    if loop_scope.level < 2:
      raise ValueError(
          '"%s" must be used inside a statement' % directive.__name__)
    loop_node = loop_scope.ast_node
    loop_directives = anno.getanno(loop_node, anno.Basic.DIRECTIVES, {})
    loop_directives[directive] = _map_args(call_node, directive)
    anno.setanno(loop_node, anno.Basic.DIRECTIVES, loop_directives)
    return call_node

  def visit_Name(self, node):
    """Annotates loaded names lacking definitions with their namespace value."""
    node = self.generic_visit(node)
    if not isinstance(node.ctx, gast.Load):
      return node
    if anno.getanno(node, anno.Static.DEFINITIONS, ()):
      # The name has recorded definitions, so it gets no static value.
      return node
    namespace = self.ctx.info.namespace
    if node.id in namespace:
      anno.setanno(node, STATIC_VALUE, namespace[node.id])
    return node

  def visit_Attribute(self, node):
    """Propagates static values through attribute access on modules."""
    node = self.generic_visit(node)
    owner = anno.getanno(node.value, STATIC_VALUE, default=None)
    resolvable = (
        owner is not None
        and inspect.ismodule(owner)
        and hasattr(owner, node.attr))
    if resolvable:
      anno.setanno(node, STATIC_VALUE, getattr(owner, node.attr))
    return node

  def visit_Assign(self, node):
    """Counts the assignment as a visited statement, then visits children."""
    loop_scope = self.state[_LoopScope]
    loop_scope.statements_visited += 1
    return self.generic_visit(node)

  def visit_AugAssign(self, node):
    """Counts the augmented assignment as a visited statement, then recurses."""
    loop_scope = self.state[_LoopScope]
    loop_scope.statements_visited += 1
    return self.generic_visit(node)

  def visit_Expr(self, node):
    """Counts the statement and strips it if it is a recognized directive call.

    Returns:
      None when the expression is a call to set_element_type or
      set_loop_options (after recording the directive); otherwise the visited
      node.
    """
    loop_scope = self.state[_LoopScope]
    loop_scope.statements_visited += 1
    node = self.generic_visit(node)

    call_node = node.value
    if not isinstance(call_node, gast.Call):
      return node
    callee = anno.getanno(call_node.func, STATIC_VALUE, default=None)
    if callee is None:
      return node

    if callee is directives.set_element_type:
      handler = self._process_symbol_directive
    elif callee is directives.set_loop_options:
      handler = self._process_statement_directive
    else:
      return node

    # Note: directive calls are not output in the generated code, hence
    # the removal from the code by returning None.
    handler(call_node, callee)
    return None

  # TODO(mdan): This will be insufficient for other control flow.
  # That means that if we ever have a directive that affects things other than
  # loops, we'll need support for parallel scopes, or have multiple converters.
  def _track_and_visit_loop(self, node):
    """Visits a loop inside a fresh _LoopScope bound to that loop node."""
    loop_scope = self.state[_LoopScope]
    loop_scope.enter()
    loop_scope.ast_node = node
    visited = self.generic_visit(node)
    # Edge case: a loop with just one directive statement would become empty.
    if not visited.body:
      visited.body = [gast.Pass()]
    self.state[_LoopScope].exit()
    return visited

  def visit_While(self, node):
    """Handles `while` loops via the shared loop-tracking logic."""
    return self._track_and_visit_loop(node)

  def visit_For(self, node):
    """Handles `for` loops via the shared loop-tracking logic."""
    return self._track_and_visit_loop(node)
def transform(node, ctx):
  """Runs DirectivesTransformer over `node` and returns the visit result.

  Args:
    node: the AST node to process.
    ctx: the context object handed to the DirectivesTransformer constructor.

  Returns:
    Whatever DirectivesTransformer(ctx).visit(node) returns.
  """
  transformer = DirectivesTransformer(ctx)
  return transformer.visit(node)