222 lines
7.2 KiB
Python
222 lines
7.2 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Handles function calls, by generating compiled function names and calls.

Note: this transformer does not rename the top level object being converted;
that is the caller's responsibility.

Requires function_scopes.
"""
|
|
|
|
import gast
|
|
|
|
from tensorflow.python.autograph.core import converter
|
|
from tensorflow.python.autograph.pyct import anno
|
|
from tensorflow.python.autograph.pyct import parser
|
|
from tensorflow.python.autograph.pyct import qual_names
|
|
from tensorflow.python.autograph.pyct import templates
|
|
from tensorflow.python.autograph.utils import ag_logging
|
|
|
|
|
|
# TODO(mdan): Rename to FunctionCallsTransformer.
|
|
|
|
|
|
class _Function(object):
|
|
|
|
no_root = True
|
|
|
|
def __init__(self):
|
|
self.context_name = None
|
|
|
|
|
|
# Module-level flag that limits the debugger warning emitted by
# CallTreeTransformer.visit_Call to a single occurrence: visit_Call sets it to
# True right after logging the warning for the first time.
set_trace_warned = False
|
|
|
|
|
|
class _ArgTemplateBuilder(object):
|
|
"""Constructs a tuple representing the positional arguments in a call.
|
|
|
|
Example (yes, it's legal Python 3):
|
|
|
|
f(*args1, b, *args2, c, d) -> args1 + (b,) + args2 + (c, d)
|
|
"""
|
|
|
|
def __init__(self):
|
|
self._arg_accumulator = []
|
|
self._argspec = []
|
|
self._finalized = False
|
|
|
|
def _consume_args(self):
|
|
if self._arg_accumulator:
|
|
self._argspec.append(
|
|
gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load()))
|
|
self._arg_accumulator = []
|
|
|
|
def add_arg(self, a):
|
|
self._arg_accumulator.append(a)
|
|
|
|
def add_stararg(self, a):
|
|
self._consume_args()
|
|
self._argspec.append(
|
|
gast.Call(
|
|
gast.Name(
|
|
'tuple', ctx=gast.Load(), annotation=None, type_comment=None),
|
|
args=[a],
|
|
keywords=()))
|
|
|
|
def finalize(self):
|
|
self._consume_args()
|
|
self._finalized = True
|
|
|
|
def to_ast(self):
|
|
assert self._finalized
|
|
if self._argspec:
|
|
result = self._argspec[0]
|
|
for i in range(1, len(self._argspec)):
|
|
result = gast.BinOp(result, gast.Add(), self._argspec[i])
|
|
return result
|
|
return gast.Tuple([], gast.Load())
|
|
|
|
|
|
class CallTreeTransformer(converter.Base):
  """Rewrites call expressions into `ag__.converted_call(...)` invocations."""

  def visit_Lambda(self, node):
    """Visits a lambda, entering a function scope when one is annotated."""
    if anno.hasanno(node, 'function_context_name'):
      with self.state[_Function] as fn_scope:
        fn_scope.context_name = anno.getanno(node, 'function_context_name')
        return self.generic_visit(node)
    # Lambda functions created during the conversion process carry no
    # annotation and have no context manager; visit them in the current scope.
    return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    """Visits a function definition inside its own function scope."""
    # Decorators and argument defaults belong to the outer scope, so they are
    # visited before the function's own scope is entered.
    node.decorator_list = self.visit_block(node.decorator_list)
    node.args.defaults = self.visit_block(node.args.defaults)
    for index, default in enumerate(node.args.kw_defaults):
      if default is None:
        continue
      node.args.kw_defaults[index] = self.visit(default)

    with self.state[_Function] as fn_scope:
      # Note: if the conversion process ever creates helper functions, this
      # assumption will no longer hold.
      assert anno.hasanno(node, 'function_context_name'), (
          'The function_scopes converter always creates a scope for functions.')
      fn_scope.context_name = anno.getanno(node, 'function_context_name')
      node.body = self.visit_block(node.body)
      if node.returns:
        node.returns = self.visit(node.returns)
      return node

  def visit_With(self, node):
    """Visits only the body of a `with` statement."""
    # Context manager calls (in node.items) are deliberately left unconverted.
    node.body = self.visit_block(node.body)
    return node

  def _args_to_tuple(self, node):
    """Ties together all positional and *arg arguments in a single tuple."""
    # TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better?
    # For example for
    #   f(a, b, *args)
    # instead of writing:
    #   (a, b) + args
    # just write this?
    #   tuple(a, b, *args)
    builder = _ArgTemplateBuilder()
    for arg in node.args:
      if not isinstance(arg, gast.Starred):
        builder.add_arg(arg)
      else:
        builder.add_stararg(arg.value)
    builder.finalize()
    return builder.to_ast()

  def _kwargs_to_dict(self, node):
    """Ties together all keyword and **kwarg arguments in a single dict."""
    if not node.keywords:
      return parser.parse_expression('None')
    dict_fn = gast.Name(
        'dict', ctx=gast.Load(), annotation=None, type_comment=None)
    return gast.Call(dict_fn, args=(), keywords=node.keywords)

  def visit_Call(self, node):
    """Replaces a call with `ag__.converted_call`, unless it is exempt.

    The arguments of the call are always visited. The call itself is left
    as-is for `ag__.*` targets, targets on the current function context,
    `pdb.set_trace` / `ipdb.set_trace` / `breakpoint`, and `print` when the
    BUILTIN_FUNCTIONS feature is not enabled.
    """
    # Both values are captured before the children are visited.
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    function_context_name = self.state[_Function].context_name
    node = self.generic_visit(node)

    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    if full_name.startswith('ag__.'):
      return node

    # Calls to the function context manager (inserted by function_scopes) are
    # also safe.
    if full_name.startswith(function_context_name + '.'):
      return node

    # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
    # the normal mechanisms to bypass these literals because they are sensitive
    # to the frame they are being called from.
    # TODO(mdan): Generalize this to a "static allowlist" config.
    if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
      global set_trace_warned
      if not set_trace_warned:
        # TODO(mdan): Update and shorten once available on tensorflow.org.
        ag_logging.warning(
            'Detected `pdb.set_trace()` in user code. The code'
            ' generated by AutoGraph is not optimized for step-by-step'
            ' debugging. See https://github.com/tensorflow/tensorflow/'
            'blob/master/tensorflow/python/autograph/g3doc/reference/'
            'debugging.md.')
        set_trace_warned = True
      return node

    # `print` is only converted when builtin function conversion is enabled.
    if full_name == 'print':
      if not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS):
        return node

    call_template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
    return templates.replace_as_expression(
        call_template,
        func=node.func,
        args=self._args_to_tuple(node),
        kwargs=self._kwargs_to_dict(node),
        function_ctx=function_context_name)
|
|
|
|
|
|
def transform(node, ctx):
  """Transforms function calls into their compiled counterparts.

  `qual_names.resolve` is applied to the tree first; `CallTreeTransformer`
  reads the `anno.Basic.QN` annotation of call targets, which is presumably
  populated by that step.

  Args:
    node: AST
    ctx: EntityContext

  Returns:
    The transformed AST.
  """
  node = qual_names.resolve(node)
  return CallTreeTransformer(ctx).visit(node)
|