178 lines
6.6 KiB
Python
178 lines
6.6 KiB
Python
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Handles directives.
|
||
|
|
||
|
This converter removes the directive functions from the code and moves the
|
||
|
information they specify into AST annotations. It is a specialized form of
|
||
|
static analysis, one that is specific to AutoGraph.
|
||
|
|
||
|
Note that this requires that the actual directive functions are static - that
|
||
|
is, they do not change at runtime. So if you do something like this:
|
||
|
|
||
|
tf.autograph.set_loop_options = <new function>
|
||
|
|
||
|
Then the directive may no longer be recognized. Furthermore, if the
|
||
|
converted function is cached, such an action may be irreversible.
|
||
|
"""
|
||
|
|
||
|
import inspect
|
||
|
|
||
|
import gast
|
||
|
|
||
|
from tensorflow.python.autograph.core import converter
|
||
|
from tensorflow.python.autograph.lang import directives
|
||
|
from tensorflow.python.autograph.pyct import anno
|
||
|
from tensorflow.python.util import tf_inspect
|
||
|
|
||
|
|
||
|
# Key used for AST annotations, see visit_Name. visit_Name and visit_Attribute
# store the resolved Python value of a symbol under this key, and visit_Expr
# reads it back to recognize directive calls.
STATIC_VALUE = 'static_value'
|
||
|
|
||
|
|
||
|
class _LoopScope(object):
  """State tracked for one loop while its body is being visited.

  Attributes:
    ast_node: the loop node owning this scope; assigned by
      DirectivesTransformer._track_and_visit_loop. None until then.
    statements_visited: count of Assign, AugAssign and Expr statements
      visited so far in this scope.
  """

  def __init__(self):
    # Both fields start "empty"; the transformer fills them in as it walks.
    self.statements_visited = 0
    self.ast_node = None
|
||
|
|
||
|
|
||
|
def _map_args(call_node, function):
  """Maps AST call nodes to the actual function's arguments.

  Args:
    call_node: ast.Call
    function: Callable[..., Any], the actual function matching call_node

  Returns:
    Dict[Text, ast.AST], mapping each of the function's argument names to
    the respective AST node. Arguments left at directives.UNSPECIFIED are
    omitted.

  Raises:
    ValueError: if the default arguments are not correctly set
  """
  args = call_node.args
  kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
  call_args = tf_inspect.getcallargs(function, *args, **kwds)

  # Keyword arguments not specified in kwds will be mapped to their defaults,
  # which are Python values. Since we don't currently have a way to transform
  # those into AST references, we simply remove them. By convention, directives
  # use UNSPECIFIED as default value for optional arguments. No other
  # defaults should be present.
  unexpected_defaults = [
      k for k, v in call_args.items()
      if k not in kwds and v not in args and v is not directives.UNSPECIFIED
  ]
  if unexpected_defaults:
    # Build a concrete list of (name, value) pairs: formatting a bare zip()
    # would only print the iterator's repr on Python 3.
    offending = [(k, call_args[k]) for k in unexpected_defaults]
    raise ValueError('Unexpected keyword argument values, %s, for function %s'
                     % (offending, function))
  return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}
|
||
|
|
||
|
|
||
|
class DirectivesTransformer(converter.Base):
  """Parses compiler directives and converts them into AST annotations."""

  def _process_symbol_directive(self, call_node, directive):
    """Attaches a symbol directive to every original definition of its target.

    Args:
      call_node: the call node invoking the directive; its first positional
        argument is the target symbol.
      directive: the directive function being called.

    Returns:
      call_node, unchanged.

    Raises:
      ValueError: if the call has no positional argument to use as target.
    """
    if not call_node.args:
      raise ValueError('"%s" requires a positional first argument'
                       ' as the target' % directive.__name__)
    target = call_node.args[0]
    defs = anno.getanno(target, anno.Static.ORIG_DEFINITIONS)
    for def_ in defs:
      def_.directives[directive] = _map_args(call_node, directive)
    return call_node

  def _process_statement_directive(self, call_node, directive):
    """Attaches a statement directive to the enclosing loop node.

    Args:
      call_node: the call node invoking the directive.
      directive: the directive function being called.

    Returns:
      call_node, unchanged.

    Raises:
      ValueError: if the directive is not inside a tracked loop, or if it is
        not the first counted statement (Assign, AugAssign or Expr) of the
        loop block.
    """
    # Check for an enclosing loop first: outside of one, the statement count
    # belongs to the surrounding scope and would produce a misleading
    # "first statement in the loop block" error.
    if self.state[_LoopScope].level < 2:
      raise ValueError(
          '"%s" must be used inside a statement' % directive.__name__)
    # visit_Expr has already counted the directive itself, so a count of 1
    # means it was the first statement.
    if self.state[_LoopScope].statements_visited > 1:
      raise ValueError(
          '"%s" must be the first statement in the loop block' % (
              directive.__name__))
    target = self.state[_LoopScope].ast_node
    node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {})
    node_anno[directive] = _map_args(call_node, directive)
    anno.setanno(target, anno.Basic.DIRECTIVES, node_anno)
    return call_node

  def visit_Name(self, node):
    """Annotates loaded names that have no definitions with their static value.

    A name read in Load context that carries no DEFINITIONS annotation and is
    present in the context namespace gets a STATIC_VALUE annotation holding
    the namespace entry.
    """
    node = self.generic_visit(node)
    if isinstance(node.ctx, gast.Load):
      defs = anno.getanno(node, anno.Static.DEFINITIONS, ())
      is_defined = bool(defs)
      if not is_defined and node.id in self.ctx.info.namespace:
        anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id])
    return node

  def visit_Attribute(self, node):
    """Propagates static values through attribute access on modules.

    If the attribute's parent has a STATIC_VALUE annotation that is a module
    and the module has the attribute, the node is annotated with that
    attribute's value.
    """
    node = self.generic_visit(node)
    parent_val = anno.getanno(node.value, STATIC_VALUE, default=None)
    if parent_val is not None and inspect.ismodule(parent_val):
      if hasattr(parent_val, node.attr):
        anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr))
    return node

  def visit_Assign(self, node):
    """Counts the assignment as a statement of the current scope."""
    self.state[_LoopScope].statements_visited += 1
    return self.generic_visit(node)

  def visit_AugAssign(self, node):
    """Counts the augmented assignment as a statement of the current scope."""
    self.state[_LoopScope].statements_visited += 1
    return self.generic_visit(node)

  def visit_Expr(self, node):
    """Counts the statement and strips recognized directive calls.

    Returns:
      None when the expression is a call to set_element_type or
      set_loop_options (after recording the directive), otherwise the
      visited node.
    """
    self.state[_LoopScope].statements_visited += 1
    node = self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      call_node = node.value
      static_val = anno.getanno(call_node.func, STATIC_VALUE, default=None)
      if static_val is not None:
        # Note: directive calls are not output in the generated code, hence
        # the removal from the code by returning None.

        if static_val is directives.set_element_type:
          self._process_symbol_directive(call_node, static_val)
          return None
        elif static_val is directives.set_loop_options:
          self._process_statement_directive(call_node, static_val)
          return None
    return node

  # TODO(mdan): This will be insufficient for other control flow.
  # That means that if we ever have a directive that affects things other than
  # loops, we'll need support for parallel scopes, or have multiple converters.
  def _track_and_visit_loop(self, node):
    """Visits a loop inside a fresh _LoopScope that points at the loop node."""
    self.state[_LoopScope].enter()
    self.state[_LoopScope].ast_node = node
    node = self.generic_visit(node)
    # Edge case: a loop with just one directive statement would become empty.
    if not node.body:
      node.body = [gast.Pass()]
    self.state[_LoopScope].exit()
    return node

  def visit_While(self, node):
    """Visits a while loop, tracking it as the current loop scope."""
    return self._track_and_visit_loop(node)

  def visit_For(self, node):
    """Visits a for loop, tracking it as the current loop scope."""
    return self._track_and_visit_loop(node)
|
||
|
|
||
|
|
||
|
def transform(node, ctx):
  """Applies DirectivesTransformer to node and returns the visited result.

  Args:
    node: the AST node to process.
    ctx: the context object handed to the DirectivesTransformer constructor.

  Returns:
    Whatever DirectivesTransformer.visit returns for node.
  """
  transformer = DirectivesTransformer(ctx)
  return transformer.visit(node)
|